From 426e4c6c6029f4d1a1a767a54f85e89c15e5eb0a Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 19 May 2024 12:15:44 +0200
Subject: [PATCH] Improved error checking

---
 include/aidge/graph/Matching.hpp   |  9 ++++++---
 src/graph/Matching.cpp             | 10 ++++++++++
 unit_tests/graph/Test_Matching.cpp | 16 ++++++++++++++++
 3 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp
index b6aeaf2dc..441ae44cf 100644
--- a/include/aidge/graph/Matching.hpp
+++ b/include/aidge/graph/Matching.hpp
@@ -37,7 +37,10 @@ public:
         bool singleOutput = true;
         IOIndex_t edgeLeftIdx = 0;
         IOIndex_t edgeRightIdx = 0;
+
+        // For check & debug purpose:
         size_t depth = 0;
+        std::set<std::string> anchors;
     };
 
     struct MatchingResult {
@@ -75,13 +78,13 @@ public:
      * 
      * Some rules:
      * - The first node of the first sequence is the root node and cannot be optional
-     *   WRONG: Conv?->ReLU
+     *   WRONG: Conv?->ReLU (will throws an error)
      *   GOOD: ReLU<-Conv?
      * 
      * - The first node of any further sequence must be an existing anchor
      *   (the anchor cannot be in the middle of the sequence)
-     *   WRONG: Conv->ReLU;Pad->Conv
-     *          Pad->Conv;Conv->ReLU
+     *   WRONG: Conv->ReLU;Pad->Conv (will throws an error)
+     *          Pad->Conv;Conv->ReLU (will throws an error)
      *   GOOD: Conv#->ReLU;Conv#<-Pad
      *         Pad->Conv#;Conv#->ReLU
      * 
diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp
index 3ba80fd01..d8f03d942 100644
--- a/src/graph/Matching.cpp
+++ b/src/graph/Matching.cpp
@@ -68,6 +68,9 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<Mat
     size_t matchQuantity = 0;
     removeWhiteSpace(newCtx.query);
     if (!newCtx.query.empty() && (newCtx.query[0] == '?' || newCtx.query[0] == '*')) {
+        AIDGE_ASSERT(!(ctx.firstSequence && ctx.firstNode),
+            "Ill-formed query; the root node cannot be optional in query at: {}", ctx.query);
+
         for (const auto& match : matches) {
             bool found = false;
             for (const auto& newMatch : newMatches) {
@@ -132,6 +135,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<Mat
             auto additionalCtx = ctx;
             additionalCtx.firstNode = newCtx.firstNode;
             additionalCtx.firstSequence = newCtx.firstSequence;
+            additionalCtx.anchors = newCtx.anchors;
             ++additionalCtx.depth;
             additionalMatches = newMatches;
 
@@ -531,6 +535,11 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe
     }
     else if (newCtx.firstNode) {
         // First node of a (new) sequence: it has to be an existing anchor
+        AIDGE_ASSERT(!anchor.empty(),
+            "Ill-formed query; an anchor is expected in query at: {}", ctx.query);
+        AIDGE_ASSERT(newCtx.anchors.find(type + anchor) != newCtx.anchors.end(),
+            "Ill-formed query; the node anchor {} has to be an existing anchor in query at: {}", type + anchor, ctx.query);
+
         for (auto it = newMatches.begin(); it != newMatches.end(); ) {
             const auto anchors = it->anchors[type];
             const auto anchorNode = anchors.find(anchor);
@@ -629,6 +638,7 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe
         Log::debug("{}node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size());
     }
 
+    newCtx.anchors.insert(type + anchor);
     ctx = newCtx;
     matches = newMatches;
     return true;
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index 93a110ebd..fa2c4fdcb 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -74,6 +74,22 @@ TEST_CASE("[core/graph] Matching") {
         });
     }
 
+    SECTION("Conv->ReLU;ReLU->Pad") {
+        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv->ReLU;ReLU->Pad"));
+    }
+
+    SECTION("Conv->ReLU#1;ReLU#2->Pad") {
+        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv->ReLU#1;ReLU#2->Pad"));
+    }
+
+    SECTION("Conv?->ReLU") {
+        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv?->ReLU"));
+    }
+
+    SECTION("(Add#<*~.)*") {
+        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("(Add#<*~.)*"));
+    }
+
     SECTION("Conv->(ReLU~>Pad->Conv)*") {
         const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*");
 
-- 
GitLab