From 6121620a35a8f8373453e57c8b01545c91320166 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 30 Apr 2024 10:42:39 +0200 Subject: [PATCH] Fixed MetaOperator setInput() --- include/aidge/operator/MetaOperator.hpp | 13 +++------- include/aidge/operator/OperatorTensor.hpp | 4 +-- src/operator/MetaOperator.cpp | 31 +++++++++++++++++++++++ 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index c677da0f2..a41110161 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -70,16 +70,9 @@ public: return mScheduler; } - void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final { - AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type"); - AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx); - - const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; - inputOp.first->getOperator()->associateInput(inputOp.second, data); - - // Associate inputs for custom implementation - mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); - } + void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; + void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; + void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; bool forwardDims(bool allowDataDependency = false) override final { // Check first that all required inputs are available, otherwise diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 6086c5145..a49379327 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -56,8 +56,8 @@ public: /////////////////////////////////////////////////// // Tensor access // input management - void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; - void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; + void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override; + void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const; std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final; diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 46e9e1173..36ff18547 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -37,6 +37,37 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar } } +void Aidge::MetaOperator_Op::associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) { + AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type"); + AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx); + + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->associateInput(inputOp.second, data); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); +} + +void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Data>& data) { + AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); + + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->setInput(inputOp.second, data); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); +} + +void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Data>&& data) { + AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); + + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->setInput(inputOp.second, std::forward<std::shared_ptr<Data>>(data)); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); +} + Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbRequiredData(inputIdx); -- GitLab