diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 1fc9168da120ba87c916b1a6a346997be69184b4..a9f968c59c4d568cf138001b309635d805ce5141 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -14,73 +14,142 @@ #include <string> #include <vector> +#include <functional> #include "aidge/utils/Types.h" +#include "aidge/utils/DynamicAttributes.hpp" +#include "aidge/data/Data.hpp" #include "aidge/data/Elts.hpp" +#include "aidge/scheduler/ProdConso.hpp" namespace Aidge { +class Node; class Operator; +struct ImplSpec { + struct IOSpec { + IOSpec(DataType type_): + type(type_), + format(DataFormat::Any), + dims({}) + {} + + IOSpec(DataType type_, DataFormat format_): + type(type_), + format(format_), + dims({}) + {} + + DataType type; + DataFormat format; + std::vector<std::pair<size_t, size_t>> dims; + }; + + ImplSpec(IOSpec io) { + inputs.push_back(io); + outputs.push_back(io); + } + + ImplSpec(IOSpec i, IOSpec o) { + inputs.push_back(i); + outputs.push_back(o); + } + + std::vector<IOSpec> inputs; + std::vector<IOSpec> outputs; + //DynamicAttributes attrs; +}; + +inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) { + return (lhs.type < rhs.type) && (lhs.format < rhs.format) && (lhs.dims < rhs.dims); +} + +inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) { + return (lhs.inputs < rhs.inputs) && (lhs.outputs < rhs.outputs); +} + +template <class FwdFunc, class BwdFunc> +struct Impl { + Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_, + std::function<FwdFunc> forward_, + std::function<BwdFunc> backward_ = nullptr): + prodConso(prodConso_), forward(forward_), backward(backward_) {} + + std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso; + std::function<FwdFunc> forward; + std::function<BwdFunc> backward; +}; + class OperatorImpl { public: OperatorImpl(const Operator& op, const std::string& backend = ""); virtual void forward(); virtual void backward(); + virtual std::shared_ptr<ProdConso> prodConso(); const std::string& backend() const noexcept { return mBackend; } - /** - * @brief Minimum amount of data from a specific input required by the - * implementation to be run. - * - * @param inputIdx Index of the input analysed. - * @return std::size_t - */ - virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const; - - // Amount of input data that cannot be overwritten during the execution. - virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; - - // Memory required at an output for a given input size. - virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; /** - * @brief Total amount of consumed data from a specific input. - * - * @param inputIdx Index of the input analysed. - * @return DimSize_t + * @brief Get the operator required implementation specification, according + * to the current operator configuration. + * */ - virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; + ImplSpec getRequiredSpec() const { + // TODO + return ImplSpec{DataType::Float32}; + } /** - * @brief Total amount of produced data ready to be used on a specific output. - * - * @param outputIdx Index of the output analysed. - * @return DimSize_t + * @brief Get the best implementation that matches \p requiredSpecs. + * */ - virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; + ImplSpec getBestMatch(ImplSpec /*requiredSpecs*/) const { + // TODO: + return getAvailableImplSpecs()[0]; + } - /** - * @brief Update the Consummer Producer system by simulating the consumption and production of i/o - * - */ - virtual void updateConsummerProducer(); + // std::shared_ptr<Node> getAdaptedOp(ImplSpec requiredSpecs) { - /** - * @brief Reset the Consummer Producer system. - * - */ - virtual void resetConsummerProducer(); + // } virtual ~OperatorImpl() = default; protected: + virtual std::shared_ptr<ProdConso> getProdConso() const; + virtual std::vector<ImplSpec> getAvailableImplSpecs() const; + const Operator &mOp; const std::string mBackend; - std::vector<Elts_t> mNbConsumedData; - std::vector<Elts_t> mNbProducedData; + std::shared_ptr<ProdConso> mProdConso; }; } // namespace Aidge +template<> +struct fmt::formatter<Aidge::ImplSpec::IOSpec> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(Aidge::ImplSpec::IOSpec const& ioSpec, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{{{}, {}, {}}}", ioSpec.type, ioSpec.format, ioSpec.dims); + } +}; + +template<> +struct fmt::formatter<Aidge::ImplSpec> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(Aidge::ImplSpec const& implSpec, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{{{}, {}}}", implSpec.inputs, implSpec.outputs); + } +}; + #endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */ diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index eaadc7a7ca5fa85672619fb2d3b5b17590fd3778..ea6d9f98b5ca2dcfc624ed71eb63b8aa02c1ffb3 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -48,7 +48,8 @@ enum class DataType { UInt8, UInt16, UInt32, - UInt64 + UInt64, + Any }; enum class DataFormat { @@ -58,7 +59,8 @@ enum class DataFormat { CHWN, NCDHW, NDHWC, - CDHWN + CDHWN, + Any }; using DataFormatTranspose = std::array<size_t, 5>; @@ -145,11 +147,11 @@ const char* const EnumStrings<Aidge::DataType>::data[] = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Ternary", "Int2", "Int3", "Int4", "Int5", "Int6", "Int7", "Int8", "Int16", "Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6", - "UInt7", "UInt8", "UInt16", "UInt32", "UInt64"}; + "UInt7", "UInt8", "UInt16", "UInt32", "UInt64", "Any"}; template <> const char* const EnumStrings<Aidge::DataFormat>::data[] - = {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN"}; + = {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN", "Any"}; template <Aidge::DataType D> struct cpptype { using type = void; // Placeholder diff --git a/include/aidge/operator/Memorize.hpp b/include/aidge/operator/Memorize.hpp index 8adac69f88ea4ed5e7d3b7549f7a41446db47ca6..bd37c0544032759bac48e168abd84542d8a0143e 100644 --- a/include/aidge/operator/Memorize.hpp +++ b/include/aidge/operator/Memorize.hpp @@ -25,12 +25,18 @@ #include "aidge/utils/Types.h" namespace Aidge { -class Memorize_OpImpl : public OperatorImpl { +class Memorize_ProdConso : public ProdConso { public: - Memorize_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + Memorize_ProdConso(const Operator& op): ProdConso(op) {} Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override final; void updateConsummerProducer() override; +}; + +class Memorize_OpImpl : public OperatorImpl { +public: + Memorize_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + std::shared_ptr<ProdConso> getProdConso() const override { return std::make_shared<Memorize_ProdConso>(mOp); }; void forward() override; }; diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index dc9f0e8ad4380cfbebc502fcad736a82f793091d..09e6c010f51f2c72cf07e0a7c34a8cb3101e2651 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -21,6 +21,7 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/scheduler/ProdConso.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" diff --git a/include/aidge/operator/Pop.hpp b/include/aidge/operator/Pop.hpp index 41ab3c537eacc88920419cb5e0deecc4720796ba..1bdf30da839c5bebf3372b8ff5c9cd0b5b271e51 100644 --- a/include/aidge/operator/Pop.hpp +++ b/include/aidge/operator/Pop.hpp @@ -24,17 +24,23 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Pop_ProdConso : public ProdConso { +public: + Pop_ProdConso(const Operator& op): ProdConso(op) {} + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override; +}; + class Pop_OpImpl : public OperatorImpl { public: Pop_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} - Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + std::shared_ptr<ProdConso> getProdConso() const override { return std::make_shared<Pop_ProdConso>(mOp); }; void forward() override; }; enum class PopAttr { ForwardStep }; class Pop_Op : public OperatorTensor, - public Registrable<Pop_Op, std::string, std::unique_ptr<OperatorImpl>(const Pop_Op&)> { + public Registrable<Pop_Op, std::string, std::function<std::unique_ptr<OperatorImpl>(const Pop_Op&)>> { public: static const std::string Type; diff --git a/include/aidge/scheduler/ProdConso.hpp b/include/aidge/scheduler/ProdConso.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a7c0ed5ae73d1f891744e835f0da5ad14a37f850 --- /dev/null +++ b/include/aidge/scheduler/ProdConso.hpp @@ -0,0 +1,89 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_SCHEDULER_PRODCONSO_H_ +#define AIDGE_SCHEDULER_PRODCONSO_H_ + +#include <string> +#include <vector> + +#include "aidge/utils/Types.h" +#include "aidge/data/Elts.hpp" + +namespace Aidge { +class Operator; + +class ProdConso { +public: + ProdConso(const Operator& op, bool inPlace = false); + + static std::unique_ptr<ProdConso> defaultModel(const Operator& op) { + return std::make_unique<ProdConso>(op, false); + } + + static std::unique_ptr<ProdConso> inPlaceModel(const Operator& op) { + return std::make_unique<ProdConso>(op, true); + } + + /** + * @brief Minimum amount of data from a specific input required by the + * implementation to be run. + * + * @param inputIdx Index of the input analysed. + * @return std::size_t + */ + virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const; + + // Amount of input data that cannot be overwritten during the execution. + virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; + + // Memory required at an output for a given input size. + virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + + /** + * @brief Total amount of consumed data from a specific input. + * + * @param inputIdx Index of the input analysed. + * @return DimSize_t + */ + virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; + + /** + * @brief Total amount of produced data ready to be used on a specific output. + * + * @param outputIdx Index of the output analysed. + * @return DimSize_t + */ + virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; + + /** + * @brief Update the Consummer Producer system by simulating the consumption and production of i/o + * + */ + virtual void updateConsummerProducer(); + + /** + * @brief Reset the Consummer Producer system. + * + */ + virtual void resetConsummerProducer(); + + virtual ~ProdConso() = default; + +protected: + const Operator &mOp; + const bool mInPlace; + std::vector<Elts_t> mNbConsumedData; + std::vector<Elts_t> mNbProducedData; +}; +} // namespace Aidge + +#endif /* AIDGE_SCHEDULER_PRODCONSO_H_ */ diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index d992703fedb224e6650ce2ad50317cda3bae650f..0eb1e635a1402a76959d105db9ec2c88664ca642 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -14,106 +14,22 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" +#include "aidge/scheduler/ProdConso.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/ErrorHandling.hpp" Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend): mOp(op), - mBackend(backend), - mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()), - mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts()) + mBackend(backend) { //ctor } -Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { - if (mOp.getRawInput(inputIdx)) { - const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); - if (!input->undefined()) { - // Known amount of data: requires the whole tensor by default - return Elts_t::DataElts(input->size()); - } - else { - // Unknown amount of data: require a single token by default - return Elts_t::TokenElts(1); - } - } - - // Input not connected, meaning it is an optional input: do no require anything! - return Elts_t::NoneElts(); -} - -Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { - if (mOp.getRawInput(inputIdx)) { - const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); - if (!input->undefined()) { - // Known amount of data: protect the whole tensor by default - return Elts_t::DataElts(input->size()); - } - else { - // Unknown amount of data: protect a single token by default - // (this does not really make sense for now, as getNbRequiredProtected() - // is supposed to give a precise amount of data to protect for - // memory management purpose...) - return Elts_t::TokenElts(1); - } - } - - // Input not connected, meaning it is an optional input: do no require anything! - return Elts_t::NoneElts(); -} - -Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, - const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { - if (mOp.getRawOutput(outputIdx)) { - const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx)); - if (!output->undefined()) { - // Known amount of data: requires the whole tensor by default, - // regardless of available data on inputs - return Elts_t::DataElts(output->size()); - } - else { - // Unknown amount of data: require a single token by default - // (this does not really make sense for now, as getRequiredMemory() - // is supposed to give a precise amount of data to allocate for - // memory management purpose...) - return Elts_t::TokenElts(1); - } - } - - // Output not set, meaning it is an optional output: do no require anything! - return Elts_t::NoneElts(); -} - -Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { - AIDGE_ASSERT(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size(), - "input index ({}) is out of bound ({}) for operator type {}", - inputIdx, mNbConsumedData.size(), mOp.type()); - return mNbConsumedData[static_cast<std::size_t>(inputIdx)]; -} - -Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const { - AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(), - "output index ({}) is out of bound ({}) for operator type {}", - outputIdx, mNbProducedData.size(), mOp.type()); - return mNbProducedData[static_cast<std::size_t>(outputIdx)]; -} - -void Aidge::OperatorImpl::updateConsummerProducer(){ - // Update producer-consumer data - for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) { - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx)); - } - - for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) { - mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {}); +std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() { + if (!mProdConso) { + mProdConso = getProdConso(); } -} - -void Aidge::OperatorImpl::resetConsummerProducer(){ - std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts()); - std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts()); + return mProdConso; } void Aidge::OperatorImpl::forward() { @@ -123,3 +39,11 @@ void Aidge::OperatorImpl::forward() { void Aidge::OperatorImpl::backward() { AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented yet for operator of type {}", mOp.type()); } + +std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::getProdConso() const { + return std::make_shared<ProdConso>(mOp); +} + +std::vector<Aidge::ImplSpec> Aidge::OperatorImpl::getAvailableImplSpecs() const { + return std::vector<ImplSpec>(); +} diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp index f713fdaad793aebebf5047d4ebf1dfd5aca10cd1..620eed4071fdaaadee971cc4cd24fa029e18b13e 100644 --- a/src/operator/Memorize.cpp +++ b/src/operator/Memorize.cpp @@ -20,7 +20,7 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" -Aidge::Elts_t Aidge::Memorize_OpImpl::getNbRequiredData( +Aidge::Elts_t Aidge::Memorize_ProdConso::getNbRequiredData( Aidge::IOIndex_t inputIdx) const { const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); @@ -35,11 +35,11 @@ Aidge::Elts_t Aidge::Memorize_OpImpl::getNbRequiredData( return Elts_t::NoneElts(); } else { - return OperatorImpl::getNbRequiredData(inputIdx); + return ProdConso::getNbRequiredData(inputIdx); } } -Aidge::Elts_t Aidge::Memorize_OpImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, +Aidge::Elts_t Aidge::Memorize_ProdConso::getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { assert(mOp.getRawOutput(outputIdx) && "requires valid output"); @@ -53,8 +53,8 @@ Aidge::Elts_t Aidge::Memorize_OpImpl::getRequiredMemory(const Aidge::IOIndex_t o } } -void Aidge::Memorize_OpImpl::updateConsummerProducer() { - OperatorImpl::updateConsummerProducer(); +void Aidge::Memorize_ProdConso::updateConsummerProducer() { + ProdConso::updateConsummerProducer(); const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); AIDGE_ASSERT(op.endStep() == 0 || op.scheduleStep() <= op.endStep(), "cannot update consumer producer anymore, number of cycles exceeded"); diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 71e3a4781569820267b7d623da8d73134692c05d..fd49476557a4bab3837a1fac74ed1028086b3ae9 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -77,7 +77,7 @@ void Aidge::MetaOperator_Op::setBackend(const std::string &name, Aidge::DeviceId Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { - return mImpl->getNbRequiredData(inputIdx); + return mImpl->prodConso()->getNbRequiredData(inputIdx); } else { const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; @@ -92,7 +92,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const { if (mImpl) { - return mImpl->getNbRequiredProtected(inputIdx); + return mImpl->prodConso()->getNbRequiredProtected(inputIdx); } else { const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; @@ -107,7 +107,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inp Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { if (mImpl) { - return mImpl->getRequiredMemory(outputIdx, inputsSize); + return mImpl->prodConso()->getRequiredMemory(outputIdx, inputsSize); } else { const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; @@ -122,7 +122,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputId Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { if (mImpl) { - return mImpl->getNbConsumedData(inputIdx); + return mImpl->prodConso()->getNbConsumedData(inputIdx); } else { const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; @@ -137,7 +137,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) cons Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { if (mImpl) { - return mImpl->getNbProducedData(outputIdx); + return mImpl->prodConso()->getNbProducedData(outputIdx); } else { const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; @@ -152,7 +152,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) con void Aidge::MetaOperator_Op::resetConsummerProducer() { if (mImpl) { - mImpl->resetConsummerProducer(); + mImpl->prodConso()->resetConsummerProducer(); } else { if (!mScheduler) { @@ -166,7 +166,7 @@ void Aidge::MetaOperator_Op::resetConsummerProducer() { void Aidge::MetaOperator_Op::updateConsummerProducer() { if (mImpl) { - mImpl->updateConsummerProducer(); + mImpl->prodConso()->updateConsummerProducer(); } else { if (!mScheduler) { diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 317bbd364572f49a714e328bf33f3cd58c19215f..762d5fda8655c3094abcc7cb9118f4a00683a879 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -16,6 +16,7 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" +#include "aidge/scheduler/ProdConso.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" @@ -33,35 +34,35 @@ Aidge::Operator::~Operator() noexcept = default; Aidge::Elts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredData(): an implementation is required for {}!", type()); - return mImpl->getNbRequiredData(inputIdx); + return mImpl->prodConso()->getNbRequiredData(inputIdx); } Aidge::Elts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type()); - return mImpl->getNbRequiredProtected(inputIdx); + return mImpl->prodConso()->getNbRequiredProtected(inputIdx); } Aidge::Elts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type()); - return mImpl->getRequiredMemory(outputIdx, inputsSize); + return mImpl->prodConso()->getRequiredMemory(outputIdx, inputsSize); } Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type()); - return mImpl->getNbConsumedData(inputIdx); + return mImpl->prodConso()->getNbConsumedData(inputIdx); } Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type()); - return mImpl->getNbProducedData(outputIdx); + return mImpl->prodConso()->getNbProducedData(outputIdx); } void Aidge::Operator::updateConsummerProducer(){ AIDGE_ASSERT(mImpl != nullptr, "updateConsummerProducer(): an implementation is required for {}!", type()); - mImpl->updateConsummerProducer(); + mImpl->prodConso()->updateConsummerProducer(); } void Aidge::Operator::resetConsummerProducer(){ AIDGE_ASSERT(mImpl != nullptr, "resetConsummerProducer(): an implementation is required for {}!", type()); - mImpl->resetConsummerProducer(); + mImpl->prodConso()->resetConsummerProducer(); } void Aidge::Operator::runHooks() const { diff --git a/src/operator/Pop.cpp b/src/operator/Pop.cpp index 5d32a06fd01d8674d8e072f14838f3fd80d1f30a..70193d38b22d1288b19c27567475d95bf69fa30d 100644 --- a/src/operator/Pop.cpp +++ b/src/operator/Pop.cpp @@ -20,7 +20,7 @@ #include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" -Aidge::Elts_t Aidge::Pop_OpImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Pop_ProdConso::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { assert(mOp.getRawInput(inputIdx) && "requires valid input"); const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp); diff --git a/src/scheduler/ProdConso.cpp b/src/scheduler/ProdConso.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3bff53c3643a5da361dec5944f47a27f148a995 --- /dev/null +++ b/src/scheduler/ProdConso.cpp @@ -0,0 +1,117 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <string> + +#include "aidge/scheduler/ProdConso.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" + +Aidge::ProdConso::ProdConso(const Operator& op, bool inPlace): + mOp(op), + mInPlace(inPlace), + mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()), + mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts()) +{ + //ctor +} + +Aidge::Elts_t Aidge::ProdConso::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { + if (mOp.getRawInput(inputIdx)) { + const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); + if (!input->undefined()) { + // Known amount of data: requires the whole tensor by default + return Elts_t::DataElts(input->size()); + } + else { + // Unknown amount of data: require a single token by default + return Elts_t::TokenElts(1); + } + } + + // Input not connected, meaning it is an optional input: do no require anything! + return Elts_t::NoneElts(); +} + +Aidge::Elts_t Aidge::ProdConso::getNbRequiredProtected(IOIndex_t inputIdx) const { + if (mOp.getRawInput(inputIdx)) { + const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); + if (!input->undefined()) { + // Known amount of data: protect the whole tensor by default + return Elts_t::DataElts((mInPlace) ? 0 : input->size()); + } + else { + // Unknown amount of data: protect a single token by default + // (this does not really make sense for now, as getNbRequiredProtected() + // is supposed to give a precise amount of data to protect for + // memory management purpose...) + return Elts_t::TokenElts((mInPlace) ? 0 : 1); + } + } + + // Input not connected, meaning it is an optional input: do no require anything! + return Elts_t::NoneElts(); +} + +Aidge::Elts_t Aidge::ProdConso::getRequiredMemory(const Aidge::IOIndex_t outputIdx, + const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { + if (mOp.getRawOutput(outputIdx)) { + const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx)); + if (!output->undefined()) { + // Known amount of data: requires the whole tensor by default, + // regardless of available data on inputs + return Elts_t::DataElts(output->size()); + } + else { + // Unknown amount of data: require a single token by default + // (this does not really make sense for now, as getRequiredMemory() + // is supposed to give a precise amount of data to allocate for + // memory management purpose...) + return Elts_t::TokenElts(1); + } + } + + // Output not set, meaning it is an optional output: do no require anything! + return Elts_t::NoneElts(); +} + +Aidge::Elts_t Aidge::ProdConso::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { + AIDGE_ASSERT(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size(), + "input index ({}) is out of bound ({}) for operator type {}", + inputIdx, mNbConsumedData.size(), mOp.type()); + return mNbConsumedData[static_cast<std::size_t>(inputIdx)]; +} + +Aidge::Elts_t Aidge::ProdConso::getNbProducedData(Aidge::IOIndex_t outputIdx) const { + AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(), + "output index ({}) is out of bound ({}) for operator type {}", + outputIdx, mNbProducedData.size(), mOp.type()); + return mNbProducedData[static_cast<std::size_t>(outputIdx)]; +} + +void Aidge::ProdConso::updateConsummerProducer(){ + // Update producer-consumer data + for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) { + // each input is consumed by the minimum amount for a forward pass + mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx)); + } + + for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) { + mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {}); + } +} + +void Aidge::ProdConso::resetConsummerProducer(){ + std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts()); + std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts()); +}