// Concat.cpp
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/data/Tensor.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"

void Aidge::Concat_OpImpl::forward() {
    const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp);
    const DimSize_t axis = op.template getAttr<DimSize_t>("Axis");

    assert(op.getInput(0) && "missing input in Concat operator");
    DataType datatypeFirstInput = op.getInput(0)->dataType();
    for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
        assert(op.getInput(i) && "missing input in Concat operator");
        assert(op.getInput(i)->dataType() == datatypeFirstInput);
    }

    DimSize_t outputAxisValue = 0;
    for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
        outputAxisValue += op.getInput(i)->dims()[axis];
    }

    DimSize_t prodDimLower = 1;
    for (DimIdx_t i = 0; i < axis; ++i) {
        prodDimLower *= op.getInput(0)->dims()[i];
    }
    DimSize_t prodDimHigher = 1;
    for (DimIdx_t i = axis + 1; static_cast<std::size_t>(i) < op.getInput(0)->dims().size();
         ++i) {
        prodDimHigher *= op.getInput(0)->dims()[i];
    }

    std::size_t oIndexStart = 0;
Maxence Naud's avatar
Maxence Naud committed
    // std::size_t oIndex = 0;
    for (std::size_t inputId = 0; inputId < op.nbInputs(); ++inputId) {
Maxence Naud's avatar
Maxence Naud committed
        // oIndex = oIndexStart;
        const DimSize_t iOffset = prodDimHigher*op.getInput(inputId)->dims()[axis];
Maxence Naud's avatar
Maxence Naud committed
        for (std::size_t iIndex = 0, oIndex = oIndexStart; iIndex < prodDimLower; ++iIndex) {
            op.getOutput(0)->getImpl()->copy(op.getInput(inputId)->getImpl()->rawPtr(iIndex*iOffset), iOffset, oIndex);
            oIndex += prodDimHigher*outputAxisValue;
        }
        oIndexStart += op.getInput(inputId)->dims()[axis]*prodDimHigher;
    }
}

// Registry/type name identifying the Concat operator.
const std::string Aidge::Concat_Op::Type{"Concat"};

bool Aidge::Concat_Op::forwardDims(bool /*allowDataDependency*/) {
    // Every input is non-empty with the same number of dimensions
    bool associated = (getInput(0) != nullptr);
    associated &= !(getInput(0)->empty()) && (getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims()); // do not compute anything if no input
    auto outputDims =  getInput(0)->dims();
    const auto firstInputNbDims = getInput(0) -> nbDims();
    for (IOIndex_t i = 1; i < nbInputs(); ++i) {
        if (!getInput(i)) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
        }

        if (getInput(i)->nbDims() == firstInputNbDims) {
            for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) {
                if (dim == getAttr<ConcatAttr::Axis>()) {
                    outputDims[dim] += getInput(i)->dims()[dim];
                }
                else {
                    associated &= (getInput(i)->dims()[dim] == outputDims[dim]);
                }
            }
        }
        else {
            associated = false;
            break;
        }
    }
    if (associated) {
        getOutput(0)->resize(outputDims);
    }

    return associated;
void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) {
    if (Registrar<Concat_Op>::exists({name})) {
        SET_IMPL_MACRO(Concat_Op, *this, name);
    }
    else {
        mImpl = std::make_shared<Concat_OpImpl>(*this);
    }