Commit d5e82429 authored by Cyril Moineau

Update MSE to take a tensor as an input.

parent 4daac675
Part of 2 merge requests: !6 "version 0.1.1", !5 "Update how loss function work"
@@ -20,10 +20,19 @@
 namespace Aidge {
 namespace loss {
-Tensor MSE(std::shared_ptr<GraphView> graph,
+/**
+ * @brief Compute the Mean Square Error loss.
+ * This function returns the loss and sets the ``grad()`` of the prediction
+ * input.
+ * @param prediction Tensor returned by the Aidge Graph; it is important that
+ * this tensor is not a copy, as otherwise the backward function will not have
+ * a gradient to start from.
+ * @param target Tensor representing the ground truth; it must be one-hot
+ * encoded.
+ */
+Tensor MSE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
-}  // loss
+}  // namespace loss
 }  // namespace Aidge
 #endif /* AIDGE_CORE_LOSS_LOSSLIST_H_ */
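
For readers coming from the old GraphView-based signature, here is a minimal, hedged sketch (editor's illustration, not part of the commit) of how the new tensor-based entry point can be fed. The prediction must be the tensor actually owned by the graph output, retrieved the same way the old implementation did it internally, so that the gradient written by MSE is visible to the backward pass; the names `graph` and `target` are assumptions.

    // Sketch only: `graph` is an already-forwarded GraphView and `target` a
    // one-hot Tensor of dims [BatchSize, NbChannel].
    std::shared_ptr<Node> lastNode = *(graph->outputNodes().begin());
    std::shared_ptr<Tensor> prediction =
        std::dynamic_pointer_cast<OperatorTensor>(lastNode->getOperator())
            ->getOutput(0);  // the live output tensor, not a copy
    Tensor loss = Aidge::loss::MSE(prediction, target);  // also fills prediction->grad()
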
@@ -26,23 +26,32 @@
 #include "aidge/recipes/GraphViewHelper.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
+int CPT = 0;
-Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<GraphView> graph,
+Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<Tensor>& prediction,
                                const std::shared_ptr<Tensor>& target) {
-    compile_gradient(graph);  // Warning compile gradient here, without
-                              // it, grad is nullptr. Maybe we can find a better
-                              // place to do so ?
+    /*
+    Implementation note:
+    MSE is computed using a graph in order to not be backend dependent.
+    The graph used is the following:
+    pred->Sub
+    label->Sub
+    Sub->Pow
+    (2)->Pow->ReduceMean->Loss
+    Sub->Mul
+    (2/NbBatch)->Mul->Gradient
+    */
+    prediction->initGrad();  // Enable gradient for output
+    // compile_gradient(graph);  // Warning compile gradient here, without
+    //                           // it, grad is nullptr. Maybe we can find a better
+    //                           // place to do so ?
     AIDGE_ASSERT(target->dims().size() == 2,
                  "Label must have two dims: [BatchSize, NbChannel]");
-    AIDGE_ASSERT(
-        graph->outputNodes().size() == 1,
-        "MSE can only be computed on graph with one output, {} were found.",
-        graph->outputs().size());
-    std::shared_ptr<Node> lastNode = *(graph->outputNodes().begin());
-    const std::shared_ptr<Tensor>& prediction =
-        std::dynamic_pointer_cast<OperatorTensor>(lastNode->getOperator())
-            ->getOutput(0);
     std::shared_ptr<Tensor> outputGrad = prediction->grad();
     AIDGE_ASSERT(prediction->backend() == target->backend(),
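
Read symbolically (editor's hedged gloss of the comment above, not text from the commit), the two branches compute, for prediction p and one-hot target t of dims [NbBatch, NbChan] and assuming ReduceMean reduces over all elements:

    Loss               = mean over all NbBatch x NbChan elements of (p - t)^2
    prediction->grad() = (2 / NbBatch) * (p - t)

So the loss averages over every element while the gradient branch only divides by the batch size; that is what the wiring below builds.
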
@@ -73,6 +82,8 @@ Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<GraphView> graph,
     Producer(target, "label")->addChild(sub_node, 0, 1);
     const std::shared_ptr<Node> mul_node = Mul("gradient");
+    // Note: this assumes target is [nbBatch, nbChan]
     Producer(std::make_shared<Tensor>(
                  Array1D<float, 1>{{2 / float(target->dims()[0])}}))
         ->addChild(mul_node, 0, 1);
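
A quick worked check of the constant fed to the Mul branch just above (illustrative numbers, not from the commit): target->dims()[0] is the batch size, so for a target of dims {4, 10} the Producer holds 2 / float(4) = 0.5f and the gradient branch emits 0.5 * (pred - target). The float(...) cast matters: 2 / 4 in integer arithmetic would be 0 and the gradient would vanish.
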
@@ -35,9 +35,8 @@ TEST_CASE("[loss/regression] MSE", "[loss][regression][MSE]") {
     std::uniform_real_distribution<float> valueDist(0.0f, 1.0f);
     for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
-        // Create a random number generator
-        const std::size_t nb_dims = nbDimsDist(gen);
-        std::vector<std::size_t> dims(nb_dims);
+        const std::size_t nb_dims = 2;  // For MSE test, nb_dims is fixed as 2: NbBatch, NbChan
+        std::vector<std::size_t> dims(2);
         for (std::size_t i = 0; i < nb_dims; ++i) { dims[i] = dimsDist(gen); }
         const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
@@ -78,11 +77,11 @@ TEST_CASE("[loss/regression] MSE", "[loss][regression][MSE]") {
         targ_tensor->setBackend("cpu");
         targ_tensor->getImpl()->setRawPtr(targ.get(), nb_elements);
         targ_tensor->print();
         const Tensor res_function = loss::MSE(pred_tensor, targ_tensor);
         // compare results
         Tensor res_manual_tensor = Tensor(res_manual);
         REQUIRE(approxEq<float>(res_manual, res_function));
     }
 }
 }  // namespace Aidge
\ No newline at end of file
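
As a companion to the test above, a hedged sketch of how the `res_manual` reference could be produced before calling `loss::MSE`. The loop below is the editor's illustration, not code from the commit, and assumes `pred` and `targ` are the raw float buffers wrapped by `pred_tensor` and `targ_tensor`.

    // Illustration only: scalar MSE reference over the flattened buffers.
    float res_manual = 0.0f;
    for (std::size_t i = 0; i < nb_elements; ++i) {
        const float diff = pred.get()[i] - targ.get()[i];
        res_manual += diff * diff;                   // squared error per element
    }
    res_manual /= static_cast<float>(nb_elements);   // mean over NbBatch * NbChan
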