Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Producer.cpp 7.29 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/operator/Producer.hpp"

#include <cstddef>
#include <array>
#include <memory>
#include <string>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"


// Operator type identifier shared by every Producer_Op instance.
const std::string Aidge::Producer_Op::Type{"Producer"};

/**
 * @brief Build a Producer whose single output Tensor is resized to @p dims.
 * @tparam DIM Number of dimensions of the output.
 * @param dims Requested output dimensions.
 * @param constant Initial value of the Constant attribute.
 */
template <std::size_t DIM>
Aidge::Producer_Op::Producer_Op(
            const std::array<Aidge::DimSize_t, DIM>& dims,
            bool constant)
    : OperatorTensor(Type, {}, 1),
      mAttributes(std::make_shared<Attributes_>(attr<ProdAttr::Constant>(constant)))
{
    // Shape the (already allocated) output Tensor.
    mOutputs[0]->resize(dims);

    // No backend has been chosen yet: install the generic implementation.
    mImpl = std::make_shared<OperatorImpl>(*this);
}

/**
 * @brief Build a Producer around an existing Tensor.
 *
 * The Tensor pointer itself is stored as the output, so the caller and the
 * Producer refer to the same Tensor object. @p tensor may be null.
 * If the Tensor already has an implementation whose backend has a registered
 * Producer_Op implementation, that implementation is selected. Otherwise a
 * generic OperatorImpl is installed.
 * @param tensor Tensor to expose as output 0 (may be null).
 * @param constant Initial value of the Constant attribute.
 */
Aidge::Producer_Op::Producer_Op(const std::shared_ptr<Aidge::Tensor> tensor, bool constant)
    : OperatorTensor(Type, {}, 1),
      mAttributes(std::make_shared<Attributes_>(attr<ProdAttr::Constant>(constant)))
{
    mOutputs[0] = tensor; // share the Tensor, do not copy its content

    const bool useRegisteredImpl = mOutputs[0]
        && mOutputs[0]->hasImpl()
        && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()});

    if (!useRegisteredImpl) {
        mImpl = std::make_shared<OperatorImpl>(*this);
        return;
    }
    SET_IMPL_MACRO(Producer_Op, *this, mOutputs[0]->getImpl()->backend());
}

/**
 * @brief Copy-constructor. Copy the operator attributes and its output tensor(s),
 * but not its input tensors (the new operator has no input associated).
 *
 * The attributes are deep-copied, so the new operator owns its own Constant
 * flag. Copying only the shared_ptr would make both operators share a single
 * attribute object, and a later change on one would silently affect the other.
 * @param op Producer_Op to copy. Its output may be null (possible when it
 *           was built from a null Tensor); in that case no output content is copied.
 */
Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op)
    : OperatorTensor(op),
      mAttributes(std::make_shared<Attributes_>(*op.mAttributes))
{
    // Copy the source output only if there is one to copy.
    const std::shared_ptr<Tensor> source = op.getOutput(0);
    if (source) {
        *mOutputs[0] = *source;
    }

    // Same implementation-selection rule as the Tensor-based constructor.
    if (mOutputs[0] && mOutputs[0]->hasImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})) {
        SET_IMPL_MACRO(Producer_Op, *this, mOutputs[0]->getImpl()->backend());
    }
    else {
        mImpl = std::make_shared<OperatorImpl>(*this);
    }
}

/**
 * @brief Create a new Producer_Op holding a copy of this one's output.
 *
 * The output Tensor is copied through Tensor::clone() into a freshly allocated
 * Tensor, and the Constant attribute is carried over. If this Producer has a
 * null output (possible when it was built from a null Tensor), the clone also
 * gets a null output instead of dereferencing a null pointer.
 * @return The new operator.
 */
