Skip to content
Snippets Groups Projects
Commit a3d47a6c authored by Cyril Moineau's avatar Cyril Moineau Committed by Cyril Moineau
Browse files

[fix] producers helper function always retrieved constant producer, by default...

[fix] producers helper function always retrieved constant producer, by default this is not the case anymore, added a constant argument if we need constant producers.
parent c5e2286a
No related branches found
No related tags found
2 merge requests!279v0.4.0,!257[fix] producers helper function always retrieved constant producer, by default...
Pipeline #59769 passed
...@@ -22,11 +22,12 @@ ...@@ -22,11 +22,12 @@
namespace Aidge { namespace Aidge {
/** /**
* @brief Getter for every Producer operator in a GraphView. * @brief Getter for every Tensor held by a Producer operator in a GraphView.
* @param graphview GraphView instance where Producers should be searched. * @param graphview GraphView instance where Producers should be searched.
* @param constant If true, Producer with attribute ``constant=true`` are also included in the returned set, default=false
* @return std::set<std::shared_ptr<Node>> * @return std::set<std::shared_ptr<Node>>
*/ */
std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview); std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview, bool constant=false);
// TODO: change for every Tensor of Operator Producer not constant // TODO: change for every Tensor of Operator Producer not constant
......
...@@ -23,6 +23,6 @@ namespace py = pybind11; ...@@ -23,6 +23,6 @@ namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_GraphViewHelper(py::module &m) { void init_GraphViewHelper(py::module &m) {
m.def("producers", &producers, py::arg("graphview")); m.def("producers", &producers, py::arg("graphview"), py::arg("constant")=false);
} }
} // namespace Aidge } // namespace Aidge
...@@ -16,17 +16,20 @@ ...@@ -16,17 +16,20 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
#include "aidge/recipes/GraphViewHelper.hpp" #include "aidge/recipes/GraphViewHelper.hpp"
std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) { std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview, bool constant) {
std::set<std::shared_ptr<Tensor>> res; std::set<std::shared_ptr<Tensor>> res;
const auto& nodes = graphview->getNodes(); const auto& nodes = graphview->getNodes();
for (const auto& node : nodes) { for (const auto& node : nodes) {
if (node->type() == "Producer") { if (node->type() == "Producer") {
const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator()); const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
res.insert(param->getOutput(0)); if (!producer->constant() || constant) {
res.insert(producer->getOutput(0));
}
} }
} }
return res; return res;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment