From b5130c8fffafaeb22e31b7bcad7568afd05ae884 Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Wed, 27 Nov 2024 16:46:00 +0100 Subject: [PATCH] chore: Fix rebase mess --- include/aidge/backend/OperatorImpl.hpp | 158 ++++++++++++++++++++----- 1 file changed, 131 insertions(+), 27 deletions(-) diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 68e2a57b4..649898dd1 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -14,73 +14,177 @@ #include <string> #include <vector> +#include <functional> #include "aidge/utils/Types.h" +#include "aidge/utils/DynamicAttributes.hpp" +#include "aidge/data/Data.hpp" #include "aidge/data/Elts.hpp" +#include "aidge/scheduler/ProdConso.hpp" namespace Aidge { +class Node; class Operator; +/** + * @brief ImplSpec stores the requirements or the specifications of an implementation. + * + */ +struct ImplSpec { + struct IOSpec { + IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, const std::vector<std::pair<int, int>>& dims_ = {}): + type(type_), + format(format_), + dims(dims_) + {} + + DataType type; + DataFormat format; + std::vector<std::pair<int, int>> dims; + }; + + ImplSpec(const DynamicAttributes& attrs_ = DynamicAttributes()); + ImplSpec(const IOSpec& io, const DynamicAttributes& attrs_ = DynamicAttributes()); + ImplSpec(const IOSpec& i, const IOSpec& o, const DynamicAttributes& attrs_ = DynamicAttributes()); + ImplSpec(const std::vector<IOSpec>& i, const std::vector<IOSpec>& o, const DynamicAttributes& attrs_ = DynamicAttributes()); + ImplSpec(const Aidge::ImplSpec&); + ~ImplSpec() noexcept; + + std::vector<IOSpec> inputs; + std::vector<IOSpec> outputs; + DynamicAttributes attrs; +}; + +inline bool operator==(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) { + return (lhs.type == rhs.type) + && (lhs.format == rhs.format) + && (lhs.dims == rhs.dims); +} + +inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) { + return (lhs.type < rhs.type) + || (lhs.type == rhs.type && lhs.format < rhs.format) + || (lhs.type == rhs.type && lhs.format == rhs.format && lhs.dims < rhs.dims); +} + +inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) { + return (lhs.inputs < rhs.inputs) + || (lhs.inputs == rhs.inputs && lhs.outputs < rhs.outputs) + || (lhs.inputs == rhs.inputs && lhs.outputs == rhs.outputs && lhs.attrs < rhs.attrs); +} + + +inline bool operator==(const ImplSpec& lhs, const ImplSpec& rhs) { + return !(lhs < rhs) && !(rhs < lhs); +} + +/** + * @brief Impl stores the details of a specific implementation. + * It is associated to a ImplSpec in a registry. + * + */ +template <class FwdFunc, class BwdFunc> +struct Impl { + Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_, + std::function<FwdFunc> forward_, + std::function<BwdFunc> backward_ = nullptr): + prodConso(prodConso_), forward(forward_), backward(backward_) {} + + std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso; + std::function<FwdFunc> forward; + std::function<BwdFunc> backward; +}; + class OperatorImpl { public: OperatorImpl(const Operator& op, const std::string& backend = ""); virtual void forward(); virtual void backward(); + virtual std::shared_ptr<ProdConso> prodConso(); const std::string& backend() const noexcept { return mBackend; } - /** - * @brief Minimum amount of data from a specific input required by the - * implementation to be run. - * - * @param inputIdx Index of the input analyzed. - * @return std::size_t - */ - virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const; - // Amount of input data that cannot be overwritten during the execution. - virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; - - // Memory required at an output for a given input size. - virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + const Operator& getOperator() const noexcept { + return mOp; + } /** - * @brief Total amount of consumed data from a specific input. + * @brief Get the operator required implementation specification, according + * to the current operator configuration. * - * @param inputIdx Index of the input analyzed. - * @return DimSize_t */ - virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; + ImplSpec getRequiredSpec() const; /** - * @brief Total amount of produced data ready to be used on a specific output. + * @brief Get the best implementation that matches \p requiredSpecs. + * If no implementation matches \p requiredSpecs, \p requiredSpecs is + * returned. * - * @param outputIdx Index of the output analyzed. - * @return DimSize_t */ - virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; + ImplSpec getBestMatch(const ImplSpec& requiredSpecs) const; /** - * @brief Update the Consumer Producer system by simulating the consumption and production of i/o + * @brief Get an adapted meta operator corresponding to the required + * specifications \p requiredSpecs from the implementation specifications + * \p spec. * + * @param spec Implementation specification + * @param requiredSpecs Required specifications + * @return std::shared_ptr<Node> Adapted meta op or nullptr */ - virtual void updateConsummerProducer(); + std::shared_ptr<Node> getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const; /** - * @brief Reset the Consumer Producer system. + * @brief Get the best adapted meta operator corresponding to the required + * specifications \p requiredSpecs. + * The best adaptation is the one with the lowest overhead cost. + * Currently, it is the one requiring the least number of additionnal + * operators to match the available implementations. * + * @param requiredSpecs Required specifications + * @return std::shared_ptr<Node> Adapted meta op or nullptr */ - virtual void resetConsummerProducer(); + std::shared_ptr<Node> getBestAdaptation(const ImplSpec& requiredSpecs) const; virtual ~OperatorImpl() = default; protected: + virtual std::shared_ptr<ProdConso> getProdConso() const; + virtual std::vector<ImplSpec> getAvailableImplSpecs() const; + bool checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const; + const Operator &mOp; const std::string mBackend; - std::vector<Elts_t> mNbConsumedData; - std::vector<Elts_t> mNbProducedData; + std::shared_ptr<ProdConso> mProdConso; }; } // namespace Aidge +template<> +struct fmt::formatter<Aidge::ImplSpec::IOSpec> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(Aidge::ImplSpec::IOSpec const& ioSpec, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{}, {}, {}", ioSpec.type, ioSpec.format, ioSpec.dims); + } +}; + +template<> +struct fmt::formatter<Aidge::ImplSpec> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(Aidge::ImplSpec const& implSpec, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{}, {}", implSpec.inputs, implSpec.outputs); + } +}; + #endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */ -- GitLab