Skip to content
Snippets Groups Projects
Commit 990a1382 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Rename initGradient -> initGrad for name consistency.

parent 7729a5a6
No related branches found
No related tags found
No related merge requests found
...@@ -569,7 +569,7 @@ public: ...@@ -569,7 +569,7 @@ public:
* @note If Tensor instance and implementation already existed for the gradient * @note If Tensor instance and implementation already existed for the gradient
* nothing is done. * nothing is done.
*/ */
void initGradient() { void initGrad() {
if (!mGrad) { if (!mGrad) {
mGrad = std::make_shared<Tensor>(mDims); mGrad = std::make_shared<Tensor>(mDims);
} }
......
/******************************************************************************** ./********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2023 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
...@@ -79,7 +79,7 @@ void init_Tensor(py::module& m){ ...@@ -79,7 +79,7 @@ void init_Tensor(py::module& m){
.def("grad", &Tensor::grad) .def("grad", &Tensor::grad)
.def("set_grad", &Tensor::setGrad) .def("set_grad", &Tensor::setGrad)
.def("dtype", &Tensor::dataType) .def("dtype", &Tensor::dataType)
.def("init_gradient", &Tensor::initGradient) .def("init_grad", &Tensor::initGrad)
.def("size", &Tensor::size) .def("size", &Tensor::size)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
.def("has_impl", &Tensor::hasImpl) .def("has_impl", &Tensor::hasImpl)
......
...@@ -51,7 +51,7 @@ void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) { ...@@ -51,7 +51,7 @@ void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type()); AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type());
const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator()); const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator());
for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
op->getOutput(o)->initGradient(); op->getOutput(o)->initGrad();
} }
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment