Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
GenericOperator.cpp 6.72 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/operator/GenericOperator.hpp"

#include <cstddef>  // std::size_t
#include <vector>

#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"


/**
 * @brief Build a generic operator from an explicit list of input categories.
 *
 * @param type            Operator type name, forwarded to OperatorTensor.
 * @param inputsCategory  Category of each input, forwarded to OperatorTensor.
 * @param nbOut           Number of outputs, forwarded to OperatorTensor.
 *
 * The attribute container starts as a default-constructed DynamicAttributes
 * and the implementation is a plain OperatorImpl bound to this operator.
 */
Aidge::GenericOperator_Op::GenericOperator_Op(
        const std::string& type,
        const std::vector<Aidge::InputCategory>& inputsCategory,
        Aidge::IOIndex_t nbOut)
    : OperatorTensor(type, inputsCategory, nbOut),
      mAttributes(std::make_shared<DynamicAttributes>())
{
    // Install an OperatorImpl constructed from this operator only.
    mImpl = std::make_shared<OperatorImpl>(*this);
}

/**
 * @brief Build a generic operator from input counts.
 *
 * @param type     Operator type name, forwarded to OperatorTensor.
 * @param nbData   Number of inputs categorised as InputCategory::Data.
 * @param nbParam  Number of inputs categorised as InputCategory::Param.
 * @param nbOut    Number of outputs, forwarded to OperatorTensor.
 *
 * The category list handed to OperatorTensor holds nbData Data entries
 * followed by nbParam Param entries.
 */
Aidge::GenericOperator_Op::GenericOperator_Op(
        const std::string& type,
        Aidge::IOIndex_t nbData,
        Aidge::IOIndex_t nbParam,
        Aidge::IOIndex_t nbOut)
    : OperatorTensor(type,
                     [nbData, nbParam]() {
                         const std::size_t dataCount = nbData;
                         const std::size_t paramCount = nbParam;

                         // Data inputs first, parameter inputs appended after them.
                         std::vector<InputCategory> categories;
                         categories.reserve(dataCount + paramCount);
                         categories.insert(categories.end(), dataCount, InputCategory::Data);
                         categories.insert(categories.end(), paramCount, InputCategory::Param);
                         return categories;
                     }(),
                     nbOut),
      mAttributes(std::make_shared<DynamicAttributes>())
{
    // Install an OperatorImpl constructed from this operator only.
    mImpl = std::make_shared<OperatorImpl>(*this);
}

/**
 * @brief Copy-construct a generic operator.
 *
 * @param op Operator to copy from.
 *
 * The OperatorTensor base and the forward-dims function are copied from
 * @p op. The attributes are copy-constructed into a new DynamicAttributes
 * object, and a new OperatorImpl is created for this instance using the
 * backend reported by @p op.
 */
Aidge::GenericOperator_Op::GenericOperator_Op(
        const Aidge::GenericOperator_Op& op)
    : OperatorTensor(op),
      mForwardDims(op.mForwardDims),
      mAttributes(std::make_shared<DynamicAttributes>(*op.mAttributes))
{
    // The implementation object is rebuilt for this instance, not taken from op.
    mImpl = std::make_shared<OperatorImpl>(*this, op.backend());
}

// Defaulted, non-throwing destructor, defined out-of-line in this translation unit.
Aidge::GenericOperator_Op::~GenericOperator_Op() noexcept = default;

/**
 * @brief Create a copy of this operator through the copy constructor.
 * @return Shared pointer to the newly allocated GenericOperator_Op, typed as Operator.
 */
std::shared_ptr<Aidge::Operator> Aidge::GenericOperator_Op::clone() const {
    std::shared_ptr<Aidge::Operator> duplicate = std::make_shared<GenericOperator_Op>(*this);
    return duplicate;
}

/// ComputeDimsFunc that returns the list of input dimensions unchanged,
/// i.e. output #i receives the dimensions of input #i.
const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Identity =
    [](const std::vector<std::vector<std::size_t>>& inputsDims) {
        return inputsDims;
    };

/**
 * @brief Build a ComputeDimsFunc that gives every output the dimensions of one input.
 *
 * @param inputIdx   Index of the input whose dimensions are replicated.
 * @param nbOutputs  Number of entries in the returned list of output dimensions.
 * @return A function returning @p nbOutputs copies of inputsDims[inputIdx].
 *
 * @note The returned function indexes inputsDims without a bounds check.
 */
const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::InputIdentity(IOIndex_t inputIdx, IOIndex_t nbOutputs) {
    using DimsList = std::vector<std::vector<std::size_t>>;

    return [inputIdx, nbOutputs](const DimsList& inputsDims) {
        const std::vector<std::size_t>& selectedDims = inputsDims[inputIdx];
        return DimsList(nbOutputs, selectedDims);
    };
}

/**
 * @brief Compute and apply the output dimensions using the registered ComputeDimsFunc.
 *
 * The parameter (allowDataDependency) is unused.
 *
 * @return true once every output tensor has been resized; false, after logging
 *         a warning, when no ComputeDimsFunc is set or inputsAssociated(false)
 *         returns false.
 *
 * Asserts (AIDGE_ASSERT) that the ComputeDimsFunc returns a non-empty list
 * holding exactly nbOutputs() entries.
 */
