From a9e44123c1db68396a620cb3895403d76940af44 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 21 Mar 2024 14:12:47 +0000
Subject: [PATCH] Upd GraphViewHelper functions to return Tensors instead of
 Nodes

---
 include/aidge/recipes/GraphViewHelper.hpp | 32 +++++++--------
 src/recipes/GraphViewHelper.cpp           | 47 +++++++++++++++++++++++
 2 files changed, 64 insertions(+), 15 deletions(-)
 create mode 100644 src/recipes/GraphViewHelper.cpp

diff --git a/include/aidge/recipes/GraphViewHelper.hpp b/include/aidge/recipes/GraphViewHelper.hpp
index c6204cdff..8fdf1e1d7 100644
--- a/include/aidge/recipes/GraphViewHelper.hpp
+++ b/include/aidge/recipes/GraphViewHelper.hpp
@@ -9,14 +9,14 @@
  *
  ********************************************************************************/
 
-#ifndef AIDGE_CORE_UTILS_RECIPES_H_
-#define AIDGE_CORE_UTILS_RECIPES_H_
+#ifndef AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
+#define AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
 
 #include <memory>
 #include <set>
 
-#include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
+#include "aidge/data/Tensor.hpp"
 
 
 namespace Aidge {
@@ -26,15 +26,17 @@ namespace Aidge {
  * @param graphview GraphView instance where Producers should be searched.
  * @return std::set<std::shared_ptr<Node>>
  */
-std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) {
-    std::set<std::shared_ptr<Node>> res;
-    const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
-
-    std::copy_if(nodes.cbegin(),
-                    nodes.cend(),
-                    std::inserter(res, res.begin()),
-                    [](std::shared_ptr<Node> n){ return n->type() == "Producer"; });
-
-    return res;
-}
-} // namespace Aidge
\ No newline at end of file
+std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview);
+
+/**
+ * @brief Getter for every ``Tensor`` owned by an ``Operator`` inside the provided ``GraphView``.
+ * @note An ``Operator`` owns its output ``Tensor``s.
+ *
+ * @param graphview Pointer to the ``GraphView`` from which ``Tensor``s should be extracted.
+ * @return std::set<std::shared_ptr<Tensor>> Set of pointers to the ``Tensor``s.
+ */
+std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview);
+
+} // namespace Aidge
+
+#endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */
diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp
new file mode 100644
index 000000000..ec58871b2
--- /dev/null
+++ b/src/recipes/GraphViewHelper.cpp
@@ -0,0 +1,47 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/recipes/GraphViewHelper.hpp"
+
+#include <memory>
+#include <set>
+#include <vector>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+
+
+std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
+    std::set<std::shared_ptr<Tensor>> res;
+    const auto& nodes = graphview->getNodes();
+    for (const auto& node : nodes) {
+        if (node->type() == "Producer") {
+            const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator());
+            res.insert(param->getOutput(0));
+        }
+    }
+    return res;
+}
+
+
+std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge::GraphView> graphview) {
+    std::set<std::shared_ptr<Tensor>> res;
+    const auto& nodes = graphview->getNodes();
+    for (const auto& node : nodes) {
+        const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator());
+        for (std::size_t o = 0; o < param->nbOutputs(); ++o) {
+            res.insert(param->getOutput(o));
+        }
+    }
+    return res;
+}
-- 
GitLab