Skip to content
Snippets Groups Projects
Commit 8c0397d7 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added default operator impl with default producer-consumer model

parent 4feec190
No related branches found
No related tags found
1 merge request!37Added default operator impl with default producer-consumer model
Pipeline #33159 passed
...@@ -18,11 +18,13 @@ ...@@ -18,11 +18,13 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
class Operator;
class OperatorImpl { class OperatorImpl {
public: public:
OperatorImpl(const Operator& op);
virtual void forward(){}; virtual void forward();
virtual void backward(){}; virtual void backward();
/** /**
* @brief Minimum amount of data from a specific input required by the * @brief Minimum amount of data from a specific input required by the
...@@ -31,13 +33,13 @@ public: ...@@ -31,13 +33,13 @@ public:
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return std::size_t * @return std::size_t
*/ */
virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0; virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution. // Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0; virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
// Memory required at an output for a given input size. // Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0; virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/** /**
* @brief Total amount of consumed data from a specific input. * @brief Total amount of consumed data from a specific input.
...@@ -45,7 +47,7 @@ public: ...@@ -45,7 +47,7 @@ public:
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return DimSize_t * @return DimSize_t
*/ */
virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0; virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/** /**
* @brief Total amount of produced data ready to be used on a specific output. * @brief Total amount of produced data ready to be used on a specific output.
...@@ -53,15 +55,20 @@ public: ...@@ -53,15 +55,20 @@ public:
* @param outputIdx Index of the output analysed. * @param outputIdx Index of the output analysed.
* @return DimSize_t * @return DimSize_t
*/ */
virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0; virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
/** /**
* @brief Update the Consummer Producer system by simulating the consumption and production of i/o * @brief Update the Consummer Producer system by simulating the consumption and production of i/o
* *
*/ */
virtual void updateConsummerProducer() = 0; virtual void updateConsummerProducer();
virtual ~OperatorImpl() = default; virtual ~OperatorImpl() = default;
protected:
const Operator &mOp;
std::vector<NbElts_t> mNbConsumedData;
std::vector<NbElts_t> mNbProducedData;
}; };
} // namespace Aidge } // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
Aidge::OperatorImpl::OperatorImpl(const Operator& op):
mOp(op),
mNbConsumedData(mOp.nbInputs(), 0),
mNbProducedData(mOp.nbOutputs(), 0)
{
//ctor
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
assert(mOp.getInput(inputIdx) && "requires valid input");
// Requires the whole tensor by default
return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
assert(mOp.getInput(inputIdx) && "requires valid input");
// Protect the whole tensor by default
return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
}
Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
assert(mOp.getOutput(outputIdx) && "requires valid output");
// Requires the whole tensor by default, regardless of available data on inputs
return std::static_pointer_cast<Tensor>(mOp.getOutput(outputIdx))->size();
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
assert(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size());
return mNbConsumedData[static_cast<std::size_t>(inputIdx)];
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
assert(static_cast<std::size_t>(outputIdx) < mNbProducedData.size());
return mNbProducedData[static_cast<std::size_t>(outputIdx)];
}
void Aidge::OperatorImpl::updateConsummerProducer(){
// Update producer-consumer data
for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) {
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
}
for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) {
mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {});
}
}
void Aidge::OperatorImpl::forward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented");
}
void Aidge::OperatorImpl::backward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented");
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment