diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 5f6be6045167f6ff523876aaa309a536683810de..6c18061014b0c1360a90bdf90229d147fa1dc6c8 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -591,23 +591,16 @@ public: inline void print() const { fmt::print("{}\n", toString()); } - std::shared_ptr<Tensor> grad() { - return mGrad; - } - void setGrad(std::shared_ptr<Tensor> newGrad) { - mGrad = newGrad; - } - /** - * @brief Associate the gradient with a Tensor instance and set its implementation - * if none was previously set. + * @brief Get the gradient Tensor. If not initialized, set a Tensor instance + * and set its implementation if none was previously set. * @note Dimensions for the Tensor instance are copied from the original current Tensor. * @note If a Tensor instance was already associated, only the implementation is created * with values set to 0. * @note If Tensor instance and implementation already existed for the gradient * nothing is done. */ - void initGrad() { + std::shared_ptr<Tensor> grad() { if (!mGrad) { mGrad = std::make_shared<Tensor>(mDims); } @@ -617,6 +610,11 @@ public: mGrad->setBackend(hasImpl() ? mImpl->backend() : "cpu"); mGrad->zeros(); } + return mGrad; + } + + void setGrad(std::shared_ptr<Tensor> newGrad) { + mGrad = newGrad; } /** diff --git a/include/aidge/recipes/GraphViewHelper.hpp b/include/aidge/recipes/GraphViewHelper.hpp index a2c571bf4ed164729f7c3416c814b913b4d07e6f..3b8ba7627362c945a6bfbe587ec952fdda013e98 100644 --- a/include/aidge/recipes/GraphViewHelper.hpp +++ b/include/aidge/recipes/GraphViewHelper.hpp @@ -39,8 +39,6 @@ std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview */ std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview); -void compile_gradient(std::shared_ptr<Aidge::GraphView> gv); - } // namespace Aidge #endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */ diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index a7929fde8a2affdd562d70d11a7c809aaf3357d0..35dafead6dc424550df7d83d54f5ec998c3b4d86 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -54,7 +54,7 @@ public: /** * @brief Run the provided Computational Graph with a batch of data */ - void backward(bool instantiateGrad = true); + void backward(); private: SchedulingPolicy mSchedulingPolicy; diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 005175ab613594c48959073c4674e6d69b60b29f..1acd42b964cea22aaf0f8493efbb2e8f8fe751fd 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -83,7 +83,6 @@ void init_Tensor(py::module& m){ .def("grad", &Tensor::grad) .def("set_grad", &Tensor::setGrad) .def("dtype", &Tensor::dataType) - .def("init_grad", &Tensor::initGrad) .def("size", &Tensor::size) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) .def("has_impl", &Tensor::hasImpl) diff --git a/python_binding/recipes/pybind_GraphViewHelper.cpp b/python_binding/recipes/pybind_GraphViewHelper.cpp index e65b790d3eba6072e3e1b112c7d841959d4a5672..ac56fb4b43eb5b0a737157ec9e64c6771a692816 100644 --- a/python_binding/recipes/pybind_GraphViewHelper.cpp +++ b/python_binding/recipes/pybind_GraphViewHelper.cpp @@ -24,6 +24,5 @@ namespace py = pybind11; namespace Aidge { void init_GraphViewHelper(py::module &m) { m.def("producers", &producers, py::arg("graphview")); - m.def("compile_gradient", &compile_gradient, py::arg("graphview")); } } // namespace Aidge diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp index b0c99bffb895dc64b20d76991911ae5f4b604c85..9522c0fe7346e78875a08d3ebf19a04dea2909e1 100644 --- a/src/recipes/GraphViewHelper.cpp +++ b/src/recipes/GraphViewHelper.cpp @@ -44,14 +44,3 @@ std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge } return res; } - -void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) { - for (const auto& node : gv->getNodes()) { - // TODO: check that each node is an OperatorTensor - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type()); - const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator()); - for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { - op->getOutput(o)->initGrad(); - } - } -} diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 74b1b3f0c6e9be164792460669821744661c15b3..88b5e98bc62456bd59dc235c3112396daaeddd24 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -73,10 +73,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std } } -void Aidge::SequentialScheduler::backward(bool instanciateGrad) { - // create ad set Grad values - if (instanciateGrad) { compile_gradient(mGraphView); } - +void Aidge::SequentialScheduler::backward() { // TODO: Check output grad are not empty // Generate scheduling *only if empty*