From 1cca0848cdea96c172c3021907c43f75691b55b8 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 11 Feb 2025 14:48:09 +0000
Subject: [PATCH 1/2] rework the gradient cleaning routine

---
 include/aidge/learning/optimizer/Optimizer.hpp | 10 ++++------
 .../learning/optimizer/pybind_optimizer.cpp    |  2 +-
 src/optimizer/Optimizer.cpp                    | 15 +++++++++++++++
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/include/aidge/learning/optimizer/Optimizer.hpp b/include/aidge/learning/optimizer/Optimizer.hpp
index 83ba3f3..c4225bb 100644
--- a/include/aidge/learning/optimizer/Optimizer.hpp
+++ b/include/aidge/learning/optimizer/Optimizer.hpp
@@ -16,7 +16,9 @@
 #include <vector>
 
 #include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 #include "aidge/learning/learningRate/LRScheduler.hpp"
 
 namespace Aidge {
@@ -71,13 +73,9 @@ public:
     virtual void update() {}
 
     /**
-     * @brief Reset the gradient of each parameter registered in the Optimizer.
+     * @brief Reset recursively the gradient of each tensor in the GraphView
      */
-    void resetGrad() const {
-        for (const auto& t_ptr : mParameters) {
-            t_ptr -> grad() -> zeros();
-        }
-    }
+    void resetGrad(std::shared_ptr<GraphView> graphView);
 };
 
 } // namespace Aidge
diff --git a/python_binding/learning/optimizer/pybind_optimizer.cpp b/python_binding/learning/optimizer/pybind_optimizer.cpp
index 965e573..437db44 100644
--- a/python_binding/learning/optimizer/pybind_optimizer.cpp
+++ b/python_binding/learning/optimizer/pybind_optimizer.cpp
@@ -26,7 +26,7 @@ void init_Optimizer(py::module& m) {
         .def("learning_rate", &Optimizer::learningRate)
         .def("learning_rate_scheduler", &Optimizer::learningRateScheduler)
         .def("set_learning_rate_scheduler", &Optimizer::setLearningRateScheduler)
-        .def("reset_grad", &Optimizer::resetGrad)
+        .def("reset_grad", &Optimizer::resetGrad, py::arg("graphview"))
         .def("update", &Optimizer::update);
 }
 // } // namespace learning
diff --git a/src/optimizer/Optimizer.cpp b/src/optimizer/Optimizer.cpp
index 367f2e8..723f140 100644
--- a/src/optimizer/Optimizer.cpp
+++ b/src/optimizer/Optimizer.cpp
@@ -12,3 +12,18 @@
 #include "aidge/learning/optimizer/Optimizer.hpp"
 
 Aidge::Optimizer::~Optimizer() noexcept = default;
+
+void Aidge::Optimizer::resetGrad(std::shared_ptr<GraphView> graphView)
+{
+    for (auto node : graphView->getNodes())
+    {
+        auto op = node->getOperator();
+        if (op->isAtomic()) {
+            auto tensorOp = std::static_pointer_cast<OperatorTensor>(op);
+            tensorOp->getOutput(0)->grad()->zeros();
+        } else {
+            auto metaOp = std::static_pointer_cast<MetaOperator_Op>(op);
+            resetGrad(metaOp->getMicroGraph());
+        }
+    }
+}
\ No newline at end of file
--
GitLab


From 062008bc67760ab3361ab339931e7da3374251cc Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 11 Feb 2025 15:16:09 +0000
Subject: [PATCH 2/2] handle multi-output operators

---
 src/optimizer/Optimizer.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/optimizer/Optimizer.cpp b/src/optimizer/Optimizer.cpp
index 723f140..5e1e2f3 100644
--- a/src/optimizer/Optimizer.cpp
+++ b/src/optimizer/Optimizer.cpp
@@ -20,7 +20,9 @@ void Aidge::Optimizer::resetGrad(std::shared_ptr<GraphView> graphView)
         auto op = node->getOperator();
         if (op->isAtomic()) {
             auto tensorOp = std::static_pointer_cast<OperatorTensor>(op);
-            tensorOp->getOutput(0)->grad()->zeros();
+            for (auto outputTensor : tensorOp->getOutputs()) {
+                outputTensor->grad()->zeros();
+            }
         } else {
             auto metaOp = std::static_pointer_cast<MetaOperator_Op>(op);
             resetGrad(metaOp->getMicroGraph());
--
GitLab
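
Usage note: after this series, resetGrad() takes the GraphView explicitly and
recurses into meta-operator micro-graphs on its own. Below is a minimal sketch
of a training step under the new API; the trainStep() wrapper, the call order,
and the elided forward/backward steps are illustrative assumptions, not code
from these patches.

    #include <memory>

    #include "aidge/graph/GraphView.hpp"
    #include "aidge/learning/optimizer/Optimizer.hpp"

    // Sketch only: assumes 'opt' is any concrete Optimizer and 'model' is the
    // GraphView being trained (possibly containing meta-operators).
    void trainStep(Aidge::Optimizer& opt, std::shared_ptr<Aidge::GraphView> model) {
        // Zero every output gradient in the graph; with PATCH 1/2 and 2/2 this
        // recurses into micro-graphs and covers multi-output operators.
        opt.resetGrad(model);
        // ... forward pass, loss computation and backward pass go here ...
        opt.update(); // apply the optimizer-specific parameter update
    }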