std::shared_ptr<Aidge::Operator> Aidge::Producer_Op::clone() const {
    // mOutputs[0] can be null: the Producer_Op(std::shared_ptr<Tensor>, bool)
    // constructor stores the caller's pointer as-is, even when it is nullptr.
    std::shared_ptr<Tensor> newTensor;
    if (mOutputs[0]) {
        newTensor = std::make_shared<Tensor>(mOutputs[0]->clone());
    }
    return std::make_shared<Producer_Op>(newTensor, constant());
}

/**
 * @brief Select backend @p name for this Producer and its output Tensor.
 *
 * Uses the Producer_Op implementation registered for @p name if there is one,
 * otherwise a generic OperatorImpl. In both cases the output Tensor then
 * receives setBackend(name, device).
 * @param name Backend name.
 * @param device Device index forwarded to the output Tensor.
 */
void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    const bool isRegistered = Registrar<Producer_Op>::exists({name});
    if (!isRegistered) {
        mImpl = std::make_shared<OperatorImpl>(*this);
    } else {
        SET_IMPL_MACRO(Producer_Op, *this, name);
    }
    mOutputs[0]->setBackend(name, device);
}

/**
 * @brief List the backend names for which a Producer_Op implementation is registered.
 * @return The Registrar keys for Producer_Op.
 */
std::set<std::string> Aidge::Producer_Op::getAvailableBackends() const {
    std::set<std::string> backends = Registrar<Producer_Op>::getKeys();
    return backends;
}

/**
 * @brief Run the current implementation's forward pass.
 *
 * This is a no-op while backend() reports an empty name.
 */
void Aidge::Producer_Op::forward() {
    if (backend().empty()) {
        return; // nothing to execute without a named backend
    }
    mImpl->forward();
}

/**
 * @brief Replace output @p outputIdx with @p data, unless the Producer is constant.
 *
 * For a constant Producer, the error is reported through AIDGE_THROW_OR_ABORT
 * with std::runtime_error and the output is left unchanged.
 * @param outputIdx Index of the output to replace.
 * @param data New output data.
 */
void Aidge::Producer_Op::setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) const {
    const bool isConstant = mAttributes->template getAttr<ProdAttr::Constant>();
    if (!isConstant) {
        OperatorTensor::setOutput(outputIdx, data);
        return;
    }
    AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}

/////////////////////////////////////////////

/**
 * @brief Create a Node wrapping a new Producer_Op whose output has shape @p dims.
 * @tparam DIM Number of dimensions; rejected at compile time when above MaxDim.
 * @param dims Output dimensions.
 * @param name Name of the created Node.
 * @param constant Initial value of the Constant attribute.
 * @return The new Node.
 */
template <std::array<Aidge::DimSize_t, 1>::size_type DIM>
std::shared_ptr<Aidge::Node> Aidge::Producer(const std::array<Aidge::DimSize_t, DIM> &dims,
        const std::string& name,
        bool constant)
{
    static_assert(DIM <= MaxDim, "Too many tensor dimensions required by Producer, not supported");
    auto producerOp = std::make_shared<Producer_Op>(dims, constant);
    return std::make_shared<Node>(producerOp, name);
}

