From d9e59761905aba69ad39da6c1b31f5bd72ec27c6 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 15 Mar 2024 14:31:19 +0100 Subject: [PATCH] Updated C-P model to work with both data and tokens --- include/aidge/backend/OperatorImpl.hpp | 15 ++- include/aidge/data/Elts.hpp | 124 ++++++++++++++++++ include/aidge/operator/MetaOperator.hpp | 10 +- include/aidge/operator/Operator.hpp | 16 +-- include/aidge/scheduler/Scheduler.hpp | 2 +- include/aidge/utils/ErrorHandling.hpp | 1 + .../backend/pybind_OperatorImpl.cpp | 20 +-- src/backend/OperatorImpl.cpp | 73 ++++++++--- src/operator/MetaOperator.cpp | 20 +-- src/operator/Operator.cpp | 10 +- src/scheduler/Scheduler.cpp | 51 ++++--- unit_tests/scheduler/Test_Scheduler.cpp | 42 ++++++ 12 files changed, 305 insertions(+), 79 deletions(-) create mode 100644 include/aidge/data/Elts.hpp diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 8b5aba10d..215ac804c 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -16,6 +16,7 @@ #include <vector> #include <memory> #include "aidge/utils/Types.h" +#include "aidge/data/Elts.hpp" namespace Aidge { class Operator; @@ -33,13 +34,13 @@ public: * @param inputIdx Index of the input analysed. * @return std::size_t */ - virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const; // Amount of input data that cannot be overwritten during the execution. - virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; // Memory required at an output for a given input size. - virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; /** * @brief Total amount of consumed data from a specific input. @@ -47,7 +48,7 @@ public: * @param inputIdx Index of the input analysed. * @return DimSize_t */ - virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Total amount of produced data ready to be used on a specific output. @@ -55,7 +56,7 @@ public: * @param outputIdx Index of the output analysed. * @return DimSize_t */ - virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; /** * @brief Update the Consummer Producer system by simulating the consumption and production of i/o @@ -73,8 +74,8 @@ public: protected: const Operator &mOp; - std::vector<NbElts_t> mNbConsumedData; - std::vector<NbElts_t> mNbProducedData; + std::vector<Elts_t> mNbConsumedData; + std::vector<Elts_t> mNbProducedData; }; } // namespace Aidge diff --git a/include/aidge/data/Elts.hpp b/include/aidge/data/Elts.hpp new file mode 100644 index 000000000..1a5a9e10e --- /dev/null +++ b/include/aidge/data/Elts.hpp @@ -0,0 +1,124 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_ELTS_H_ +#define AIDGE_ELTS_H_ + +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +/** + * Base object for Aidge consumer-producer model (C-P model). + * It is a hybrid model: operator implementations can specify their C-P model + * with precise data (bytes) or with tokens. +*/ +struct Elts_t { + enum EltType { + Data, + Token, + Undef + }; + + NbElts_t data; + NbElts_t token; + EltType type; + + // Addition operator + inline Elts_t operator+(const Elts_t& other) const { + AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef, + "Incompatible C-P model types: {} + {}. Data and Token cannot be mixed.", type, other.type); + return Elts_t(data + other.data, token + other.token, (other.type == Undef) ? type : other.type); + } + + // Addition assignment operator + inline Elts_t& operator+=(const Elts_t& other) { + AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef, + "Incompatible C-P model types: {} += {}. Data and Token cannot be mixed.", type, other.type); + data += other.data; + token += other.token; + type = (other.type == Undef) ? type : other.type; + return *this; + } + + // Comparison operators + inline bool operator<(const Elts_t& other) const { + if (type == Elts_t::Undef || type == Elts_t::Token) { + // Nothing, or only a token is required: don't care about how much data has been produced for the token + return (token < other.token); + } + else if (type == Elts_t::Data && other.type != Elts_t::Token) { + // A precise amount of data is required, so the amount of produced data must be specified, a token is not enough + return (data < other.data); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Incompatible C-P model types: {} < {}. Data is expected for right-hand side.", type, other.type); + } + } + + inline bool operator>(const Elts_t& other) const { + if (type == Elts_t::Undef || type == Elts_t::Token) { + // Nothing, or only a token is required: don't care about how much data has been produced for the token + return (token > other.token); + } + else if (type == Elts_t::Data && other.type != Elts_t::Token) { + // A precise amount of data is required, so the amount of produced data must be specified, a token is not enough + return (data > other.data); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Incompatible C-P model types: {} > {}. Data is expected for right-hand side.", type, other.type); + } + } + + inline static Elts_t NoneElts() { + return Elts_t(0, 0, Elts_t::Undef); + } + + inline static Elts_t DataElts(NbElts_t data, NbElts_t token = 1) { + return Elts_t(data, token, Elts_t::Data); + } + + inline static Elts_t TokenElts(NbElts_t token) { + return Elts_t(0, token, Elts_t::Token); + } + +private: + inline Elts_t(NbElts_t data_, NbElts_t token_, EltType type_): + data(data_), token(token_), type(type_) {} +}; +} // end namespace Aidge + +template<> +struct fmt::formatter<Aidge::Elts_t> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(Aidge::Elts_t const& elt, FormatContext& ctx) { + return fmt::format_to(ctx.out(), "{}:{}", elt.data, elt.token); + } +}; + +namespace { +template <> +const char* const EnumStrings<Aidge::Elts_t::EltType>::data[] + = {"Data", "Token", "Undef"}; +} + +namespace Aidge { +inline auto format_as(Elts_t::EltType elt) { return EnumStrings<Aidge::Elts_t::EltType>::data[static_cast<int>(elt)]; } +} + +#endif /* AIDGE_ELTS_H_ */ diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index ce328c23f..cd23acd90 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -107,11 +107,11 @@ public: mGraph->setDataType(datatype); } - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override; - NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; - NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; + Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override; + Elts_t getNbConsumedData(IOIndex_t inputIdx) const override; + Elts_t getNbProducedData(IOIndex_t outputIdx) const override; void updateConsummerProducer() override; void forward() override; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 396c60e46..6e2e44426 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -131,31 +131,31 @@ public: /** * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. - * @return NbElts_t + * @return Elts_t */ - virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const; // Amount of input data that cannot be overwritten during the execution. - virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; // Memory required at an output for a given input size. - virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; /** * @brief Total amount of consumed data from a specific input. * * @param inputIdx Index of the input analysed. - * @return NbElts_t + * @return Elts_t */ - virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Total amount of produced data ready to be used on a specific output. * * @param outputIdx Index of the output analysed. - * @return NbElts_t + * @return Elts_t */ - virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; virtual void updateConsummerProducer(); diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index e0284f0fb..4c5b3bd4c 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -141,7 +141,7 @@ protected: * @return std::set<std::shared_ptr<Node>> */ std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; - NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; + Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const; /** @brief Shared ptr to the scheduled graph view */ diff --git a/include/aidge/utils/ErrorHandling.hpp b/include/aidge/utils/ErrorHandling.hpp index d4235d2db..f6a9aefe2 100644 --- a/include/aidge/utils/ErrorHandling.hpp +++ b/include/aidge/utils/ErrorHandling.hpp @@ -14,6 +14,7 @@ #define AIDGE_ERRORHANDLING_H_ #include <memory> +#include <cassert> #include <fmt/format.h> #include <fmt/ranges.h> diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 91d65484a..5259d877d 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -42,18 +42,18 @@ public: ); } - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override { + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_required_data", getNbRequiredData, inputIdx ); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_required_protected", getNbRequiredProtected, @@ -61,10 +61,10 @@ public: ); } - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, + Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_required_memory", getRequiredMemory, @@ -73,9 +73,9 @@ public: ); } - NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override { + Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_consumed_data", getNbConsumedData, @@ -83,9 +83,9 @@ public: ); } - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override { + Elts_t getNbProducedData(const IOIndex_t outputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_produced_data", getNbProducedData, diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 1439391b2..42e8545d3 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -18,48 +18,91 @@ Aidge::OperatorImpl::OperatorImpl(const Operator& op): mOp(op), - mNbConsumedData(mOp.nbInputs(), 0), - mNbProducedData(mOp.nbOutputs(), 0) + mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()), + mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts()) { //ctor } -Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mOp.getRawInput(inputIdx), "a valid input is required at index {} for operator type {}", inputIdx, mOp.type()); - // Requires the whole tensor by default - return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size(); + if (mOp.getRawInput(inputIdx)) { + const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); + if (!input->empty()) { + // Known amount of data: requires the whole tensor by default + return Elts_t::DataElts(input->size()); + } + else { + // Unknown amount of data: require a single token by default + return Elts_t::TokenElts(1); + } + } + + // Input not connected, meaning it is an optional input: do no require anything! + return Elts_t::NoneElts(); } -Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { AIDGE_ASSERT(mOp.getRawInput(inputIdx), "a valid input is required at index {} for operator type {}", inputIdx, mOp.type()); - // Protect the whole tensor by default - return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size(); + if (mOp.getRawInput(inputIdx)) { + const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); + if (!input->empty()) { + // Known amount of data: protect the whole tensor by default + return Elts_t::DataElts(input->size()); + } + else { + // Unknown amount of data: protect a single token by default + // (this does not really make sense for now, as getNbRequiredProtected() + // is supposed to give a precise amount of data to protect for + // memory management purpose...) + return Elts_t::TokenElts(1); + } + } + + // Input not connected, meaning it is an optional input: do no require anything! + return Elts_t::NoneElts(); } -Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, +Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { AIDGE_ASSERT(mOp.getRawOutput(outputIdx), "a valid output is required at index {} for operator type {}", outputIdx, mOp.type()); - // Requires the whole tensor by default, regardless of available data on inputs - return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size(); + if (mOp.getRawOutput(outputIdx)) { + const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx)); + if (!output->empty()) { + // Known amount of data: requires the whole tensor by default, + // regardless of available data on inputs + return Elts_t::DataElts(output->size()); + } + else { + // Unknown amount of data: require a single token by default + // (this does not really make sense for now, as getRequiredMemory() + // is supposed to give a precise amount of data to allocate for + // memory management purpose...) + return Elts_t::TokenElts(1); + } + } + + // Output not set, meaning it is an optional output: do no require anything! + return Elts_t::NoneElts(); } -Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size(), "input index ({}) is out of bound ({}) for operator type {}", inputIdx, mNbConsumedData.size(), mOp.type()); return mNbConsumedData[static_cast<std::size_t>(inputIdx)]; } -Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const { AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(), "output index ({}) is out of bound ({}) for operator type {}", outputIdx, mNbProducedData.size(), mOp.type()); @@ -79,8 +122,8 @@ void Aidge::OperatorImpl::updateConsummerProducer(){ } void Aidge::OperatorImpl::resetConsummerProducer(){ - std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0); - std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0); + std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts()); + std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts()); } void Aidge::OperatorImpl::forward() { diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 883185021..1d15db1fb 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -30,7 +30,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbRequiredData(inputIdx); } @@ -40,12 +40,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbRequiredProtected(inputIdx); } @@ -55,12 +55,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t i return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { if (mImpl) { return mImpl->getRequiredMemory(outputIdx, inputsSize); } @@ -70,12 +70,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t output return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbConsumedData(inputIdx); } @@ -85,12 +85,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { if (mImpl) { return mImpl->getNbProducedData(outputIdx); } @@ -100,7 +100,7 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c return outputOp.first->getOperator()->getNbProducedData(outputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index e4213cad8..317bbd364 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -31,27 +31,27 @@ Aidge::Operator::~Operator() noexcept = default; // IMPLEMENTATION /////////////////////////////////////////////////////// -Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredData(): an implementation is required for {}!", type()); return mImpl->getNbRequiredData(inputIdx); } -Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type()); return mImpl->getNbRequiredProtected(inputIdx); } -Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { +Aidge::Elts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type()); return mImpl->getRequiredMemory(outputIdx, inputsSize); } -Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type()); return mImpl->getNbConsumedData(inputIdx); } -Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type()); return mImpl->getNbProducedData(outputIdx); } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 639375902..906b3fa71 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -138,8 +138,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S bool isRunnable = true; for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { - if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0 - && */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > + if ((consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > getNbAvailableData(consumer, inputIdx)) { Log::debug(" not runnable: C{} + R{} > P{} for input #{}", consumer->getOperator()->getNbConsumedData(inputIdx), @@ -226,12 +225,17 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S IOIndex_t inputIdx = 0; for (const auto& childParent : child->getParents()) { if (childParent == consumer) { - if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) { + if (child->getOperator()->getNbConsumedData(inputIdx) < consumer->getOperator()->getNbProducedData(outId)) { isProducer = true; + break; } } ++inputIdx; } + + if (isProducer) { + break; + } } } /* @@ -383,17 +387,22 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr } const auto childs = node->getChildren(); - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, + "Operator must be of Tensor type for node {} (of type {}).", + node->name(), node->type()); const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; // Allocate a memory plane for each node's output for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { - const size_t requiredSize = op->getRequiredMemory(outputIdx, {}); + const auto requiredSize = op->getRequiredMemory(outputIdx, {}); + AIDGE_ASSERT(requiredSize.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + node->name(), node->type()); // By default, specifies a fully monolithic memory block - size_t size = requiredSize; + size_t size = requiredSize.data; size_t stride = 0; size_t length = 1; size_t count = 1; @@ -425,21 +434,27 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr // memSpace should not be already released && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1) { - const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx)); + const auto requiredData = op->getNbRequiredData(inputIdx); + const auto requiredProtected = op->getNbRequiredProtected(inputIdx); + AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + node->name(), node->type()); + + const bool isWrappable = (requiredProtected.data < requiredData.data); const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second]; if (isWrappable || !memManager.isWrapAround( memPlane.memSpace, memPlane.getFinalOffset() - memPlane.memSpace->offset, - requiredSize)) + requiredSize.data)) { - if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx) + if (memPlane.getSize() > wrapAroundSize + requiredProtected.data && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end()) { - wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx); - if (requiredSize > wrapAroundSize) { - wrapAroundExtra = requiredSize - wrapAroundSize; + wrapAroundSize = memPlane.getSize() - requiredProtected.data; + if (requiredSize.data > wrapAroundSize) { + wrapAroundExtra = requiredSize.data - wrapAroundSize; } wrapAroundMemPlane[outputIdx] = &memPlane; } @@ -456,17 +471,17 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr const MemoryManager::MemoryPlane& memPlane = (wrapAroundBuffer && wrapAroundSize > 0) ? (*wrapAroundMemPlane[outputIdx]) : - memManager.allocate(requiredSize, childs, stride, length, count); + memManager.allocate(requiredSize.data, childs, stride, length, count); if (wrapAroundBuffer && wrapAroundSize > 0) { memManager.reallocate(memPlane, node, 0, - requiredSize, true, wrapAroundExtra, childs, stride, length, count); + requiredSize.data, true, wrapAroundExtra, childs, stride, length, count); } else { memManager.reallocate(memPlane.memSpace, node, memPlane.offset, - requiredSize, false, 0, childs, stride, length, count); + requiredSize.data, false, 0, childs, stride, length, count); } } @@ -574,7 +589,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers( return consumers; } -Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const { const auto parent = node->inputs()[inputIdx]; if (parent.first) { @@ -605,14 +620,14 @@ Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node> // In this case, we assume a single-use data (unlike a Producer, which // keep producing the data each time it is needed). fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type()); - return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size(); + return Elts_t::DataElts(std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size()); } else { // Input is not connected, this is an error AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type()); } - return 0; + return Elts_t::NoneElts(); } Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers( diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index ab5fef1f6..75a0daed6 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -75,3 +75,45 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); } + +TEST_CASE("randomScheduling_tokens", "[Scheduler][randomGen]") { + const size_t nbTests = 100; + size_t nbUnicity = 0; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + randGraph.acyclic = true; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + + if (unicity1) { + const auto orderedInputs = g1->getOrderedInputs(); + for (const auto& input : orderedInputs) { + auto prod = Producer({16, 32}); + prod->addChild(input.first, 0, input.second); + g1->add(prod); + } + + g1->save("schedule"); + + auto scheduler = SequentialScheduler(g1); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + + const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); + + std::vector<std::string> nodesName; + std::transform(sch.begin(), sch.end(), + std::back_inserter(nodesName), + [&namePtrTable](auto val){ return namePtrTable.at(val); }); + + fmt::print("schedule: {}\n", nodesName); + REQUIRE(sch.size() == 10 + orderedInputs.size()); + } + } + + fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); +} -- GitLab