Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
OperatorImpl.cpp 15.36 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Cast.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/scheduler/ProdConso.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"

/// Builds a spec that only carries attributes; `inputs` and `outputs` are not
/// set by this constructor.
Aidge::ImplSpec::ImplSpec(const DynamicAttributes& attrs_):
    attrs(attrs_) {}
/// Builds a spec with exactly one input entry and one output entry, both
/// copies of `io`.
Aidge::ImplSpec::ImplSpec(const IOSpec& io, const DynamicAttributes& attrs_):
    inputs(1, io), outputs(1, io), attrs(attrs_) {}
/// Builds a spec with exactly one input entry (copy of `i`) and one output
/// entry (copy of `o`).
Aidge::ImplSpec::ImplSpec(const IOSpec& i, const IOSpec& o, const DynamicAttributes& attrs_):
    inputs(1, i), outputs(1, o), attrs(attrs_) {}
/// Builds a spec from explicit lists of input (`i`) and output (`o`) entries,
/// which are copied.
Aidge::ImplSpec::ImplSpec(const std::vector<IOSpec>& i, const std::vector<IOSpec>& o, const DynamicAttributes& attrs_):
    inputs(i), outputs(o), attrs(attrs_) {}
/// Compiler-generated copy constructor, defined out of line.
Aidge::ImplSpec::ImplSpec(const Aidge::ImplSpec&) = default;
/// Compiler-generated destructor, defined out of line.
Aidge::ImplSpec::~ImplSpec() noexcept = default;

/**
 * @brief Creates an implementation attached to an operator for a given backend.
 * @param op      Operator used to initialize mOp.
 * @param backend Backend name used to initialize mBackend.
 */
Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend)
    : mOp(op),
      mBackend(backend)
{
}

/**
 * @brief Returns the ProdConso object of this implementation, creating it on
 *        first use.
 *
 * While mProdConso is null, it is obtained from getProdConso() and cached in
 * mProdConso; the cached pointer is then returned.
 *
 * @return The cached ProdConso shared pointer.
 */
std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() {
    const bool alreadyCreated = (mProdConso != nullptr);
    if (!alreadyCreated) {
        mProdConso = getProdConso();
    }
    return mProdConso;
}

/**
 * @brief Builds the implementation spec required by the operator in its
 *        current state.
 *
 * - Each connected input contributes its data type, data format and dims;
 *   an unconnected (null) input contributes an IOSpec built from
 *   DataType::Any only.
 * - Each output contributes its data type, data format and dims.
 * - Every dimension is expressed as a (min, max) pair with min == max.
 * - The operator type is stored under "type" for atomic operators and under
 *   "type:!" (mandatory qualifier) otherwise.
 * - An inherited "impl" attribute, when present, is forwarded as is.
 *
 * @return The required ImplSpec.
 * @note mOp is dynamic_cast to a const OperatorTensor reference, so
 *       std::bad_cast is thrown when it is not an OperatorTensor.
 */
Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
    const auto& tensorOp = dynamic_cast<const OperatorTensor&>(mOp);

    // Expresses a concrete shape as a list of (min, max) pairs where both
    // bounds are the actual extent of the dimension.
    const auto toExactRanges = [](const auto& shape) {
        std::vector<std::pair<int, int>> ranges;
        for (const auto extent : shape) {
            const int bound = static_cast<int>(extent);
            ranges.push_back(std::pair<int, int>(bound, bound));
        }
        return ranges;
    };

    ImplSpec required;

    // Inputs specs
    for (size_t idx = 0; idx < tensorOp.nbInputs(); ++idx) {
        const auto& input = tensorOp.getInput(idx);
        if (!input) {
            // Nothing connected: leave this input unconstrained.
            required.inputs.push_back({DataType::Any});
            continue;
        }
        required.inputs.push_back({input->dataType(), input->dataFormat(), toExactRanges(input->dims())});
    }

    // Outputs specs
    for (size_t idx = 0; idx < tensorOp.nbOutputs(); ++idx) {
        const auto& output = tensorOp.getOutput(idx);
        required.outputs.push_back({output->dataType(), output->dataFormat(), toExactRanges(output->dims())});
    }

    // Attributes: ":!" is the mandatory qualifier, used for non-atomic operators.
    const std::string typeKey = mOp.isAtomic() ? "type" : "type:!";
    required.attrs.setAttr(typeKey, mOp.type());

    const auto& inherited = mOp.inheritedAttributes();
    if (inherited && inherited->hasAttr("impl")) {
        required.attrs.setAttr("impl", inherited->getAny("impl"));
    }
    return required;
}

/**
 * @brief Selects, among the specs returned by getAvailableImplSpecs(), the one
 *        that best satisfies @p requiredSpecs.
 *
 * A candidate spec matches when:
 * - every required input and output passes checkIOSpec() against the
 *   candidate I/O spec at the same index (the candidate's last I/O spec is
 *   reused when it lists fewer entries than required);
 * - every mandatory attribute (name suffixed with ":!") exists in the
 *   candidate with an equivalent value.
 *
 * Non-mandatory attributes never reject a candidate. When the candidate holds
 * an equivalent value, they contribute a priority equal to the integer that
 * follows ':' in the attribute name (0 without qualifier); the candidate
 * priority is the highest of those. The matching candidate with the highest
 * priority is returned (the first one in case of ties).
 *
 * @param requiredSpecs Requirements, typically built by getRequiredSpec().
 * @return The best matching available spec, or a copy of @p requiredSpecs
 *         when no candidate matches (the registrar lookup is then expected to
 *         fail with a "missing or invalid registrar key").
 * @note Fails through AIDGE_ASSERT when no spec is available at all.
 *       std::stoi() throws when a non-mandatory qualifier is not an integer.
 */
Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) const {
    Log::debug("getBestMatch() for requirements: {}", requiredSpecs);

    const auto availableSpecsSet = getAvailableImplSpecs();
    // Adjacent string literals are concatenated verbatim: each fragment must
    // carry its own separating space.
    AIDGE_ASSERT(availableSpecsSet.size() > 0 ,
                 "OperatorImpl::getBestMatch(): No available specs found by "
                 "getAvailableImplSpecs(). "
                 "Cannot find best implementation for required specs, aborting.");
    const std::vector<ImplSpec> availableSpecs(availableSpecsSet.begin(), availableSpecsSet.end());
    std::vector<int> matchingSpecs(availableSpecs.size(), -1);

    // Checks a list of required I/O specs against the I/O specs of a candidate.
    // NOTE(review): a candidate that declares no I/O spec at all is treated as
    // unconstrained; calling back() on its empty list was undefined behaviour.
    // Confirm this is the intended meaning of an empty list.
    const auto ioListMatches = [this](const auto& requiredList, const auto& availableList) {
        if (availableList.empty()) {
            return true;
        }
        for (size_t i = 0; i < requiredList.size(); ++i) {
            const auto& candidate = (i < availableList.size()) ? availableList[i] : availableList.back();
            if (!checkIOSpec(requiredList[i], candidate)) {
                return false;
            }
        }
        return true;
    };

    for (size_t s = 0; s < availableSpecs.size(); ++s) {
        const auto& spec = availableSpecs[s];
        int priority = 0;

        // Check inputs and outputs
        bool match = ioListMatches(requiredSpecs.inputs, spec.inputs)
                  && ioListMatches(requiredSpecs.outputs, spec.outputs);

        // Check attributes
        for (const auto& attrName : requiredSpecs.attrs.getAttrsName()) {
            // Attribute names have the form "name" or "name:qualifier".
            const auto separatorPos = attrName.find(':');
            const std::string name = attrName.substr(0, separatorPos);
            const std::string qualifier = (separatorPos != std::string::npos)
                ? attrName.substr(separatorPos + 1)
                : std::string();

            const bool mandatory = (qualifier == "!");
            if (mandatory) {
                // Required attribute:
                if (!spec.attrs.hasAttr(name)) {
                    Log::debug("Could not find mandatory attribute '{}'.", name);
                    // Missing attribute
                    match = false;
                    break;
                }
                else if (requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name)
                    || spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName))
                {
                    Log::debug("Attribute ({}) value mismatch {} != {}.", name, requiredSpecs.attrs.getAttr<std::string>(attrName), spec.attrs.getAttr<std::string>(name));
                    // Attribute value mismatch
                    match = false;
                    break;
                }
            }
            else {
                const int attrPriority = (!qualifier.empty()) ? std::stoi(qualifier) : 0;

                // Values are equivalent when neither compares less than the other.
                if (spec.attrs.hasAttr(name)
                    && !(requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name))
                    && !(spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName)))
                {
                    // Attribute value match
                    priority = std::max(priority, attrPriority);
                }
            }
        }

        if (match) {
            matchingSpecs[s] = priority;
        }

        Log::debug("  {}:{} - {}", (match) ? "MATCH" : "MISMATCH", priority, spec);
    }

    // Defensive guard for std::max_element() below; the assertion above is
    // expected to have already rejected an empty list.
    if(matchingSpecs.empty()){
        Log::debug("  No spec to match registered, returning requiredSpecs.");
        return requiredSpecs;
    }
    // Return best match (-1 marks candidates that did not match)
    const auto bestMatch = std::max_element(matchingSpecs.begin(), matchingSpecs.end());
    if (*bestMatch >= 0) {
        const auto bestSpecIdx = bestMatch - matchingSpecs.begin();
        return availableSpecs[bestSpecIdx];
    }

    // If there is no match, return the required specs for the registrar, which
    // will throw a "missing or invalid registrar key"
    return requiredSpecs;
}

/**
 * @brief Tells whether an available I/O spec is compatible with a required one.
 *
 * - Type: compatible when either side is DataType::Any or both are equal.
 * - Format: compatible when either side is DataFormat::Any, both are equal,
 *   or the transposition returned by getDataFormatTranspose() leaves every
 *   axis in place.
 * - Dims: ignored when either side lists no dims; otherwise both must have
 *   the same number of dims and, for each dim where neither lower bound is
 *   -1, the available (first, second) range must contain the required one.
 *
 * @param required Spec that must be satisfied.
 * @param spec     Spec offered by an implementation.
 * @return true when @p spec is compatible with @p required.
 */
bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const {
    // Check type
    const bool typeConstrained = (required.type != DataType::Any) && (spec.type != DataType::Any);
    if (typeConstrained && required.type != spec.type) {
        return false;
    }

    // Check format
    const bool formatConstrained = (required.format != DataFormat::Any) && (spec.format != DataFormat::Any);
    if (formatConstrained && required.format != spec.format) {
        // Different formats are still accepted when converting from one to
        // the other keeps every axis at its position.
        const auto permutation = getDataFormatTranspose(required.format, spec.format);
        for (size_t axis = 0; axis < permutation.size(); ++axis) {
            if (!(permutation[axis] == axis)) {
                return false;
            }
        }
    }

    // Check dims
    if (required.dims.empty() || spec.dims.empty()) {
        return true;
    }
    if (required.dims.size() != spec.dims.size()) {
        return false;
    }

    for (size_t d = 0; d < required.dims.size(); ++d) {
        const auto& wanted = required.dims[d];
        const auto& offered = spec.dims[d];

        const bool bothSpecified = (wanted.first != -1) && (offered.first != -1);
        const bool contained = (offered.first <= wanted.first) && (offered.second >= wanted.second);
        if (bothSpecified && !contained) {
            return false;
        }
    }

    return true;
}

