diff --git a/include/aidge/loss/LossList.hpp b/include/aidge/loss/LossList.hpp
index 61763bc8465c7caf70924de846c4cbe0782149c6..26b6e9c2838ef076c020c531e6af290aff1223c9 100644
--- a/include/aidge/loss/LossList.hpp
+++ b/include/aidge/loss/LossList.hpp
@@ -20,10 +20,19 @@
 namespace Aidge {
 namespace loss {
 
-Tensor MSE(std::shared_ptr<GraphView> graph,
+/**
+ * @brief Compute the Mean Square Error loss.
+ * This function returns the loss and sets the ``grad()`` of the prediction
+ * input.
+ * @param prediction Tensor returned by the Aidge graph; it is important that
+ * this tensor is not a copy, otherwise the backward function will not have a
+ * gradient to start from.
+ * @param target Tensor representing the ground truth; it must be one-hot encoded.
+ */
+Tensor MSE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
 
-} // loss
-} // namespace Aidge
+} // namespace loss
+} // namespace Aidge
 
 #endif /* AIDGE_CORE_LOSS_LOSSLIST_H_ */
diff --git a/src/loss/regression/MSE.cpp b/src/loss/regression/MSE.cpp
index 8c5ceeff4a57bea49ed667d9ed58713ff7cd0fee..87f685a0f550a1cb60563503447407f70868ce9a 100644
--- a/src/loss/regression/MSE.cpp
+++ b/src/loss/regression/MSE.cpp
@@ -26,23 +26,32 @@
 #include "aidge/recipes/GraphViewHelper.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
-int CPT = 0;
-Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<GraphView> graph,
+
+Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<Tensor>& prediction,
                                const std::shared_ptr<Tensor>& target) {
-    compile_gradient(graph);  // Warning compile gradient here, without
-                              // it, grad is nullptr. Maybe we can find a better
-                              // place to do so ?
+    /*
+    Implementation note:
+    MSE is computed using a graph in order not to be backend dependent.
+
+    The graph used is the following:
+
+    pred->Sub
+    label->Sub
+    Sub->Pow
+    (2)->Pow->ReduceMean->Loss
+    Sub->Mul
+    (2/NbBatch)->Mul->Gradient
+    */
+
+    prediction->initGrad(); // Enable gradient for output
+
+    // compile_gradient(graph);  // Warning compile gradient here, without
+    //                           // it, grad is nullptr. Maybe we can find a better
+    //                           // place to do so ?
     AIDGE_ASSERT(target->dims().size() == 2,
                  "Label must have two dims: [BatchSize, NbChannel]");
-    AIDGE_ASSERT(
-        graph->outputNodes().size() == 1,
-        "MSE can only be computed on graph with one output, {} were found.",
-        graph->outputs().size());
-    std::shared_ptr<Node> lastNode = *(graph->outputNodes().begin());
-    const std::shared_ptr<Tensor>& prediction =
-        std::dynamic_pointer_cast<OperatorTensor>(lastNode->getOperator())
-            ->getOutput(0);
+    std::shared_ptr<Tensor> outputGrad = prediction->grad();
 
     AIDGE_ASSERT(prediction->backend() == target->backend(),
@@ -73,6 +82,8 @@ Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<GraphView> graph,
     Producer(target, "label")->addChild(sub_node, 0, 1);
 
     const std::shared_ptr<Node> mul_node = Mul("gradient");
+
+    // Note: this assumes target is [nbBatch, nbChan]
     Producer(std::make_shared<Tensor>(
                  Array1D<float, 1>{{2 / float(target->dims()[0])}}))
         ->addChild(mul_node, 0, 1);
diff --git a/unit_tests/loss/regression/Test_MSE.cpp b/unit_tests/loss/regression/Test_MSE.cpp
index 3899470b5f0141fc747f6a2a52cc35b41a590d49..2b0e6d1edfaa1d452c714a08c4998725331df2c3 100644
--- a/unit_tests/loss/regression/Test_MSE.cpp
+++ b/unit_tests/loss/regression/Test_MSE.cpp
@@ -35,9 +35,8 @@ TEST_CASE("[loss/regression] MSE", "[loss][regression][MSE]") {
     std::uniform_real_distribution<float> valueDist(0.0f, 1.0f);
 
     for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
-        // Create a random number generator
-        const std::size_t nb_dims = nbDimsDist(gen);
-        std::vector<std::size_t> dims(nb_dims);
+        const std::size_t nb_dims = 2; // For the MSE test, nb_dims is fixed to 2: NbBatch, NbChan
+        std::vector<std::size_t> dims(2);
         for (std::size_t i = 0; i < nb_dims; ++i) { dims[i] = dimsDist(gen); }
         const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(),
                                                         std::size_t(1), std::multiplies<std::size_t>());
@@ -78,11 +77,11 @@ TEST_CASE("[loss/regression] MSE", "[loss][regression][MSE]") {
         targ_tensor->setBackend("cpu");
         targ_tensor->getImpl()->setRawPtr(targ.get(), nb_elements);
         targ_tensor->print();
-        const Tensor res_function = loss::MSE(pred_tensor, targ_tensor);
+        const Tensor res_function = loss::MSE(pred_tensor, targ_tensor);
 
         // compare results
         Tensor res_manual_tensor = Tensor(res_manual);
         REQUIRE(approxEq<float>(res_manual, res_function));
     }
 }
-} // namespace Aidge
\ No newline at end of file
+} // namespace Aidge
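
Below is a minimal usage sketch of the new loss::MSE signature, mirroring the unit test above. It is written under assumptions: the Array2D tensor initializer and the "aidge/data/Tensor.hpp" include path follow the aidge_core conventions already used in the patch (which relies on Array1D), a CPU backend (aidge_backend_cpu) is assumed to be registered, and the tensor values and variable names are purely illustrative. The gradient that MSE writes into prediction->grad() corresponds to the (2/NbBatch)->Mul->Gradient branch of the internal graph, i.e. 2/NbBatch * (prediction - target).

// Minimal sketch (assumptions noted above), not part of the patch.
#include <memory>

#include "aidge/data/Tensor.hpp"   // assumed include path for Aidge::Tensor and Array2D
#include "aidge/loss/LossList.hpp" // declares Aidge::loss::MSE (see patch above)

int main() {
    using namespace Aidge;

    // Prediction of shape [NbBatch, NbChan] = [2, 3]. It must be the network's
    // actual output tensor (not a copy) so that MSE can fill its grad().
    std::shared_ptr<Tensor> prediction = std::make_shared<Tensor>(
        Array2D<float, 2, 3>{{{0.1f, 0.7f, 0.2f}, {0.3f, 0.3f, 0.4f}}});

    // One-hot encoded ground truth with the same shape.
    const std::shared_ptr<Tensor> target = std::make_shared<Tensor>(
        Array2D<float, 2, 3>{{{0.0f, 1.0f, 0.0f}, {0.0f, 0.0f, 1.0f}}});

    // Both tensors must live on the same backend (here the CPU backend).
    prediction->setBackend("cpu");
    target->setBackend("cpu");

    // Returns the loss value and sets prediction->grad() to
    // 2/NbBatch * (prediction - target), the starting gradient of the backward pass.
    const Tensor mseLoss = loss::MSE(prediction, target);

    std::shared_ptr<Tensor> outputGrad = prediction->grad();
    return 0;
}

Since MSE calls prediction->initGrad() itself, the caller does not need to prepare the gradient beforehand; passing the actual output tensor rather than a copy is what lets the subsequent backward pass start from the gradient filled here.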