diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 4746aeacf2bf6ad43306b60660ecd0bace0a7ab5..8f76d53955d27eaade320d1b933b6cc640217b58 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -21,7 +21,7 @@ namespace py = pybind11; namespace Aidge { void init_Recipies(py::module &m) { m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes")); - m.def("remove_flatten", &removeFlatten, py::arg("view")); + m.def("remove_flatten", &removeFlatten, py::arg("nodes")); } } // namespace Aidge diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index 23a9d645af52d0fe6af0b4d03f773a550b104e0c..cc3c3324e40636a1edcbc73cdc4a9dcfeec8a026 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -16,13 +16,9 @@ #include "aidge/utils/Recipies.hpp" namespace Aidge { - void removeFlatten(std::shared_ptr<GraphView> view) { - for (auto& nodePtr : view->getNodes()) { - if (nodePtr->type() == "Flatten") { - auto g = std::make_shared<GraphView>(); - g->add(std::set<std::shared_ptr<Node>>({nodePtr})); - g->replaceWith({}); - } - } + void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { + auto g = std::make_shared<GraphView>(); + g->add(std::set<std::shared_ptr<Node>>({nodes})); + g->replaceWith({}); } } \ No newline at end of file