Skip to content
Snippets Groups Projects

Refactor OperatorImpl for backend/export

Merged Olivier BICHLER requested to merge backend_export into dev
Compare and
173 files
+ 1703
386
Compare changes
  • Side-by-side
  • Inline
Files
173
@@ -14,73 +14,172 @@
@@ -14,73 +14,172 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
 
#include <functional>
#include "aidge/utils/Types.h"
#include "aidge/utils/Types.h"
 
#include "aidge/utils/DynamicAttributes.hpp"
 
#include "aidge/data/Data.hpp"
#include "aidge/data/Elts.hpp"
#include "aidge/data/Elts.hpp"
 
#include "aidge/scheduler/ProdConso.hpp"
namespace Aidge {
namespace Aidge {
 
class Node;
class Operator;
class Operator;
 
/**
 
* @brief ImplSpec stores the requirements or the specifications of an implementation.
 
*
 
*/
 
struct ImplSpec {
 
struct IOSpec {
 
IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, const std::vector<std::pair<int, int>>& dims_ = {}):
 
type(type_),
 
format(format_),
 
dims(dims_)
 
{}
 
 
DataType type;
 
DataFormat format;
 
std::vector<std::pair<int, int>> dims;
 
};
 
 
ImplSpec(const DynamicAttributes& attrs_ = DynamicAttributes());
 
ImplSpec(const IOSpec& io, const DynamicAttributes& attrs_ = DynamicAttributes());
 
ImplSpec(const IOSpec& i, const IOSpec& o, const DynamicAttributes& attrs_ = DynamicAttributes());
 
ImplSpec(const std::vector<IOSpec>& i, const std::vector<IOSpec>& o, const DynamicAttributes& attrs_ = DynamicAttributes());
 
ImplSpec(const Aidge::ImplSpec&);
 
~ImplSpec() noexcept;
 
 
std::vector<IOSpec> inputs;
 
std::vector<IOSpec> outputs;
 
DynamicAttributes attrs;
 
};
 
 
inline bool operator==(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
 
return (lhs.type == rhs.type)
 
&& (lhs.format == rhs.format)
 
&& (lhs.dims == rhs.dims);
 
}
 
 
inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
 
return (lhs.type < rhs.type)
 
|| (lhs.type == rhs.type && lhs.format < rhs.format)
 
|| (lhs.type == rhs.type && lhs.format == rhs.format && lhs.dims < rhs.dims);
 
}
 
 
inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) {
 
return (lhs.inputs < rhs.inputs)
 
|| (lhs.inputs == rhs.inputs && lhs.outputs < rhs.outputs)
 
|| (lhs.inputs == rhs.inputs && lhs.outputs == rhs.outputs && lhs.attrs < rhs.attrs);
 
}
 
 
/**
 
* @brief Impl stores the details of a specific implementation.
 
* It is associated to a ImplSpec in a registry.
 
*
 
*/
 
template <class FwdFunc, class BwdFunc>
 
struct Impl {
 
Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_,
 
std::function<FwdFunc> forward_,
 
std::function<BwdFunc> backward_ = nullptr):
 
prodConso(prodConso_), forward(forward_), backward(backward_) {}
 
 
std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso;
 
std::function<FwdFunc> forward;
 
std::function<BwdFunc> backward;
 
};
 
class OperatorImpl {
class OperatorImpl {
public:
public:
OperatorImpl(const Operator& op, const std::string& backend = "");
OperatorImpl(const Operator& op, const std::string& backend = "");
virtual void forward();
virtual void forward();
virtual void backward();
virtual void backward();
 
virtual std::shared_ptr<ProdConso> prodConso();
const std::string& backend() const noexcept {
const std::string& backend() const noexcept {
return mBackend;
return mBackend;
}
}
/**
* @brief Minimum amount of data from a specific input required by the
* implementation to be run.
*
* @param inputIdx Index of the input analysed.
* @return std::size_t
*/
virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution.
const Operator& getOperator() const noexcept {
virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
return mOp;
}
// Memory required at an output for a given input size.
virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/**
/**
* @brief Total amount of consumed data from a specific input.
* @brief Get the operator required implementation specification, according
*
* to the current operator configuration.
* @param inputIdx Index of the input analysed.
*
* @return DimSize_t
*/
*/
virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;
ImplSpec getRequiredSpec() const;
/**
/**
* @brief Total amount of produced data ready to be used on a specific output.
* @brief Get the best implementation that matches \p requiredSpecs.
*
* If no implementation matches \p requiredSpecs, \p requiredSpecs is
* @param outputIdx Index of the output analysed.
* returned.
* @return DimSize_t
*
*/
*/
virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;
ImplSpec getBestMatch(const ImplSpec& requiredSpecs) const;
/**
/**
* @brief Update the Consummer Producer system by simulating the consumption and production of i/o
* @brief Get an adapted meta operator corresponding to the required
*
* specifications \p requiredSpecs from the implementation specifications
 
* \p spec.
 
*
 
* @param spec Implementation specification
 
* @param requiredSpecs Required specifications
 
* @return std::shared_ptr<Node> Adapted meta op or nullptr
*/
*/
virtual void updateConsummerProducer();
std::shared_ptr<Node> getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const;
/**
/**
* @brief Reset the Consummer Producer system.
* @brief Get the best adapted meta operator corresponding to the required
*
* specifications \p requiredSpecs.
 
* The best adaptation is the one with the lowest overhead cost.
 
* Currently, it is the one requiring the least number of additionnal
 
* operators to match the available implementations.
 
*
 
* @param requiredSpecs Required specifications
 
* @return std::shared_ptr<Node> Adapted meta op or nullptr
*/
*/
virtual void resetConsummerProducer();
std::shared_ptr<Node> getBestAdaptation(const ImplSpec& requiredSpecs) const;
virtual ~OperatorImpl() = default;
virtual ~OperatorImpl() = default;
protected:
protected:
 
virtual std::shared_ptr<ProdConso> getProdConso() const;
 
virtual std::set<ImplSpec> getAvailableImplSpecs() const;
 
bool checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const;
 
const Operator &mOp;
const Operator &mOp;
const std::string mBackend;
const std::string mBackend;
std::vector<Elts_t> mNbConsumedData;
std::shared_ptr<ProdConso> mProdConso;
std::vector<Elts_t> mNbProducedData;
};
};
} // namespace Aidge
} // namespace Aidge
 
template<>
 
struct fmt::formatter<Aidge::ImplSpec::IOSpec> {
 
template<typename ParseContext>
 
inline constexpr auto parse(ParseContext& ctx) {
 
return ctx.begin();
 
}
 
 
template<typename FormatContext>
 
inline auto format(Aidge::ImplSpec::IOSpec const& ioSpec, FormatContext& ctx) const {
 
return fmt::format_to(ctx.out(), "{}, {}, {}", ioSpec.type, ioSpec.format, ioSpec.dims);
 
}
 
};
 
 
template<>
 
struct fmt::formatter<Aidge::ImplSpec> {
 
template<typename ParseContext>
 
inline constexpr auto parse(ParseContext& ctx) {
 
return ctx.begin();
 
}
 
 
template<typename FormatContext>
 
inline auto format(Aidge::ImplSpec const& implSpec, FormatContext& ctx) const {
 
return fmt::format_to(ctx.out(), "{}, {}", implSpec.inputs, implSpec.outputs);
 
}
 
};
 
#endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */
#endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */
Loading