From 5678a9645b229c723e4f37c0d0c9d6082e435439 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 10 Jun 2025 14:35:55 +0000
Subject: [PATCH 1/7] use the device index (CELoss + SGD)

---
 include/aidge/learning/optimizer/SGD.hpp | 30 ++++++++++--
 src/loss/classification/CELoss.cpp       | 59 +++++++++++++++++++-----
 2 files changed, 73 insertions(+), 16 deletions(-)

diff --git a/include/aidge/learning/optimizer/SGD.hpp b/include/aidge/learning/optimizer/SGD.hpp
index da029b3..d388684 100644
--- a/include/aidge/learning/optimizer/SGD.hpp
+++ b/include/aidge/learning/optimizer/SGD.hpp
@@ -22,6 +22,9 @@
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/TensorUtils.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/Sub.hpp"
+
 
 namespace Aidge {
 
@@ -56,16 +59,33 @@ public:
     }
 
     void update() override final {
+
+        auto backend  = mParameters[0]->backend();
+        auto device   = mParameters[0]->device();
+        auto dataType = mParameters[0]->dataType();
+
         mLR = Tensor(learningRate());
-        mLR.setBackend(mParameters[0]->getImpl()->backend());
-        mLR.setDataType(mParameters[0]->dataType());
-        mWeightDecay.setBackend(mParameters[0]->getImpl()->backend());
-        mWeightDecay.setDataType(mParameters[0]->dataType());
+
+        // Set backends / devices
+
+        mLR.setDataType(dataType);
+        mLR.setBackend(backend, device);
+
+        mWeightDecay.setDataType(dataType);
+        mWeightDecay.setBackend(backend, device);
+
+        mReversedDampening.setDataType(dataType);
+        mReversedDampening.setBackend(backend, device);
+
+        mMomentum.setDataType(dataType);
+        mMomentum.setBackend(backend, device);
+
+        // Update loop
 
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
                 mGradientInertia[i] = mParameters[i]->grad()->clone();
-                *mParameters[i] -= mLR*mGradientInertia[i];
+                *mParameters[i] -= mLR * mGradientInertia[i];
             }
         } else {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
diff --git a/src/loss/classification/CELoss.cpp b/src/loss/classification/CELoss.cpp
index 0a2ba76..96f05d0 100644
--- a/src/loss/classification/CELoss.cpp
+++ b/src/loss/classification/CELoss.cpp
@@ -36,14 +36,22 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 {
     AIDGE_ASSERT(prediction->nbDims() == 2, 
                  "Label must have two dims: [BatchSize, NbChannel]");
+
     AIDGE_ASSERT(prediction->backend() == target->backend(),
                  "'prediction' and 'target' Tensors must be on the "
                  "same backend. Found {} and {}.\n",
                  prediction->backend(), target->backend());
+
+    AIDGE_ASSERT(prediction->device() == target->device(),
+                 "'prediction' and 'target' Tensors must be on the "
+                 "same device. Found {} and {}.\n",
+                 prediction->device(), target->device());
+
     AIDGE_ASSERT(prediction->dims() == target->dims(),
                  "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                  "have the same dimensions.\n",
                  prediction->dims(), target->dims());
+
     AIDGE_ASSERT(prediction->dataType() == target->dataType(),
                  "'prediction' (data type {}) and 'target' (data type {}) "
                  "Tensors must have the same data type.\n",
@@ -51,17 +59,21 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     auto backend  = prediction->backend();
     auto dataType = prediction->dataType();
+    auto device   = prediction->device();
 
     // Compute the prediction SoftMax
 
     auto softmaxOp = Softmax_Op(1);
     softmaxOp.setDataType(dataType);
-    softmaxOp.setBackend(backend);
+    softmaxOp.setBackend(backend, device);
 
     softmaxOp.associateInput(0, prediction);
     softmaxOp.forward();
     auto softmax = softmaxOp.getOutput(0);
 
+    // Log::notice(" softmax {}", softmax->device());
+    // Log::notice(" target  {}", target->device());
+
     // Compute the loss value using a GraphView
 
     auto targetNode  = Producer(target);
@@ -79,19 +91,23 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     std::shared_ptr<GraphView> lossGraphView = std::make_shared<GraphView>("CELoss");
     lossGraphView->add({targetNode, softmaxNode, logNode, mulNode, sumNode, meanNode});
-    lossGraphView->compile(backend, dataType);
+    lossGraphView->compile(backend, dataType, device);
 
     SequentialScheduler scheduler(lossGraphView);
     scheduler.forward(true);
 
     auto meanOp = std::static_pointer_cast<OperatorTensor>(meanNode->getOperator());
     auto lossTensor = meanOp->getOutput(0);
+    
+    // Log::notice(" lossTensor {}", lossTensor->device());
+
+    auto scalar = std::make_shared<Tensor>(Tensor(-1.0f));
+    scalar->setBackend(backend, device);
+    scalar->setDataType(dataType);
 
-    auto scalar = Tensor(-1.0f);
-    scalar.setBackend(backend);
-    scalar.setDataType(dataType);
+    // Log::notice(" scalar {}", scalar->device());
 
-    (*lossTensor) = (*lossTensor) * scalar;
+    (*lossTensor) = (*lossTensor) * (*scalar);
 
     lossTensor->setBackend("cpu");
 
@@ -99,7 +115,7 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     auto subOp = Sub_Op();
     subOp.setDataType(dataType);
-    subOp.setBackend(backend);
+    subOp.setBackend(backend, device);
 
     subOp.associateInput(0, softmax);
     subOp.associateInput(1, target);
@@ -109,14 +125,35 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     const float batchSize = static_cast<float>((target->dims())[0]);
 
-    scalar = Tensor(1.0f / batchSize);
-    scalar.setBackend(backend);
-    scalar.setDataType(dataType);
+    // Compute the rescaled error
 
-    (*err) = (*err) * scalar;
+    scalar = std::make_shared<Tensor>(Tensor(1.0f / batchSize));
+    scalar->setBackend(backend, device);
+    scalar->setDataType(dataType);
+
+    //scalar = Tensor(1.0f / batchSize);
+    //scalar.setBackend(backend, device);
+    //scalar.setDataType(dataType);
+
+    // XXX (*err) = (*err) * (*scalar);   
+
+    auto mulOp = Mul_Op();
+    mulOp.setDataType(dataType);
+    mulOp.setBackend(backend, device);
+
+    mulOp.associateInput(0, err);
+    mulOp.associateInput(1, scalar);
+    mulOp.forward();
+    err = mulOp.getOutput(0);    
+
+    // Log::notice(" err {}", err->device());
+
+    // Set the error signal
 
     prediction->setGrad(err);
 
+    // Log::notice(" prediction {} {} " , prediction->device(), prediction->grad()->device());
+
     // Return the loss value
 
     return (*lossTensor); 
-- 
GitLab
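
Note on the pattern above: placement calls now pass the device index along
with the backend string, so tensors created inside the loss or optimizer no
longer fall back to device 0 when the parameters live elsewhere. A minimal
sketch of the convention, using only calls that appear in the diff
(illustrative, not a verified build; the helper placeLike is hypothetical):

    #include <memory>
    #include "aidge/data/Tensor.hpp"

    using namespace Aidge;

    // Hypothetical helper: give `t` the same placement as `ref`.
    void placeLike(Tensor& t, const std::shared_ptr<Tensor>& ref) {
        const auto backend  = ref->backend();   // e.g. "cuda"
        const auto device   = ref->device();    // e.g. 1
        const auto dataType = ref->dataType();

        t.setDataType(dataType);
        // setBackend(backend) alone would implicitly target device 0.
        t.setBackend(backend, device);
    }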


From 4a03ee148b688ba22f30e6b47da9bc95f94733c0 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 10 Jun 2025 14:47:45 +0000
Subject: [PATCH 2/7] minor changes: drop the debug logs and extra includes

---
 include/aidge/learning/optimizer/SGD.hpp |  3 ---
 src/loss/classification/CELoss.cpp       | 19 +++----------------
 2 files changed, 3 insertions(+), 19 deletions(-)

diff --git a/include/aidge/learning/optimizer/SGD.hpp b/include/aidge/learning/optimizer/SGD.hpp
index d388684..b03054e 100644
--- a/include/aidge/learning/optimizer/SGD.hpp
+++ b/include/aidge/learning/optimizer/SGD.hpp
@@ -22,9 +22,6 @@
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/TensorUtils.hpp"
-#include "aidge/operator/Mul.hpp"
-#include "aidge/operator/Sub.hpp"
-
 
 namespace Aidge {
 
diff --git a/src/loss/classification/CELoss.cpp b/src/loss/classification/CELoss.cpp
index 96f05d0..b3c3ee8 100644
--- a/src/loss/classification/CELoss.cpp
+++ b/src/loss/classification/CELoss.cpp
@@ -71,9 +71,6 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
     softmaxOp.forward();
     auto softmax = softmaxOp.getOutput(0);
 
-    // Log::notice(" softmax {}", softmax->device());
-    // Log::notice(" target  {}", target->device());
-
     // Compute the loss value using a GraphView
 
     auto targetNode  = Producer(target);
@@ -99,14 +96,10 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
     auto meanOp = std::static_pointer_cast<OperatorTensor>(meanNode->getOperator());
     auto lossTensor = meanOp->getOutput(0);
     
-    // Log::notice(" lossTensor {}", lossTensor->device());
-
     auto scalar = std::make_shared<Tensor>(Tensor(-1.0f));
     scalar->setBackend(backend, device);
     scalar->setDataType(dataType);
 
-    // Log::notice(" scalar {}", scalar->device());
-
     (*lossTensor) = (*lossTensor) * (*scalar);
 
     lossTensor->setBackend("cpu");
@@ -127,15 +120,13 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     // Compute the rescaled error
 
-    scalar = std::make_shared<Tensor>(Tensor(1.0f / batchSize));
-    scalar->setBackend(backend, device);
-    scalar->setDataType(dataType);
-
     //scalar = Tensor(1.0f / batchSize);
     //scalar.setBackend(backend, device);
     //scalar.setDataType(dataType);
 
-    // XXX (*err) = (*err) * (*scalar);   
+    scalar = std::make_shared<Tensor>(Tensor(1.0f / batchSize));
+    scalar->setBackend(backend, device);
+    scalar->setDataType(dataType);
 
     auto mulOp = Mul_Op();
     mulOp.setDataType(dataType);
@@ -146,14 +137,10 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
     mulOp.forward();
     err = mulOp.getOutput(0);    
 
-    // Log::notice(" err {}", err->device());
-
     // Set the error signal
 
     prediction->setGrad(err);
 
-    // Log::notice(" prediction {} {} " , prediction->device(), prediction->grad()->device());
-
     // Return the loss value
 
     return (*lossTensor); 
-- 
GitLab
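
For orientation, the GraphView assembled in CELoss.cpp (softmax, log, mul,
sum, mean, then the -1 rescale) evaluates the usual cross-entropy over a
batch of N one-hot targets,

    CELoss = -(1/N) * sum_i sum_c target[i][c] * log(softmax(prediction)[i][c])

and the backward path built from Sub_Op plus the 1/batchSize rescale yields
the matching gradient with respect to the pre-softmax prediction,

    dCELoss/dprediction = (softmax(prediction) - target) / N

which is the error signal stored by prediction->setGrad(err).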


From 6585f6c06d40a2f763fe58b403976defae365146 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 10 Jun 2025 15:08:05 +0000
Subject: [PATCH 3/7] simplify the CELoss code

---
 src/loss/classification/CELoss.cpp | 31 +++++++++++-------------------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/src/loss/classification/CELoss.cpp b/src/loss/classification/CELoss.cpp
index b3c3ee8..657074f 100644
--- a/src/loss/classification/CELoss.cpp
+++ b/src/loss/classification/CELoss.cpp
@@ -95,12 +95,12 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     auto meanOp = std::static_pointer_cast<OperatorTensor>(meanNode->getOperator());
     auto lossTensor = meanOp->getOutput(0);
-    
-    auto scalar = std::make_shared<Tensor>(Tensor(-1.0f));
-    scalar->setBackend(backend, device);
-    scalar->setDataType(dataType);
 
-    (*lossTensor) = (*lossTensor) * (*scalar);
+    auto scalar = Tensor(-1.0f);
+    scalar.setBackend(backend, device);
+    scalar.setDataType(dataType);
+
+    (*lossTensor) = (*lossTensor) * scalar;
 
     lossTensor->setBackend("cpu");
 
@@ -120,25 +120,16 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     // Compute the rescaled error
 
-    //scalar = Tensor(1.0f / batchSize);
-    //scalar.setBackend(backend, device);
-    //scalar.setDataType(dataType);
-
-    scalar = std::make_shared<Tensor>(Tensor(1.0f / batchSize));
-    scalar->setBackend(backend, device);
-    scalar->setDataType(dataType);
+    scalar = Tensor(1.0f / batchSize);
+    scalar.setBackend(backend, device);
+    scalar.setDataType(dataType);
 
-    auto mulOp = Mul_Op();
-    mulOp.setDataType(dataType);
-    mulOp.setBackend(backend, device);
-
-    mulOp.associateInput(0, err);
-    mulOp.associateInput(1, scalar);
-    mulOp.forward();
-    err = mulOp.getOutput(0);    
+    (*err) = (*err) * scalar;
 
     // Set the error signal
 
+    // Log::notice(" backend : {} {} ", err->backend(), err->device());
+
     prediction->setGrad(err);
 
     // Return the loss value
-- 
GitLab
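
The simplification above relies on Tensor's overloaded arithmetic preserving
the operands' backend/device placement, which makes the explicit Mul_Op
detour from the first patch unnecessary. Both forms compute the same
rescaled error; a side-by-side sketch built only from calls shown in the
diffs (not a verified build):

    // Patch 1 form: an explicit operator, placed by hand.
    auto mulOp = Mul_Op();
    mulOp.setDataType(dataType);
    mulOp.setBackend(backend, device);
    mulOp.associateInput(0, err);
    mulOp.associateInput(1, scalar);   // scalar: shared_ptr<Tensor> holding 1/batchSize
    mulOp.forward();
    err = mulOp.getOutput(0);

    // Patch 3 form: direct tensor arithmetic on the same placement.
    auto scale = Tensor(1.0f / batchSize);   // `scale` mirrors the patch's `scalar`
    scale.setBackend(backend, device);
    scale.setDataType(dataType);
    (*err) = (*err) * scale;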


From a20f8d6ca5dfb2d8631fc779d84700914824c028 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Mon, 16 Jun 2025 12:07:34 +0000
Subject: [PATCH 4/7] modify the accuracy routine

---
 include/aidge/learning/metrics/Accuracy.hpp |  4 +-
 src/metrics/Accuracy.cpp                    | 59 +++++++++++----------
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/include/aidge/learning/metrics/Accuracy.hpp b/include/aidge/learning/metrics/Accuracy.hpp
index 34e9b6e..9ea9cc2 100644
--- a/include/aidge/learning/metrics/Accuracy.hpp
+++ b/include/aidge/learning/metrics/Accuracy.hpp
@@ -30,8 +30,8 @@ namespace metrics {
  * @param target Tensor representing the ground truth; it must be one-hot encoded.
  * @param axis The classes axis.
  */
-Tensor Accuracy(std::shared_ptr<Tensor>& prediction,
-                const std::shared_ptr<Tensor>& target,
+Tensor Accuracy(std::shared_ptr<Tensor> prediction,
+                const std::shared_ptr<Tensor> target,
                 std::int32_t axis);
 }  // namespace metrics
 }  // namespace Aidge
diff --git a/src/metrics/Accuracy.cpp b/src/metrics/Accuracy.cpp
index 318d838..c7f8dd3 100644
--- a/src/metrics/Accuracy.cpp
+++ b/src/metrics/Accuracy.cpp
@@ -24,68 +24,73 @@
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
 
-Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
-                                        const std::shared_ptr<Tensor>& target,
+Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor> prediction,
+                                        const std::shared_ptr<Tensor> target,
                                         std::int32_t axis) {
     /*
-    Implementation note:
-    Accuracy is computed using a graph in order to not be backend dependant.
+        Implementation note:
+        Accuracy is computed using a graph so as not to be backend-dependent.
+        The graph used is the following:
 
-    The graph used is the following:
-
-    pred->ArgMax
-                  ->Equal->ReduceSum
-    label->ArgMax
+            pred->ArgMax  \
+                           ->Equal->ReduceSum
+            label->ArgMax /
     */
 
     AIDGE_ASSERT(target->dims().size() == 2,
                  "Label must have two dims: [BatchSize, NbChannel]");
 
-    std::shared_ptr<Tensor> outputGrad = prediction->grad();
-
     AIDGE_ASSERT(prediction->backend() == target->backend(),
                  "'prediction' and 'target' Tensors must be on the "
                  "same backend. Found {} and {}.\n",
                  prediction->backend(), target->backend());
+
+    AIDGE_ASSERT(prediction->device() == target->device(),
+                 "'prediction' and 'target' Tensors must be on the "
+                 "same device. Found {} and {}.\n",
+                 prediction->device(), target->device());
+
     AIDGE_ASSERT(prediction->dims() == target->dims(),
                  "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                  "have the same dimensions.\n",
                  prediction->dims(), target->dims());
+
     AIDGE_ASSERT(prediction->dataType() == target->dataType(),
                  "'prediction' (data type {}) and 'target' (data type {}) "
                  "Tensors must have the same data type.\n",
                  prediction->dataType(), target->dataType());
 
     // Create graph nodes and connections
-    const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis);
+
+    const std::shared_ptr<Node> argmax_pred_node = ArgMax(axis);
     const std::shared_ptr<Node> argmax_target_node = ArgMax(axis);
 
     const std::shared_ptr<Node> equal_node = Equal();
-    const std::shared_ptr<Node> rs_node = ReduceSum();
+    const std::shared_ptr<Node> reduce_node = ReduceSum();
 
     const std::shared_ptr<Node> pred_node = Producer(prediction, "pred");
-    pred_node->addChild(argmax_perd_node);
+    pred_node->addChild(argmax_pred_node);
+
     const std::shared_ptr<Node> label_node = Producer(target, "label");
     label_node->addChild(argmax_target_node);
 
-    argmax_perd_node->addChild(equal_node,0,0);
-    argmax_target_node->addChild(equal_node,0,1);
+    argmax_pred_node->addChild(equal_node, 0, 0);
+    argmax_target_node->addChild(equal_node, 0, 1);
+
+    std::shared_ptr<GraphView> graphView = Sequential({equal_node, reduce_node});
+
+    graphView->add({pred_node, argmax_pred_node, label_node, argmax_target_node});
+    graphView->compile(prediction->backend(), prediction->dataType(), prediction->device());
 
-    // equal_node->addChild(rs_node,0,0);
+    // Execute the graph and retrieve the result
 
-    // Create the graph
-    std::shared_ptr<GraphView> gv_local =
-        Sequential({ equal_node, rs_node});
+    SequentialScheduler scheduler{graphView};
+    scheduler.forward(false);
 
-    gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node});
-    gv_local->compile(prediction->getImpl()->backend(), prediction->dataType());
+    // TODO: way too complicated to access
 
-    SequentialScheduler ss_local{gv_local};
-    ss_local.forward(false);
+    const std::shared_ptr<OperatorTensor> res = std::dynamic_pointer_cast<OperatorTensor>(reduce_node->getOperator());
 
-    // TODO: way too complicated to access
-    const std::shared_ptr<OperatorTensor> res =
-        std::dynamic_pointer_cast<OperatorTensor>(rs_node->getOperator());
     std::shared_ptr<Tensor> fallback;
 
     return res->getOutput(0)->refFrom(fallback, "cpu");
-- 
GitLab
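
With one-hot targets the graph above reduces each batch row to a class index
and counts the matches, so the ReduceSum output is the number of correct
predictions in the batch. A worked example with batch size 3 and axis = 1:

    prediction = [[0.1, 0.9],    --ArgMax-->  [1,
                  [0.8, 0.2],                  0,
                  [0.3, 0.7]]                  1]

    target     = [[0.0, 1.0],    --ArgMax-->  [1,
                  [1.0, 0.0],                  0,
                  [0.0, 1.0]]                  1]

    Equal     -> [1, 1, 1]
    ReduceSum -> 3 correct predictions out of 3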


From a5d731b8e429f440beab7576248771b5c67ded23 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 17 Jun 2025 09:38:49 +0000
Subject: [PATCH 5/7] use the device index

---
 include/aidge/learning/metrics/Accuracy.hpp |  4 +-
 include/aidge/learning/optimizer/Adam.hpp   | 50 ++++++++++++++++-----
 include/aidge/learning/optimizer/SGD.hpp    |  6 +--
 src/loss/classification/BCE.cpp             | 10 ++++-
 src/loss/distillation/KD.cpp                |  9 +++-
 src/loss/regression/MSE.cpp                 |  9 +++-
 src/metrics/Accuracy.cpp                    |  4 +-
 7 files changed, 71 insertions(+), 21 deletions(-)

diff --git a/include/aidge/learning/metrics/Accuracy.hpp b/include/aidge/learning/metrics/Accuracy.hpp
index 9ea9cc2..34e9b6e 100644
--- a/include/aidge/learning/metrics/Accuracy.hpp
+++ b/include/aidge/learning/metrics/Accuracy.hpp
@@ -30,8 +30,8 @@ namespace metrics {
  * @param target Tensor representing the ground truth; it must be one-hot encoded.
  * @param axis The classes axis.
  */
-Tensor Accuracy(std::shared_ptr<Tensor> prediction,
-                const std::shared_ptr<Tensor> target,
+Tensor Accuracy(std::shared_ptr<Tensor>& prediction,
+                const std::shared_ptr<Tensor>& target,
                 std::int32_t axis);
 }  // namespace metrics
 }  // namespace Aidge
diff --git a/include/aidge/learning/optimizer/Adam.hpp b/include/aidge/learning/optimizer/Adam.hpp
index 8c89e53..da21469 100644
--- a/include/aidge/learning/optimizer/Adam.hpp
+++ b/include/aidge/learning/optimizer/Adam.hpp
@@ -58,23 +58,43 @@ public:
     }
 
     void update() override final {
+
+        auto backend = mParameters[0]->backend();
+        auto device = mParameters[0]->device();
+        auto dataType = mParameters[0]->dataType();
+
         float mBeta1Power = std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1));
         float mBeta2Power = std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1));
+        
         float mReversedBeta1Power = 1.0f - mBeta1Power;
         float mSqrtReversedBeta2Power = std::sqrt(1.0f - mBeta2Power);
 
         Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power);
-        alpha.setBackend(mParameters[0]->getImpl()->backend());
-        alpha.setDataType(mParameters[0]->dataType());
+        alpha.setBackend(backend, device);
+        alpha.setDataType(dataType);
 
         Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power);
-        epsilon_hat.setBackend(mParameters[0]->getImpl()->backend());
-        epsilon_hat.setDataType(mParameters[0]->dataType());
+        epsilon_hat.setBackend(backend, device);
+        epsilon_hat.setDataType(dataType);
+
+        mBeta1.setBackend(backend, device);
+        mBeta1.setDataType(dataType);
+        mReversedBeta1.setBackend(backend, device);
+        mReversedBeta1.setDataType(dataType);
+
+        mBeta2.setBackend(backend, device);
+        mBeta2.setDataType(dataType);
+        mReversedBeta2.setBackend(backend, device);
+        mReversedBeta2.setDataType(dataType);
 
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
                 mMomentum1[i].zeros();
-                mMomentum2[i].zeros();
+                mMomentum1[i].setBackend(backend, device);
+                mMomentum1[i].setDataType(dataType);
+                mMomentum2[i].zeros(); 
+                mMomentum2[i].setBackend(backend, device);
+                mMomentum2[i].setDataType(dataType);
             }
         }
 
@@ -88,25 +108,33 @@ public:
     }
 
     void setParameters(const std::vector<std::shared_ptr<Tensor>>& parameters) override final {
+
         Optimizer::setParameters(parameters);
         mMomentum1 = std::vector<Tensor>(parameters.size());
         mMomentum2 = std::vector<Tensor>(parameters.size());
+
         for (std::size_t i = 0; i < parameters.size(); ++i) {
+
             mMomentum1[i] = Tensor(parameters[i]->dims());
-            mMomentum1[i].setBackend(parameters[i]->getImpl()->backend());
+            mMomentum1[i].setBackend(parameters[i]->backend(), parameters[i]->device());
             mMomentum1[i].setDataType(parameters[i]->dataType());
+
             mMomentum2[i] = Tensor(parameters[i]->dims());
-            mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
+            mMomentum2[i].setBackend(parameters[i]->backend(), parameters[i]->device());
             mMomentum2[i].setDataType(parameters[i]->dataType());
         }
         if (parameters.size() > 0) {
-            mBeta1.setBackend(mParameters[0]->getImpl()->backend());
+
+            mBeta1.setBackend(mParameters[0]->backend(), mParameters[0]->device());
             mBeta1.setDataType(parameters[0]->dataType());
-            mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
+
+            mReversedBeta1.setBackend(mParameters[0]->backend(), mParameters[0]->device());
             mReversedBeta1.setDataType(parameters[0]->dataType());
-            mBeta2.setBackend(mParameters[0]->getImpl()->backend());
+
+            mBeta2.setBackend(mParameters[0]->backend(), mParameters[0]->device());
             mBeta2.setDataType(parameters[0]->dataType());
-            mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
+
+            mReversedBeta2.setBackend(mParameters[0]->backend(), mParameters[0]->device());
             mReversedBeta2.setDataType(parameters[0]->dataType());
         }
     }
diff --git a/include/aidge/learning/optimizer/SGD.hpp b/include/aidge/learning/optimizer/SGD.hpp
index b03054e..1c5f1a6 100644
--- a/include/aidge/learning/optimizer/SGD.hpp
+++ b/include/aidge/learning/optimizer/SGD.hpp
@@ -99,13 +99,13 @@ public:
         mGradientInertia = std::vector<Tensor>(parameters.size());
         for (std::size_t i = 0; i < parameters.size(); ++i) {
             mGradientInertia[i] = Tensor(parameters[i]->dims());
-            mGradientInertia[i].setBackend(parameters[i]->backend());
+            mGradientInertia[i].setBackend(parameters[i]->backend(), parameters[i]->device());
             mGradientInertia[i].setDataType(parameters[i]->dataType());
         }
         if (parameters.size() > 0) {
-            mReversedDampening.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedDampening.setBackend(mParameters[0]->backend(), mParameters[0]->device());
             mReversedDampening.setDataType(parameters[0]->dataType());
-            mMomentum.setBackend(mParameters[0]->getImpl()->backend());
+            mMomentum.setBackend(mParameters[0]->backend(), mParameters[0]->device());
             mMomentum.setDataType(parameters[0]->dataType());
         }
     }
diff --git a/src/loss/classification/BCE.cpp b/src/loss/classification/BCE.cpp
index 722184e..dfaafa7 100644
--- a/src/loss/classification/BCE.cpp
+++ b/src/loss/classification/BCE.cpp
@@ -43,14 +43,22 @@ Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction,
 
     AIDGE_ASSERT(target->dims().size() == 2,
                  "Label must have two dims: [BatchSize, NbChannel]");
+
     AIDGE_ASSERT(prediction->backend() == target->backend(),
                  "'prediction' and 'target' Tensors must be on the "
                  "same backend. Found {} and {}.\n",
                  prediction->backend(), target->backend());
+
+    AIDGE_ASSERT(prediction->device() == target->device(),
+                 "'prediction' and 'target' Tensors must be on the "
+                 "same device. Found {} and {}.\n",
+                 prediction->device(), target->device());
+
     AIDGE_ASSERT(prediction->dims() == target->dims(),
                  "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                  "have the same dimensions.\n",
                  prediction->dims(), target->dims());
+
     AIDGE_ASSERT(prediction->dataType() == target->dataType(),
                  "'prediction' (data type {}) and 'target' (data type {}) "
                  "Tensors must have the same data type.\n",
@@ -134,7 +142,7 @@ Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction,
                   ln1_node, ln2_node,
                   sub2_node, mul1_node, mul2_node, sub3_node, loss_node,
                   sub4_node, mul3_node, div1_node, gradient_node->getParent(1), gradient_node});
-    gv_loss->compile(prediction->getImpl()->backend(), prediction->dataType());
+    gv_loss->compile(prediction->backend(), prediction->dataType(), prediction->device());
 
     // Compute loss and gradient
     SequentialScheduler ss_loss{gv_loss};
diff --git a/src/loss/distillation/KD.cpp b/src/loss/distillation/KD.cpp
index 4283a28..411be45 100644
--- a/src/loss/distillation/KD.cpp
+++ b/src/loss/distillation/KD.cpp
@@ -59,10 +59,17 @@ Aidge::Tensor Aidge::loss::KD(std::shared_ptr<Tensor>& student_prediction,
                  "'prediction' and 'target' Tensors must be on the "
                  "same backend. Found {} and {}.\n",
                  student_prediction->backend(), teacher_prediction->backend());
+
+    AIDGE_ASSERT(student_prediction->device() == teacher_prediction->device(),
+                 "'prediction' and 'target' Tensors must be on the "
+                 "same device. Found {} and {}.\n",
+                 student_prediction->device(), teacher_prediction->device());
+
     AIDGE_ASSERT(student_prediction->dims() == teacher_prediction->dims(),
                  "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                  "have the same dimensions.\n",
                  student_prediction->dims(), teacher_prediction->dims());
+
     AIDGE_ASSERT(student_prediction->dataType() == teacher_prediction->dataType(),
                  "'prediction' (data type {}) and 'target' (data type {}) "
                  "Tensors must have the same data type.\n",
@@ -134,7 +141,7 @@ Aidge::Tensor Aidge::loss::KD(std::shared_ptr<Tensor>& student_prediction,
                   soft_teacher_node, mul_node, 
                   mul2_node->getParent(1), mul2_node, 
                   rm_node, sub_node});
-    gv_loss->compile(student_prediction->getImpl()->backend(), student_prediction->dataType());
+    gv_loss->compile(student_prediction->backend(), student_prediction->dataType(), student_prediction->device());
 
     SequentialScheduler ss_loss{gv_loss};
     ss_loss.forward(false);
diff --git a/src/loss/regression/MSE.cpp b/src/loss/regression/MSE.cpp
index b82eab8..06f6741 100644
--- a/src/loss/regression/MSE.cpp
+++ b/src/loss/regression/MSE.cpp
@@ -56,10 +56,17 @@ Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<Tensor>& prediction,
                  "'prediction' and 'target' Tensors must be on the "
                  "same backend. Found {} and {}.\n",
                  prediction->backend(), target->backend());
+
+    AIDGE_ASSERT(prediction->device() == target->device(),
+                 "'prediction' and 'target' Tensors must be on the "
+                 "same device. Found {} and {}.\n",
+                 prediction->device(), target->device());
+
     AIDGE_ASSERT(prediction->dims() == target->dims(),
                  "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                  "have the same dimensions.\n",
                  prediction->dims(), target->dims());
+
     AIDGE_ASSERT(prediction->dataType() == target->dataType(),
                  "'prediction' (data type {}) and 'target' (data type {}) "
                  "Tensors must have the same data type.\n",
@@ -91,7 +98,7 @@ Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<Tensor>& prediction,
         Sequential({sub_node, pow_node, rm_node});
     gv_local->add({sub_node->getParent(0), sub_node->getParent(1), pow_exp_node,
                    mul_node->getParent(1), mul_node});
-    gv_local->compile(prediction->getImpl()->backend(), prediction->dataType());
+    gv_local->compile(prediction->backend(), prediction->dataType(), prediction->device());
 
     SequentialScheduler ss_local{gv_local};
     ss_local.forward(false);
diff --git a/src/metrics/Accuracy.cpp b/src/metrics/Accuracy.cpp
index c7f8dd3..2a65fc7 100644
--- a/src/metrics/Accuracy.cpp
+++ b/src/metrics/Accuracy.cpp
@@ -24,8 +24,8 @@
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
 
-Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor> prediction,
-                                        const std::shared_ptr<Tensor> target,
+Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
+                                        const std::shared_ptr<Tensor>& target,
                                         std::int32_t axis) {
     /*
         Implementation note:
-- 
GitLab
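
For reference, the alpha and epsilon_hat tensors that this patch places on
the parameters' device implement the folded bias correction from the Adam
paper: with beta1/beta2 the moment decays, g_t the gradient and t the step,

    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2

    alpha_t     = lr  * sqrt(1 - beta2^t) / (1 - beta1^t)
    epsilon_hat = eps * sqrt(1 - beta2^t)

    theta_t = theta_{t-1} - alpha_t * m_t / (sqrt(v_t) + epsilon_hat)

The raw moments never need bias-corrected copies this way; only the two
scalars change from step to step, which is why they can be cheaply rebuilt
and re-placed on the parameters' backend/device inside every update() call.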


From 4b8422258b99cfa7a2b5203d25d4cff24361f3c0 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Fri, 4 Jul 2025 14:12:41 +0000
Subject: [PATCH 6/7] remove extra comments

---
 src/loss/classification/CELoss.cpp | 2 --
 src/metrics/Accuracy.cpp           | 6 ++----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/loss/classification/CELoss.cpp b/src/loss/classification/CELoss.cpp
index 657074f..09a8f77 100644
--- a/src/loss/classification/CELoss.cpp
+++ b/src/loss/classification/CELoss.cpp
@@ -128,8 +128,6 @@ Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
 
     // Set the error signal
 
-    // Log::notice(" backend : {} {} ", err->backend(), err->device());
-
     prediction->setGrad(err);
 
     // Return the loss value
diff --git a/src/metrics/Accuracy.cpp b/src/metrics/Accuracy.cpp
index 2a65fc7..5810fb0 100644
--- a/src/metrics/Accuracy.cpp
+++ b/src/metrics/Accuracy.cpp
@@ -87,11 +87,9 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     SequentialScheduler scheduler{graphView};
     scheduler.forward(false);
 
-    // TODO: way too complicated to access
-
-    const std::shared_ptr<OperatorTensor> res = std::dynamic_pointer_cast<OperatorTensor>(reduce_node->getOperator());
+    const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(reduce_node->getOperator());
 
     std::shared_ptr<Tensor> fallback;
 
-    return res->getOutput(0)->refFrom(fallback, "cpu");
+    return op->getOutput(0)->refFrom(fallback, "cpu");
 }
-- 
GitLab


From 6800fd3957b72bfd5a96ef5795ca6f77fdfca20e Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Fri, 4 Jul 2025 14:55:21 +0000
Subject: [PATCH 7/7] fix the merge

---
 src/metrics/Accuracy.cpp | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/metrics/Accuracy.cpp b/src/metrics/Accuracy.cpp
index 865f860..7181feb 100644
--- a/src/metrics/Accuracy.cpp
+++ b/src/metrics/Accuracy.cpp
@@ -81,18 +81,20 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     argmax_pred_node->addChild(equal_node, 0, 0);
     argmax_target_node->addChild(equal_node, 0, 1);
 
-    std::shared_ptr<GraphView> graphView = Sequential({equal_node, reduce_node});
-
     // Create the graph
-    std::shared_ptr<GraphView> gv_local =
-        Sequential({equal_node, cast_node, rs_node});
+
+    std::shared_ptr<GraphView> gv_local = Sequential({equal_node, cast_node, rs_node});
+
+    gv_local->add({pred_node, argmax_pred_node, label_node, argmax_target_node});
+    gv_local->compile(prediction->backend(), prediction->dataType(), prediction->device());
 
     // Execute the graph and retrieve the result
 
-    SequentialScheduler scheduler{graphView};
-    scheduler.forward(false);
+    SequentialScheduler ss_local{gv_local};
+
+    ss_local.forward(false);
 
-    const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(reduce_node->getOperator());
+    const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(rs_node->getOperator());
 
     std::shared_ptr<Tensor> fallback;
 
-- 
GitLab
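
After this merge fix the metric keeps the upstream graph shape with the
device index threaded through; a cast_node (introduced upstream, not shown
in this series, presumably converting Equal's boolean output before the
summation) now sits between Equal and ReduceSum:

    pred  -> ArgMax \
                     -> Equal -> Cast -> ReduceSum
    label -> ArgMax /

The whole view is compiled with the backend, data type and device all taken
from the prediction tensor, so the metric runs wherever the network's
outputs already live.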