diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index ddc76115717ac50b34d7e436f1d26ddfe9271169..b09064c36f65a1e00d99ce5e2ff559e31681b065 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -350,7 +350,8 @@ public: * @param other_graph GraphView containing the Nodes to include. * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - bool add(std::shared_ptr<GraphView> otherGraph); + bool add(std::shared_ptr<GraphView> otherGraph, + bool includeLearnableParam = true); /** * @brief Include a Node in the current GraphView and link it to another diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 1000374454020625aada7f2043893b229deec833..376d61ae85b1166a62a392bee90324cf80d2c0f5 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -61,13 +61,15 @@ void init_GraphView(py::module& m) { :type include_learnable_parameters: bool, optional )mydelimiter") - .def("add", (bool (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add, - py::arg("other_graph"), + .def("add", (bool (GraphView::*)(std::shared_ptr<GraphView>, bool)) & GraphView::add, + py::arg("other_graph"), py::arg("include_learnable_parameters") = true, R"mydelimiter( Include a GraphView to the current GraphView object. :param other_graph: GraphView to add :type other_graph: GraphView + :param include_learnable_parameters: include non-data inputs, like weights and biases, default True. + :type include_learnable_parameters: bool, optional )mydelimiter") .def("add_child", diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 73830ab323e6a60a6896f01c2c6050bb50498e4a..5124d41f575b0ebf7f3c6cf258900e0ae656d213 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -764,10 +764,10 @@ bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool inc return add(nodes.second, includeLearnableParam); } -bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { +bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph, bool includeLearnableParam) { // set the rootNode to the other graphView rootNode if no rootNode yet mRootNode = mRootNode ? mRootNode : graph->rootNode(); - return add(graph->getNodes(), true); + return add(graph->getNodes(), includeLearnableParam); } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp index 198e3a44bc7663aea42554cd9f08b0bfc616a06d..e7748936c00a20ec235ea7853f4d17e2c10261fb 100644 --- a/src/recipes/FuseToMetaOps.cpp +++ b/src/recipes/FuseToMetaOps.cpp @@ -13,30 +13,20 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Matching.hpp" #include "aidge/operator/MetaOperator.hpp" #include "aidge/recipes/Recipes.hpp" -//Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" - size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setKeyFromGraph(graphView); - regex->addQuery(query); - const auto metaType = (!type.empty()) ? type : query; + const auto matches = SinglePassGraphMatching(graphView).match(query); size_t nbReplaced = 0; - const auto matches = regex->match(graphView); - - for (const auto& solution : matches) { - auto microGraph = std::make_shared<GraphView>(); - microGraph->add(solution->getAll()); - - auto metaOp = MetaOperator(metaType.c_str(), microGraph->clone()); + for (const auto& match : matches) { + auto metaOp = MetaOperator(metaType.c_str(), match.graph->clone()); auto metaOpGraph = std::make_shared<GraphView>(); - metaOpGraph->add(metaOp); - const auto success = GraphView::replace(microGraph, metaOpGraph); + metaOpGraph->add(metaOp, false); + const auto success = GraphView::replace(match.graph, metaOpGraph); if (!success) { Log::notice("Could not replace sub-graph with meta operator");