From 60f51b7a575f1f4b61900d2c4d229562ce988465 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 9 Jan 2025 10:31:40 +0100
Subject: [PATCH] Fix of the previous fix

---
 src/scheduler/Scheduler.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index b928334f3..ac73b8264 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -716,24 +716,24 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer
 
                     // Compute concatOffset
                     for (auto concatParent : concat->getParents()) {
+                        const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator());
+                        const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {});
+                        const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
+                        const auto parentOutputFormat = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dataFormat() : DataFormat::Default;
+
                         if (concatParent == node) {
+                            if (parentOutputFormat != DataFormat::NHWC) {
+                                concatSize = parentRequiredSize.data;
+                            }
                             break;
                         }
                         else {
-                            const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator());
-                            const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {});
-                            const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
-                            const auto parentOutputFormat = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dataFormat() : DataFormat::Default;
-
                             // By default, specifies a fully monolithic memory block
                             std::size_t parentSize = parentRequiredSize.data;
 
                             if (parentOutputFormat == DataFormat::NHWC) {
                                 parentSize = parentOutputDims.end()[-3];
                             }
-                            else {
-                                concatSize = parentSize;
-                            }
 
                             concatOffset += parentSize;
                         }
-- 
GitLab