diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index b928334f30bad74ab7c0c9a07dcb5ef43e23b68e..ac73b8264502b970cead955262e482eb97592b84 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -716,24 +716,24 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer // Compute concatOffset for (auto concatParent : concat->getParents()) { + const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator()); + const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {}); + const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); + const auto parentOutputFormat = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dataFormat() : DataFormat::Default; + if (concatParent == node) { + if (parentOutputFormat != DataFormat::NHWC) { + concatSize = parentRequiredSize.data; + } break; } else { - const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator()); - const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {}); - const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); - const auto parentOutputFormat = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dataFormat() : DataFormat::Default; - // By default, specifies a fully monolithic memory block std::size_t parentSize = parentRequiredSize.data; if (parentOutputFormat == DataFormat::NHWC) { parentSize = parentOutputDims.end()[-3]; } - else { - concatSize = parentSize; - } concatOffset += parentSize; }