diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index c677da0f2e34a299ddec6ee85f5a84616206193d..a411101618a5f4acaf070516d67691a6b55e3ff5 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -70,16 +70,9 @@ public: return mScheduler; } - void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final { - AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type"); - AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx); - - const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; - inputOp.first->getOperator()->associateInput(inputOp.second, data); - - // Associate inputs for custom implementation - mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); - } + void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; + void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; + void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; bool forwardDims(bool allowDataDependency = false) override final { // Check first that all required inputs are available, otherwise diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 6086c5145eb39cee081468ba91473dc983cfa35f..a493793278d42904d8a62e31571720f94ff1655d 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -56,8 +56,8 @@ public: /////////////////////////////////////////////////// // Tensor access // input management - void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; - void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; + void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override; + void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const; std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final; diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 46e9e1173af98ed5711aa0bbce54705fb61dc03c..36ff1854703d015980a1943390eb87d0863d877f 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -37,6 +37,37 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar } } +void Aidge::MetaOperator_Op::associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) { + AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type"); + AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx); + + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->associateInput(inputOp.second, data); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); +} + +void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Data>& data) { + AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); + + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->setInput(inputOp.second, data); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); +} + +void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Data>&& data) { + AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); + + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->setInput(inputOp.second, std::forward<std::shared_ptr<Data>>(data)); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); +} + Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbRequiredData(inputIdx);