/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/

#include "aidge/operator/Concat.hpp"
#include <string>

#include <vector>

#include "aidge/data/Tensor.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
void Aidge::Concat_OpImpl::forward() {
    // Reference (backend-agnostic) kernel for Concat: copies every input
    // tensor's raw data into the output, slice by slice, along the axis
    // given by the "Axis" attribute.
    const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp);
    const DimSize_t axis = op.template getAttr<DimSize_t>("Axis");

    // Every input must be associated and share the first input's data type.
    assert(op.getInput(0) && "missing input in Concat operator");
    DataType datatypeFirstInput = op.getInput(0)->dataType();
    for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
        assert(op.getInput(i) && "missing input in Concat operator");
        assert(op.getInput(i)->dataType() == datatypeFirstInput);
    }

    // Output size along the concatenation axis = sum of the input sizes there.
    DimSize_t outputAxisValue = 0;
    for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
        outputAxisValue += op.getInput(i)->dims()[axis];
    }

    // Flatten the tensor shape around the axis:
    //   prodDimLower  — product of dims BEFORE the axis (number of outer slices);
    //   prodDimHigher — product of dims AFTER the axis (contiguous elements per
    //                   unit step along the axis).
    DimSize_t prodDimLower = 1;
    for (DimIdx_t i = 0; i < axis; ++i) {
        prodDimLower *= op.getInput(0)->dims()[i];
    }
    DimSize_t prodDimHigher = 1;
    for (DimIdx_t i = axis + 1; static_cast<std::size_t>(i) < op.getInput(0)->dims().size();
         ++i) {
        prodDimHigher *= op.getInput(0)->dims()[i];
    }

    // For each input, copy its contiguous per-slice chunk into the output.
    // oIndexStart is the output offset where this input's data begins within
    // a slice; it advances by the input's axis extent after each input.
    std::size_t oIndexStart = 0;
    for (std::size_t inputId = 0; inputId < op.nbInputs(); ++inputId) {
        // Number of contiguous elements this input contributes per outer slice.
        const DimSize_t iOffset = prodDimHigher*op.getInput(inputId)->dims()[axis];
        for (std::size_t iIndex = 0, oIndex = oIndexStart; iIndex < prodDimLower; ++iIndex) {
            op.getOutput(0)->getImpl()->copy(op.getInput(inputId)->getImpl()->rawPtr(iIndex*iOffset), iOffset, oIndex);
            // Step to the same position within the next outer slice of the output.
            oIndex += prodDimHigher*outputAxisValue;
        }
        oIndexStart += op.getInput(inputId)->dims()[axis]*prodDimHigher;
    }
}

Maxence Naud
committed
const std::string Aidge::Concat_Op::Type = "Concat";
bool Aidge::Concat_Op::forwardDims(bool /*allowDataDependency*/) {
    // Output dims are computable only when every input is associated,
    // non-empty, has the same rank as input 0, and matches input 0 on every
    // dimension except the concatenation axis (whose sizes accumulate).
    //
    // Fixes vs. previous revision:
    //  - the function is declared `bool` but had no return statement
    //    (undefined behavior); it now returns whether dims were resolved;
    //  - input 0 was dereferenced even when null; now bail out early.
    if (getInput(0) == nullptr) {
        return false;
    }

    // Do not compute anything if input 0 is empty or the axis is out of range.
    bool associated = !(getInput(0)->empty()) && (getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims());
    auto outputDims = getInput(0)->dims();
    const auto firstInputNbDims = getInput(0)->nbDims();
    for (IOIndex_t i = 1; i < nbInputs(); ++i) {
        if (!getInput(i)) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
        }
        if (getInput(i)->nbDims() == firstInputNbDims) {
            for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) {
                if (dim == getAttr<ConcatAttr::Axis>()) {
                    // Concatenation axis: extents add up.
                    outputDims[dim] += getInput(i)->dims()[dim];
                }
                else {
                    // All other axes must match input 0 exactly.
                    associated &= (getInput(i)->dims()[dim] == outputDims[dim]);
                }
            }
        }
        else {
            // Rank mismatch: dims cannot be resolved.
            associated = false;
            break;
        }
    }
    if (associated) {
        getOutput(0)->resize(outputDims);
    }
    return associated;
}

Maxence Naud
committed
void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) {
if (Registrar<Concat_Op>::exists({name})) {
SET_IMPL_MACRO(Concat_Op, *this, name);
}
else {
mImpl = std::make_shared<Concat_OpImpl>(*this);
}

Maxence Naud
committed
mOutputs[0]->setBackend(name, device);
}