From aeb8d162fb18099e74c2339474cd6257a511b8dc Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 9 Apr 2024 10:39:06 +0200
Subject: [PATCH] Fixed Identity to not require forwardDims() and removed
 associateInput() from forwardDims()

---
 include/aidge/operator/Identity.hpp | 21 +--------------------
 src/graph/GraphView.cpp             | 15 +++++----------
 src/operator/Identity.cpp           |  8 +++++++-
 3 files changed, 13 insertions(+), 31 deletions(-)

diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp
index 08634d9fa..51c70eae5 100644
--- a/include/aidge/operator/Identity.hpp
+++ b/include/aidge/operator/Identity.hpp
@@ -78,29 +78,10 @@ public:
     }
 
 
-    void forward() override final { runHooks(); }
+    void forward() override final;
 
     void backward() override final { }
 
-    void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final {
-        AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as outputs", type());
-        AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
-        *mInputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
-    }
-
-    void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final {
-        AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as inputs", type());
-        AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
-        *mInputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
-    }
-
-    const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const override final {
-        AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
-        if (mInputs[outputIdx] == nullptr){
-            return mOutputs[outputIdx]; // Input is not initialized with empty tensor
-        }
-        return mInputs[outputIdx]; // Identity, so Output is Input
-    }
     void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final {
         // setBackend do nothing, Identity node has no backend it just pass the same Tensor
     }
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 9b53a9d82..88c7383a9 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -406,19 +406,14 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
     // Ensure every node in the graph is correctly connected
     for (std::shared_ptr<Node> nodePtr : getNodes()) {
         for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
-            // assess if the input was not already set and is a Tensor then link it to parent output
             std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
             if (inputI.first) {
-                if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
-                    if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
-                        // assert provided Data is of "Tensor" type
-                        nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
-                    }
-                    else {
-                        AIDGE_ASSERT(false, "Non-tensor entries not handled yet, for node {} (of type {}).", nodePtr->name(), nodePtr->type());
-                    }
-                }
+                // Check that tensors are properly connected...
+                AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i) == inputI.first->getOperator()->getRawOutput(inputI.second),
+                  "Input#{} for node {} ({}) is not properly connected to output#{} of node {} ({}): Data or Tensor mismatch!",
+                    i, nodePtr->name(), nodePtr->type(), inputI.second, inputI.first->name(), inputI.first->type());
             } else {
+                // Input is missing
                 AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i)
                     && !std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty(),
                   "Missing input#{} for node {} ({})", i, nodePtr->name(), nodePtr->type());
diff --git a/src/operator/Identity.cpp b/src/operator/Identity.cpp
index f57906dd4..2b8107bfc 100644
--- a/src/operator/Identity.cpp
+++ b/src/operator/Identity.cpp
@@ -13,4 +13,10 @@
 
 #include "aidge/operator/Identity.hpp"
 
-const std::string Aidge::Identity_Op::Type = "Identity";
\ No newline at end of file
+const std::string Aidge::Identity_Op::Type = "Identity";
+
+void Aidge::Identity_Op::forward() {
+    // Perform a shallow copy
+    *(mOutputs[0]) = *(mInputs[0]);
+    runHooks();
+}
-- 
GitLab