diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index e389d3872e7da0def6b5aaf53b2110a949e1bb48..5b4ae548726405d51292576ecda4379b059624f6 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -123,7 +123,14 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView); */ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); -void fuseToMetaOps(std::shared_ptr<GraphView> graph, const std::string& query, const std::string& name = ""); +/** + * Fuse each sub-graph matching a query in a Meta Operator. + * @param graph Graph to manipulate + * @param query Sub-graph matching query + * @param type Type name of the resulting meta operators + * @return size_t Number of replacement +*/ +size_t fuseToMetaOps(std::shared_ptr<GraphView> graph, const std::string& query, const std::string& type = ""); } // namespace Aidge diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp index be6f3c1df24835eb76d1f288d033cc6ae9856713..198e3a44bc7663aea42554cd9f08b0bfc616a06d 100644 --- a/src/recipes/FuseToMetaOps.cpp +++ b/src/recipes/FuseToMetaOps.cpp @@ -19,11 +19,13 @@ //Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" -void Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& name) { +size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); regex->setKeyFromGraph(graphView); regex->addQuery(query); + const auto metaType = (!type.empty()) ? type : query; + size_t nbReplaced = 0; const auto matches = regex->match(graphView); @@ -31,8 +33,10 @@ void Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::strin auto microGraph = std::make_shared<GraphView>(); microGraph->add(solution->getAll()); - auto metaOp = MetaOperator(query.c_str(), microGraph, name); - const auto success = GraphView::replace(solution->getAll(), {metaOp}); + auto metaOp = MetaOperator(metaType.c_str(), microGraph->clone()); + auto metaOpGraph = std::make_shared<GraphView>(); + metaOpGraph->add(metaOp); + const auto success = GraphView::replace(microGraph, metaOpGraph); if (!success) { Log::notice("Could not replace sub-graph with meta operator"); @@ -43,4 +47,5 @@ void Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::strin } Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size()); + return nbReplaced; } diff --git a/unit_tests/recipes/Test_FuseToMetaOps.cpp b/unit_tests/recipes/Test_FuseToMetaOps.cpp index 34ca595dfd92bcb33a2ec009e6849aed8a2d5bdd..9fceedf2feef0a3ed79b83a8494a1a2b49f77291 100644 --- a/unit_tests/recipes/Test_FuseToMetaOps.cpp +++ b/unit_tests/recipes/Test_FuseToMetaOps.cpp @@ -14,13 +14,31 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/recipes/Recipes.hpp" namespace Aidge { TEST_CASE("[cpu/recipes] FuseToMetaOps", "[FuseToMetaOps][recipes]") { - + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + ReLU("relu1"), + Conv(32, 64, {3, 3}, "conv2"), + ReLU("relu2"), + Conv(64, 10, {1, 1}, "conv3") + }); + g1->save("FuseToMetaOps_before"); + + // FIXME: GraphRegex also matches the Conv Producers, which are not in the query! + const auto nbFused = fuseToMetaOps(g1, "Conv->ReLU", "ConvReLU"); + g1->save("FuseToMetaOps_after", true); + + REQUIRE(nbFused == 2); } } // namespace Aidge