Commit d2e420ea authored by Olivier BICHLER's avatar Olivier BICHLER

Merge branch 'Fix_loss_derivative' into 'dev'

Fix loss derivative

See merge request !43
parents 9c7303e5 608963e9
Part of 2 merge requests: !44 "Update 0.2.3 -> 0.3.0", !43 "Fix loss derivative"
Pipeline #71051 passed
```diff
@@ -123,7 +123,7 @@ Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction,
     // Define node: gradient
     const std::shared_ptr<Node> gradient_node = Mul("gradient");
     div1_node->addChild(gradient_node, 0, 0);
-    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{-1.0f/float(target->dims()[0])}}))
+    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{-1.0f/float(target->size())}}))
         ->addChild(gradient_node, 0, 1);
     // Create GraphView
```
```diff
@@ -83,7 +83,7 @@ Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<Tensor>& prediction,
     // Note: this assume target is [nbBatch, nbChan]
     Producer(std::make_shared<Tensor>(
-        Array1D<float, 1>{{2 / float(target->dims()[0])}}))
+        Array1D<float, 1>{{2 / float(target->size())}}))
         ->addChild(mul_node, 0, 1);
     sub_node->addChild(mul_node, 0, 0); // Error computation branch !
```