diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index ead6c19fa5fe1e91ec1c24cf8dfee6146390477f..ee4d9cdcb6638c15ecffcb5d86de00fca62046e1 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -554,16 +554,11 @@ public: inline void print() const { fmt::print("{}\n", toString()); } std::shared_ptr<Tensor> grad() { - // if (!mGrad && mImpl) { - // mGrad = std::make_shared<Tensor>(mDims); - // mGrad->setDataType(mDataType); - // mGrad->setBackend(mImpl->backend()); - - // // if (mImpl) mGrad->setBackend(mImpl->backend()); - // } - return mGrad; } + void setGrad(std::shared_ptr<Tensor> newGrad) { + mGrad = newGrad; + } /** * @brief Associate the gradient with a Tensor instance and set its implementation diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index b97af94ad583cf42e25fa3afc0697021f6dcadcc..a21aa8be5a68ef32a4735eacb5701670a2d6a56c 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -77,7 +77,9 @@ void init_Tensor(py::module& m){ .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("grad", &Tensor::grad) + .def("set_grad", &Tensor::setGrad) .def("dtype", &Tensor::dataType) + .def("init_gradient", &Tensor::initGradient) .def("size", &Tensor::size) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) .def("has_impl", &Tensor::hasImpl)