From 1cca0848cdea96c172c3021907c43f75691b55b8 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 11 Feb 2025 14:48:09 +0000
Subject: [PATCH 1/2] rework the gradient reset routine

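resetGrad() used to zero only the gradients of the parameters
registered in the Optimizer. It now takes a GraphView and zeroes the
gradient of every operator output tensor it contains, recursing into
the micrograph of non-atomic (meta) operators.

A minimal call-site sketch; the names opt and model are illustrative,
not part of this patch:

    // before: opt.resetGrad();   // cleared registered parameters only
    // after:  pass the GraphView whose gradients should be cleared
    opt.resetGrad(model);         // recurses into MetaOperator micrographs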
---
 include/aidge/learning/optimizer/Optimizer.hpp    | 10 ++++------
 .../learning/optimizer/pybind_optimizer.cpp       |  2 +-
 src/optimizer/Optimizer.cpp                       | 15 +++++++++++++++
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/include/aidge/learning/optimizer/Optimizer.hpp b/include/aidge/learning/optimizer/Optimizer.hpp
index 83ba3f3..c4225bb 100644
--- a/include/aidge/learning/optimizer/Optimizer.hpp
+++ b/include/aidge/learning/optimizer/Optimizer.hpp
@@ -16,7 +16,9 @@
 #include <vector>
 
 #include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 #include "aidge/learning/learningRate/LRScheduler.hpp"
 
 namespace Aidge {
@@ -71,13 +73,9 @@ public:
     virtual void update() {}
 
     /**
-     * @brief Reset the gradient of each parameter registered in the Optimizer.
+     * @brief Recursively reset the gradient of each tensor in the given GraphView.
      */
-    void resetGrad() const {
-        for (const auto& t_ptr : mParameters) {
-            t_ptr -> grad() -> zeros();
-        }
-    }
+    void resetGrad(std::shared_ptr<GraphView> graphView);
 };
 
 } // namespace Aidge
diff --git a/python_binding/learning/optimizer/pybind_optimizer.cpp b/python_binding/learning/optimizer/pybind_optimizer.cpp
index 965e573..437db44 100644
--- a/python_binding/learning/optimizer/pybind_optimizer.cpp
+++ b/python_binding/learning/optimizer/pybind_optimizer.cpp
@@ -26,7 +26,7 @@ void init_Optimizer(py::module& m) {
     .def("learning_rate", &Optimizer::learningRate)
     .def("learning_rate_scheduler", &Optimizer::learningRateScheduler)
     .def("set_learning_rate_scheduler", &Optimizer::setLearningRateScheduler)
-    .def("reset_grad", &Optimizer::resetGrad)
+    .def("reset_grad", &Optimizer::resetGrad, py::arg("graphview"))
     .def("update", &Optimizer::update);
 }
 // }  // namespace learning
diff --git a/src/optimizer/Optimizer.cpp b/src/optimizer/Optimizer.cpp
index 367f2e8..723f140 100644
--- a/src/optimizer/Optimizer.cpp
+++ b/src/optimizer/Optimizer.cpp
@@ -12,3 +12,18 @@
 #include "aidge/learning/optimizer/Optimizer.hpp"
 
 Aidge::Optimizer::~Optimizer() noexcept = default;
+
+void Aidge::Optimizer::resetGrad(std::shared_ptr<GraphView> graphView)
+{
+    for (auto node : graphView->getNodes())
+    {
+        auto op = node->getOperator();
+        if (op->isAtomic()) {
+            auto tensorOp = std::static_pointer_cast<OperatorTensor>(op);
+            tensorOp->getOutput(0)->grad()->zeros();
+        } else {
+            auto metaOp = std::static_pointer_cast<MetaOperator_Op>(op);
+            resetGrad(metaOp->getMicroGraph());
+        }
+    }
+}
\ No newline at end of file
-- 
GitLab


From 062008bc67760ab3361ab339931e7da3374251cc Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 11 Feb 2025 15:16:09 +0000
Subject: [PATCH 2/2] handle multi-output operators

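resetGrad() previously cleared only the gradient of output #0, so the
remaining outputs of a multi-output operator kept stale gradients
across training steps. Iterate over every output tensor instead.

Sketch of the fixed per-node behaviour, assuming an atomic operator op
with several outputs:

    auto tensorOp = std::static_pointer_cast<OperatorTensor>(op);
    for (auto outputTensor : tensorOp->getOutputs()) {
        outputTensor->grad()->zeros();  // zero each output's gradient
    }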
---
 src/optimizer/Optimizer.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/optimizer/Optimizer.cpp b/src/optimizer/Optimizer.cpp
index 723f140..5e1e2f3 100644
--- a/src/optimizer/Optimizer.cpp
+++ b/src/optimizer/Optimizer.cpp
@@ -20,7 +20,9 @@ void Aidge::Optimizer::resetGrad(std::shared_ptr<GraphView> graphView)
         auto op = node->getOperator();
         if (op->isAtomic()) {
             auto tensorOp = std::static_pointer_cast<OperatorTensor>(op);
-            tensorOp->getOutput(0)->grad()->zeros();
+            for (auto outputTensor : tensorOp->getOutputs()) {
+                outputTensor->grad()->zeros();
+            }
         } else {
             auto metaOp = std::static_pointer_cast<MetaOperator_Op>(op);
             resetGrad(metaOp->getMicroGraph());
-- 
GitLab