bool Aidge::GenericOperator_Op::forwardDims(bool /*allowDataDependency*/) {
    // Guard clause: inputsAssociated() is only queried when a function is set.
    if (!mForwardDims || !inputsAssociated(false)) {
        Log::warn("GenericOperator: cannot compute output dims, no ComputeDimsFunc function provided.");
        return false;
    }

    // Gather the dimensions of every input; a missing input keeps empty dims.
    const std::size_t inputCount = nbInputs();
    std::vector<std::vector<std::size_t>> inputsDims(inputCount);
    for (std::size_t inIdx = 0; inIdx < inputCount; ++inIdx) {
        const auto& input = getInput(inIdx);
        if (input) {
            inputsDims[inIdx] = input->dims();
        }
    }

    const auto& outputsDims = mForwardDims(inputsDims);
    AIDGE_ASSERT(!outputsDims.empty(), "The provided ComputeDimsFunc cannot compute the output dims (an empty vector was returned)");
    AIDGE_ASSERT(outputsDims.size() == nbOutputs(), "The provided ComputeDimsFunc function returned the wrong number of outputs: {}, but {} are expected", outputsDims.size(), nbOutputs());

    // Apply the computed dimensions to each output tensor.
    for (std::size_t outIdx = 0; outIdx < nbOutputs(); ++outIdx) {
        mOutputs[outIdx]->resize(outputsDims[outIdx]);
    }
    return true;
}

/**
 * @brief Select the backend used by this operator and by its output tensors.
 *
 * @param name    Backend name.
 * @param device  Device index passed on to each output tensor.
 *
 * If an implementation is registered for the pair {name, type()}, it replaces
 * the current implementation; otherwise a warning is logged and the current
 * implementation is left as is. In both cases every output tensor has its
 * backend set to (name, device).
 */
void Aidge::GenericOperator_Op::setBackend(const std::string & name, DeviceIdx_t device) {
    const bool hasCustomImpl = Registrar<GenericOperator_Op>::exists({name, type()});
    if (!hasCustomImpl) {
        Log::warn("GenericOperator::setBackend(): cannot set backend for a generic operator, as no implementation has been provided!");
    }
    else {
        // A custom implementation exists for this generic operator
        mImpl = Registrar<GenericOperator_Op>::create({name, type()})(*this);
    }

    // Outputs are moved to the requested backend whether or not an implementation was found.
    for (std::size_t outIdx = 0; outIdx < nbOutputs(); ++outIdx) {
        mOutputs[outIdx]->setBackend(name, device);
    }
}

///////////////////////////////////////////////

/**
 * @brief Create a Node wrapping a new GenericOperator_Op built from input categories.
 *
 * @param type           Operator type name.
 * @param inputCategory  Category of each input of the operator.
 * @param nbOut          Number of outputs of the operator.
 * @param name           Name given to the created Node.
 * @return Shared pointer to the created Node.
 */
std::shared_ptr<Aidge::Node> Aidge::GenericOperator(
        const std::string& type,
        const std::vector<Aidge::InputCategory>& inputCategory,
        Aidge::IOIndex_t nbOut,
        const std::string& name)
{
    auto genericOp = std::make_shared<GenericOperator_Op>(type, inputCategory, nbOut);
    return std::make_shared<Node>(genericOp, name);
}

/**
 * @brief Create a Node wrapping a new GenericOperator_Op built from input counts.
 *
 * @param type     Operator type name.
 * @param nbData   Number of data inputs of the operator.
 * @param nbParam  Number of parameter inputs of the operator.
 * @param nbOut    Number of outputs of the operator.
 * @param name     Name given to the created Node.
 * @return Shared pointer to the created Node.
 */
std::shared_ptr<Aidge::Node> Aidge::GenericOperator(
        const std::string& type,
        Aidge::IOIndex_t nbData,
        Aidge::IOIndex_t nbParam,
        Aidge::IOIndex_t nbOut,
        const std::string& name)
{
    auto genericOp = std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut);
    return std::make_shared<Node>(genericOp, name);
}

std::shared_ptr<Aidge::Node> Aidge::GenericOperator(const std::string& type,
                                            std::shared_ptr<OperatorTensor> op,
                                            const std::string& name)
{
    // Create a generic op with the same inputs/outputs
    auto genericOp = std::make_shared<GenericOperator_Op>(type, op->inputCategory(), op->nbOutputs());

    // Copy attributes
    genericOp->setAttrs(op->attributes()->getAttrs());

    // Set a default forward dims if possible
    if (op->dimsForwarded()) {
        auto opInputDims = std::vector<std::vector<DimSize_t>>(op->nbInputs());
        for (size_t i = 0; i < op->nbInputs(); ++i) {
            opInputDims[i] = op->getInput(i)->dims();
        }

        auto opOutputDims = std::vector<std::vector<DimSize_t>>(op->nbOutputs());
        for (size_t o = 0; o < op->nbOutputs(); ++o) {
            opOutputDims[o] = op->getOutput(o)->dims();
        }

        genericOp->setForwardDims([opInputDims, opOutputDims](const std::vector<std::vector<std::size_t>>& inputsDims) {
            // Check input dims
            for (size_t i = 0; i < opInputDims.size(); ++i) {
                if (inputsDims[i] != opInputDims[i]) {
                    // No matching => unable to compute output dims!
                    return std::vector<std::vector<std::size_t>>();
                }
            }
            return opOutputDims;
        });
    }

    return std::make_shared<Node>(genericOp, name);
}