From 48de8d5dc56f79ca8e6353dc7d73ebd64d3cb3ec Mon Sep 17 00:00:00 2001
From: ms245755 <michal.szczepanski@cea.fr>
Date: Tue, 11 Jun 2024 11:40:16 +0000
Subject: [PATCH] AllowDataDependency added for resize operator.

---
 include/aidge/operator/Resize.hpp |  2 +-
 src/operator/Resize.cpp           | 30 ++++++++++++++++++++++++------
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/include/aidge/operator/Resize.hpp b/include/aidge/operator/Resize.hpp
index 3ecf74d2b..266a793b9 100644
--- a/include/aidge/operator/Resize.hpp
+++ b/include/aidge/operator/Resize.hpp
@@ -78,7 +78,7 @@ public:
         return std::make_shared<Resize_Op>(*this);
     }
 
-    // function see inputs
+    bool dimsForwarded() const override final;
     bool forwardDims(bool allowDataDependency = false) override final;
 
     void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
diff --git a/src/operator/Resize.cpp b/src/operator/Resize.cpp
index 98bffb9e0..989ab6fe2 100644
--- a/src/operator/Resize.cpp
+++ b/src/operator/Resize.cpp
@@ -25,6 +25,19 @@
 
 const std::string Aidge::Resize_Op::Type = "Resize";
 
+bool Aidge::Resize_Op::dimsForwarded() const {
+    if ((getInput(1) && !getInput(1)->empty())
+        || (getInput(2) && !getInput(2)->empty())
+        || (getInput(3) && !getInput(3)->empty()))
+    {
+        // output dims are data dependent
+        return false;
+    }
+
+    return OperatorTensor::dimsForwarded();
+}
+
+
 bool Aidge::Resize_Op::forwardDims(bool allowDataDependency) {
 
     AIDGE_ASSERT(getInput(0)->nbDims() ==  4,\
@@ -35,12 +48,6 @@ bool Aidge::Resize_Op::forwardDims(bool allowDataDependency) {
         if (!getInput(i)) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} not provided", type(), i);
         }
-
-        if (!allowDataDependency) {
-            Log::warn("Resize_Op: unable to forwardDims() because output dims are data dependent\
-             on input#0 and (input#2 or input#3)");
-            return false;
-        }
     }
 
     if (this->template getAttr<ResizeAttr::NoROI>() && this->template getAttr<ResizeAttr::NoSizes>())  {
@@ -49,6 +56,11 @@ bool Aidge::Resize_Op::forwardDims(bool allowDataDependency) {
          fmt::print("Condition scales: Input#0 and Input#2 must be provided and must have the same dimension,\
           while Inputs#1 and #3 must not be provided.\n"); 
         */
+        if (!allowDataDependency) {
+            Log::warn("Resize_Op: unable to forwardDims() because output dims are data dependent\
+             on input#0 and input#2");
+            return false;
+        }
 
         AIDGE_ASSERT(getInput(0)->nbDims() ==  getInput(2)->size(),\
              "input tensor and Scales must have the same dimensions.");
@@ -73,6 +85,12 @@ bool Aidge::Resize_Op::forwardDims(bool allowDataDependency) {
           while Inputs#1 and #2 must not be provided.\n"); 
         */
 
+        if (!allowDataDependency) {
+            Log::warn("Resize_Op: unable to forwardDims() because output dims are data dependent\
+             on input#0 and input#3)");
+            return false;
+        }
+
         AIDGE_ASSERT(getInput(0)->nbDims() ==  getInput(3)->size(),\
              "input tensor and Sizes must have the same dimensions.");       
         
-- 
GitLab