diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index a5bd260ec189ac998134b738ca1ae757f2a0038c..567270d63c092aef6411a4438f59b7770ee3d5bf 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -132,11 +132,13 @@ void declare_registrable(py::module& m, const std::string& class_name){ #ifdef PYBIND #define SET_IMPL_MACRO(T_Op, op, backend_name) \ \ - if(Py_IsInitialized()) { \ - auto obj = py::cast(&(op)); \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ - } else { \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + if (Registrar<T_Op>::exists(backend_name)) { \ + if(Py_IsInitialized()) { \ + auto obj = py::cast(&(op)); \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + } else { \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + } \ } #else #define SET_IMPL_MACRO(T_Op, op, backend_name) \ diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp index 00a031e3fa9b03ff1870446b9ae58e8d3eb65bf7..ca15ff8dec5ff5ebd4ea69141c6e286849162bb5 100644 --- a/src/graphRegex/GraphRegex.cpp +++ b/src/graphRegex/GraphRegex.cpp @@ -117,6 +117,8 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination); solutions.insert(solutions.end(), solution.begin(), solution.end()); } + + } return _findLargestCompatibleSet(solutions); } @@ -142,7 +144,10 @@ void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f throw std::runtime_error(key + " is define"); } mAllLambda[key] = f; + _majConditionalInterpreterLambda(); + //we add the lambda as key by default + setNodeKey(key, key + "($)==true"); } void GraphRegex::_majConditionalInterpreterLambda(){ diff --git a/src/nodeTester/ConditionalLexer.cpp b/src/nodeTester/ConditionalLexer.cpp index 9379bd8409f8f7ec4bae3e0122f88de79718e9dd..e70772fc1a5d6136fb56f5981d73bf6cb0622991 100644 --- a/src/nodeTester/ConditionalLexer.cpp +++ b/src/nodeTester/ConditionalLexer.cpp @@ -120,7 +120,7 @@ std::shared_ptr<ParsingToken<ConditionalTokenTypes>> ConditionalLexer::getNextTo } - if (std::regex_match(currentChars,std::regex("(true|false)"))){ + if (std::regex_match(currentChars,std::regex("(true|false|True|False)"))){ return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::BOOL,currentChars); } else if (isLambda){ diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index bcd6d0f4cd9ba32ee4318188343b7e6360670d3b..a62b9a8602b494f26fb47061b899eaba41129a1f 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -18,6 +18,32 @@ using namespace Aidge; TEST_CASE("GraphRegexUser") { + + SECTION("Match using custom lambda") { + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> fc = GenericOperator("FC", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 0, 1, "c3"); + + g1->add(conv); + g1->addChild(fc, "c"); + g1->addChild(conv2, "c1"); + g1->addChild(fc2, "c2"); + + /// + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); + sut->setNodeKey("C",+[](NodePtr NodeOp){return NodeOp->type() == "FC";}); + + sut->setNodeKey("A","C($)==True"); + sut->addQuery("A"); + auto match = sut->match(g1); + REQUIRE(match.size() == 2); + + } + + SECTION("INIT") { const std::string query = "Conv->FC";