From ecc96977839512bfdac4374aa76e0887ef678fce Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 18 Jun 2024 22:12:39 +0200
Subject: [PATCH] Added lazy init for grad

---
 include/aidge/data/Tensor.hpp                 | 18 ++++++++----------
 include/aidge/recipes/GraphViewHelper.hpp     |  2 --
 .../aidge/scheduler/SequentialScheduler.hpp   |  2 +-
 python_binding/data/pybind_Tensor.cpp         |  1 -
 .../recipes/pybind_GraphViewHelper.cpp        |  1 -
 src/recipes/GraphViewHelper.cpp               | 11 -----------
 src/scheduler/SequentialScheduler.cpp         |  5 +----
 7 files changed, 10 insertions(+), 30 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 5f6be6045..6c1806101 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -591,23 +591,16 @@ public:
 
     inline void print() const { fmt::print("{}\n", toString()); }
 
-    std::shared_ptr<Tensor> grad() {
-        return mGrad;
-    }
-    void setGrad(std::shared_ptr<Tensor> newGrad) {
-        mGrad = newGrad;
-    }
-
     /**
-     * @brief Associate the gradient with a Tensor instance and set its implementation
-     * if none was previously set.
+     * @brief Get the gradient Tensor. If not initialized, set a Tensor instance
+     * and set its implementation if none was previously set.
      * @note Dimensions for the Tensor instance are copied from the original current Tensor.
      * @note If a Tensor instance was already associated, only the implementation is created
      * with values set to 0.
      * @note If Tensor instance and implementation already existed for the gradient
      * nothing is done.
      */
-    void initGrad() {
+    std::shared_ptr<Tensor> grad() {
         if (!mGrad) {
             mGrad = std::make_shared<Tensor>(mDims);
         }
@@ -617,6 +610,11 @@ public:
             mGrad->setBackend(hasImpl() ? mImpl->backend() : "cpu");
             mGrad->zeros();
         }
+        return mGrad;
+    }
+
+    void setGrad(std::shared_ptr<Tensor> newGrad) {
+        mGrad = newGrad;
     }
 
     /**
diff --git a/include/aidge/recipes/GraphViewHelper.hpp b/include/aidge/recipes/GraphViewHelper.hpp
index a2c571bf4..3b8ba7627 100644
--- a/include/aidge/recipes/GraphViewHelper.hpp
+++ b/include/aidge/recipes/GraphViewHelper.hpp
@@ -39,8 +39,6 @@ std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview
 */
 std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview);
 
-void compile_gradient(std::shared_ptr<Aidge::GraphView> gv);
-
 } // namespace Aidge
 
 #endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */
diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp
index a7929fde8..35dafead6 100644
--- a/include/aidge/scheduler/SequentialScheduler.hpp
+++ b/include/aidge/scheduler/SequentialScheduler.hpp
@@ -54,7 +54,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    void backward(bool instantiateGrad = true);
+    void backward();
 
 private:
     SchedulingPolicy mSchedulingPolicy;
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 005175ab6..1acd42b96 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -83,7 +83,6 @@ void init_Tensor(py::module& m){
         .def("grad", &Tensor::grad)
         .def("set_grad", &Tensor::setGrad)
         .def("dtype", &Tensor::dataType)
-        .def("init_grad", &Tensor::initGrad)
         .def("size", &Tensor::size)
         .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
         .def("has_impl", &Tensor::hasImpl)
diff --git a/python_binding/recipes/pybind_GraphViewHelper.cpp b/python_binding/recipes/pybind_GraphViewHelper.cpp
index e65b790d3..ac56fb4b4 100644
--- a/python_binding/recipes/pybind_GraphViewHelper.cpp
+++ b/python_binding/recipes/pybind_GraphViewHelper.cpp
@@ -24,6 +24,5 @@ namespace py = pybind11;
 namespace Aidge {
 void init_GraphViewHelper(py::module &m) {
     m.def("producers", &producers, py::arg("graphview"));
-    m.def("compile_gradient", &compile_gradient, py::arg("graphview"));
 }
 } // namespace Aidge
diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp
index b0c99bffb..9522c0fe7 100644
--- a/src/recipes/GraphViewHelper.cpp
+++ b/src/recipes/GraphViewHelper.cpp
@@ -44,14 +44,3 @@ std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidg
     }
     return res;
 }
-
-void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
-    for (const auto& node : gv->getNodes()) {
-        // TODO: check that each node is an OperatorTensor
-        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type());
-        const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator());
-        for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
-            op->getOutput(o)->initGrad();
-        }
-    }
-}
diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp
index 74b1b3f0c..88b5e98bc 100644
--- a/src/scheduler/SequentialScheduler.cpp
+++ b/src/scheduler/SequentialScheduler.cpp
@@ -73,10 +73,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
     }
 }
 
-void Aidge::SequentialScheduler::backward(bool instanciateGrad) {
-    // create ad set Grad values
-    if (instanciateGrad) { compile_gradient(mGraphView); }
-
+void Aidge::SequentialScheduler::backward() {
     // TODO: Check output grad are not empty
 
     // Generate scheduling *only if empty*
-- 
GitLab
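
Editor's note (not part of the patch): a minimal caller-side sketch of the behaviour change, grounded only in the hunks above. With initGrad() and compile_gradient() removed, Tensor::grad() now lazily allocates and zero-fills the gradient on first access, so SequentialScheduler::backward() no longer needs an explicit graph-wide gradient pass. The dims-vector constructor and setBackend("cpu") call below mirror what the patched grad() itself uses; the surrounding function name is hypothetical.

#include <memory>
#include <vector>

#include "aidge/data/Tensor.hpp"

// Hypothetical caller, for illustration only.
void lazyGradExample() {
    // Dims-vector constructor assumed from `std::make_shared<Tensor>(mDims)`
    // in the patch.
    auto t = std::make_shared<Aidge::Tensor>(std::vector<Aidge::DimSize_t>{2, 3});
    t->setBackend("cpu");

    // Before this patch: grad() returned mGrad as-is (possibly nullptr), and a
    // separate initGrad() call (e.g. via compile_gradient() or backward(true))
    // was required to allocate it.
    // After this patch: the first grad() call allocates a gradient with the
    // same dims as the tensor, zero-fills it, and reuses the tensor's backend
    // (falling back to "cpu" when no implementation is set).
    std::shared_ptr<Aidge::Tensor> g = t->grad();
}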