Skip to content
Snippets Groups Projects
Commit cfdec01b authored by Maxence Naud's avatar Maxence Naud
Browse files

fix concat formward_dims()

parent f27194c2
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!145Improve UI for Operator/Node/GraphView/Tensor
Pipeline #49432 failed
...@@ -60,31 +60,47 @@ void Aidge::Concat_OpImpl::forward() { ...@@ -60,31 +60,47 @@ void Aidge::Concat_OpImpl::forward() {
const std::string Aidge::Concat_Op::Type = "Concat"; const std::string Aidge::Concat_Op::Type = "Concat";
bool Aidge::Concat_Op::forwardDims(bool /*allowDataDependency*/) { bool Aidge::Concat_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) { if (!inputsAssociated()) {
AIDGE_ASSERT(axis() < getInput(0)->nbDims(), "Concat: Axis ({}) out of range ({})", return false;
axis(), getInput(0)->nbDims()); }
const std::size_t nbDimsInput0 = getInput(0)->nbDims();
if (nbDimsInput0 == 0) {
return false;
}
AIDGE_ASSERT(nbDimsInput0 > 0, "First input in {} Operator is empty", type());
for (IOIndex_t i = 1; i < nbInputs(); ++i) {
if (getInput(i)->nbDims() == 0) {
return false;
}
AIDGE_ASSERT(nbDimsInput0 == getInput(i)->nbDims(),
"Input 0 and input {} in {} Operator have different number of dimensions: {} / {}",
i, type(), nbDimsInput0, getInput(i)->nbDims());
}
// Check validity of attributes with inputs
// Axis
std::int32_t axis = mAttributes->template getAttr<ConcatAttr::Axis>();
axis = (axis < 0) ? axis + static_cast<std::int32_t>(nbDimsInput0) : axis;
AIDGE_ASSERT(((axis >= 0) && (axis < static_cast<std::int32_t>(nbDimsInput0))),
"'Axis' attribute not compatible with provided inputs.")
const std::size_t axis_u64 = static_cast<std::size_t>(axis);
auto outputDims = getInput(0)->dims(); // Check validity of inputs
const auto firstInputNbDims = getInput(0) -> nbDims(); auto outputDims = getInput(0)->dims();
for (IOIndex_t i = 1; i < nbInputs(); ++i) { for (IOIndex_t i = 1; i < nbInputs(); ++i) {
if (getInput(i)->nbDims() == firstInputNbDims) { for (DimSize_t dim = 0; dim < nbDimsInput0; ++dim) {
for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) { if (dim == axis_u64) {
if (dim == axis()) { outputDims[axis_u64] += getInput(i)->dims()[axis_u64];
outputDims[dim] += getInput(i)->dims()[dim]; }
} else {
else { AIDGE_ASSERT(getInput(i)->dims()[dim] == outputDims[dim],
AIDGE_ASSERT(getInput(i)->dims()[dim] == outputDims[dim], "Concat: input #{} dim #{} ({}) must match value {}", "Incomatible dimensions between input 0 {} and input {} {}",
i, dim, getInput(i)->dims()[dim], outputDims[dim]); getInput(0)->dims(), i, getInput(i)->dims());
}
}
} }
} }
getOutput(0)->resize(outputDims);
return true;
} }
return false; getOutput(0)->resize(outputDims);
return true;
} }
void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) { void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment