From a3d47a6c048bb012f67842339062bda1599ec1e6 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 19 Nov 2024 16:30:25 +0000
Subject: [PATCH] [fix] producers helper function always retrieved constant
 producer, by default this is not the case anymore, added a constant argument
 if we need constant producers.

---
 include/aidge/recipes/GraphViewHelper.hpp         | 5 +++--
 python_binding/recipes/pybind_GraphViewHelper.cpp | 2 +-
 src/recipes/GraphViewHelper.cpp                   | 9 ++++++---
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/include/aidge/recipes/GraphViewHelper.hpp b/include/aidge/recipes/GraphViewHelper.hpp
index 3b8ba7627..bfd0a5364 100644
--- a/include/aidge/recipes/GraphViewHelper.hpp
+++ b/include/aidge/recipes/GraphViewHelper.hpp
@@ -22,11 +22,12 @@
 namespace Aidge {
 
 /**
- * @brief Getter for every Producer operator in a GraphView.
+ * @brief Getter for every Tensor held by a Producer operator in a GraphView.
  * @param graphview GraphView instance where Producers should be searched.
+ * @param constant If true, Producer with attribute ``constant=true`` are also included in the returned set, default=false
  * @return std::set<std::shared_ptr<Node>>
  */
-std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview);
+std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview, bool constant=false);
 
 
 // TODO: change for every Tensor of Operator Producer not constant
diff --git a/python_binding/recipes/pybind_GraphViewHelper.cpp b/python_binding/recipes/pybind_GraphViewHelper.cpp
index ac56fb4b4..eee5b6a61 100644
--- a/python_binding/recipes/pybind_GraphViewHelper.cpp
+++ b/python_binding/recipes/pybind_GraphViewHelper.cpp
@@ -23,6 +23,6 @@ namespace py = pybind11;
 
 namespace Aidge {
 void init_GraphViewHelper(py::module &m) {
-    m.def("producers", &producers, py::arg("graphview"));
+    m.def("producers", &producers, py::arg("graphview"), py::arg("constant")=false);
 }
 } // namespace Aidge
diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp
index 9522c0fe7..bbc6524cc 100644
--- a/src/recipes/GraphViewHelper.cpp
+++ b/src/recipes/GraphViewHelper.cpp
@@ -16,17 +16,20 @@
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Producer.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/recipes/GraphViewHelper.hpp"
 
 
-std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
+std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview, bool constant) {
     std::set<std::shared_ptr<Tensor>> res;
     const auto& nodes = graphview->getNodes();
     for (const auto& node : nodes) {
         if (node->type() == "Producer") {
-            const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator());
-            res.insert(param->getOutput(0));
+            const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
+            if (!producer->constant() || constant) {
+                res.insert(producer->getOutput(0));
+            }
         }
     }
     return res;
-- 
GitLab