From ef3acd297eeeb2b8ba129dbac2147b400cbec0b5 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 7 Feb 2024 14:27:18 +0000
Subject: [PATCH] [Upd][WIP][NF] 'backward()' for SequentialScheduler

---
 include/aidge/recipies/GraphViewHelper.hpp | 44 ++-------------
 src/recipies/GraphViewHelper.cpp           | 62 ++++++++++++++++++++++
 src/scheduler/Scheduler.cpp                |  2 +-
 3 files changed, 66 insertions(+), 42 deletions(-)
 create mode 100644 src/recipies/GraphViewHelper.cpp

diff --git a/include/aidge/recipies/GraphViewHelper.hpp b/include/aidge/recipies/GraphViewHelper.hpp
index 14f59db9f..7cd5d662f 100644
--- a/include/aidge/recipies/GraphViewHelper.hpp
+++ b/include/aidge/recipies/GraphViewHelper.hpp
@@ -15,10 +15,7 @@
 #include <memory>
 #include <set>
 
-#include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
-#include "aidge/operator/OperatorTensor.hpp"
-#include "aidge/utils/ErrorHandling.hpp"
 
 
 namespace Aidge {
@@ -28,51 +25,16 @@ namespace Aidge {
  * @param graphview GraphView instance where Producers should be searched.
  * @return std::set<std::shared_ptr<Node>>
  */
-std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) {
-    std::set<std::shared_ptr<Node>> res;
-    const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
-
-    std::copy_if(nodes.cbegin(),
-                    nodes.cend(),
-                    std::inserter(res, res.begin()),
-                    [](std::shared_ptr<Node> n){ return n->type() == "Producer"; });
-
-    return res;
-}
+std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview);
 
 /**
  * @brief Getter for every Producer operator in a GraphView that is a parameter.
  * @param graphview GraphView instance where Producers should be searched.
  * @return std::set<std::shared_ptr<Node>>
  */
-std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview) {
-    std::set<std::shared_ptr<Node>> res;
-    const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
-
-    for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) {
-        for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) {
-            const std::shared_ptr<Node>& parent = (*it)->getParent(inID);
-            if (parent && parent->type() == "Producer") {
-                res.insert(parent);
-            }
-        }
-    }
-
-    return res;
-}
+std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview);
 
-void instanciateGradient(std::shared_ptr<Aidge::GraphView> gv) {
-    for (const auto& node : gv->getNodes()) {
-        // TODO: check that each node is an OperatorTensor
-        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor.");
-        const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator());
-        for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
-           const auto& t = op->getOutput(o);
-           t -> grad() -> setDataType(t -> dataType());
-           t -> grad() -> setBackend(t -> getImpl() -> backend());
-        }
-    }
-}
+void compile_gradient(std::shared_ptr<Aidge::GraphView> gv);
 
 } // namespace Aidge
 
diff --git a/src/recipies/GraphViewHelper.cpp b/src/recipies/GraphViewHelper.cpp
new file mode 100644
index 000000000..ac2cb1fdf
--- /dev/null
+++ b/src/recipies/GraphViewHelper.cpp
@@ -0,0 +1,62 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+#include <set>
+
+#include "aidge/graph/Node.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/recipies/GraphViewHelper.hpp"
+
+
+std::set<std::shared_ptr<Aidge::Node>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
+    std::set<std::shared_ptr<Node>> res;
+    const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
+
+    std::copy_if(nodes.cbegin(),
+                    nodes.cend(),
+                    std::inserter(res, res.begin()),
+                    [](std::shared_ptr<Node> n){ return n->type() == "Producer"; });
+
+    return res;
+}
+
+
+std::set<std::shared_ptr<Aidge::Node>> Aidge::parameters(std::shared_ptr<Aidge::GraphView> graphview) {
+    std::set<std::shared_ptr<Node>> res;
+    const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
+
+    for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) {
+        for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) {
+            const std::shared_ptr<Node>& parent = (*it)->getParent(inID);
+            if (parent && parent->type() == "Producer") {
+                res.insert(parent);
+            }
+        }
+    }
+
+    return res;
+}
+
+void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
+    for (const auto& node : gv->getNodes()) {
+        // TODO: check that each node is an OperatorTensor
+        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor.");
+        const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator());
+        for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
+           const auto& t = op->getOutput(o);
+           t -> grad() -> setDataType(t -> dataType());
+           t -> grad() -> setBackend(t -> getImpl() -> backend());
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 074f3a98e..d5a3d2764 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -211,7 +211,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
 
 void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) {
     // Forward dims (if allowed)
-    if (instanciateGrad) {instanciateGradient(mGraphView); }
+    if (instanciateGrad) {compile_gradient(mGraphView); }
 
     // Generate scheduling *only if empty*
     // If scheduling was already generated (in one or several steps, i.e. one or
-- 
GitLab