From 10eba2d39e483ed2c5fa5dba4923bfe223ea6d1d Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Mon, 5 Feb 2024 13:06:02 +0000
Subject: [PATCH] [Upd] Reshape.cpp and Gather.cpp computeOutputDims() function
 to check input emptyness

---
 src/operator/Gather.cpp  | 30 ++++++++++++++------------
 src/operator/Reshape.cpp | 46 +++++++++++++++++++++-------------------
 2 files changed, 41 insertions(+), 35 deletions(-)

diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp
index 3eafe99ef..b5f9d738a 100644
--- a/src/operator/Gather.cpp
+++ b/src/operator/Gather.cpp
@@ -9,8 +9,8 @@
  *
  ********************************************************************************/
 
-#include <cassert>
 #include <cstddef>
+#include <cstdint>
 #include <string>
 #include <vector>
 
@@ -26,18 +26,22 @@ void Aidge::Gather_Op::computeOutputDims() {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
     }
 
-    std::vector<DimSize_t> outDims = getInput(0)->dims();
-    const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
-    // TODO: check indices and gatheredShape
+    if (!getInput(0)->empty()) {
+        std::vector<DimSize_t> outDims = getInput(0)->dims();
+        const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
+        // TODO: check indices and gatheredShape
 
-    const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
-                                 this->template getAttr<GatherAttr::Axis>() :
-                                 this->template getAttr<GatherAttr::Axis>() + outDims.size();
-    outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
-    if (!gatheredShape.empty())
-    {
-        outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), gatheredShape.begin(),gatheredShape.end());
-    }
+        const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
+                                        this->template getAttr<GatherAttr::Axis>() :
+                                        this->template getAttr<GatherAttr::Axis>() + outDims.size();
+        outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
+        if (!gatheredShape.empty())
+        {
+            outDims.insert(outDims.cbegin() + static_cast<std::size_t>(axisIdx),
+                            gatheredShape.cbegin(),
+                            gatheredShape.cend());
+        }
 
-    mOutputs[0]->resize(outDims);
+        mOutputs[0]->resize(outDims);
+    }
 }
\ No newline at end of file
diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp
index 7032c8110..30b060cd2 100644
--- a/src/operator/Reshape.cpp
+++ b/src/operator/Reshape.cpp
@@ -27,30 +27,32 @@ void Aidge::Reshape_Op::computeOutputDims() {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
     }
 
-    std::vector<DimSize_t> outDims;
-    // variables to handle a negative dimension
-    bool foundNegativeDimension = false;
-    std::size_t outSize = 1;
-    DimIdx_t negativeIndex = 0;
-
-    for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
-    {
-        std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
-        if (dimSize < 0) {
-            if (foundNegativeDimension) {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
+    if (!getInput(0)->empty()) {
+        std::vector<DimSize_t> outDims;
+        // variables to handle a negative dimension
+        bool foundNegativeDimension = false;
+        std::size_t outSize = 1;
+        DimIdx_t negativeIndex = 0;
+
+        for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
+        {
+            std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
+            if (dimSize < 0) {
+                if (foundNegativeDimension) {
+                    AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
+                }
+                foundNegativeDimension = true;
+                dimSize = 1;
+                negativeIndex = static_cast<DimIdx_t>(i);
             }
-            foundNegativeDimension = true;
-            dimSize = 1;
-            negativeIndex = static_cast<DimIdx_t>(i);
+            outDims.push_back(static_cast<DimSize_t>(dimSize));
+            outSize *= static_cast<DimSize_t>(dimSize);
         }
-        outDims.push_back(static_cast<DimSize_t>(dimSize));
-        outSize *= static_cast<DimSize_t>(dimSize);
-    }
 
-    if (foundNegativeDimension) {
-        outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
-    }
+        if (foundNegativeDimension) {
+            outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
+        }
 
-    mOutputs[0]->resize(outDims);
+        mOutputs[0]->resize(outDims);
+    }
 }
\ No newline at end of file
-- 
GitLab