diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index ddc76115717ac50b34d7e436f1d26ddfe9271169..b09064c36f65a1e00d99ce5e2ff559e31681b065 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -350,7 +350,8 @@ public: * @param other_graph GraphView containing the Nodes to include. * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - bool add(std::shared_ptr<GraphView> otherGraph); + bool add(std::shared_ptr<GraphView> otherGraph, + bool includeLearnableParam = true); /** * @brief Include a Node in the current GraphView and link it to another diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index a00debad1b2efe4c9577f6e1a2e730d7011907e5..7e9cfe399a1e13f281c999fafcf7d823276b7670 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -252,7 +252,9 @@ public: inline std::set<std::shared_ptr<GraphView>> views() const noexcept { std::set<std::shared_ptr<GraphView>> res; for (const auto &v : mViews) { - res.insert(v.lock()); + if (auto p = v.lock()) { + res.insert(p); + } } return res; } diff --git a/include/aidge/scheduler/MemoryManager.hpp b/include/aidge/scheduler/MemoryManager.hpp index 360b01f76e7a9b51f36b83d4d35286eced35016a..94add56e8afdebb8e42f7ae49a32da2aeed9e9cb 100644 --- a/include/aidge/scheduler/MemoryManager.hpp +++ b/include/aidge/scheduler/MemoryManager.hpp @@ -182,16 +182,7 @@ public: const std::shared_ptr<MemorySpace>& p1); }; - struct CompByNodeName { - bool operator()(const std::shared_ptr<Node>& lhs, - const std::shared_ptr<Node>& rhs) const - { - return lhs->name() < rhs->name(); - } - }; - - typedef std::map<std::shared_ptr<Node>, std::vector<MemoryPlane>, - CompByNodeName> MemMap_T; + typedef std::map<std::shared_ptr<Node>, std::vector<MemoryPlane>> MemMap_T; public: MemoryManager(): mClock(0) {} diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 1000374454020625aada7f2043893b229deec833..4e74be8878eb3ca081fd2d5457e42768f4026be5 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -30,6 +30,10 @@ void init_GraphView(py::module& m) { :param path: save location :type path: str )mydelimiter") + .def("in_view", (bool (GraphView::*)(const NodePtr&) const) &GraphView::inView) + .def("in_view", (bool (GraphView::*)(const std::string&) const) &GraphView::inView) + .def("root_node", &GraphView::rootNode) + .def("set_root_node", &GraphView::setRootNode, py::arg("node")) .def("log_outputs", &GraphView::logOutputs, py::arg("path")) .def("get_ordered_inputs", &GraphView::getOrderedInputs) .def("get_ordered_outputs", &GraphView::getOrderedOutputs) @@ -61,13 +65,15 @@ void init_GraphView(py::module& m) { :type include_learnable_parameters: bool, optional )mydelimiter") - .def("add", (bool (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add, - py::arg("other_graph"), + .def("add", (bool (GraphView::*)(std::shared_ptr<GraphView>, bool)) & GraphView::add, + py::arg("other_graph"), py::arg("include_learnable_parameters") = true, R"mydelimiter( Include a GraphView to the current GraphView object. :param other_graph: GraphView to add :type other_graph: GraphView + :param include_learnable_parameters: include non-data inputs, like weights and biases, default True. + :type include_learnable_parameters: bool, optional )mydelimiter") .def("add_child", diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 73830ab323e6a60a6896f01c2c6050bb50498e4a..5124d41f575b0ebf7f3c6cf258900e0ae656d213 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -764,10 +764,10 @@ bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool inc return add(nodes.second, includeLearnableParam); } -bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { +bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph, bool includeLearnableParam) { // set the rootNode to the other graphView rootNode if no rootNode yet mRootNode = mRootNode ? mRootNode : graph->rootNode(); - return add(graph->getNodes(), true); + return add(graph->getNodes(), includeLearnableParam); } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp index 198e3a44bc7663aea42554cd9f08b0bfc616a06d..e7748936c00a20ec235ea7853f4d17e2c10261fb 100644 --- a/src/recipes/FuseToMetaOps.cpp +++ b/src/recipes/FuseToMetaOps.cpp @@ -13,30 +13,20 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Matching.hpp" #include "aidge/operator/MetaOperator.hpp" #include "aidge/recipes/Recipes.hpp" -//Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" - size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setKeyFromGraph(graphView); - regex->addQuery(query); - const auto metaType = (!type.empty()) ? type : query; + const auto matches = SinglePassGraphMatching(graphView).match(query); size_t nbReplaced = 0; - const auto matches = regex->match(graphView); - - for (const auto& solution : matches) { - auto microGraph = std::make_shared<GraphView>(); - microGraph->add(solution->getAll()); - - auto metaOp = MetaOperator(metaType.c_str(), microGraph->clone()); + for (const auto& match : matches) { + auto metaOp = MetaOperator(metaType.c_str(), match.graph->clone()); auto metaOpGraph = std::make_shared<GraphView>(); - metaOpGraph->add(metaOp); - const auto success = GraphView::replace(microGraph, metaOpGraph); + metaOpGraph->add(metaOp, false); + const auto success = GraphView::replace(match.graph, metaOpGraph); if (!success) { Log::notice("Could not replace sub-graph with meta operator"); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index af10e3dcd3ead044f8619c40570936f53039d9a2..acd583d873930bba38c48f43dc7cd336ce83268e 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -643,7 +643,7 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& if (input.first == node) { // Current node is an input const auto upperInput = upperNode->inputs()[nodeInputIdx]; - if (upperInput.first) { + if (upperInput.first && nodeInputIdx == inputIdx) { return upperInput.first->getOperator()->getNbProducedData(upperInput.second); } }