diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 63ef1a0623349e7dc998e908803c3c6979b64096..4b4fee0147462298dd08a47853d2b168e3d7e380 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -112,16 +112,34 @@ public: /** * @brief Get the best implementation that matches \p requiredSpecs. + * If no implementation matches \p requiredSpecs, \p requiredSpecs is + * returned. * */ - ImplSpec getBestMatch(ImplSpec requiredSpecs) const; + ImplSpec getBestMatch(const ImplSpec& requiredSpecs) const; /** - * @brief Get the best alternative node (containing a meta operator) - * fulfilling the requirements. + * @brief Get an adapted meta operator corresponding to the required + * specifications \p requiredSpecs from the implementation specifications + * \p spec. * + * @param spec Implementation specification + * @param requiredSpecs Required specifications + * @return std::shared_ptr<Node> Adapted meta op or nullptr */ - std::shared_ptr<Node> getBestAlternative(ImplSpec requiredSpecs); + std::shared_ptr<Node> getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const; + + /** + * @brief Get the best adapted meta operator corresponding to the required + * specifications \p requiredSpecs. + * The best adaptation is the one with the lowest overhead cost. + * Currently, it is the one requiring the least number of additionnal + * operators to match the available implementations. + * + * @param requiredSpecs Required specifications + * @return std::shared_ptr<Node> Adapted meta op or nullptr + */ + std::shared_ptr<Node> getBestAdaptation(const ImplSpec& requiredSpecs) const; virtual ~OperatorImpl() = default; diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 91db61e40b42e65fa4b12f4a73bf643d1db63367..ed64f569598b763ac13a71f389403b24d1cc172d 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -15,6 +15,9 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Transpose.hpp" +#include "aidge/operator/Cast.hpp" +#include "aidge/operator/MetaOperator.hpp" #include "aidge/scheduler/ProdConso.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/ErrorHandling.hpp" @@ -70,7 +73,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { return requiredSpec; } -Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const { +Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) const { Log::debug("getBestMatch() for requirements: {}", requiredSpecs); const auto availableSpecs = getAvailableImplSpecs(); @@ -192,8 +195,131 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im return true; } -std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) { - // TODO: have a generic getBestAlternative() that handle at least data type and data format conversions. +std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const { + auto op = std::static_pointer_cast<OperatorTensor>(mOp.clone()); + auto node = std::make_shared<Node>(op); + + // Adapt inputs + for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) { + const ImplSpec::IOSpec& IOSpec = spec.inputs[i]; + const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.inputs[i]; + std::shared_ptr<Node> parent = node; + + if (requiredIOSpec.type != DataType::Any + && IOSpec.type != DataType::Any + && requiredIOSpec.type != IOSpec.type) + { + const auto cast = Cast(IOSpec.type); + cast->addChild(parent, 0, i); + + op->getInput(i)->setDataType(IOSpec.type); + } + + if (requiredIOSpec.format != DataFormat::Any + && IOSpec.format != DataFormat::Any + && requiredIOSpec.format != IOSpec.format) + { + const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format); + auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); + transposeOp->getOperator()->setDataFormat(IOSpec.format); + transposeOp->getOperator()->setDataType(IOSpec.type); + transposeOp->addChild(parent, 0, i); + + op->getInput(i)->setDataFormat(IOSpec.format); + } + + if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { + if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { + return nullptr; + } + + for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) { + const auto requiredDim = requiredIOSpec.dims[dim]; + const auto specDim = IOSpec.dims[dim]; + + if (requiredDim.first != -1 + && specDim.first != -1 + && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second)) + { + return nullptr; + } + } + } + } + + // Adapt outputs + for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) { + const ImplSpec::IOSpec& IOSpec = spec.outputs[i]; + const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.outputs[i]; + std::shared_ptr<Node> parent = node; + + if (requiredIOSpec.type != DataType::Any + && IOSpec.type != DataType::Any + && requiredIOSpec.type != IOSpec.type) + { + const auto cast = Cast(requiredIOSpec.type); + parent->addChild(cast, i, 0); + + op->getOutput(i)->setDataType(IOSpec.type); + } + + if (requiredIOSpec.format != DataFormat::Any + && IOSpec.format != DataFormat::Any + && requiredIOSpec.format != IOSpec.format) + { + const auto transpose = getDataFormatTranspose(IOSpec.format, requiredIOSpec.format); + auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); + transposeOp->getOperator()->setDataFormat(requiredIOSpec.format); + transposeOp->getOperator()->setDataType(requiredIOSpec.type); + parent->addChild(transposeOp, i, 0); + + op->getOutput(i)->setDataFormat(IOSpec.format); + } + + if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { + if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { + return nullptr; + } + + for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) { + const auto requiredDim = requiredIOSpec.dims[dim]; + const auto specDim = IOSpec.dims[dim]; + + if (requiredDim.first != -1 + && specDim.first != -1 + && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second)) + { + return nullptr; + } + } + } + } + + return MetaOperator("", getConnectedGraphView(node)); +} + +std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const { + const auto availableSpecs = getAvailableImplSpecs(); + + using AdaptationCost = int; + std::map<std::shared_ptr<Node>, AdaptationCost> adaptations; + + for (size_t s = 0; s < availableSpecs.size(); ++s) { + auto adaptation = getAdaptation(availableSpecs[s], requiredSpecs); + + if (adaptation) { + auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation)->getMicroGraph(); + adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size())); + } + } + + if (!adaptations.empty()) { + // Return best adaptation (with min. AdaptationCost) + const auto bestAdaptation = std::min_element(adaptations.begin(), adaptations.end(), + [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; }); + return bestAdaptation->first; + } + return nullptr; }