diff --git a/include/aidge/loss/LossList.hpp b/include/aidge/loss/LossList.hpp
index 61763bc8465c7caf70924de846c4cbe0782149c6..26b6e9c2838ef076c020c531e6af290aff1223c9 100644
--- a/include/aidge/loss/LossList.hpp
+++ b/include/aidge/loss/LossList.hpp
@@ -20,10 +20,19 @@
 namespace Aidge {
 namespace loss {
 
-Tensor MSE(std::shared_ptr<GraphView> graph,
+/**
+ * @brief Compute the Mean Squared Error (MSE) loss.
+ * This function returns the loss and sets the ``grad()`` of the prediction
+ * input.
+ * @param prediction Tensor returned by the Aidge graph. It is important that
+ * this tensor is not a copy, as otherwise the backward pass would have no
+ * gradient to start from.
+ * @param target Tensor representing the ground truth; it must be one-hot encoded.
+ */
+Tensor MSE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
 
-} // loss
-} // namespace Aidge
+}  // namespace loss
+}  // namespace Aidge
 
 #endif /* AIDGE_CORE_LOSS_LOSSLIST_H_ */
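
Reviewer note: a minimal usage sketch of the new signature (the Array2D construction, shapes, and values are illustrative only, and a CPU backend such as aidge_backend_cpu is assumed to be registered for the internal scheduler run; none of this is part of the patch):

    #include "aidge/data/Tensor.hpp"
    #include "aidge/loss/LossList.hpp"

    // `prediction` must be the graph's own output tensor (not a copy) so that
    // MSE can fill its grad(); here both tensors are built by hand for brevity.
    std::shared_ptr<Aidge::Tensor> prediction = std::make_shared<Aidge::Tensor>(
        Aidge::Array2D<float, 2, 3>{{{0.1f, 0.7f, 0.2f}, {0.8f, 0.1f, 0.1f}}});
    std::shared_ptr<Aidge::Tensor> target = std::make_shared<Aidge::Tensor>(
        Aidge::Array2D<float, 2, 3>{{{0.0f, 1.0f, 0.0f}, {1.0f, 0.0f, 0.0f}}});
    prediction->setBackend("cpu");
    target->setBackend("cpu");

    const Aidge::Tensor loss = Aidge::loss::MSE(prediction, target);
    // prediction->grad() now holds 2 * (prediction - target) / NbBatch.
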
diff --git a/src/loss/regression/MSE.cpp b/src/loss/regression/MSE.cpp
index 8c5ceeff4a57bea49ed667d9ed58713ff7cd0fee..87f685a0f550a1cb60563503447407f70868ce9a 100644
--- a/src/loss/regression/MSE.cpp
+++ b/src/loss/regression/MSE.cpp
@@ -26,23 +26,32 @@
 #include "aidge/recipes/GraphViewHelper.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
-int CPT = 0;
-Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<GraphView> graph,
+
+Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<Tensor>& prediction,
                                const std::shared_ptr<Tensor>& target) {
-    compile_gradient(graph);  // Warning compile gradient here, without
-                              // it, grad is nullptr. Maybe we can find a better
-                              // place to do so ?
+    /*
+    Implementation note:
+    MSE is computed using a graph so that it is not backend-dependent.
+
+    The graph used is the following:
+
+    pred->Sub
+    label->Sub
+    Sub->Pow
+    (2)->Pow->ReduceMean->Loss
+    Sub->Mul
+    (2/NbBatch)->Mul->Gradient
+    */
+
+    prediction->initGrad(); // Enable gradient for output
+
+    // compile_gradient(graph);  // Warning compile gradient here, without
+    //                           // it, grad is nullptr. Maybe we can find a better
+    //                           // place to do so ?
 
     AIDGE_ASSERT(target->dims().size() == 2,
                  "Label must have two dims: [BatchSize, NbChannel]");
-    AIDGE_ASSERT(
-        graph->outputNodes().size() == 1,
-        "MSE can only be computed on graph with one output, {} were found.",
-        graph->outputs().size());
-    std::shared_ptr<Node> lastNode = *(graph->outputNodes().begin());
-    const std::shared_ptr<Tensor>& prediction =
-        std::dynamic_pointer_cast<OperatorTensor>(lastNode->getOperator())
-            ->getOutput(0);
+
     std::shared_ptr<Tensor> outputGrad = prediction->grad();
 
     AIDGE_ASSERT(prediction->backend() == target->backend(),
@@ -73,6 +82,8 @@ Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<GraphView> graph,
     Producer(target, "label")->addChild(sub_node, 0, 1);
 
     const std::shared_ptr<Node> mul_node = Mul("gradient");
+
+    // Note: this assumes target has shape [nbBatch, nbChan]
     Producer(std::make_shared<Tensor>(
                  Array1D<float, 1>{{2 / float(target->dims()[0])}}))
         ->addChild(mul_node, 0, 1);
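
Reviewer note: to make the gradient branch explicit, this is the element-wise value the Sub -> Mul(2/NbBatch) chain writes into prediction->grad(), spelled out on raw buffers (a sketch only; mseGradSketch is a hypothetical helper for illustration, and the patch itself keeps this as graph nodes so it stays backend-agnostic):

    #include <cstddef>
    #include <vector>

    // pred and target are flat [nbBatch * nbChan] buffers, matching the
    // [NbBatch, NbChan] layout assumed by the patch.
    std::vector<float> mseGradSketch(const std::vector<float>& pred,
                                     const std::vector<float>& target,
                                     std::size_t nbBatch) {
        std::vector<float> grad(pred.size());
        for (std::size_t i = 0; i < pred.size(); ++i) {
            grad[i] = (pred[i] - target[i]) * 2.0f / static_cast<float>(nbBatch);
        }
        return grad;
    }
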
diff --git a/unit_tests/loss/regression/Test_MSE.cpp b/unit_tests/loss/regression/Test_MSE.cpp
index 3899470b5f0141fc747f6a2a52cc35b41a590d49..2b0e6d1edfaa1d452c714a08c4998725331df2c3 100644
--- a/unit_tests/loss/regression/Test_MSE.cpp
+++ b/unit_tests/loss/regression/Test_MSE.cpp
@@ -35,9 +35,8 @@ TEST_CASE("[loss/regression] MSE", "[loss][regression][MSE]") {
     std::uniform_real_distribution<float> valueDist(0.0f, 1.0f);
 
     for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
-        // Create a random number generator
-        const std::size_t nb_dims = nbDimsDist(gen);
-        std::vector<std::size_t> dims(nb_dims);
+        const std::size_t nb_dims = 2; // For the MSE test, nb_dims is fixed to 2: [NbBatch, NbChan]
+        std::vector<std::size_t> dims(nb_dims);
 
         for (std::size_t i = 0; i < nb_dims; ++i) { dims[i] = dimsDist(gen); }
         const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
@@ -78,11 +77,11 @@ TEST_CASE("[loss/regression] MSE", "[loss][regression][MSE]") {
         targ_tensor->setBackend("cpu");
         targ_tensor->getImpl()->setRawPtr(targ.get(), nb_elements);
         targ_tensor->print();
-        const Tensor res_function = loss::MSE(pred_tensor, targ_tensor);
+        const Tensor res_function = loss::MSE(pred_tensor, targ_tensor);
 
         // compare results
         Tensor res_manual_tensor = Tensor(res_manual);
         REQUIRE(approxEq<float>(res_manual, res_function));
     }
 }
-}  // namespace Aidge
\ No newline at end of file
+}  // namespace Aidge