Skip to content
Snippets Groups Projects
Commit d497b10f authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

Merge remote-tracking branch 'origin/dev' into feat/support_ASAN

parents 000d75fc 93500677
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!100fix/scheduler_exec_time
......@@ -132,11 +132,13 @@ void declare_registrable(py::module& m, const std::string& class_name){
#ifdef PYBIND
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
\
if(Py_IsInitialized()) { \
auto obj = py::cast(&(op)); \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} else { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
if (Registrar<T_Op>::exists(backend_name)) { \
if(Py_IsInitialized()) { \
auto obj = py::cast(&(op)); \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} else { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} \
}
#else
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
......
......@@ -117,6 +117,8 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination);
solutions.insert(solutions.end(), solution.begin(), solution.end());
}
}
return _findLargestCompatibleSet(solutions);
}
......@@ -142,7 +144,10 @@ void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f
throw std::runtime_error(key + " is define");
}
mAllLambda[key] = f;
_majConditionalInterpreterLambda();
//we add the lambda as key by default
setNodeKey(key, key + "($)==true");
}
void GraphRegex::_majConditionalInterpreterLambda(){
......
......@@ -120,7 +120,7 @@ std::shared_ptr<ParsingToken<ConditionalTokenTypes>> ConditionalLexer::getNextTo
}
if (std::regex_match(currentChars,std::regex("(true|false)"))){
if (std::regex_match(currentChars,std::regex("(true|false|True|False)"))){
return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::BOOL,currentChars);
} else if (isLambda){
......
......@@ -18,6 +18,32 @@ using namespace Aidge;
TEST_CASE("GraphRegexUser") {
SECTION("Match using custom lambda") {
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
std::shared_ptr<Node> fc = GenericOperator("FC", 1, 0, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 0, 1, "c3");
g1->add(conv);
g1->addChild(fc, "c");
g1->addChild(conv2, "c1");
g1->addChild(fc2, "c2");
///
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
sut->setNodeKey("C",+[](NodePtr NodeOp){return NodeOp->type() == "FC";});
sut->setNodeKey("A","C($)==True");
sut->addQuery("A");
auto match = sut->match(g1);
REQUIRE(match.size() == 2);
}
SECTION("INIT") {
const std::string query = "Conv->FC";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment