From 95a2f4da366c6453800f78422e5b25d8c709befe Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 6 Feb 2024 13:47:45 +0000
Subject: [PATCH] [Add] features

[function] 'parameters()' to extract parameters of type Producer from a GraphView
[function] 'instanciateGraphView()' to initialize Tensors gradient with the same datatype/backend
---
 include/aidge/recipies/GraphViewHelper.hpp | 41 +++++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/include/aidge/recipies/GraphViewHelper.hpp b/include/aidge/recipies/GraphViewHelper.hpp
index d7bcec713..14f59db9f 100644
--- a/include/aidge/recipies/GraphViewHelper.hpp
+++ b/include/aidge/recipies/GraphViewHelper.hpp
@@ -17,6 +17,8 @@
 
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
 
 
 namespace Aidge {
@@ -37,4 +39,41 @@ std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphVie
 
     return res;
 }
-} // namespace Aidge
\ No newline at end of file
+
+/**
+ * @brief Getter for every Producer operator in a GraphView that is a parameter.
+ * @param graphview GraphView instance where Producers should be searched.
+ * @return std::set<std::shared_ptr<Node>>
+ */
+std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview) {
+    std::set<std::shared_ptr<Node>> res;
+    const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
+
+    for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) {
+        for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) {
+            const std::shared_ptr<Node>& parent = (*it)->getParent(inID);
+            if (parent && parent->type() == "Producer") {
+                res.insert(parent);
+            }
+        }
+    }
+
+    return res;
+}
+
+void instanciateGradient(std::shared_ptr<Aidge::GraphView> gv) {
+    for (const auto& node : gv->getNodes()) {
+        // TODO: check that each node is an OperatorTensor
+        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor.");
+        const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator());
+        for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
+           const auto& t = op->getOutput(o);
+           t -> grad() -> setDataType(t -> dataType());
+           t -> grad() -> setBackend(t -> getImpl() -> backend());
+        }
+    }
+}
+
+} // namespace Aidge
+
+#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
\ No newline at end of file
-- 
GitLab