Skip to content
Snippets Groups Projects
Commit d6039d95 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added first prototype of getBestAdaptation()

parent 8857ae78
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!186Refactor OperatorImpl for backend/export
Pipeline #53564 failed
......@@ -112,16 +112,34 @@ public:
/**
* @brief Get the best implementation that matches \p requiredSpecs.
* If no implementation matches \p requiredSpecs, \p requiredSpecs is
* returned.
*
*/
ImplSpec getBestMatch(ImplSpec requiredSpecs) const;
ImplSpec getBestMatch(const ImplSpec& requiredSpecs) const;
/**
* @brief Get the best alternative node (containing a meta operator)
* fulfilling the requirements.
* @brief Get an adapted meta operator corresponding to the required
* specifications \p requiredSpecs from the implementation specifications
* \p spec.
*
* @param spec Implementation specification
* @param requiredSpecs Required specifications
* @return std::shared_ptr<Node> Adapted meta op or nullptr
*/
std::shared_ptr<Node> getBestAlternative(ImplSpec requiredSpecs);
std::shared_ptr<Node> getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const;
/**
* @brief Get the best adapted meta operator corresponding to the required
* specifications \p requiredSpecs.
* The best adaptation is the one with the lowest overhead cost.
* Currently, it is the one requiring the least number of additionnal
* operators to match the available implementations.
*
* @param requiredSpecs Required specifications
* @return std::shared_ptr<Node> Adapted meta op or nullptr
*/
std::shared_ptr<Node> getBestAdaptation(const ImplSpec& requiredSpecs) const;
virtual ~OperatorImpl() = default;
......
......@@ -15,6 +15,9 @@
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Cast.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/scheduler/ProdConso.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
......@@ -70,7 +73,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
return requiredSpec;
}
Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const {
Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) const {
Log::debug("getBestMatch() for requirements: {}", requiredSpecs);
const auto availableSpecs = getAvailableImplSpecs();
......@@ -192,8 +195,131 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im
return true;
}
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) {
// TODO: have a generic getBestAlternative() that handle at least data type and data format conversions.
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const {
auto op = std::static_pointer_cast<OperatorTensor>(mOp.clone());
auto node = std::make_shared<Node>(op);
// Adapt inputs
for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) {
const ImplSpec::IOSpec& IOSpec = spec.inputs[i];
const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.inputs[i];
std::shared_ptr<Node> parent = node;
if (requiredIOSpec.type != DataType::Any
&& IOSpec.type != DataType::Any
&& requiredIOSpec.type != IOSpec.type)
{
const auto cast = Cast(IOSpec.type);
cast->addChild(parent, 0, i);
op->getInput(i)->setDataType(IOSpec.type);
}
if (requiredIOSpec.format != DataFormat::Any
&& IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format)
{
const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format);
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(IOSpec.format);
transposeOp->getOperator()->setDataType(IOSpec.type);
transposeOp->addChild(parent, 0, i);
op->getInput(i)->setDataFormat(IOSpec.format);
}
if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) {
if (requiredIOSpec.dims.size() != IOSpec.dims.size()) {
return nullptr;
}
for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) {
const auto requiredDim = requiredIOSpec.dims[dim];
const auto specDim = IOSpec.dims[dim];
if (requiredDim.first != -1
&& specDim.first != -1
&& !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second))
{
return nullptr;
}
}
}
}
// Adapt outputs
for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) {
const ImplSpec::IOSpec& IOSpec = spec.outputs[i];
const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.outputs[i];
std::shared_ptr<Node> parent = node;
if (requiredIOSpec.type != DataType::Any
&& IOSpec.type != DataType::Any
&& requiredIOSpec.type != IOSpec.type)
{
const auto cast = Cast(requiredIOSpec.type);
parent->addChild(cast, i, 0);
op->getOutput(i)->setDataType(IOSpec.type);
}
if (requiredIOSpec.format != DataFormat::Any
&& IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format)
{
const auto transpose = getDataFormatTranspose(IOSpec.format, requiredIOSpec.format);
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
transposeOp->getOperator()->setDataType(requiredIOSpec.type);
parent->addChild(transposeOp, i, 0);
op->getOutput(i)->setDataFormat(IOSpec.format);
}
if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) {
if (requiredIOSpec.dims.size() != IOSpec.dims.size()) {
return nullptr;
}
for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) {
const auto requiredDim = requiredIOSpec.dims[dim];
const auto specDim = IOSpec.dims[dim];
if (requiredDim.first != -1
&& specDim.first != -1
&& !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second))
{
return nullptr;
}
}
}
}
return MetaOperator("", getConnectedGraphView(node));
}
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const {
const auto availableSpecs = getAvailableImplSpecs();
using AdaptationCost = int;
std::map<std::shared_ptr<Node>, AdaptationCost> adaptations;
for (size_t s = 0; s < availableSpecs.size(); ++s) {
auto adaptation = getAdaptation(availableSpecs[s], requiredSpecs);
if (adaptation) {
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation)->getMicroGraph();
adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
}
}
if (!adaptations.empty()) {
// Return best adaptation (with min. AdaptationCost)
const auto bestAdaptation = std::min_element(adaptations.begin(), adaptations.end(),
[](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });
return bestAdaptation->first;
}
return nullptr;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment