From 859227aa62f50dd166f4e8249361abc7ff90bccb Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Wed, 2 Aug 2023 15:39:30 +0000 Subject: [PATCH] Change remove_flatten to take in a node instead of a graphView --- python_binding/recipies/pybind_Recipies.cpp | 2 +- src/recipies/RemoveFlatten.cpp | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 4746aeacf..8f76d5395 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -21,7 +21,7 @@ namespace py = pybind11; namespace Aidge { void init_Recipies(py::module &m) { m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes")); - m.def("remove_flatten", &removeFlatten, py::arg("view")); + m.def("remove_flatten", &removeFlatten, py::arg("nodes")); } } // namespace Aidge diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index 23a9d645a..cc3c3324e 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -16,13 +16,9 @@ #include "aidge/utils/Recipies.hpp" namespace Aidge { - void removeFlatten(std::shared_ptr<GraphView> view) { - for (auto& nodePtr : view->getNodes()) { - if (nodePtr->type() == "Flatten") { - auto g = std::make_shared<GraphView>(); - g->add(std::set<std::shared_ptr<Node>>({nodePtr})); - g->replaceWith({}); - } - } + void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { + auto g = std::make_shared<GraphView>(); + g->add(std::set<std::shared_ptr<Node>>({nodes})); + g->replaceWith({}); } } \ No newline at end of file -- GitLab