diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index d456c8af2cfbfb9680069cda35eacc9941e9fe45..b928334f30bad74ab7c0c9a07dcb5ef43e23b68e 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -707,6 +707,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer } } + size_t concatSize = size; size_t concatOffset = 0; if (concat) { @@ -730,6 +731,9 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer if (parentOutputFormat == DataFormat::NHWC) { parentSize = parentOutputDims.end()[-3]; } + else { + concatSize = parentSize; + } concatOffset += parentSize; } @@ -747,16 +751,27 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer if (wrapAroundBuffer && wrapAroundSize > 0) { memManager.reallocate(memPlane, node, concatOffset, - size, true, wrapAroundExtra, childs, stride, length, count); + concatSize, true, wrapAroundExtra, childs, stride, length, count); } else { memManager.reallocate(memPlane.memSpace, node, memPlane.offset + concatOffset, - size, false, 0, childs, stride, length, count); + concatSize, false, 0, childs, stride, length, count); } if (concat && itConcat == concatMemPlane.end()) { concatMemPlane.emplace(concat, memPlane); + + if (wrapAroundBuffer && wrapAroundSize > 0) { + memManager.reallocate(memPlane, + concat, 0, + size, true, wrapAroundExtra, childs, stride, length, count); + } + else { + memManager.reallocate(memPlane.memSpace, + concat, memPlane.offset, + size, false, 0, childs, stride, length, count); + } } }