Skip to content
Snippets Groups Projects
Commit be448ed5 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

Merge branch 'dev' of gitlab.eclipse.org:eclipse/aidge/aidge_core into...

Merge branch 'dev' of gitlab.eclipse.org:eclipse/aidge/aidge_core into feat/operator_globalAveragePooling
parents 3e6342d1 93500677
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!91Feat/operator global average pooling
...@@ -117,6 +117,8 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph ...@@ -117,6 +117,8 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination); std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination);
solutions.insert(solutions.end(), solution.begin(), solution.end()); solutions.insert(solutions.end(), solution.begin(), solution.end());
} }
} }
return _findLargestCompatibleSet(solutions); return _findLargestCompatibleSet(solutions);
} }
...@@ -142,7 +144,10 @@ void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f ...@@ -142,7 +144,10 @@ void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f
throw std::runtime_error(key + " is define"); throw std::runtime_error(key + " is define");
} }
mAllLambda[key] = f; mAllLambda[key] = f;
_majConditionalInterpreterLambda(); _majConditionalInterpreterLambda();
//we add the lambda as key by default
setNodeKey(key, key + "($)==true");
} }
void GraphRegex::_majConditionalInterpreterLambda(){ void GraphRegex::_majConditionalInterpreterLambda(){
......
...@@ -120,7 +120,7 @@ std::shared_ptr<ParsingToken<ConditionalTokenTypes>> ConditionalLexer::getNextTo ...@@ -120,7 +120,7 @@ std::shared_ptr<ParsingToken<ConditionalTokenTypes>> ConditionalLexer::getNextTo
} }
if (std::regex_match(currentChars,std::regex("(true|false)"))){ if (std::regex_match(currentChars,std::regex("(true|false|True|False)"))){
return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::BOOL,currentChars); return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::BOOL,currentChars);
} else if (isLambda){ } else if (isLambda){
......
...@@ -18,6 +18,32 @@ using namespace Aidge; ...@@ -18,6 +18,32 @@ using namespace Aidge;
TEST_CASE("GraphRegexUser") { TEST_CASE("GraphRegexUser") {
SECTION("Match using custom lambda") {
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
std::shared_ptr<Node> fc = GenericOperator("FC", 1, 0, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 0, 1, "c3");
g1->add(conv);
g1->addChild(fc, "c");
g1->addChild(conv2, "c1");
g1->addChild(fc2, "c2");
///
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
sut->setNodeKey("C",+[](NodePtr NodeOp){return NodeOp->type() == "FC";});
sut->setNodeKey("A","C($)==True");
sut->addQuery("A");
auto match = sut->match(g1);
REQUIRE(match.size() == 2);
}
SECTION("INIT") { SECTION("INIT") {
const std::string query = "Conv->FC"; const std::string query = "Conv->FC";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment