Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>  // std::fill
#include <cassert>
#include <memory>     // std::static_pointer_cast
#include <string>
#include <vector>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
/**
 * @brief Build the default implementation of an operator for a given backend.
 *
 * @param op      Operator this implementation is attached to.
 * @param backend Name of the backend this implementation belongs to.
 *
 * The consumed/produced counters get one entry per input/output of @p op,
 * each initialised to Elts_t::NoneElts().
 */
Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend):
    mOp(op),
    mBackend(backend),
    // Sized from the constructor argument rather than mOp, so correctness does
    // not depend on mOp being declared before these vectors in the header.
    mNbConsumedData(op.nbInputs(), Elts_t::NoneElts()),
    mNbProducedData(op.nbOutputs(), Elts_t::NoneElts())
{
    //ctor
}
/**
 * @brief Default amount of data an input must hold before forward() can run.
 *
 * @param inputIdx Index of the input to query.
 * @return - Elts_t::DataElts(size) when the input is connected and its tensor
 *           is not empty (the whole tensor is required);
 *         - Elts_t::TokenElts(1) when the input is connected but its tensor is
 *           empty (amount of data unknown);
 *         - Elts_t::NoneElts() when the input is not connected.
 */
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
    // An unconnected input is treated as optional by the fallback below, so
    // it must not be rejected up front: updateConsummerProducer() queries
    // every input index, including unconnected ones.
    const auto rawInput = mOp.getRawInput(inputIdx);
    if (rawInput) {
        const auto input = std::static_pointer_cast<Tensor>(rawInput);
        if (!input->empty()) {
            // Known amount of data: requires the whole tensor by default
            return Elts_t::DataElts(input->size());
        }
        // Unknown amount of data: require a single token by default
        return Elts_t::TokenElts(1);
    }

    // Input not connected, meaning it is an optional input: do not require anything!
    return Elts_t::NoneElts();
}
/**
 * @brief Default amount of input data to protect while forward() runs.
 *
 * @param inputIdx Index of the input to query.
 * @return - Elts_t::DataElts(size) when the input is connected and its tensor
 *           is not empty (the whole tensor is protected);
 *         - Elts_t::TokenElts(1) when the input is connected but its tensor is
 *           empty (amount of data unknown);
 *         - Elts_t::NoneElts() when the input is not connected.
 */
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
    // An unconnected input is treated as optional by the fallback below,
    // so it is not rejected up front.
    const auto rawInput = mOp.getRawInput(inputIdx);
    if (rawInput) {
        const auto input = std::static_pointer_cast<Tensor>(rawInput);
        if (!input->empty()) {
            // Known amount of data: protect the whole tensor by default
            return Elts_t::DataElts(input->size());
        }
        // Unknown amount of data: protect a single token by default
        // (this does not really make sense for now, as getNbRequiredProtected()
        // is supposed to give a precise amount of data to protect for
        // memory management purpose...)
        return Elts_t::TokenElts(1);
    }

    // Input not connected, meaning it is an optional input: do not protect anything!
    return Elts_t::NoneElts();
}
/**
 * @brief Default amount of memory an output needs for one forward pass.
 *
 * @param outputIdx  Index of the output to query.
 * @param inputsSize Unused by this default implementation.
 * @return - Elts_t::DataElts(size) when the output is set and its tensor is
 *           not empty (the whole tensor, regardless of available input data);
 *         - Elts_t::TokenElts(1) when the output is set but its tensor is
 *           empty (amount of data unknown);
 *         - Elts_t::NoneElts() when the output is not set.
 */
Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                     const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
    // An unset output is treated as optional by the fallback below,
    // so it is not rejected up front.
    const auto rawOutput = mOp.getRawOutput(outputIdx);
    if (rawOutput) {
        const auto output = std::static_pointer_cast<Tensor>(rawOutput);
        if (!output->empty()) {
            // Known amount of data: requires the whole tensor by default,
            // regardless of available data on inputs
            return Elts_t::DataElts(output->size());
        }
        // Unknown amount of data: require a single token by default
        // (this does not really make sense for now, as getRequiredMemory()
        // is supposed to give a precise amount of data to allocate for
        // memory management purpose...)
        return Elts_t::TokenElts(1);
    }

    // Output not set, meaning it is an optional output: do not require anything!
    return Elts_t::NoneElts();
}
/**
 * @brief Amount of data consumed so far on input @p inputIdx.
 *
 * @param inputIdx Index of the input; must be below the number of inputs,
 *                 otherwise the AIDGE_ASSERT check fails.
 * @return The accumulated consumed-data counter for that input.
 */
Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
    const auto slot = static_cast<std::size_t>(inputIdx);
    const auto nbSlots = mNbConsumedData.size();
    AIDGE_ASSERT(slot < nbSlots,
        "input index ({}) is out of bound ({}) for operator type {}",
        inputIdx, nbSlots, mOp.type());
    return mNbConsumedData[slot];
}
/**
 * @brief Amount of data produced so far on output @p outputIdx.
 *
 * @param outputIdx Index of the output; must be below the number of outputs,
 *                  otherwise the AIDGE_ASSERT check fails.
 * @return The accumulated produced-data counter for that output.
 */
Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
    const auto slot = static_cast<std::size_t>(outputIdx);
    const auto nbSlots = mNbProducedData.size();
    AIDGE_ASSERT(slot < nbSlots,
        "output index ({}) is out of bound ({}) for operator type {}",
        outputIdx, nbSlots, mOp.type());
    return mNbProducedData[slot];
}
/**
 * @brief Advance the producer/consumer counters by one forward pass.
 *
 * Every input counter is increased by getNbRequiredData() for that input;
 * every output counter is increased by getRequiredMemory() for that output,
 * queried with an empty list of input sizes.
 */
void Aidge::OperatorImpl::updateConsummerProducer(){
    // Inputs: each one is consumed by the minimum amount for a forward pass.
    for (std::size_t i = 0, nbIn = mNbConsumedData.size(); i < nbIn; ++i) {
        mNbConsumedData[i] += getNbRequiredData(static_cast<IOIndex_t>(i));
    }

    // Outputs: each one produces the amount reported by getRequiredMemory().
    for (std::size_t o = 0, nbOut = mNbProducedData.size(); o < nbOut; ++o) {
        mNbProducedData[o] += getRequiredMemory(static_cast<IOIndex_t>(o), {});
    }
}
/**
 * @brief Reset every consumed and produced counter to Elts_t::NoneElts().
 *
 * The number of counters (one per input/output) is left unchanged.
 */
void Aidge::OperatorImpl::resetConsummerProducer(){
    std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts());
    std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts());
}
/**
 * @brief Default forward(): no computation is provided at this level.
 *
 * Always reports a std::runtime_error through AIDGE_THROW_OR_ABORT, naming
 * the operator type. Presumably meant to be overridden by backend-specific
 * implementations.
 */
void Aidge::OperatorImpl::forward() {
    AIDGE_THROW_OR_ABORT(std::runtime_error,
                         "forward() not implemented yet for operator of type {}",
                         mOp.type());
}
void Aidge::OperatorImpl::backward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented yet for operator of type {}", mOp.type());