Commit d9e59761 authored by Olivier BICHLER

Updated C-P model to work with both data and tokens

parent 26c97931
@@ -16,6 +16,7 @@
#include <vector>
#include <memory>
#include "aidge/utils/Types.h"
#include "aidge/data/Elts.hpp"
namespace Aidge {
class Operator;
@@ -33,13 +34,13 @@ public:
* @param inputIdx Index of the input analysed.
* @return std::size_t
*/
virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
// Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/**
* @brief Total amount of consumed data from a specific input.
@@ -47,7 +48,7 @@ public:
* @param inputIdx Index of the input analysed.
* @return DimSize_t
*/
virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/**
* @brief Total amount of produced data ready to be used on a specific output.
@@ -55,7 +56,7 @@ public:
* @param outputIdx Index of the output analysed.
* @return DimSize_t
*/
virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;
/**
* @brief Update the Consummer Producer system by simulating the consumption and production of i/o
@@ -73,8 +74,8 @@ public:
protected:
const Operator &mOp;
std::vector<NbElts_t> mNbConsumedData;
std::vector<NbElts_t> mNbProducedData;
std::vector<Elts_t> mNbConsumedData;
std::vector<Elts_t> mNbProducedData;
};
} // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_ELTS_H_
#define AIDGE_ELTS_H_
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* Base object for Aidge consumer-producer model (C-P model).
* It is a hybrid model: operator implementations can specify their C-P model
* with precise data (bytes) or with tokens.
*/
struct Elts_t {
enum EltType {
Data,
Token,
Undef
};
NbElts_t data;
NbElts_t token;
EltType type;
// Addition operator
inline Elts_t operator+(const Elts_t& other) const {
AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef,
"Incompatible C-P model types: {} + {}. Data and Token cannot be mixed.", type, other.type);
return Elts_t(data + other.data, token + other.token, (other.type == Undef) ? type : other.type);
}
// Addition assignment operator
inline Elts_t& operator+=(const Elts_t& other) {
AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef,
"Incompatible C-P model types: {} += {}. Data and Token cannot be mixed.", type, other.type);
data += other.data;
token += other.token;
type = (other.type == Undef) ? type : other.type;
return *this;
}
// Comparison operators
inline bool operator<(const Elts_t& other) const {
if (type == Elts_t::Undef || type == Elts_t::Token) {
// Nothing, or only a token is required: don't care about how much data has been produced for the token
return (token < other.token);
}
else if (type == Elts_t::Data && other.type != Elts_t::Token) {
// A precise amount of data is required, so the amount of produced data must be specified, a token is not enough
return (data < other.data);
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Incompatible C-P model types: {} < {}. Data is expected for right-hand side.", type, other.type);
}
}
inline bool operator>(const Elts_t& other) const {
if (type == Elts_t::Undef || type == Elts_t::Token) {
// Nothing, or only a token is required: don't care about how much data has been produced for the token
return (token > other.token);
}
else if (type == Elts_t::Data && other.type != Elts_t::Token) {
// A precise amount of data is required, so the amount of produced data must be specified, a token is not enough
return (data > other.data);
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Incompatible C-P model types: {} > {}. Data is expected for right-hand side.", type, other.type);
}
}
inline static Elts_t NoneElts() {
return Elts_t(0, 0, Elts_t::Undef);
}
inline static Elts_t DataElts(NbElts_t data, NbElts_t token = 1) {
return Elts_t(data, token, Elts_t::Data);
}
inline static Elts_t TokenElts(NbElts_t token) {
return Elts_t(0, token, Elts_t::Token);
}
private:
inline Elts_t(NbElts_t data_, NbElts_t token_, EltType type_):
data(data_), token(token_), type(type_) {}
};
} // end namespace Aidge
template<>
struct fmt::formatter<Aidge::Elts_t> {
template<typename ParseContext>
inline constexpr auto parse(ParseContext& ctx) {
return ctx.begin();
}
template<typename FormatContext>
inline auto format(Aidge::Elts_t const& elt, FormatContext& ctx) {
return fmt::format_to(ctx.out(), "{}:{}", elt.data, elt.token);
}
};
namespace {
template <>
const char* const EnumStrings<Aidge::Elts_t::EltType>::data[]
= {"Data", "Token", "Undef"};
}
namespace Aidge {
inline auto format_as(Elts_t::EltType elt) { return EnumStrings<Aidge::Elts_t::EltType>::data[static_cast<int>(elt)]; }
}
#endif /* AIDGE_ELTS_H_ */
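
The hybrid semantics of Elts_t are easiest to see in a small usage sketch (illustrative only, not part of this commit; the include path and the standalone main() harness are assumptions):

#include <fmt/core.h>
#include "aidge/data/Elts.hpp"  // assumed include path for the header shown above

using Aidge::Elts_t;

int main() {
    // Data-based bookkeeping: tensor dimensions are known, so amounts are precise.
    Elts_t consumed = Elts_t::NoneElts();   // starts as Undef, compatible with both models
    consumed += Elts_t::DataElts(256);      // 256 elements (and 1 implicit token)
    consumed += Elts_t::DataElts(256);      // now 512 elements, 2 tokens

    const Elts_t required  = Elts_t::DataElts(1024);
    const Elts_t available = Elts_t::DataElts(512);

    // Mirrors the scheduler's runnability test: consumed + required must not exceed available.
    const bool runnable = !((consumed + required) > available);
    fmt::print("consumed={} required={} available={} runnable={}\n",
               consumed, required, available, runnable);   // Elts_t prints as "data:token"

    // Token-based bookkeeping: dimensions unknown, only firing counts are tracked.
    Elts_t produced = Elts_t::TokenElts(1);
    produced += Elts_t::TokenElts(1);
    // produced += Elts_t::DataElts(64);    // would fail AIDGE_ASSERT: Data and Token cannot be mixed
    return 0;
}
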
@@ -107,11 +107,11 @@ public:
mGraph->setDataType(datatype);
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override;
NbElts_t getNbProducedData(IOIndex_t outputIdx) const override;
Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
Elts_t getNbConsumedData(IOIndex_t inputIdx) const override;
Elts_t getNbProducedData(IOIndex_t outputIdx) const override;
void updateConsummerProducer() override;
void forward() override;
......
@@ -131,31 +131,31 @@ public:
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @param inputIdx Index of the input analysed.
* @return NbElts_t
* @return Elts_t
*/
virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
// Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/**
* @brief Total amount of consumed data from a specific input.
*
* @param inputIdx Index of the input analysed.
* @return NbElts_t
* @return Elts_t
*/
virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/**
* @brief Total amount of produced data ready to be used on a specific output.
*
* @param outputIdx Index of the output analysed.
* @return NbElts_t
* @return Elts_t
*/
virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;
virtual void updateConsummerProducer();
......
@@ -141,7 +141,7 @@ protected:
* @return std::set<std::shared_ptr<Node>>
*/
std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
/** @brief Shared ptr to the scheduled graph view */
......
@@ -14,6 +14,7 @@
#define AIDGE_ERRORHANDLING_H_
#include <memory>
#include <cassert>
#include <fmt/format.h>
#include <fmt/ranges.h>
......
@@ -42,18 +42,18 @@ public:
);
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_required_data",
getNbRequiredData,
inputIdx
);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_required_protected",
getNbRequiredProtected,
@@ -61,10 +61,10 @@ public:
);
}
NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
Elts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_required_memory",
getRequiredMemory,
@@ -73,9 +73,9 @@ public:
);
}
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_consumed_data",
getNbConsumedData,
@@ -83,9 +83,9 @@ public:
);
}
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override {
Elts_t getNbProducedData(const IOIndex_t outputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_produced_data",
getNbProducedData,
......
@@ -18,48 +18,91 @@
Aidge::OperatorImpl::OperatorImpl(const Operator& op):
mOp(op),
mNbConsumedData(mOp.nbInputs(), 0),
mNbProducedData(mOp.nbOutputs(), 0)
mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()),
mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts())
{
//ctor
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mOp.getRawInput(inputIdx),
"a valid input is required at index {} for operator type {}",
inputIdx, mOp.type());
// Requires the whole tensor by default
return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
if (mOp.getRawInput(inputIdx)) {
const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
if (!input->empty()) {
// Known amount of data: requires the whole tensor by default
return Elts_t::DataElts(input->size());
}
else {
// Unknown amount of data: require a single token by default
return Elts_t::TokenElts(1);
}
}
// Input not connected, meaning it is an optional input: do not require anything!
return Elts_t::NoneElts();
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
AIDGE_ASSERT(mOp.getRawInput(inputIdx),
"a valid input is required at index {} for operator type {}",
inputIdx, mOp.type());
// Protect the whole tensor by default
return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
if (mOp.getRawInput(inputIdx)) {
const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
if (!input->empty()) {
// Known amount of data: protect the whole tensor by default
return Elts_t::DataElts(input->size());
}
else {
// Unknown amount of data: protect a single token by default
// (this does not really make sense for now, as getNbRequiredProtected()
// is supposed to give a precise amount of data to protect for
// memory management purposes...)
return Elts_t::TokenElts(1);
}
}
// Input not connected, meaning it is an optional input: do not require anything!
return Elts_t::NoneElts();
}
Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
AIDGE_ASSERT(mOp.getRawOutput(outputIdx),
"a valid output is required at index {} for operator type {}",
outputIdx, mOp.type());
// Requires the whole tensor by default, regardless of available data on inputs
return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size();
if (mOp.getRawOutput(outputIdx)) {
const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx));
if (!output->empty()) {
// Known amount of data: requires the whole tensor by default,
// regardless of available data on inputs
return Elts_t::DataElts(output->size());
}
else {
// Unknown amount of data: require a single token by default
// (this does not really make sense for now, as getRequiredMemory()
// is supposed to give a precise amount of data to allocate for
// memory management purposes...)
return Elts_t::TokenElts(1);
}
}
// Output not set, meaning it is an optional output: do not require anything!
return Elts_t::NoneElts();
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size(),
"input index ({}) is out of bound ({}) for operator type {}",
inputIdx, mNbConsumedData.size(), mOp.type());
return mNbConsumedData[static_cast<std::size_t>(inputIdx)];
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(),
"output index ({}) is out of bound ({}) for operator type {}",
outputIdx, mNbProducedData.size(), mOp.type());
@@ -79,8 +122,8 @@ void Aidge::OperatorImpl::updateConsummerProducer(){
}
void Aidge::OperatorImpl::resetConsummerProducer(){
std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0);
std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0);
std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts());
std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts());
}
void Aidge::OperatorImpl::forward() {
......
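
With these new defaults, getNbRequiredData() asks for the whole input tensor when its dimensions are known and falls back to a single token otherwise. Backends that need a finer-grained consumer-producer model can still override these methods. A minimal sketch, assuming the usual Aidge headers, of a hypothetical implementation that requires its input in fixed-size chunks (the class name, header paths and 128-element chunk size are illustrative assumptions, not part of this commit):

#include <algorithm>
#include <memory>
#include "aidge/backend/OperatorImpl.hpp"  // assumed header locations
#include "aidge/data/Tensor.hpp"

// Hypothetical implementation: precise, chunked data requirement when the
// input dimensions are known, token fallback otherwise.
class ChunkedOpImpl : public Aidge::OperatorImpl {
public:
    using Aidge::OperatorImpl::OperatorImpl;

    Aidge::Elts_t getNbRequiredData(const Aidge::IOIndex_t inputIdx) const override {
        const auto input = std::static_pointer_cast<Aidge::Tensor>(mOp.getRawInput(inputIdx));
        if (!input) {
            return Aidge::Elts_t::NoneElts();    // unconnected optional input: require nothing
        }
        if (input->empty()) {
            return Aidge::Elts_t::TokenElts(1);  // dimensions not known yet: one token
        }
        // Dimensions known: require at most 128 elements per scheduling step.
        return Aidge::Elts_t::DataElts(std::min<Aidge::NbElts_t>(128, input->size()));
    }
};
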
@@ -30,7 +30,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredData(inputIdx);
}
@@ -40,12 +40,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI
return inputOp.first->getOperator()->getNbRequiredData(inputOp.second);
}
else {
return 0;
return Elts_t::NoneElts();
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredProtected(inputIdx);
}
@@ -55,12 +55,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t i
return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second);
}
else {
return 0;
return Elts_t::NoneElts();
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
if (mImpl) {
return mImpl->getRequiredMemory(outputIdx, inputsSize);
}
@@ -70,12 +70,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t output
return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize);
}
else {
return 0;
return Elts_t::NoneElts();
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbConsumedData(inputIdx);
}
@@ -85,12 +85,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co
return inputOp.first->getOperator()->getNbConsumedData(inputOp.second);
}
else {
return 0;
return Elts_t::NoneElts();
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const {
Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const {
if (mImpl) {
return mImpl->getNbProducedData(outputIdx);
}
@@ -100,7 +100,7 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c
return outputOp.first->getOperator()->getNbProducedData(outputOp.second);
}
else {
return 0;
return Elts_t::NoneElts();
}
}
}
......
@@ -31,27 +31,27 @@ Aidge::Operator::~Operator() noexcept = default;
// IMPLEMENTATION
///////////////////////////////////////////////////////
Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredData(): an implementation is required for {}!", type());
return mImpl->getNbRequiredData(inputIdx);
}
Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type());
return mImpl->getNbRequiredProtected(inputIdx);
}
Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
Aidge::Elts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type());
return mImpl->getRequiredMemory(outputIdx, inputsSize);
}
Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type());
return mImpl->getNbConsumedData(inputIdx);
}
Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type());
return mImpl->getNbProducedData(outputIdx);
}
......
@@ -138,8 +138,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
bool isRunnable = true;
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
&& */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
if ((consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
getNbAvailableData(consumer, inputIdx)) {
Log::debug(" not runnable: C{} + R{} > P{} for input #{}",
consumer->getOperator()->getNbConsumedData(inputIdx),
@@ -226,12 +225,17 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
IOIndex_t inputIdx = 0;
for (const auto& childParent : child->getParents()) {
if (childParent == consumer) {
if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
if (child->getOperator()->getNbConsumedData(inputIdx) < consumer->getOperator()->getNbProducedData(outId)) {
isProducer = true;
break;
}
}
++inputIdx;
}
if (isProducer) {
break;
}
}
}
/*
@@ -383,17 +387,22 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
}
const auto childs = node->getChildren();
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor,
"Operator must be of Tensor type for node {} (of type {}).",
node->name(), node->type());
const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
// Allocate a memory plane for each node's output
for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
const size_t requiredSize = op->getRequiredMemory(outputIdx, {});
const auto requiredSize = op->getRequiredMemory(outputIdx, {});
AIDGE_ASSERT(requiredSize.type == Elts_t::Data,
"Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
node->name(), node->type());
// By default, specifies a fully monolithic memory block
size_t size = requiredSize;
size_t size = requiredSize.data;
size_t stride = 0;
size_t length = 1;
size_t count = 1;
@@ -425,21 +434,27 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
// memSpace should not be already released
&& memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
{
const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx));
const auto requiredData = op->getNbRequiredData(inputIdx);
const auto requiredProtected = op->getNbRequiredProtected(inputIdx);
AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data,
"Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
node->name(), node->type());
const bool isWrappable = (requiredProtected.data < requiredData.data);
const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];
if (isWrappable || !memManager.isWrapAround(
memPlane.memSpace,
memPlane.getFinalOffset()
- memPlane.memSpace->offset,
requiredSize))
requiredSize.data))
{
if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx)
if (memPlane.getSize() > wrapAroundSize + requiredProtected.data
&& std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
{
wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx);
if (requiredSize > wrapAroundSize) {
wrapAroundExtra = requiredSize - wrapAroundSize;
wrapAroundSize = memPlane.getSize() - requiredProtected.data;
if (requiredSize.data > wrapAroundSize) {
wrapAroundExtra = requiredSize.data - wrapAroundSize;
}
wrapAroundMemPlane[outputIdx] = &memPlane;
}
@@ -456,17 +471,17 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
const MemoryManager::MemoryPlane& memPlane
= (wrapAroundBuffer && wrapAroundSize > 0)
? (*wrapAroundMemPlane[outputIdx]) :
memManager.allocate(requiredSize, childs, stride, length, count);
memManager.allocate(requiredSize.data, childs, stride, length, count);
if (wrapAroundBuffer && wrapAroundSize > 0) {
memManager.reallocate(memPlane,
node, 0,
requiredSize, true, wrapAroundExtra, childs, stride, length, count);
requiredSize.data, true, wrapAroundExtra, childs, stride, length, count);
}
else {
memManager.reallocate(memPlane.memSpace,
node, memPlane.offset,
requiredSize, false, 0, childs, stride, length, count);
requiredSize.data, false, 0, childs, stride, length, count);
}
}
@@ -574,7 +589,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers(
return consumers;
}
Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
const auto parent = node->inputs()[inputIdx];
if (parent.first) {
@@ -605,14 +620,14 @@ Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>
// In this case, we assume single-use data (unlike a Producer, which
// keeps producing the data each time it is needed).
fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
return Elts_t::DataElts(std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size());
}
else {
// Input is not connected, this is an error
AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
}
return 0;
return Elts_t::NoneElts();
}
Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
......
@@ -75,3 +75,45 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
}
TEST_CASE("randomScheduling_tokens", "[Scheduler][randomGen]") {
const size_t nbTests = 100;
size_t nbUnicity = 0;
for (int test = 0; test < nbTests; ++test) {
std::random_device rd;
const std::mt19937::result_type seed(rd());
RandomGraph randGraph;
randGraph.acyclic = true;
const auto g1 = std::make_shared<GraphView>("g1");
const bool unicity1 = g1->add(randGraph.gen(seed, 10));
if (unicity1) {
++nbUnicity;
const auto orderedInputs = g1->getOrderedInputs();
for (const auto& input : orderedInputs) {
auto prod = Producer({16, 32});
prod->addChild(input.first, 0, input.second);
g1->add(prod);
}
g1->save("schedule");
auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform(sch.begin(), sch.end(),
std::back_inserter(nodesName),
[&namePtrTable](auto val){ return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == 10 + orderedInputs.size());
}
}
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
}