From 0356759aa624814c3d8565af7cdb07cffbd8e7ff Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 3 May 2024 09:10:53 +0000 Subject: [PATCH] Add setGrad method and bind initGrad method. --- include/aidge/data/Tensor.hpp | 11 +++-------- python_binding/data/pybind_Tensor.cpp | 2 ++ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index ead6c19fa..ee4d9cdcb 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -554,16 +554,11 @@ public: inline void print() const { fmt::print("{}\n", toString()); } std::shared_ptr<Tensor> grad() { - // if (!mGrad && mImpl) { - // mGrad = std::make_shared<Tensor>(mDims); - // mGrad->setDataType(mDataType); - // mGrad->setBackend(mImpl->backend()); - - // // if (mImpl) mGrad->setBackend(mImpl->backend()); - // } - return mGrad; } + void setGrad(std::shared_ptr<Tensor> newGrad) { + mGrad = newGrad; + } /** * @brief Associate the gradient with a Tensor instance and set its implementation diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index b97af94ad..a21aa8be5 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -77,7 +77,9 @@ void init_Tensor(py::module& m){ .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("grad", &Tensor::grad) + .def("set_grad", &Tensor::setGrad) .def("dtype", &Tensor::dataType) + .def("init_gradient", &Tensor::initGradient) .def("size", &Tensor::size) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) .def("has_impl", &Tensor::hasImpl) -- GitLab