Skip to content
Snippets Groups Projects
Commit c2b40c94 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working recipe

parent fb245d4a
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!131New features to simplify exports
Pipeline #47078 passed
......@@ -123,7 +123,14 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView);
*/
void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false);
void fuseToMetaOps(std::shared_ptr<GraphView> graph, const std::string& query, const std::string& name = "");
/**
* Fuse each sub-graph matching a query in a Meta Operator.
* @param graph Graph to manipulate
* @param query Sub-graph matching query
* @param type Type name of the resulting meta operators
* @return size_t Number of replacement
*/
size_t fuseToMetaOps(std::shared_ptr<GraphView> graph, const std::string& query, const std::string& type = "");
} // namespace Aidge
......
......@@ -19,11 +19,13 @@
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& name) {
size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setKeyFromGraph(graphView);
regex->addQuery(query);
const auto metaType = (!type.empty()) ? type : query;
size_t nbReplaced = 0;
const auto matches = regex->match(graphView);
......@@ -31,8 +33,10 @@ void Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::strin
auto microGraph = std::make_shared<GraphView>();
microGraph->add(solution->getAll());
auto metaOp = MetaOperator(query.c_str(), microGraph, name);
const auto success = GraphView::replace(solution->getAll(), {metaOp});
auto metaOp = MetaOperator(metaType.c_str(), microGraph->clone());
auto metaOpGraph = std::make_shared<GraphView>();
metaOpGraph->add(metaOp);
const auto success = GraphView::replace(microGraph, metaOpGraph);
if (!success) {
Log::notice("Could not replace sub-graph with meta operator");
......@@ -43,4 +47,5 @@ void Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::strin
}
Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size());
return nbReplaced;
}
......@@ -14,13 +14,31 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
namespace Aidge {
TEST_CASE("[cpu/recipes] FuseToMetaOps", "[FuseToMetaOps][recipes]") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
ReLU("relu1"),
Conv(32, 64, {3, 3}, "conv2"),
ReLU("relu2"),
Conv(64, 10, {1, 1}, "conv3")
});
g1->save("FuseToMetaOps_before");
// FIXME: GraphRegex also matches the Conv Producers, which are not in the query!
const auto nbFused = fuseToMetaOps(g1, "Conv->ReLU", "ConvReLU");
g1->save("FuseToMetaOps_after", true);
REQUIRE(nbFused == 2);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment