/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #include <cassert> #include <string> #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Transpose.hpp" #include "aidge/operator/Cast.hpp" #include "aidge/operator/MetaOperator.hpp" #include "aidge/scheduler/ProdConso.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/ErrorHandling.hpp" Aidge::ImplSpec::ImplSpec(const DynamicAttributes& attrs_): attrs(attrs_) {} Aidge::ImplSpec::ImplSpec(const IOSpec& io, const DynamicAttributes& attrs_): inputs(1, io), outputs(1, io), attrs(attrs_) {} Aidge::ImplSpec::ImplSpec(const IOSpec& i, const IOSpec& o, const DynamicAttributes& attrs_): inputs(1, i), outputs(1, o), attrs(attrs_) {} Aidge::ImplSpec::ImplSpec(const std::vector<IOSpec>& i, const std::vector<IOSpec>& o, const DynamicAttributes& attrs_): inputs(i), outputs(o), attrs(attrs_) {} Aidge::ImplSpec::ImplSpec(const Aidge::ImplSpec&) = default; Aidge::ImplSpec::~ImplSpec() noexcept = default; Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend): mOp(op), mBackend(backend) { //ctor } std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() { if (!mProdConso) { mProdConso = getProdConso(); } return mProdConso; } Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { const auto& opTensor = dynamic_cast<const OperatorTensor&>(mOp); ImplSpec requiredSpec; // Inputs specs for (size_t i = 0; i < opTensor.nbInputs(); ++i) { if (opTensor.getInput(i)) { std::vector<std::pair<int, int>> dims; for (auto dim : opTensor.getInput(i)->dims()) { dims.push_back(std::make_pair<int, int>(dim, dim)); } requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims}); } else { requiredSpec.inputs.push_back({DataType::Any}); } } // Outputs specs for (size_t i = 0; i < opTensor.nbOutputs(); ++i) { std::vector<std::pair<int, int>> dims; for (auto dim : opTensor.getOutput(i)->dims()) { dims.push_back(std::make_pair<int, int>(dim, dim)); } requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims}); } // Attributes if (!mOp.isAtomic()) { requiredSpec.attrs.setAttr("type:!", mOp.type()); // :! mandatory qualifier } else { requiredSpec.attrs.setAttr("type", mOp.type()); } const auto& inhAttrs = mOp.inheritedAttributes(); if (inhAttrs) { if (inhAttrs->hasAttr("impl")) { requiredSpec.attrs.setAttr("impl", inhAttrs->getAny("impl")); } } return requiredSpec; } Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) const { Log::debug("getBestMatch() for requirements: {}", requiredSpecs); const auto availableSpecsSet = getAvailableImplSpecs(); AIDGE_ASSERT(availableSpecsSet.size() > 0 , "OperatorImpl::getBestMatch(): No available specs found by" "getAvailableSpecs(). " "Cannot find best implementation for required specs, aborting."); const std::vector<ImplSpec> availableSpecs(availableSpecsSet.begin(), availableSpecsSet.end()); std::vector<int> matchingSpecs(availableSpecs.size(), -1); for (size_t s = 0; s < availableSpecs.size(); ++s) { auto spec = availableSpecs[s]; bool match = true; int priority = 0; // Check inputs for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) { const auto inputSpec = (i < spec.inputs.size()) ? spec.inputs[i] : spec.inputs.back(); if (!checkIOSpec(requiredSpecs.inputs[i], inputSpec)) { match = false; break; } } // Check outputs for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) { const auto outputSpec = (i < spec.outputs.size()) ? spec.outputs[i] : spec.outputs.back(); if (!checkIOSpec(requiredSpecs.outputs[i], outputSpec)) { match = false; break; } } // Check attributes for (const auto& attrName : requiredSpecs.attrs.getAttrsName()) { std::string name = attrName; std::string qualifier; const auto qualifierPos = std::find_if(attrName.begin(), attrName.end(), [](char c) { return c == ':'; }); if (qualifierPos != attrName.end()) { name = attrName.substr(0, qualifierPos - attrName.begin()); qualifier = attrName.substr(qualifierPos - attrName.begin() + 1); } const bool mandatory = (qualifier == "!"); if (mandatory) { // Required attribute: if (!spec.attrs.hasAttr(name)) { Log::debug("Could not find mandatory attribute '{}'.", name); // Missing attribute match = false; break; } else if (requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name) || spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName)) { Log::debug("Attribute ({}) value mismatch {} != {}.", name, requiredSpecs.attrs.getAttr<std::string>(attrName), spec.attrs.getAttr<std::string>(name)); // Attribute value mismatch match = false; break; } } else { const int attrPriority = (!qualifier.empty()) ? std::stoi(qualifier) : 0; if (spec.attrs.hasAttr(name) && !(requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name)) && !(spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName))) { // Attribute value match priority = std::max(priority, attrPriority); } } } if (match) { matchingSpecs[s] = priority; } Log::debug(" {}:{} - {}", (match) ? "MATCH" : "MISMATCH", priority, spec); } if(matchingSpecs.empty()){ Log::debug(" No spec to match registered, returning requiredSpecs."); return requiredSpecs; } // Return best match const auto bestMatch = std::max_element(matchingSpecs.begin(), matchingSpecs.end()); if (*bestMatch >= 0) { const auto bestSpecIdx = bestMatch - matchingSpecs.begin(); return availableSpecs[bestSpecIdx]; } // If there is no match, return the required specs for the registrar, which // will throw a "missing or invalid registrar key" return requiredSpecs; } bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const { // Check type if (required.type != DataType::Any && spec.type != DataType::Any && required.type != spec.type) { return false; } // Check format if (required.format != DataFormat::Any && spec.format != DataFormat::Any && required.format != spec.format) { const auto transpose = getDataFormatTranspose(required.format, spec.format); std::vector<size_t> identity(transpose.size()); std::iota(std::begin(identity), std::end(identity), 0); if (!std::equal(transpose.begin(), transpose.end(), identity.begin())) { return false; } } // Check dims if (!required.dims.empty() && !spec.dims.empty()) { if (required.dims.size() != spec.dims.size()) { return false; } for (size_t dim = 0; dim < required.dims.size(); ++dim) { const auto requiredDim = required.dims[dim]; const auto specDim = spec.dims[dim]; if (requiredDim.first != -1 && specDim.first != -1 && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second)) { return false; } } } return true; } std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const { // Original graph is: // --> {required IO specs} [node] {required IO specs} --> // Obtained meta-op is: // --> {required IO specs} [adapt inputs] --> {IO specs} [node] {IO specs} --> [adapt outputs] {required IO specs} auto op = std::static_pointer_cast<OperatorTensor>(mOp.clone()); auto node = std::make_shared<Node>(op); // Adapt inputs for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) { const auto IOSpec = (i < spec.inputs.size()) ? spec.inputs[i] : spec.inputs.back(); const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.inputs[i]; std::shared_ptr<Node> parent = node; // Input type if (requiredIOSpec.type != DataType::Any && IOSpec.type != DataType::Any && requiredIOSpec.type != IOSpec.type) { const auto cast = Cast(IOSpec.type); cast->getOperator()->setBackend(node->getOperator()->backend()); cast->addChild(parent, 0, i); op->getInput(i)->setDataType(IOSpec.type); } // Input format if (requiredIOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any && requiredIOSpec.format != IOSpec.format) { const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); transposeOp->getOperator()->setDataFormat(IOSpec.format); transposeOp->getOperator()->setDataType(requiredIOSpec.type); transposeOp->getOperator()->setBackend(node->getOperator()->backend()); transposeOp->addChild(parent, 0, i); op->getInput(i)->setDataFormat(IOSpec.format); } // Input dims if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { return nullptr; } for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) { const auto requiredDim = requiredIOSpec.dims[dim]; const auto specDim = IOSpec.dims[dim]; if (requiredDim.first != -1 && specDim.first != -1 && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second)) { return nullptr; } } } } // Adapt outputs for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) { const auto IOSpec = (i < spec.outputs.size()) ? spec.outputs[i] : spec.outputs.back(); const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.outputs[i]; std::shared_ptr<Node> parent = node; // Output type if (requiredIOSpec.type != DataType::Any && IOSpec.type != DataType::Any && requiredIOSpec.type != IOSpec.type) { const auto cast = Cast(requiredIOSpec.type); cast->getOperator()->setBackend(node->getOperator()->backend()); parent->addChild(cast, i, 0); op->getOutput(i)->setDataType(IOSpec.type); } // Output format if (requiredIOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any && requiredIOSpec.format != IOSpec.format) { const auto transpose = getDataFormatTranspose(IOSpec.format, requiredIOSpec.format); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); transposeOp->getOperator()->setDataFormat(requiredIOSpec.format); transposeOp->getOperator()->setDataType(requiredIOSpec.type); transposeOp->getOperator()->setBackend(node->getOperator()->backend()); parent->addChild(transposeOp, i, 0); op->getOutput(i)->setDataFormat(IOSpec.format); } // Output dims if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { return nullptr; } for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) { const auto requiredDim = requiredIOSpec.dims[dim]; const auto specDim = IOSpec.dims[dim]; if (requiredDim.first != -1 && specDim.first != -1 && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second)) { return nullptr; } } } } auto adaptedGraph = getConnectedGraphView(node); if (adaptedGraph->getNodes().size() > 1) { return MetaOperator(std::string("Adapted_" + op->type()).c_str(), adaptedGraph); } else { return node; } } std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const { const auto availableSpecs = getAvailableImplSpecs(); Log::debug("Adapt operator type {}: {} impl. available", mOp.type(), availableSpecs.size()); using AdaptationCost = int; std::map<std::shared_ptr<Node>, AdaptationCost> adaptations; for (const auto& availableSpec : availableSpecs) { auto adaptation = getAdaptation(availableSpec, requiredSpecs); if (adaptation) { if (adaptation->getOperator()->isAtomic()) { adaptations.insert(std::make_pair(adaptation, 1)); } else { auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph(); adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size())); } } } Log::debug("Adapt operator type {}: found {} possible adaptations", mOp.type(), adaptations.size()); if (!adaptations.empty()) { // Return best adaptation (with min. AdaptationCost) const auto bestAdaptation = std::min_element(adaptations.begin(), adaptations.end(), [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; }); return bestAdaptation->first; } return nullptr; } void Aidge::OperatorImpl::forward() { AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented yet for operator of type {}", mOp.type()); } void Aidge::OperatorImpl::backward() { AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented yet for operator of type {}", mOp.type()); } std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::getProdConso() const { return std::make_shared<ProdConso>(mOp); } std::vector<Aidge::ImplSpec> Aidge::OperatorImpl::getAvailableImplSpecs() const { return std::vector<ImplSpec>(); }