From 8e7fa67f2050cf716dcbf37dc8e64da9941994d0 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 26 Mar 2024 15:42:33 +0000 Subject: [PATCH] Upd compile_gradient function --- src/recipes/GraphViewHelper.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp index 79f255ac3..3b42db7fe 100644 --- a/src/recipes/GraphViewHelper.cpp +++ b/src/recipes/GraphViewHelper.cpp @@ -12,11 +12,12 @@ #include <memory> #include <set> +#include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/ErrorHandling.hpp" -#include "aidge/recipies/GraphViewHelper.hpp" +#include "aidge/recipes/GraphViewHelper.hpp" std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) { @@ -47,12 +48,10 @@ std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) { for (const auto& node : gv->getNodes()) { // TODO: check that each node is an OperatorTensor - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor."); - const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator()); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type()); + const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator()); for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { - const auto& t = op->getOutput(o); - t -> grad() -> setDataType(t -> dataType()); - t -> grad() -> setBackend(t -> getImpl() -> backend()); + op->getOutput(o)->initGradient(); } } } \ No newline at end of file -- GitLab