From 3672c60e676df75c0dffe40ed85c353b2b1644df Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 10 Jan 2025 16:05:28 +0000
Subject: [PATCH] Update shape so that forwardDims does not require a backend.
 Update test to remove unecessary Node.

---
 include/aidge/data/Tensor.hpp          |  7 +++++++
 include/aidge/operator/Shape.hpp       |  4 +---
 src/operator/Shape.cpp                 | 12 ++++++++++--
 unit_tests/operator/Test_ShapeImpl.cpp |  9 ++++-----
 4 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index fdeef2a8e..3f609e54d 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -393,6 +393,13 @@ public:
         return hasImpl() ? getImpl()->backend() : "";
     }
 
+
+    /**
+     * @brief Get the device index.
+     * @return DeviceIdx_t
+     */
+    DeviceIdx_t device() const noexcept { return mImpl ? mImpl->device().second : static_cast<DeviceIdx_t>(0); }
+
     /**
      * @brief Set the backend of the Tensor associated implementation. If there
      * was no previous implementation set, data will be allocated, but it will
diff --git a/include/aidge/operator/Shape.hpp b/include/aidge/operator/Shape.hpp
index cfd43fa0d..d40067d29 100644
--- a/include/aidge/operator/Shape.hpp
+++ b/include/aidge/operator/Shape.hpp
@@ -47,9 +47,7 @@ private:
     const std::shared_ptr<Attributes_> mAttributes;
 
 public:
-    Shape_Op() = delete;
-
-    Shape_Op(const std::int64_t start, const std::int64_t end);
+    Shape_Op(const std::int64_t start = 0, const std::int64_t end = -1);
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
diff --git a/src/operator/Shape.cpp b/src/operator/Shape.cpp
index ecaa12191..c38f52d76 100644
--- a/src/operator/Shape.cpp
+++ b/src/operator/Shape.cpp
@@ -18,10 +18,14 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/Log.hpp"
 
 void Aidge::Shape_OpImpl::forward() {
-    // Do nothing...
     // Output is already valid after forwardDims()
+    // But it may be with the wrong device (default cpu)
+    // This can happen if forwardDims is called before setBackend
+    const Shape_Op& op = dynamic_cast<const Shape_Op&>(mOp);
+    op.getOutput(0)->setBackend(op.getInput(0)->backend(), op.getInput(0)->device());
 }
 
 ///////////////////////////////////////////////
@@ -69,6 +73,10 @@ bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) {
         AIDGE_ASSERT(roi> 1, "Invalid ROI for Shape");
 
         mOutputs[0]->resize({roi});
+        if (!mOutputs[0]->getImpl()){
+            Log::debug("Shape::forwardDims, no implementation set for output, defaulting to CPU.");
+            mOutputs[0]->setBackend("cpu");
+        }
         // Ensure the output of this operator is valid after forwardDims():
         mOutputs[0]->getImpl()->copyCast(std::next(getInput(0)->dims().data(),
                                                     start),
@@ -98,4 +106,4 @@ std::set<std::string> Aidge::Shape_Op::getAvailableBackends() const {
 
 std::shared_ptr<Aidge::Node> Aidge::Shape(const std::int64_t start, const std::int64_t end, const std::string& name) {
     return std::make_shared<Node>(std::make_shared<Shape_Op>(start, end), name);
-}
\ No newline at end of file
+}
diff --git a/unit_tests/operator/Test_ShapeImpl.cpp b/unit_tests/operator/Test_ShapeImpl.cpp
index 45df89df0..56ca3eaa1 100644
--- a/unit_tests/operator/Test_ShapeImpl.cpp
+++ b/unit_tests/operator/Test_ShapeImpl.cpp
@@ -42,12 +42,11 @@ TEST_CASE("[cpu/operator] Shape(forward)", "[Shape][CPU]") {
             {1, 2, 3, 5}
         });
 
-        std::shared_ptr<Node> myShape = Shape();
-        auto op = std::static_pointer_cast<OperatorTensor>(myShape -> getOperator());
-        op->associateInput(0,input);
+        std::shared_ptr<Shape_Op> op = std::make_shared<Shape_Op>();
+        op->associateInput(0, input);
         op->setDataType(DataType::Int32);
         op->setBackend("cpu");
-        myShape->forward();
+        op->forward();
 
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
 
@@ -83,4 +82,4 @@ TEST_CASE("[cpu/operator] Shape(forward)", "[Shape][CPU]") {
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
 
     }
-}
\ No newline at end of file
+}
-- 
GitLab