Skip to content
Snippets Groups Projects
Commit 7eb944b3 authored by vincent  lorrain's avatar vincent lorrain
Browse files

Merge branch 'addFeature/graphRegex/queryResolution' into 'main'

Add feature/graph regex/query resolution

See merge request !49
parents 3182ea4a 64a3132b
No related branches found
No related tags found
1 merge request!49Add feature/graph regex/query resolution
Pipeline #34556 passed
......@@ -11,6 +11,11 @@
namespace Aidge{
/**
* type for recipes function use in query and resolve
*/
using RecipesFunctionType = std::function<void(std::shared_ptr<MatchSolution>)>;
/**
* @brief class which is the hight level interface for graph matching, used to simplify match definition
*
......@@ -19,9 +24,10 @@ class GraphRegex{
private:
std::vector<std::string> mQuery;
//std::vector<std::string> mQuery;
std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest;
std::map<std::string, std::function<bool(NodePtr)>> mAllLambda;
std::map<std::string,RecipesFunctionType> mQueryRecipe;
public:
GraphRegex(){};
......@@ -31,7 +37,15 @@ class GraphRegex{
* @brief add a topology query to the match
* @param query the topology query to find
**/
void addQuery(const std::string query);
//void addQuery(const std::string query);
/**
* @brief add a topology query to the match and a function for recipe
* @param query the topology query to find
* @param f the funct
**/
void addQuery(const std::string query,RecipesFunctionType f = nullptr);
/**
* @brief get all the types of a graph and set it as type key in the query
......@@ -53,13 +67,19 @@ class GraphRegex{
**/
void setNodeKey(const std::string key,std::function<bool(NodePtr)> f);
/***
/**
* @brief brief match the queries in the graph
* @param Reference the graph were the querys in search
* @param ref the graph were the querys in search
* @return the result
*/
std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref);
/***
* @brief match the queries in the graph and applied the recipes fuction
* @param ref the graph were the querys in search
*/
void appliedRecipes(std::shared_ptr<GraphView> ref);
private:
void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index,
......
......@@ -26,10 +26,17 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
void GraphRegex::addQuery(const std::string query){
mQuery.push_back(query);
}
// void GraphRegex::addQuery(const std::string query){
// //TODO one query only but the same string is a same query but
// //2 different string it's maybe the same query , we need to check the AST
// mQueryRecipe[query] = nullptr;
// }
void GraphRegex::addQuery(const std::string query,RecipesFunctionType f ){
mQueryRecipe[query] = f;
}
// Function to generate all combinations of n elements from a set
......@@ -87,7 +94,9 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
std::vector<std::shared_ptr<MatchSolution>> solutions = {};
for (const std::string& query : mQuery) {
//for (const std::string& query : mQuery) {
for (auto it = mQueryRecipe.begin(); it != mQueryRecipe.end(); ++it) {
const std::string query = it->first;
std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest);
std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
......@@ -108,6 +117,15 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
return _findLargestCompatibleSet(solutions);
}
void GraphRegex::appliedRecipes(std::shared_ptr<GraphView> ref){
std::set<std::shared_ptr<MatchSolution>> matchRef = match(ref);
for (const auto& solution : matchRef) {
if(mQueryRecipe[solution->getQuery()] != nullptr){
mQueryRecipe[solution->getQuery()](solution);
}
}
}
void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){
mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions));
_majConditionalInterpreterLambda();
......
......@@ -2,6 +2,15 @@
#include <catch2/catch_test_macros.hpp>
#include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Recipies.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
......@@ -47,12 +56,8 @@ TEST_CASE("GraphRegexUser") {
}
SECTION("CC") {
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1");
......@@ -81,4 +86,58 @@ TEST_CASE("GraphRegexUser") {
}
}
SECTION("Applied Recipes"){
// generate the original GraphView
auto matmul0 = MatMul(5, "matmul0");
auto add0 = Add<2>("add0");
auto matmul1 = MatMul(5, "matmul1");
auto add1 = Add<2>("add1");
auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 0);
w1->addChild(matmul1, 0, 1);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);
auto fc = GenericOperator("FC", 1, 1, 1, "c");
auto fl = GenericOperator("Flatten", 1, 1, 1, "c");
auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc});
std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>();
kitchenBook->setNodeKey("Add","getType($) =='Add'");
kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'");
kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'");
kitchenBook->setNodeKey("FC","getType($) =='FC'");
kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten));
kitchenBook->appliedRecipes(g);
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc}));
//REQUIRE(newNodes.size() == 6);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment