From cd743cb25ca1deb4866e5defb821ef07bd0657d7 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 26 Aug 2024 12:58:12 +0200 Subject: [PATCH] Fix comparison operators and improved selection concept --- include/aidge/backend/OperatorImpl.hpp | 56 ++++++++------ src/backend/OperatorImpl.cpp | 100 +++++++++++++++++++++++++ src/utils/Log.cpp | 2 +- 3 files changed, 135 insertions(+), 23 deletions(-) diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index a9f968c59..76dcdc3c4 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -26,25 +26,26 @@ namespace Aidge { class Node; class Operator; +/** + * @brief ImplSpec stores the requirements or the specifications of an implementation. + * + */ struct ImplSpec { struct IOSpec { - IOSpec(DataType type_): - type(type_), - format(DataFormat::Any), - dims({}) - {} - - IOSpec(DataType type_, DataFormat format_): + IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<DimSize_t, DimSize_t>> dims_ = {}): type(type_), format(format_), - dims({}) + dims(dims_) {} DataType type; DataFormat format; - std::vector<std::pair<size_t, size_t>> dims; + std::vector<std::pair<DimSize_t, DimSize_t>> dims; }; + ImplSpec() { + } + ImplSpec(IOSpec io) { inputs.push_back(io); outputs.push_back(io); @@ -60,14 +61,28 @@ struct ImplSpec { //DynamicAttributes attrs; }; +inline bool operator==(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) { + return (lhs.type == rhs.type) + && (lhs.format == rhs.format) + && (lhs.dims == rhs.dims); +} + inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) { - return (lhs.type < rhs.type) && (lhs.format < rhs.format) && (lhs.dims < rhs.dims); + return (lhs.type < rhs.type) + || (lhs.type == rhs.type && lhs.format < rhs.format) + || (lhs.type == rhs.type && lhs.format == rhs.format && lhs.dims < rhs.dims); } inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) { - return (lhs.inputs < rhs.inputs) && (lhs.outputs < rhs.outputs); + return (lhs.inputs < rhs.inputs) + || (lhs.inputs == rhs.inputs && lhs.outputs < rhs.outputs); } +/** + * @brief Impl stores the details of a specific implementation. + * It is associated to a ImplSpec in a registry. + * + */ template <class FwdFunc, class BwdFunc> struct Impl { Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_, @@ -96,23 +111,20 @@ public: * to the current operator configuration. * */ - ImplSpec getRequiredSpec() const { - // TODO - return ImplSpec{DataType::Float32}; - } + ImplSpec getRequiredSpec() const; /** * @brief Get the best implementation that matches \p requiredSpecs. * */ - ImplSpec getBestMatch(ImplSpec /*requiredSpecs*/) const { - // TODO: - return getAvailableImplSpecs()[0]; - } - - // std::shared_ptr<Node> getAdaptedOp(ImplSpec requiredSpecs) { + ImplSpec getBestMatch(ImplSpec requiredSpecs) const; - // } + /** + * @brief Get the best alternative node (containing a meta operator) + * fulfilling the requirements. + * + */ + std::shared_ptr<Node> getBestAlternative(ImplSpec requiredSpecs); virtual ~OperatorImpl() = default; diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 0eb1e635a..5e29418d4 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -14,6 +14,7 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/scheduler/ProdConso.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/ErrorHandling.hpp" @@ -32,6 +33,105 @@ std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() { return mProdConso; } +Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { + const auto& opTensor = dynamic_cast<const OperatorTensor&>(mOp); + + ImplSpec requiredSpec; + // Inputs specs + for (size_t i = 0; i < opTensor.nbInputs(); ++i) { + if (opTensor.getInput(i)) { + std::vector<std::pair<DimSize_t, DimSize_t>> dims; + for (auto dim : opTensor.getInput(i)->dims()) { + dims.push_back(std::make_pair(dim, dim)); + } + + requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims}); + } + else { + requiredSpec.inputs.push_back({DataType::Any}); + } + } + // Outputs specs + for (size_t i = 0; i < opTensor.nbOutputs(); ++i) { + std::vector<std::pair<DimSize_t, DimSize_t>> dims; + for (auto dim : opTensor.getOutput(i)->dims()) { + dims.push_back(std::make_pair(dim, dim)); + } + + requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims}); + } + return requiredSpec; +} + +Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const { + Log::debug("getBestMatch() for requirements: {}", requiredSpecs); + + const auto availableSpecs = getAvailableImplSpecs(); + std::vector<bool> matchingSpecs(availableSpecs.size(), false); + + for (size_t s = 0; s < availableSpecs.size(); ++s) { + auto spec = availableSpecs[s]; + bool match = true; + + // Check inputs + for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) { + if (requiredSpecs.inputs[i].type != DataType::Any + && spec.inputs[i].type != DataType::Any + && requiredSpecs.inputs[i].type != spec.inputs[i].type) + { + match = false; + break; + } + + if (requiredSpecs.inputs[i].format != DataFormat::Any + && spec.inputs[i].format != DataFormat::Any + && requiredSpecs.inputs[i].format != spec.inputs[i].format) + { + match = false; + break; + } + + if (!requiredSpecs.inputs[i].dims.empty() && !spec.inputs[i].dims.empty()) { + if (requiredSpecs.inputs[i].dims.size() != spec.inputs[i].dims.size()) { + match = false; + break; + } + + for (size_t dim = 0; dim < requiredSpecs.inputs[i].dims.size(); ++dim) { + // TODO + } + } + } + + // Check outputs + // TODO + + // Check attributes + // TODO + + matchingSpecs[s] = match; + + Log::debug(" {} - {}", (match) ? "MATCH" : "MISMATCH", spec); + } + + // Return best match + // TODO: for now, returns the **first** match + for (size_t s = 0; s < availableSpecs.size(); ++s) { + if (matchingSpecs[s]) { + return availableSpecs[s]; + } + } + + // If there is no match, return the required specs for the registrar, which + // will throw a "missing or invalid registrar key" + return requiredSpecs; +} + +std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) { + // TODO: have a generic getBestAlternative() that handle at least data type and data format conversions. + return nullptr; +} + void Aidge::OperatorImpl::forward() { AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented yet for operator of type {}", mOp.type()); } diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp index ae8816e78..da32a8e0e 100644 --- a/src/utils/Log.cpp +++ b/src/utils/Log.cpp @@ -89,7 +89,7 @@ void Aidge::Log::log(Level level, const std::string& msg) { fmt::println("Context: {}", context); } - fmt::println(mFile.get(), msg); + fmt::println(mFile.get(), "{}", msg); } } -- GitLab