From 8e7fa67f2050cf716dcbf37dc8e64da9941994d0 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 26 Mar 2024 15:42:33 +0000
Subject: [PATCH] Upd compile_gradient function

---
 src/recipes/GraphViewHelper.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp
index 79f255ac3..3b42db7fe 100644
--- a/src/recipes/GraphViewHelper.cpp
+++ b/src/recipes/GraphViewHelper.cpp
@@ -12,11 +12,12 @@
 #include <memory>
 #include <set>
 
+#include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
-#include "aidge/recipies/GraphViewHelper.hpp"
+#include "aidge/recipes/GraphViewHelper.hpp"
 
 
 std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
@@ -47,12 +48,10 @@ std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge
 void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
     for (const auto& node : gv->getNodes()) {
         // TODO: check that each node is an OperatorTensor
-        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor.");
-        const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator());
+        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type());
+        const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator());
         for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
-           const auto& t = op->getOutput(o);
-           t -> grad() -> setDataType(t -> dataType());
-           t -> grad() -> setBackend(t -> getImpl() -> backend());
+            op->getOutput(o)->initGradient();
         }
     }
 }
\ No newline at end of file
-- 
GitLab