Skip to content
Snippets Groups Projects
Commit b5130c8f authored by Jerome Hue's avatar Jerome Hue
Browse files

chore: Fix rebase mess

parent e9988fb5
No related branches found
No related tags found
No related merge requests found
...@@ -14,73 +14,177 @@ ...@@ -14,73 +14,177 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <functional>
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/DynamicAttributes.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Elts.hpp" #include "aidge/data/Elts.hpp"
#include "aidge/scheduler/ProdConso.hpp"
namespace Aidge { namespace Aidge {
class Node;
class Operator; class Operator;
/**
* @brief ImplSpec stores the requirements or the specifications of an implementation.
*
*/
struct ImplSpec {
struct IOSpec {
IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, const std::vector<std::pair<int, int>>& dims_ = {}):
type(type_),
format(format_),
dims(dims_)
{}
DataType type;
DataFormat format;
std::vector<std::pair<int, int>> dims;
};
ImplSpec(const DynamicAttributes& attrs_ = DynamicAttributes());
ImplSpec(const IOSpec& io, const DynamicAttributes& attrs_ = DynamicAttributes());
ImplSpec(const IOSpec& i, const IOSpec& o, const DynamicAttributes& attrs_ = DynamicAttributes());
ImplSpec(const std::vector<IOSpec>& i, const std::vector<IOSpec>& o, const DynamicAttributes& attrs_ = DynamicAttributes());
ImplSpec(const Aidge::ImplSpec&);
~ImplSpec() noexcept;
std::vector<IOSpec> inputs;
std::vector<IOSpec> outputs;
DynamicAttributes attrs;
};
inline bool operator==(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
return (lhs.type == rhs.type)
&& (lhs.format == rhs.format)
&& (lhs.dims == rhs.dims);
}
inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
return (lhs.type < rhs.type)
|| (lhs.type == rhs.type && lhs.format < rhs.format)
|| (lhs.type == rhs.type && lhs.format == rhs.format && lhs.dims < rhs.dims);
}
inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) {
return (lhs.inputs < rhs.inputs)
|| (lhs.inputs == rhs.inputs && lhs.outputs < rhs.outputs)
|| (lhs.inputs == rhs.inputs && lhs.outputs == rhs.outputs && lhs.attrs < rhs.attrs);
}
inline bool operator==(const ImplSpec& lhs, const ImplSpec& rhs) {
return !(lhs < rhs) && !(rhs < lhs);
}
/**
* @brief Impl stores the details of a specific implementation.
* It is associated to a ImplSpec in a registry.
*
*/
template <class FwdFunc, class BwdFunc>
struct Impl {
Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_,
std::function<FwdFunc> forward_,
std::function<BwdFunc> backward_ = nullptr):
prodConso(prodConso_), forward(forward_), backward(backward_) {}
std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso;
std::function<FwdFunc> forward;
std::function<BwdFunc> backward;
};
class OperatorImpl { class OperatorImpl {
public: public:
OperatorImpl(const Operator& op, const std::string& backend = ""); OperatorImpl(const Operator& op, const std::string& backend = "");
virtual void forward(); virtual void forward();
virtual void backward(); virtual void backward();
virtual std::shared_ptr<ProdConso> prodConso();
const std::string& backend() const noexcept { const std::string& backend() const noexcept {
return mBackend; return mBackend;
} }
/**
* @brief Minimum amount of data from a specific input required by the
* implementation to be run.
*
* @param inputIdx Index of the input analyzed.
* @return std::size_t
*/
virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution. const Operator& getOperator() const noexcept {
virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; return mOp;
}
// Memory required at an output for a given input size.
virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/** /**
* @brief Total amount of consumed data from a specific input. * @brief Get the operator required implementation specification, according
* to the current operator configuration.
* *
* @param inputIdx Index of the input analyzed.
* @return DimSize_t
*/ */
virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; ImplSpec getRequiredSpec() const;
/** /**
* @brief Total amount of produced data ready to be used on a specific output. * @brief Get the best implementation that matches \p requiredSpecs.
* If no implementation matches \p requiredSpecs, \p requiredSpecs is
* returned.
* *
* @param outputIdx Index of the output analyzed.
* @return DimSize_t
*/ */
virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; ImplSpec getBestMatch(const ImplSpec& requiredSpecs) const;
/** /**
* @brief Update the Consumer Producer system by simulating the consumption and production of i/o * @brief Get an adapted meta operator corresponding to the required
* specifications \p requiredSpecs from the implementation specifications
* \p spec.
* *
* @param spec Implementation specification
* @param requiredSpecs Required specifications
* @return std::shared_ptr<Node> Adapted meta op or nullptr
*/ */
virtual void updateConsummerProducer(); std::shared_ptr<Node> getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const;
/** /**
* @brief Reset the Consumer Producer system. * @brief Get the best adapted meta operator corresponding to the required
* specifications \p requiredSpecs.
* The best adaptation is the one with the lowest overhead cost.
* Currently, it is the one requiring the least number of additionnal
* operators to match the available implementations.
* *
* @param requiredSpecs Required specifications
* @return std::shared_ptr<Node> Adapted meta op or nullptr
*/ */
virtual void resetConsummerProducer(); std::shared_ptr<Node> getBestAdaptation(const ImplSpec& requiredSpecs) const;
virtual ~OperatorImpl() = default; virtual ~OperatorImpl() = default;
protected: protected:
virtual std::shared_ptr<ProdConso> getProdConso() const;
virtual std::vector<ImplSpec> getAvailableImplSpecs() const;
bool checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const;
const Operator &mOp; const Operator &mOp;
const std::string mBackend; const std::string mBackend;
std::vector<Elts_t> mNbConsumedData; std::shared_ptr<ProdConso> mProdConso;
std::vector<Elts_t> mNbProducedData;
}; };
} // namespace Aidge } // namespace Aidge
template<>
struct fmt::formatter<Aidge::ImplSpec::IOSpec> {
template<typename ParseContext>
inline constexpr auto parse(ParseContext& ctx) {
return ctx.begin();
}
template<typename FormatContext>
inline auto format(Aidge::ImplSpec::IOSpec const& ioSpec, FormatContext& ctx) const {
return fmt::format_to(ctx.out(), "{}, {}, {}", ioSpec.type, ioSpec.format, ioSpec.dims);
}
};
template<>
struct fmt::formatter<Aidge::ImplSpec> {
template<typename ParseContext>
inline constexpr auto parse(ParseContext& ctx) {
return ctx.begin();
}
template<typename FormatContext>
inline auto format(Aidge::ImplSpec const& implSpec, FormatContext& ctx) const {
return fmt::format_to(ctx.out(), "{}, {}", implSpec.inputs, implSpec.outputs);
}
};
#endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */ #endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment