From 58c5ee5e0bc6031abdaea9c974803aecdf8539ef Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 19 Apr 2024 12:19:38 +0000
Subject: [PATCH] Add asserts to Conv forwardDims member function

---
 include/aidge/operator/Conv.hpp     | 20 ++++++++++----------
 unit_tests/graph/Test_GraphView.cpp |  5 +----
 2 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index f0c6c12d7..d6a0df5ab 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -117,18 +117,18 @@ public:
             }
             associated &= !(getInput(i)->empty());
         }
-        AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&
+        if (associated) {
+            AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&
                      (getInput(0)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()),
                      "Wrong input size for Conv operator.");
-        AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)) &&
-                     (getInput(1)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()) &&
-                     (getInput(1)->template dims<DIM+2>()[0] == this->template getAttr<ConvAttr::OutChannels>()),
-                     "Wrong weight size for Conv operator.");
-        if(!this->template getAttr<ConvAttr::NoBias>())
-            AIDGE_ASSERT((getInput(2)->nbDims() == (1)) &&
-                     (getInput(2)->template dims<1>()[0] == this->template getAttr<ConvAttr::OutChannels>()),
-                     "Wrong bias size for Conv operator.");
-        if (associated) {
+            AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)) &&
+                        (getInput(1)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()) &&
+                        (getInput(1)->template dims<DIM+2>()[0] == this->template getAttr<ConvAttr::OutChannels>()),
+                        "Wrong weight size for Conv operator.");
+            if(!this->template getAttr<ConvAttr::NoBias>())
+                AIDGE_ASSERT((getInput(2)->nbDims() == (1)) &&
+                        (getInput(2)->template dims<1>()[0] == this->template getAttr<ConvAttr::OutChannels>()),
+                        "Wrong bias size for Conv operator.");
             std::array<DimSize_t, DIM + 2> outputDims{};
             const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
 
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 437780b95..8403686d1 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -648,11 +648,8 @@ TEST_CASE("[GraphView] clone") {
     auto conv1 = Conv(3, 32, {3, 3}, "conv1");
     auto conv2 = Conv(32, 64, {3, 3}, "conv2");
     auto conv3 = Conv(64, 10, {1, 1}, "conv3");
-    auto g1 = std::make_shared<GraphView>("TestGraph");
+    auto g1 = Sequential({conv1, conv2, conv3});
     dataProvider->addChild(conv1, 0);
-    g1->add(conv1);
-    g1->addChild(conv2, conv1, 0);
-    g1->addChild(conv3, conv2, 0);
     g1->save("clone_g1");
 
     SECTION("Check input-output connections") {
-- 
GitLab