diff --git a/include/aidge/recipes/GraphViewHelper.hpp b/include/aidge/recipes/GraphViewHelper.hpp index 3b8ba7627362c945a6bfbe587ec952fdda013e98..bfd0a5364f858c7518de4be0c18888d37ab669da 100644 --- a/include/aidge/recipes/GraphViewHelper.hpp +++ b/include/aidge/recipes/GraphViewHelper.hpp @@ -22,11 +22,12 @@ namespace Aidge { /** - * @brief Getter for every Producer operator in a GraphView. + * @brief Getter for every Tensor held by a Producer operator in a GraphView. * @param graphview GraphView instance where Producers should be searched. + * @param constant If true, Producer with attribute ``constant=true`` are also included in the returned set, default=false * @return std::set<std::shared_ptr<Node>> */ -std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview); +std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview, bool constant=false); // TODO: change for every Tensor of Operator Producer not constant diff --git a/python_binding/recipes/pybind_GraphViewHelper.cpp b/python_binding/recipes/pybind_GraphViewHelper.cpp index ac56fb4b43eb5b0a737157ec9e64c6771a692816..eee5b6a61771b4b0724699f1c45dad6c8a35f04d 100644 --- a/python_binding/recipes/pybind_GraphViewHelper.cpp +++ b/python_binding/recipes/pybind_GraphViewHelper.cpp @@ -23,6 +23,6 @@ namespace py = pybind11; namespace Aidge { void init_GraphViewHelper(py::module &m) { - m.def("producers", &producers, py::arg("graphview")); + m.def("producers", &producers, py::arg("graphview"), py::arg("constant")=false); } } // namespace Aidge diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp index 9522c0fe7346e78875a08d3ebf19a04dea2909e1..bbc6524ccbe7158476a6bd846d77da754f186091 100644 --- a/src/recipes/GraphViewHelper.cpp +++ b/src/recipes/GraphViewHelper.cpp @@ -16,17 +16,20 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/recipes/GraphViewHelper.hpp" -std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) { +std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview, bool constant) { std::set<std::shared_ptr<Tensor>> res; const auto& nodes = graphview->getNodes(); for (const auto& node : nodes) { if (node->type() == "Producer") { - const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator()); - res.insert(param->getOutput(0)); + const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator()); + if (!producer->constant() || constant) { + res.insert(producer->getOutput(0)); + } } } return res;