From 0356759aa624814c3d8565af7cdb07cffbd8e7ff Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 3 May 2024 09:10:53 +0000
Subject: [PATCH] Add setGrad method and bind initGrad method.

---
 include/aidge/data/Tensor.hpp         | 11 +++--------
 python_binding/data/pybind_Tensor.cpp |  2 ++
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index ead6c19fa..ee4d9cdcb 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -554,16 +554,11 @@ public:
     inline void print() const { fmt::print("{}\n", toString()); }
 
     std::shared_ptr<Tensor> grad() {
-        // if (!mGrad && mImpl) {
-        //     mGrad = std::make_shared<Tensor>(mDims);
-        //     mGrad->setDataType(mDataType);
-        //     mGrad->setBackend(mImpl->backend());
-
-        //     // if (mImpl) mGrad->setBackend(mImpl->backend());
-        // }
-
         return mGrad;
     }
+    void setGrad(std::shared_ptr<Tensor> newGrad) {
+        mGrad = newGrad;
+    }
 
     /**
      * @brief Associate the gradient with a Tensor instance and set its implementation
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index b97af94ad..a21aa8be5 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -77,7 +77,9 @@ void init_Tensor(py::module& m){
     .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
     .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
     .def("grad", &Tensor::grad)
+    .def("set_grad", &Tensor::setGrad)
     .def("dtype", &Tensor::dataType)
+    .def("init_gradient", &Tensor::initGradient)
     .def("size", &Tensor::size)
     .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
     .def("has_impl", &Tensor::hasImpl)
-- 
GitLab