Skip to content
Snippets Groups Projects

[Fix] Rework the Gradient Clearing routine

Merged Benjamin Halimi requested to merge grad_rework into dev
1 unresolved thread
3 files
+ 22
7
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -16,7 +16,9 @@
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/learning/learningRate/LRScheduler.hpp"
namespace Aidge {
@@ -71,13 +73,9 @@ public:
virtual void update() {}
/**
* @brief Reset the gradient of each parameter registered in the Optimizer.
* @brief Reset recursively the gradient of each tensor in the GraphView
*/
void resetGrad() const {
for (const auto& t_ptr : mParameters) {
t_ptr -> grad() -> zeros();
}
}
void resetGrad(std::shared_ptr<GraphView> graphView);
};
} // namespace Aidge
Loading