From fc39b17a0d3c6cd7d93ad825e9c8b3120ff7c889 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 7 May 2024 10:03:08 +0200
Subject: [PATCH] avoid connecting empty shape tensor

---
 src/operator/Reshape.cpp                 | 9 +++++----
 unit_tests/operator/Test_ReshapeImpl.cpp | 4 ----
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp
index b78346b25..0cce7a5b9 100644
--- a/src/operator/Reshape.cpp
+++ b/src/operator/Reshape.cpp
@@ -32,10 +32,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape";
 
 bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
     // check input has been associated
-    for (size_t i = 0; i < 2; ++i) {
-        if (!getInput(i)) {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
-        }
+    if (!getInput(0)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
     }
     if (!getInput(0)->empty()) {
         std::vector<DimSize_t> outDims;
@@ -46,6 +44,9 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
 
         // Fill shape attr if empty
         if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
+            if (!getInput(1)) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
+            }
             if(!getInput(1)->empty()) {
                 this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
                 this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
diff --git a/unit_tests/operator/Test_ReshapeImpl.cpp b/unit_tests/operator/Test_ReshapeImpl.cpp
index 5b6f5f3fb..2685518b5 100644
--- a/unit_tests/operator/Test_ReshapeImpl.cpp
+++ b/unit_tests/operator/Test_ReshapeImpl.cpp
@@ -58,8 +58,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
             std::shared_ptr<Node> myReshape = Reshape({2, 3});
             auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
             op->associateInput(0, input);
-            // TODO find a way to avoid connecting an empty tensor
-            op->associateInput(1, std::make_shared<Tensor>(Tensor({})));
             op->setDataType(DataType::Float32);
             op->setBackend("cpu");
             myReshape->forward();
@@ -116,8 +114,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
             std::shared_ptr<Node> myReshape = Reshape({3, 2});
             auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
             op->associateInput(0, input);
-            // TODO find a way to avoid connecting an empty tensor
-            op->associateInput(1, std::make_shared<Tensor>(Tensor({})));
             op->setDataType(DataType::Float32);
             op->setBackend("cpu");
             myReshape->forward();
-- 
GitLab