diff --git a/include/aidge/learning/optimizer/Optimizer.hpp b/include/aidge/learning/optimizer/Optimizer.hpp index 83ba3f37f35f608c416dc8750a25c8b226fac8bf..c4225bb0e67bd21d4f63f8492a1afa391898ca99 100644 --- a/include/aidge/learning/optimizer/Optimizer.hpp +++ b/include/aidge/learning/optimizer/Optimizer.hpp @@ -16,7 +16,9 @@ #include <vector> #include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/MetaOperator.hpp" #include "aidge/learning/learningRate/LRScheduler.hpp" namespace Aidge { @@ -71,13 +73,9 @@ public: virtual void update() {} /** - * @brief Reset the gradient of each parameter registered in the Optimizer. + * @brief Reset recursively the gradient of each tensor in the GraphView */ - void resetGrad() const { - for (const auto& t_ptr : mParameters) { - t_ptr -> grad() -> zeros(); - } - } + void resetGrad(std::shared_ptr<GraphView> graphView); }; } // namespace Aidge diff --git a/python_binding/learning/optimizer/pybind_optimizer.cpp b/python_binding/learning/optimizer/pybind_optimizer.cpp index 965e5730e20656b70a3918a1b957b4514e9f74f2..437db440511e28255465c426753945d5adcfaeb7 100644 --- a/python_binding/learning/optimizer/pybind_optimizer.cpp +++ b/python_binding/learning/optimizer/pybind_optimizer.cpp @@ -26,7 +26,7 @@ void init_Optimizer(py::module& m) { .def("learning_rate", &Optimizer::learningRate) .def("learning_rate_scheduler", &Optimizer::learningRateScheduler) .def("set_learning_rate_scheduler", &Optimizer::setLearningRateScheduler) - .def("reset_grad", &Optimizer::resetGrad) + .def("reset_grad", &Optimizer::resetGrad, py::arg("graphview")) .def("update", &Optimizer::update); } // } // namespace learning diff --git a/src/optimizer/Optimizer.cpp b/src/optimizer/Optimizer.cpp index 367f2e84b5acab55d9458aded76f3a39c7f9e9f5..5e1e2f38a3401170ee2e39b6dd0f83ac12d65a5a 100644 --- a/src/optimizer/Optimizer.cpp +++ b/src/optimizer/Optimizer.cpp @@ -12,3 +12,20 @@ #include "aidge/learning/optimizer/Optimizer.hpp" Aidge::Optimizer::~Optimizer() noexcept = default; + +void Aidge::Optimizer::resetGrad(std::shared_ptr<GraphView> graphView) +{ + for (auto node : graphView->getNodes()) + { + auto op = node->getOperator(); + if (op->isAtomic()) { + auto tensorOp = std::static_pointer_cast<OperatorTensor>(op); + for (auto outputTensor : tensorOp->getOutputs()) { + outputTensor->grad()->zeros(); + } + } else { + auto metaOp = std::static_pointer_cast<MetaOperator_Op>(op); + resetGrad(metaOp->getMicroGraph()); + } + } +} \ No newline at end of file