// Explicit instantiations of the Producer<DIM> factory defined above, for
// DIM = 1..10. NOTE(review): if the header only declares this template, these
// are the only ranks other translation units can link against — confirm. Each
// rank must also satisfy the static_assert DIM <= MaxDim in the definition.
template std::shared_ptr<Aidge::Node> Aidge::Producer<1>(const std::array<Aidge::DimSize_t, 1>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<2>(const std::array<Aidge::DimSize_t, 2>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<3>(const std::array<Aidge::DimSize_t, 3>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<4>(const std::array<Aidge::DimSize_t, 4>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<5>(const std::array<Aidge::DimSize_t, 5>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<6>(const std::array<Aidge::DimSize_t, 6>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<7>(const std::array<Aidge::DimSize_t, 7>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<8>(const std::array<Aidge::DimSize_t, 8>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<9>(const std::array<Aidge::DimSize_t, 9>&, const std::string&, bool);
template std::shared_ptr<Aidge::Node> Aidge::Producer<10>(const std::array<Aidge::DimSize_t, 10>&, const std::string&, bool);

/**
 * @brief Create a Node wrapping a Producer_Op built around an existing Tensor.
 * @param tensor Tensor exposed as the Producer's output (the pointer is shared).
 * @param name Name of the created Node.
 * @param constant Initial value of the Constant attribute.
 * @return The new Node.
 */
std::shared_ptr<Aidge::Node> Aidge::Producer(const std::shared_ptr<Aidge::Tensor> tensor,
            const std::string& name,
            bool constant)
{
    auto producerOp = std::make_shared<Producer_Op>(tensor, constant);
    return std::make_shared<Node>(producerOp, name);
}

/**
 * @brief Create a Producer of shape @p dims and connect it to input @p inputIdx of @p otherNode.
 *
 * The Producer Node is named "<otherNode name>_<extension>", or left unnamed
 * when @p otherNode has no name. The Producer's output 0 is linked as a child
 * edge and also associated as the input data of @p otherNode's operator.
 * @param otherNode Node that receives the Producer as input.
 * @param inputIdx Input slot of @p otherNode; must be below gk_IODefaultIndex.
 * @param dims Output dimensions of the new Producer.
 * @param extension Suffix appended to @p otherNode's name.
 * @return The new Producer Node.
 */
template <std::array<Aidge::DimSize_t, 1>::size_type DIM>
std::shared_ptr<Aidge::Node> Aidge::addProducer(std::shared_ptr<Aidge::Node>& otherNode,
        const IOIndex_t inputIdx,
        const std::array<Aidge::DimSize_t, DIM>& dims,
        const std::string& extension)
{
    AIDGE_ASSERT(inputIdx < gk_IODefaultIndex, "Input index too high. Cannot create Producer");
    static_assert(DIM <= MaxDim, "Too many tensor dimensions required by addProducer, not supported");

    std::string producerName;
    if (!otherNode->name().empty()) {
        producerName = otherNode->name() + "_" + extension;
    }

    std::shared_ptr<Node> producerNode = Producer(dims, producerName);
    producerNode->addChild(otherNode, 0, inputIdx);
    otherNode->getOperator()->associateInput(inputIdx, producerNode->getOperator()->getRawOutput(0));
    return producerNode;
}

// Explicit instantiations of the addProducer<DIM> template defined above, one
// per rank listed below. NOTE(review): if the header only declares this
// template, only these ranks are linkable from other translation units — confirm.
template std::shared_ptr<Aidge::Node> Aidge::addProducer<1>(std::shared_ptr<Aidge::Node>& otherNode,
        const IOIndex_t inputIdx,
        const std::array<Aidge::DimSize_t, 1>& dims,
        const std::string& extension);
template std::shared_ptr<Aidge::Node> Aidge::addProducer<2>(std::shared_ptr<Aidge::Node>& otherNode,
        const IOIndex_t inputIdx,
        const std::array<Aidge::DimSize_t, 2>& dims,
        const std::string& extension);
template std::shared_ptr<Aidge::Node> Aidge::addProducer<3>(std::shared_ptr<Aidge::Node>& otherNode,
        const IOIndex_t inputIdx,
        const std::array<Aidge::DimSize_t, 3>& dims,
        const std::string& extension);
template std::shared_ptr<Aidge::Node> Aidge::addProducer<4>(std::shared_ptr<Aidge::Node>& otherNode,
        const IOIndex_t inputIdx,
        const std::array<Aidge::DimSize_t, 4>& dims,
        const std::string& extension);
template std::shared_ptr<Aidge::Node> Aidge::addProducer<5>(std::shared_ptr<Aidge::Node>& otherNode,
        const IOIndex_t inputIdx,
        const std::array<Aidge::DimSize_t, 5>& dims,
        const std::string& extension);