From c241006751b2dcb130d2292a56780adb93d06eef Mon Sep 17 00:00:00 2001 From: Noam ZERAH <noam.zerah@cea.fr> Date: Mon, 24 Mar 2025 15:55:28 +0000 Subject: [PATCH] Fixing the resetInput method to work with MetaOperators --- include/aidge/operator/MetaOperator.hpp | 8 +++++++- include/aidge/operator/OperatorTensor.hpp | 2 +- src/operator/MetaOperator.cpp | 8 ++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 63ca56d0b..cbc9cc118 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -122,7 +122,13 @@ public: * @param data Shared pointer to the data tensor. */ void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; - + + /** + * @brief Resets the input tensor at a given index. + * @param[in] inputIdx Index of the input to reset. + */ + void resetInput(const IOIndex_t inputIdx) override; + /** * @brief Forward the dimensions through the micro-graph. * diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 0e3d275eb..1b2035222 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -94,7 +94,7 @@ public: * @brief Resets the input tensor at a given index. * @param[in] inputIdx Index of the input to reset. */ - void resetInput(const IOIndex_t inputIdx) override final; + virtual void resetInput(const IOIndex_t inputIdx) override; /////////////////////////////////////////////////// /////////////////////////////////////////////////// diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 192cc9f5e..939c7bb23 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -106,6 +106,14 @@ void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); } +void Aidge::MetaOperator_Op::resetInput(const Aidge::IOIndex_t inputIdx) { + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + AIDGE_ASSERT(inputIdx < inputOp.first->nbInputs(), "Input idx out of range."); + inputOp.first->getOperator()->resetInput(inputIdx); + + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); +} + std::string Aidge::MetaOperator_Op::backend() const noexcept { return (mImpl) ? mImpl->backend() -- GitLab