diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp index 739e69fbab6ea27c68d19205e858b7f97cc1ecab..bf4bbb85be606fc857bf8d771b9ce211ca8e858e 100644 --- a/src/operator/Concat.cpp +++ b/src/operator/Concat.cpp @@ -60,31 +60,47 @@ void Aidge::Concat_OpImpl::forward() { const std::string Aidge::Concat_Op::Type = "Concat"; bool Aidge::Concat_Op::forwardDims(bool /*allowDataDependency*/) { - if (inputsAssociated()) { - AIDGE_ASSERT(axis() < getInput(0)->nbDims(), "Concat: Axis ({}) out of range ({})", - axis(), getInput(0)->nbDims()); + if (!inputsAssociated()) { + return false; + } + const std::size_t nbDimsInput0 = getInput(0)->nbDims(); + if (nbDimsInput0 == 0) { + return false; + } + AIDGE_ASSERT(nbDimsInput0 > 0, "First input in {} Operator is empty", type()); + for (IOIndex_t i = 1; i < nbInputs(); ++i) { + if (getInput(i)->nbDims() == 0) { + return false; + } + AIDGE_ASSERT(nbDimsInput0 == getInput(i)->nbDims(), + "Input 0 and input {} in {} Operator have different number of dimensions: {} / {}", + i, type(), nbDimsInput0, getInput(i)->nbDims()); + } + // Check validity of attributes with inputs + // Axis + std::int32_t axis = mAttributes->template getAttr<ConcatAttr::Axis>(); + axis = (axis < 0) ? axis + static_cast<std::int32_t>(nbDimsInput0) : axis; + AIDGE_ASSERT(((axis >= 0) && (axis < static_cast<std::int32_t>(nbDimsInput0))), + "'Axis' attribute not compatible with provided inputs.") + const std::size_t axis_u64 = static_cast<std::size_t>(axis); - auto outputDims = getInput(0)->dims(); - const auto firstInputNbDims = getInput(0) -> nbDims(); - for (IOIndex_t i = 1; i < nbInputs(); ++i) { - if (getInput(i)->nbDims() == firstInputNbDims) { - for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) { - if (dim == axis()) { - outputDims[dim] += getInput(i)->dims()[dim]; - } - else { - AIDGE_ASSERT(getInput(i)->dims()[dim] == outputDims[dim], "Concat: input #{} dim #{} ({}) must match value {}", - i, dim, getInput(i)->dims()[dim], outputDims[dim]); - } - } + // Check validity of inputs + auto outputDims = getInput(0)->dims(); + for (IOIndex_t i = 1; i < nbInputs(); ++i) { + for (DimSize_t dim = 0; dim < nbDimsInput0; ++dim) { + if (dim == axis_u64) { + outputDims[axis_u64] += getInput(i)->dims()[axis_u64]; + } + else { + AIDGE_ASSERT(getInput(i)->dims()[dim] == outputDims[dim], + "Incomatible dimensions between input 0 {} and input {} {}", + getInput(0)->dims(), i, getInput(i)->dims()); } } - - getOutput(0)->resize(outputDims); - return true; } - return false; + getOutput(0)->resize(outputDims); + return true; } void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) {