diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index b6aeaf2dcda0b9bb5c8c16a36cd3acf803643a07..441ae44cf9839028386bdea2f36f2020c7410cf2 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -37,7 +37,10 @@ public: bool singleOutput = true; IOIndex_t edgeLeftIdx = 0; IOIndex_t edgeRightIdx = 0; + + // For check & debug purpose: size_t depth = 0; + std::set<std::string> anchors; }; struct MatchingResult { @@ -75,13 +78,13 @@ public: * * Some rules: * - The first node of the first sequence is the root node and cannot be optional - * WRONG: Conv?->ReLU + * WRONG: Conv?->ReLU (will throws an error) * GOOD: ReLU<-Conv? * * - The first node of any further sequence must be an existing anchor * (the anchor cannot be in the middle of the sequence) - * WRONG: Conv->ReLU;Pad->Conv - * Pad->Conv;Conv->ReLU + * WRONG: Conv->ReLU;Pad->Conv (will throws an error) + * Pad->Conv;Conv->ReLU (will throws an error) * GOOD: Conv#->ReLU;Conv#<-Pad * Pad->Conv#;Conv#->ReLU * diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index 3ba80fd015870a374d7c37371dbef8dbbfb67dc1..d8f03d942f9817da53e7302b68d4bff32e791bb8 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -68,6 +68,9 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<Mat size_t matchQuantity = 0; removeWhiteSpace(newCtx.query); if (!newCtx.query.empty() && (newCtx.query[0] == '?' || newCtx.query[0] == '*')) { + AIDGE_ASSERT(!(ctx.firstSequence && ctx.firstNode), + "Ill-formed query; the root node cannot be optional in query at: {}", ctx.query); + for (const auto& match : matches) { bool found = false; for (const auto& newMatch : newMatches) { @@ -132,6 +135,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<Mat auto additionalCtx = ctx; additionalCtx.firstNode = newCtx.firstNode; additionalCtx.firstSequence = newCtx.firstSequence; + additionalCtx.anchors = newCtx.anchors; ++additionalCtx.depth; additionalMatches = newMatches; @@ -531,6 +535,11 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe } else if (newCtx.firstNode) { // First node of a (new) sequence: it has to be an existing anchor + AIDGE_ASSERT(!anchor.empty(), + "Ill-formed query; an anchor is expected in query at: {}", ctx.query); + AIDGE_ASSERT(newCtx.anchors.find(type + anchor) != newCtx.anchors.end(), + "Ill-formed query; the node anchor {} has to be an existing anchor in query at: {}", type + anchor, ctx.query); + for (auto it = newMatches.begin(); it != newMatches.end(); ) { const auto anchors = it->anchors[type]; const auto anchorNode = anchors.find(anchor); @@ -629,6 +638,7 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe Log::debug("{}node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size()); } + newCtx.anchors.insert(type + anchor); ctx = newCtx; matches = newMatches; return true; diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index 93a110ebdefe97d7bdd8e590948156ef77a09725..fa2c4fdcb88cbea6b0d20cb6311af52afdd4a0c8 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -74,6 +74,22 @@ TEST_CASE("[core/graph] Matching") { }); } + SECTION("Conv->ReLU;ReLU->Pad") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv->ReLU;ReLU->Pad")); + } + + SECTION("Conv->ReLU#1;ReLU#2->Pad") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv->ReLU#1;ReLU#2->Pad")); + } + + SECTION("Conv?->ReLU") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv?->ReLU")); + } + + SECTION("(Add#<*~.)*") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("(Add#<*~.)*")); + } + SECTION("Conv->(ReLU~>Pad->Conv)*") { const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*");