Commit ecc96977 authored by Olivier BICHLER

Added lazy init for grad

parent 064955b6
Part of 2 merge requests: !152 "Update Aidge export to take a graph view as an argument instead of a...", !143 "Multiple refactors"
Pipeline #48472 failed
@@ -591,23 +591,16 @@ public:
     inline void print() const { fmt::print("{}\n", toString()); }

-    std::shared_ptr<Tensor> grad() {
-        return mGrad;
-    }
-    void setGrad(std::shared_ptr<Tensor> newGrad) {
-        mGrad = newGrad;
-    }
-
     /**
-     * @brief Associate the gradient with a Tensor instance and set its implementation
-     * if none was previously set.
+     * @brief Get the gradient Tensor. If not initialized, set a Tensor instance
+     * and set its implementation if none was previously set.
      * @note Dimensions for the Tensor instance are copied from the original current Tensor.
      * @note If a Tensor instance was already associated, only the implementation is created
      * with values set to 0.
      * @note If Tensor instance and implementation already existed for the gradient
      * nothing is done.
      */
-    void initGrad() {
+    std::shared_ptr<Tensor> grad() {
         if (!mGrad) {
             mGrad = std::make_shared<Tensor>(mDims);
         }
@@ -617,6 +610,11 @@ public:
             mGrad->setBackend(hasImpl() ? mImpl->backend() : "cpu");
             mGrad->zeros();
         }
+        return mGrad;
+    }
+    void setGrad(std::shared_ptr<Tensor> newGrad) {
+        mGrad = newGrad;
+    }

     /**
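Taken together, the two hunks above replace the eager initGrad() entry point with a lazy-init accessor: the first call to grad() allocates a zero-filled gradient with the same dimensions and backend as the owning tensor, and every later call returns that same instance. Below is a minimal, self-contained sketch of the pattern. It is not the real Aidge class: LazyGradTensor is a hypothetical stand-in, a plain float vector replaces the backend implementation, and zeros() stands in for the setBackend(...) + zeros() pair from the diff.

    // Hedged sketch of the lazy-init gradient accessor (hypothetical types).
    #include <cstddef>
    #include <memory>
    #include <vector>

    class LazyGradTensor {
    public:
        explicit LazyGradTensor(std::vector<std::size_t> dims)
            : mDims(std::move(dims)) {}

        // First call: allocate a gradient with the same dims, zero-filled.
        // Later calls: return the same instance, untouched.
        std::shared_ptr<LazyGradTensor> grad() {
            if (!mGrad) {
                mGrad = std::make_shared<LazyGradTensor>(mDims);
                mGrad->zeros();
            }
            return mGrad;
        }

        void setGrad(std::shared_ptr<LazyGradTensor> newGrad) {
            mGrad = std::move(newGrad);
        }

        std::size_t size() const {
            std::size_t n = 1;
            for (const std::size_t d : mDims) { n *= d; }
            return n;
        }

        void zeros() { mData.assign(size(), 0.0f); }

    private:
        std::vector<std::size_t> mDims;
        std::vector<float> mData;                // stand-in for the backend impl
        std::shared_ptr<LazyGradTensor> mGrad;   // created on first grad() call
    };

The point of this shape: callers can no longer observe a null gradient, which is what allows the rest of the commit to delete the eager compile_gradient() pre-pass shown further down.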
@@ -39,8 +39,6 @@ std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview
  */
 std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview);

-void compile_gradient(std::shared_ptr<Aidge::GraphView> gv);
-
 } // namespace Aidge

 #endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */
@@ -54,7 +54,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    void backward(bool instantiateGrad = true);
+    void backward();

 private:
     SchedulingPolicy mSchedulingPolicy;
@@ -83,7 +83,6 @@ void init_Tensor(py::module& m){
         .def("grad", &Tensor::grad)
         .def("set_grad", &Tensor::setGrad)
         .def("dtype", &Tensor::dataType)
-        .def("init_grad", &Tensor::initGrad)
         .def("size", &Tensor::size)
         .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
         .def("has_impl", &Tensor::hasImpl)
@@ -24,6 +24,5 @@ namespace py = pybind11;
 namespace Aidge {
 void init_GraphViewHelper(py::module &m) {
     m.def("producers", &producers, py::arg("graphview"));
-    m.def("compile_gradient", &compile_gradient, py::arg("graphview"));
 }
 } // namespace Aidge
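On the binding side, the same reasoning removes two Python entry points at once: with initialization folded into grad(), neither init_grad nor compile_gradient needs to be exposed. A hedged pybind11 sketch of what remains, reusing the hypothetical LazyGradTensor from the earlier sketch (module and class names are illustrative, not the real Aidge bindings):

    // Hedged pybind11 sketch; assumes the LazyGradTensor class defined above.
    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>   // std::vector<std::size_t> <-> Python list

    namespace py = pybind11;

    PYBIND11_MODULE(lazy_grad_demo, m) {
        py::class_<LazyGradTensor, std::shared_ptr<LazyGradTensor>>(m, "LazyGradTensor")
            .def(py::init<std::vector<std::size_t>>())
            .def("grad", &LazyGradTensor::grad)        // lazily creates on first call
            .def("set_grad", &LazyGradTensor::setGrad)
            .def("size", &LazyGradTensor::size);
    }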
@@ -44,14 +44,3 @@ std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge
     }
     return res;
 }
-
-void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
-    for (const auto& node : gv->getNodes()) {
-        // TODO: check that each node is an OperatorTensor
-        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type());
-        const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
-        for (std::size_t o = 0; o < node->nbOutputs(); ++o) {
-            op->getOutput(o)->initGrad();
-        }
-    }
-}
@@ -73,10 +73,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
     }
 }

-void Aidge::SequentialScheduler::backward(bool instanciateGrad) {
-    // create ad set Grad values
-    if (instanciateGrad) { compile_gradient(mGraphView); }
+void Aidge::SequentialScheduler::backward() {
     // TODO: Check output grad are not empty

     // Generate scheduling *only if empty*
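Caller-side, the last two hunks mean the gradient pre-pass disappears from the backward entry point. A hedged migration sketch, written against the hypothetical LazyGradTensor above rather than the real GraphView/SequentialScheduler API:

    // Compiles together with the LazyGradTensor sketch earlier on this page.
    #include <cassert>

    int main() {
        LazyGradTensor output({2, 3});

        // Old flow (removed by this commit): an eager pre-pass walked every
        // node and initialized each output gradient before backward():
        //     compile_gradient(graphView);
        //     scheduler.backward(/*instantiateGrad=*/true);

        // New flow: nothing to pre-compile; the first grad() access allocates
        // a zero-filled gradient whose dims are copied from the source tensor.
        auto g1 = output.grad();
        auto g2 = output.grad();
        assert(g1 == g2);             // same instance on repeated access
        assert(g1->size() == 2 * 3);  // dims {2, 3} copied over
        return 0;
    }

The remaining TODO in the diff ("Check output grad are not empty") also becomes easier to satisfy under this design, since grad() now guarantees a non-null, zero-initialized tensor.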