From 426e4c6c6029f4d1a1a767a54f85e89c15e5eb0a Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 19 May 2024 12:15:44 +0200 Subject: [PATCH] Improved error checking --- include/aidge/graph/Matching.hpp | 9 ++++++--- src/graph/Matching.cpp | 10 ++++++++++ unit_tests/graph/Test_Matching.cpp | 16 ++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index b6aeaf2dc..441ae44cf 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -37,7 +37,10 @@ public: bool singleOutput = true; IOIndex_t edgeLeftIdx = 0; IOIndex_t edgeRightIdx = 0; + + // For check & debug purpose: size_t depth = 0; + std::set<std::string> anchors; }; struct MatchingResult { @@ -75,13 +78,13 @@ public: * * Some rules: * - The first node of the first sequence is the root node and cannot be optional - * WRONG: Conv?->ReLU + * WRONG: Conv?->ReLU (will throws an error) * GOOD: ReLU<-Conv? * * - The first node of any further sequence must be an existing anchor * (the anchor cannot be in the middle of the sequence) - * WRONG: Conv->ReLU;Pad->Conv - * Pad->Conv;Conv->ReLU + * WRONG: Conv->ReLU;Pad->Conv (will throws an error) + * Pad->Conv;Conv->ReLU (will throws an error) * GOOD: Conv#->ReLU;Conv#<-Pad * Pad->Conv#;Conv#->ReLU * diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index 3ba80fd01..d8f03d942 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -68,6 +68,9 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<Mat size_t matchQuantity = 0; removeWhiteSpace(newCtx.query); if (!newCtx.query.empty() && (newCtx.query[0] == '?' || newCtx.query[0] == '*')) { + AIDGE_ASSERT(!(ctx.firstSequence && ctx.firstNode), + "Ill-formed query; the root node cannot be optional in query at: {}", ctx.query); + for (const auto& match : matches) { bool found = false; for (const auto& newMatch : newMatches) { @@ -132,6 +135,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<Mat auto additionalCtx = ctx; additionalCtx.firstNode = newCtx.firstNode; additionalCtx.firstSequence = newCtx.firstSequence; + additionalCtx.anchors = newCtx.anchors; ++additionalCtx.depth; additionalMatches = newMatches; @@ -531,6 +535,11 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe } else if (newCtx.firstNode) { // First node of a (new) sequence: it has to be an existing anchor + AIDGE_ASSERT(!anchor.empty(), + "Ill-formed query; an anchor is expected in query at: {}", ctx.query); + AIDGE_ASSERT(newCtx.anchors.find(type + anchor) != newCtx.anchors.end(), + "Ill-formed query; the node anchor {} has to be an existing anchor in query at: {}", type + anchor, ctx.query); + for (auto it = newMatches.begin(); it != newMatches.end(); ) { const auto anchors = it->anchors[type]; const auto anchorNode = anchors.find(anchor); @@ -629,6 +638,7 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe Log::debug("{}node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size()); } + newCtx.anchors.insert(type + anchor); ctx = newCtx; matches = newMatches; return true; diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index 93a110ebd..fa2c4fdcb 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -74,6 +74,22 @@ TEST_CASE("[core/graph] Matching") { }); } + SECTION("Conv->ReLU;ReLU->Pad") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv->ReLU;ReLU->Pad")); + } + + SECTION("Conv->ReLU#1;ReLU#2->Pad") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv->ReLU#1;ReLU#2->Pad")); + } + + SECTION("Conv?->ReLU") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv?->ReLU")); + } + + SECTION("(Add#<*~.)*") { + REQUIRE_THROWS(SinglePassGraphMatching(g1).match("(Add#<*~.)*")); + } + SECTION("Conv->(ReLU~>Pad->Conv)*") { const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*"); -- GitLab