From 8c0397d716673957cdf2e44fd94f95e75771d5e9 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 20 Oct 2023 16:24:25 +0200 Subject: [PATCH] Added default operator impl with default producer-consumer model --- include/aidge/backend/OperatorImpl.hpp | 25 ++++++--- src/backend/OperatorImpl.cpp | 77 ++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 9 deletions(-) create mode 100644 src/backend/OperatorImpl.cpp diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 453e30a86..19f083750 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -18,11 +18,13 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Operator; + class OperatorImpl { public: - - virtual void forward(){}; - virtual void backward(){}; + OperatorImpl(const Operator& op); + virtual void forward(); + virtual void backward(); /** * @brief Minimum amount of data from a specific input required by the @@ -31,13 +33,13 @@ public: * @param inputIdx Index of the input analysed. * @return std::size_t */ - virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; // Amount of input data that cannot be overwritten during the execution. - virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; // Memory required at an output for a given input size. - virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0; + virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; /** * @brief Total amount of consumed data from a specific input. @@ -45,7 +47,7 @@ public: * @param inputIdx Index of the input analysed. * @return DimSize_t */ - virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Total amount of produced data ready to be used on a specific output. @@ -53,15 +55,20 @@ public: * @param outputIdx Index of the output analysed. * @return DimSize_t */ - virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0; + virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; /** * @brief Update the Consummer Producer system by simulating the consumption and production of i/o * */ - virtual void updateConsummerProducer() = 0; + virtual void updateConsummerProducer(); virtual ~OperatorImpl() = default; + +protected: + const Operator &mOp; + std::vector<NbElts_t> mNbConsumedData; + std::vector<NbElts_t> mNbProducedData; }; } // namespace Aidge diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp new file mode 100644 index 000000000..166754cc9 --- /dev/null +++ b/src/backend/OperatorImpl.cpp @@ -0,0 +1,77 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" + +Aidge::OperatorImpl::OperatorImpl(const Operator& op): + mOp(op), + mNbConsumedData(mOp.nbInputs(), 0), + mNbProducedData(mOp.nbOutputs(), 0) +{ + //ctor +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { + assert(mOp.getInput(inputIdx) && "requires valid input"); + + // Requires the whole tensor by default + return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size(); +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { + assert(mOp.getInput(inputIdx) && "requires valid input"); + + // Protect the whole tensor by default + return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size(); +} + +Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, + const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { + assert(mOp.getOutput(outputIdx) && "requires valid output"); + + // Requires the whole tensor by default, regardless of available data on inputs + return std::static_pointer_cast<Tensor>(mOp.getOutput(outputIdx))->size(); +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { + assert(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size()); + return mNbConsumedData[static_cast<std::size_t>(inputIdx)]; +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const { + assert(static_cast<std::size_t>(outputIdx) < mNbProducedData.size()); + return mNbProducedData[static_cast<std::size_t>(outputIdx)]; +} + +void Aidge::OperatorImpl::updateConsummerProducer(){ + // Update producer-consumer data + for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) { + // each input is consumed by the minimum amount for a forward pass + mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx)); + } + + for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) { + mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {}); + } +} + +void Aidge::OperatorImpl::forward() { + AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented"); +} + +void Aidge::OperatorImpl::backward() { + AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented"); +} -- GitLab