From 6121620a35a8f8373453e57c8b01545c91320166 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 30 Apr 2024 10:42:39 +0200
Subject: [PATCH] Fixed MetaOperator setInput()

---
 include/aidge/operator/MetaOperator.hpp   | 13 +++-------
 include/aidge/operator/OperatorTensor.hpp |  4 +--
 src/operator/MetaOperator.cpp             | 31 +++++++++++++++++++++++
 3 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index c677da0f2..a41110161 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -70,16 +70,9 @@ public:
         return mScheduler;
     }
 
-    void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
-        AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
-        AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
-
-        const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
-        inputOp.first->getOperator()->associateInput(inputOp.second, data);
-
-        // Associate inputs for custom implementation
-        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
-    }
+    void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
+    void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
+    void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
 
     bool forwardDims(bool allowDataDependency = false) override final {
         // Check first that all required inputs are available, otherwise
diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp
index 6086c5145..a49379327 100644
--- a/include/aidge/operator/OperatorTensor.hpp
+++ b/include/aidge/operator/OperatorTensor.hpp
@@ -56,8 +56,8 @@ public:
     ///////////////////////////////////////////////////
     // Tensor access
     // input management
-    void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
-    void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
+    void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override;
+    void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override;
     const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
     std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
 
diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp
index 46e9e1173..36ff18547 100644
--- a/src/operator/MetaOperator.cpp
+++ b/src/operator/MetaOperator.cpp
@@ -37,6 +37,37 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar
     }
 }
 
+void Aidge::MetaOperator_Op::associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) {
+    AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
+    AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
+
+    const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
+    inputOp.first->getOperator()->associateInput(inputOp.second, data);
+
+    // Associate inputs for custom implementation
+    mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
+}
+
+void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Data>& data) {
+    AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
+
+    const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
+    inputOp.first->getOperator()->setInput(inputOp.second, data);
+
+    // Associate inputs for custom implementation
+    mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
+}
+
+void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Data>&& data) {
+    AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
+
+    const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
+    inputOp.first->getOperator()->setInput(inputOp.second, std::forward<std::shared_ptr<Data>>(data));
+
+    // Associate inputs for custom implementation
+    mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
+}
+
 Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
     if (mImpl) {
         return mImpl->getNbRequiredData(inputIdx);
-- 
GitLab