std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& spec, const ImplSpec& requiredSpecs) const {
    // Original graph is:
    //                               --> {required IO specs} [node] {required IO specs} -->
    // Obtained meta-op is:
    // --> {required IO specs} [adapt inputs] --> {IO specs} [node] {IO specs} --> [adapt outputs] {required IO specs}

    auto op = std::static_pointer_cast<OperatorTensor>(mOp.clone());
    auto node = std::make_shared<Node>(op);

    // Adapt inputs
    for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) {
        const auto IOSpec = (i < spec.inputs.size()) ? spec.inputs[i] : spec.inputs.back();
        const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.inputs[i];
        std::shared_ptr<Node> parent = node;

        // Input type
        if (requiredIOSpec.type != DataType::Any
            && IOSpec.type != DataType::Any
            && requiredIOSpec.type != IOSpec.type)
        {
            const auto cast = Cast(IOSpec.type);
            cast->getOperator()->setBackend(node->getOperator()->backend());
            cast->addChild(parent, 0, i);

            op->getInput(i)->setDataType(IOSpec.type);
        }

        // Input format
        if (requiredIOSpec.format != DataFormat::Any
            && IOSpec.format != DataFormat::Any
            && requiredIOSpec.format != IOSpec.format)
        {
            const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format);
            auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
            transposeOp->getOperator()->setDataFormat(IOSpec.format);
            transposeOp->getOperator()->setDataType(requiredIOSpec.type);
            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
            transposeOp->addChild(parent, 0, i);

            op->getInput(i)->setDataFormat(IOSpec.format);
        }

        // Input dims
        if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) {
            if (requiredIOSpec.dims.size() != IOSpec.dims.size()) {
                return nullptr;
            }

            for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) {
                const auto requiredDim = requiredIOSpec.dims[dim];
                const auto specDim = IOSpec.dims[dim];

                if (requiredDim.first != -1
                    && specDim.first != -1
                    && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second))
                {
                    return nullptr;
                }
            }
        }
    }

    // Adapt outputs
    for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) {
        const auto IOSpec = (i < spec.outputs.size()) ? spec.outputs[i] : spec.outputs.back();
        const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.outputs[i];
        std::shared_ptr<Node> parent = node;

        // Output type
        if (requiredIOSpec.type != DataType::Any
            && IOSpec.type != DataType::Any
            && requiredIOSpec.type != IOSpec.type)
        {
            const auto cast = Cast(requiredIOSpec.type);
            cast->getOperator()->setBackend(node->getOperator()->backend());
            parent->addChild(cast, i, 0);

            op->getOutput(i)->setDataType(IOSpec.type);
        }

        // Output format
        if (requiredIOSpec.format != DataFormat::Any
            && IOSpec.format != DataFormat::Any
            && requiredIOSpec.format != IOSpec.format)
        {
            const auto transpose = getDataFormatTranspose(IOSpec.format, requiredIOSpec.format);
            auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
            transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
            transposeOp->getOperator()->setDataType(requiredIOSpec.type);
            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
            parent->addChild(transposeOp, i, 0);

            op->getOutput(i)->setDataFormat(IOSpec.format);
        }

        // Output dims
        if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) {
            if (requiredIOSpec.dims.size() != IOSpec.dims.size()) {
                return nullptr;
            }

            for (size_t dim = 0; dim < requiredIOSpec.dims.size(); ++dim) {
                const auto requiredDim = requiredIOSpec.dims[dim];
                const auto specDim = IOSpec.dims[dim];

                if (requiredDim.first != -1
                    && specDim.first != -1
                    && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second))
                {
                    return nullptr;
                }
            }
        }
    }

    auto adaptedGraph = getConnectedGraphView(node);
    if (adaptedGraph->getNodes().size() > 1) {
        return MetaOperator(std::string("Adapted_" + op->type()).c_str(), adaptedGraph);
    }
    else {
        return node;
    }
}

std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const {
    const auto availableSpecs = getAvailableImplSpecs();
    Log::debug("Adapt operator type {}: {} impl. available", mOp.type(), availableSpecs.size());

    using AdaptationCost = int;
    std::map<std::shared_ptr<Node>, AdaptationCost> adaptations;

    for (const auto& availableSpec : availableSpecs) {
        auto adaptation = getAdaptation(availableSpec, requiredSpecs);

        if (adaptation) {
            if (adaptation->getOperator()->isAtomic()) {
                adaptations.insert(std::make_pair(adaptation, 1));
            }
            else {
                auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
                adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
            }
        }
    }

    Log::debug("Adapt operator type {}: found {} possible adaptations", mOp.type(), adaptations.size());

    if (!adaptations.empty()) {
        // Return best adaptation (with min. AdaptationCost)
        const auto bestAdaptation = std::min_element(adaptations.begin(), adaptations.end(),
            [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });
        return bestAdaptation->first;
    }

    return nullptr;
}

/**
 * @brief Default forward(): reports that no forward implementation exists.
 *
 * Unconditionally invokes AIDGE_THROW_OR_ABORT with std::runtime_error and a
 * message that includes the operator type.
 */
void Aidge::OperatorImpl::forward() {
    AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented yet for operator of type {}", mOp.type());
}

/**
 * @brief Default backward(): reports that no backward implementation exists.
 *
 * Unconditionally invokes AIDGE_THROW_OR_ABORT with std::runtime_error and a
 * message that includes the operator type.
 */
void Aidge::OperatorImpl::backward() {
    AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented yet for operator of type {}", mOp.type());
}

/**
 * @brief Creates a new ProdConso object constructed from the operator (mOp).
 * @return A shared pointer to the newly created ProdConso; a new object is
 *         created on every call (prodConso() is the accessor that caches it).
 */
std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::getProdConso() const {
    return std::make_shared<ProdConso>(mOp);
}

/**
 * @brief Lists the implementation specs available for this implementation.
 * @return An empty list in this default version.
 */
std::vector<Aidge::ImplSpec> Aidge::OperatorImpl::getAvailableImplSpecs() const {
    return {};
}