From 3b04d563d3bb92e64479e1fcfef9d8339d9be4b0 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 4 Mar 2025 21:58:16 +0100 Subject: [PATCH] Move generateMemory in SequentialScheduler --- aidge_core/mem_info.py | 2 +- include/aidge/scheduler/Scheduler.hpp | 16 - .../aidge/scheduler/SequentialScheduler.hpp | 17 + python_binding/scheduler/pybind_Scheduler.cpp | 4 +- src/scheduler/Scheduler.cpp | 351 ----------------- src/scheduler/SequentialScheduler.cpp | 354 ++++++++++++++++++ 6 files changed, 374 insertions(+), 370 deletions(-) diff --git a/aidge_core/mem_info.py b/aidge_core/mem_info.py index b8d3c6101..d87bea939 100644 --- a/aidge_core/mem_info.py +++ b/aidge_core/mem_info.py @@ -109,7 +109,7 @@ def log_meminfo(mem_manager:aidge_core.MemoryManager, path: Path, diplay_names:b aidge_core.Log.notice(f"Generated memory management info at: {path}") -def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder: Path = None, wrapping: bool = False, auto_concat: bool = False, display_names: bool=True) -> Tuple[int, List[dict]]: +def generate_optimized_memory_info(scheduler: aidge_core.SequentialScheduler, stats_folder: Path = None, wrapping: bool = False, auto_concat: bool = False, display_names: bool=True) -> Tuple[int, List[dict]]: """Generates optimized memory information for a computation graph managed by a scheduler. This function analyzes the memory usage of a computation graph, determining the memory peak diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 1c269cbb4..dbfdfcc39 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -192,22 +192,6 @@ public: */ void clearScheduling(); - /** - * Generate the memory layout for the current static scheduling. - * @param incProducers If true, include the producers in the memory layout. - * @param wrapAroundBuffer If true, allow wrapping in memory planes. 
- */ - MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const; - - /** - * Generate the memory layout for the current static scheduling, with auto- - * concatenation: the Concat operator is replaced by direct allocation - * when possible. - * @param incProducers If true, include the producers in the memory layout. - * @param wrapAroundBuffer If true, allow wrapping in memory planes. - */ - MemoryManager generateMemoryAutoConcat(bool incProducers = false, bool wrapAroundBuffer = false) const; - /** * @brief Connect input tensors to the data input of the GraphView. * In case of multiple data input tensors, they are mapped to producers in diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index 0ae18b085..b95a298d6 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -39,6 +39,23 @@ public: inline void setSchedulingPolicy(SchedulingPolicy policy) { mSchedulingPolicy = policy; } + + /** + * Generate the memory layout for the current static scheduling. + * @param incProducers If true, include the producers in the memory layout. + * @param wrapAroundBuffer If true, allow wrapping in memory planes. + */ + MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const; + + /** + * Generate the memory layout for the current static scheduling, with auto- + * concatenation: the Concat operator is replaced by direct allocation + * when possible. + * @param incProducers If true, include the producers in the memory layout. + * @param wrapAroundBuffer If true, allow wrapping in memory planes. 
+ */ + MemoryManager generateMemoryAutoConcat(bool incProducers = false, bool wrapAroundBuffer = false) const; + /** * @brief Run the provided Computational Graph with a batch of data */ diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 582ba4678..62bb7e0af 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -55,12 +55,12 @@ void init_Scheduler(py::module& m){ .def("get_sequential_static_scheduling", &Scheduler::getSequentialStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::SchedulingPolicy::Default) .def("get_scheduling", &Scheduler::getScheduling) .def("graph_view", &Scheduler::graphView) - .def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) - .def("generate_memory_auto_concat", &Scheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) ; py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) + .def("generate_memory", &SequentialScheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) + .def("generate_memory_auto_concat", &SequentialScheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) .def("backward", &SequentialScheduler::backward) ; diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 90222bdba..2ffaab8e7 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -33,7 +33,6 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" -#include 
"aidge/operator/Concat.hpp" #include "aidge/utils/FileManagement.hpp" #include "aidge/utils/Log.hpp" #include "aidge/utils/Types.h" @@ -712,356 +711,6 @@ void Aidge::Scheduler::clearScheduling() { mScheduling.clear(); } -/** - * @warning This version is a simplified version without special handling of concatenation. - */ -Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { - MemoryManager memManager; - - for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { - for (const auto& node : getSequentialStaticScheduling(step)) { - if (!incProducers && node->type() == Producer_Op::Type) { - memManager.releaseDependencies(node); - continue; - } - - const auto childs = node->getChildren(); - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, - "Operator must be of Tensor type for node {} (of type {}).", - node->name(), node->type()); - const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); - - std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; - - // Allocate a memory plane for each node's output - for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { - const auto requiredSize = op->getRequiredMemory(outputIdx, {}); - AIDGE_ASSERT(requiredSize.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). 
You may need to forward dimensions in the graph first.", - node->name(), node->type()); - - // By default, specifies a fully monolithic memory block - std::size_t size = requiredSize.data; - std::size_t stride = 0; - std::size_t length = 1; - std::size_t count = 1; - - if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dataFormat() == DataFormat::NHWC) { - size = op->getOutput(outputIdx)->dims().end()[-3]; - stride = size; - length = op->getOutput(outputIdx)->dims().end()[-1]; - count = op->getOutput(outputIdx)->dims().end()[-2]; - } - - // Check if wrap around buffer is possible for this node - // (re-using previous node outputs memory for this node outputs). - // => only if this node is the only child of its parent(s) - std::size_t wrapAroundSize = 0; - std::size_t wrapAroundExtra = 0; - wrapAroundMemPlane.push_back(nullptr); - - // Select the best parent among all allocable nodes for - // reallocation, which is the one with most memory (in order - // to minimize the reallocation size). - IOIndex_t inputIdx = 0; - for (const auto& parent : node->dataInputs()) { - if (parent.first && parent.first->getChildren(parent.second).size() == 1 - // there might be no existing plane if the parent was - // not yet scheduled (because it may be a recurrent connection) - && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs() - // memSpace should not be already released - && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1) - { - const auto requiredData = op->getNbRequiredData(inputIdx); - const auto requiredProtected = op->getNbRequiredProtected(inputIdx); - AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). 
You may need to forward dimensions in the graph first.", - node->name(), node->type()); - - const bool isWrappable = (requiredProtected.data < requiredData.data); - const auto& memPlanes = memManager.getPlanes(parent.first); - const MemoryManager::MemoryPlane& memPlane = memPlanes.at(memPlanes.size() - parent.first->nbOutputs() + parent.second); // use at() to avoid dangling reference pointer - - if (isWrappable || !memManager.isWrapAround( - memPlane.memSpace, - memPlane.getFinalOffset() - - memPlane.memSpace->offset, - requiredSize.data)) - { - if (memPlane.getSize() > wrapAroundSize + requiredProtected.data - && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end()) - { - wrapAroundSize = memPlane.getSize() - requiredProtected.data; - if (requiredSize.data > wrapAroundSize) { - wrapAroundExtra = requiredSize.data - wrapAroundSize; - } - wrapAroundMemPlane[outputIdx] = &memPlane; - } - - if (wrapAroundExtra == 0) { - break; - } - } - } - ++inputIdx; - } - - // MemoryPlane to (re)use - const MemoryManager::MemoryPlane& memPlane - = (wrapAroundBuffer && wrapAroundSize > 0) - ? 
(*wrapAroundMemPlane[outputIdx]) : - memManager.allocate(size, childs, stride, length, count); - - if (wrapAroundBuffer && wrapAroundSize > 0) { - memManager.reallocate(memPlane, - node, 0, - size, true, wrapAroundExtra, childs, stride, length, count); - } - else { - memManager.reallocate(memPlane.memSpace, - node, memPlane.offset, - size, false, 0, childs, stride, length, count); - } - } - - memManager.releaseDependencies(node); - memManager.tick(); - } - } - - return memManager; -} - -Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducers, bool wrapAroundBuffer) const { - MemoryManager memManager; - std::map<NodePtr, MemoryManager::MemoryPlane> concatMemPlane; - - for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { - // AsLateAsPossible ensures that when a node child is Concat, all the parents - // of the Concat parents have already been memory mapped! - for (const auto& node : getSequentialStaticScheduling(step, SchedulingPolicy::AsLateAsPossible)) { - if (!incProducers && node->type() == Producer_Op::Type) { - memManager.releaseDependencies(node); - continue; - } - - auto itConcat = concatMemPlane.find(node); - if (itConcat != concatMemPlane.end()) { - // Skip Concat - AIDGE_INTERNAL_ASSERT(itConcat->first->type() == Concat_Op::Type); - concatMemPlane.erase(itConcat); - continue; - } - itConcat = concatMemPlane.end(); - - auto childs = node->getChildren(); - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, - "Operator must be of Tensor type for node {} (of type {}).", - node->name(), node->type()); - const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); - - std::shared_ptr<Node> concat = nullptr; - // If the only child is a concatenation, check if we can allocate - // the concatenation result directly and skip allocation for this - // node output: - if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) { - concat = *childs.begin(); - - for (const auto& 
concatParent : concat->getParents()) { - if (concatParent->getChildren().size() > 1) { - // not possible: at least one of the Concat parent has - // multiple children. - concat = nullptr; - break; - } - } - } - - if (concat) { - itConcat = concatMemPlane.find(concat); - } - - std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; - - // Allocate a memory plane for each node's output - AIDGE_INTERNAL_ASSERT(!concat || node->nbOutputs() == 1); - for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { - auto requiredSize = op->getRequiredMemory(outputIdx, {}); - auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); - auto outputFormat = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dataFormat() : DataFormat::Default; - - // If concat is not nullptr, we directly allocate the concatenation result - // Must check that we are on the right output too. - if (concat && node->getChildren(outputIdx).size() == 1) { - const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator()); - requiredSize = concatOp->getRequiredMemory(0, {}); - outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>(); - outputFormat = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dataFormat() : DataFormat::Default; - } - - AIDGE_ASSERT(requiredSize.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). 
You may need to forward dimensions in the graph first.", - node->name(), node->type()); - - // By default, specifies a fully monolithic memory block - std::size_t size = requiredSize.data; - std::size_t stride = 0; - std::size_t length = 1; - std::size_t count = 1; - - if (outputFormat == DataFormat::NHWC) { - size = op->getOutput(outputIdx)->dims().end()[-3]; - stride = outputDims.end()[-3]; - length = outputDims.end()[-1]; - count = outputDims.end()[-2]; - - AIDGE_INTERNAL_ASSERT(stride >= size); - AIDGE_INTERNAL_ASSERT(length == op->getOutput(outputIdx)->dims().end()[-1]); - AIDGE_INTERNAL_ASSERT(count == op->getOutput(outputIdx)->dims().end()[-2]); - } - - // Check if wrap around buffer is possible for this node - // (re-using previous node outputs memory for this node outputs). - // => only if this node is the only child of its parent(s) - std::size_t wrapAroundSize = 0; - std::size_t wrapAroundExtra = 0; - wrapAroundMemPlane.push_back(nullptr); // default value of wrapAroundMemPlane[outputIdx] - - // Select the best parent among all allocable nodes for - // reallocation, which is the one with most memory (in order - // to minimize the reallocation size). - const auto allocableNodes = (concat) ? 
concat->getParents() : std::vector<NodePtr>{node}; - for (const auto& allocableNode : allocableNodes) { - IOIndex_t inputIdx = 0; - for (const auto& parent : allocableNode->dataInputs()) { - if (parent.first && parent.first->getChildren(parent.second).size() == 1 - // there might be no existing plane if the parent was - // not yet scheduled (because it may be a recurrent connection) - && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs() - // memSpace should not be already released - && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1) - { - const auto requiredData = allocableNode->getOperator()->getNbRequiredData(inputIdx); - const auto requiredProtected = allocableNode->getOperator()->getNbRequiredProtected(inputIdx); - AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). You may need to forward dimensions in the graph first.", - node->name(), node->type()); - - const bool isWrappable = (requiredProtected.data < requiredData.data); - const auto& memPlanes = memManager.getPlanes(parent.first); - const MemoryManager::MemoryPlane& memPlane - = (concat && itConcat != concatMemPlane.end()) - ? 
itConcat->second - : memPlanes.at(memPlanes.size()-parent.first->nbOutputs()+parent.second); // use at() to avoid dangling reference pointer - - if (isWrappable || !memManager.isWrapAround( - memPlane.memSpace, - memPlane.getFinalOffset() - - memPlane.memSpace->offset, - requiredSize.data)) - { - if (memPlane.getSize() > wrapAroundSize + requiredProtected.data - && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end()) - { - wrapAroundSize = memPlane.getSize() - requiredProtected.data; - if (requiredSize.data > wrapAroundSize) { - wrapAroundExtra = requiredSize.data - wrapAroundSize; - } - wrapAroundMemPlane[outputIdx] = &memPlane; - } - - if (wrapAroundExtra == 0) { - break; - } - } - } - ++inputIdx; - } - } - - size_t concatSize = size; - size_t concatOffset = 0; - - if (concat) { - // Dependencies should be concat node *childs*, not concat node - childs = concat->getChildren(); - - // Compute concatOffset - for (auto concatParent : concat->getParents()) { - const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator()); - const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {}); - const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); - const auto parentOutputFormat = (parentOp->getOutput(outputIdx)) ? 
parentOp->getOutput(outputIdx)->dataFormat() : DataFormat::Default; - - if (concatParent == node) { - if (parentOutputFormat != DataFormat::NHWC) { - concatSize = parentRequiredSize.data; - } - break; - } - else { - // By default, specifies a fully monolithic memory block - std::size_t parentSize = parentRequiredSize.data; - - if (parentOutputFormat == DataFormat::NHWC) { - parentSize = parentOutputDims.end()[-3]; - } - - concatOffset += parentSize; - } - } - - // Size in reallocate() is counted from the offset, not from 0, - // meaning the offset must be substracted to obtain the correct - // total size without excess. - if (concatOffset > 0) { - concatSize -= concatOffset; - } - } - - // MemoryPlane to (re)use - const MemoryManager::MemoryPlane& memPlane - = (concat && itConcat != concatMemPlane.end()) - ? itConcat->second : - (wrapAroundBuffer && wrapAroundSize > 0) - ? (*wrapAroundMemPlane[outputIdx]) : - memManager.allocate(size, childs, stride, length, count); - - if (wrapAroundBuffer && wrapAroundSize > 0) { - memManager.reallocate(memPlane, - node, concatOffset, - concatSize, true, wrapAroundExtra, childs, stride, length, count); - } - else { - memManager.reallocate(memPlane.memSpace, - node, memPlane.offset + concatOffset, - concatSize, false, 0, childs, stride, length, count); - } - - if (concat && itConcat == concatMemPlane.end()) { - concatMemPlane.emplace(concat, memPlane); - - if (wrapAroundBuffer && wrapAroundSize > 0) { - memManager.reallocate(memPlane, - concat, 0, - size, true, wrapAroundExtra, childs, stride, length, count); - } - else { - memManager.reallocate(memPlane.memSpace, - concat, memPlane.offset, - size, false, 0, childs, stride, length, count); - } - } - } - - memManager.releaseDependencies(node); - memManager.tick(); - } - } - - return memManager; -} - void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data){ // This version of connect inputs only connects tensor inputs in input data producers. 
auto inputNodes = mGraphView->getOrderedInputs(); diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 07f01ce09..6a52882f6 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -25,9 +25,363 @@ #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Memorize.hpp" +#include "aidge/operator/Concat.hpp" #include "aidge/operator/MetaOperator.hpp" #include "aidge/recipes/GraphViewHelper.hpp" +/** + * @warning This version is a simplified version without special handling of concatenation. + */ +Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { + MemoryManager memManager; + + for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { + for (const auto& node : getSequentialStaticScheduling(step, mSchedulingPolicy)) { + if (!incProducers && node->type() == Producer_Op::Type) { + memManager.releaseDependencies(node); + continue; + } + + const auto childs = node->getChildren(); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, + "Operator must be of Tensor type for node {} (of type {}).", + node->name(), node->type()); + const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + + std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; + + // Allocate a memory plane for each node's output + for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { + const auto requiredSize = op->getRequiredMemory(outputIdx, {}); + AIDGE_ASSERT(requiredSize.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). 
You may need to forward dimensions in the graph first.", + node->name(), node->type()); + + // By default, specifies a fully monolithic memory block + std::size_t size = requiredSize.data; + std::size_t stride = 0; + std::size_t length = 1; + std::size_t count = 1; + + if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dataFormat() == DataFormat::NHWC) { + size = op->getOutput(outputIdx)->dims().end()[-3]; + stride = size; + length = op->getOutput(outputIdx)->dims().end()[-1]; + count = op->getOutput(outputIdx)->dims().end()[-2]; + } + + // Check if wrap around buffer is possible for this node + // (re-using previous node outputs memory for this node outputs). + // => only if this node is the only child of its parent(s) + std::size_t wrapAroundSize = 0; + std::size_t wrapAroundExtra = 0; + wrapAroundMemPlane.push_back(nullptr); + + // Select the best parent among all allocable nodes for + // reallocation, which is the one with most memory (in order + // to minimize the reallocation size). + IOIndex_t inputIdx = 0; + for (const auto& parent : node->dataInputs()) { + if (parent.first && parent.first->getChildren(parent.second).size() == 1 + // there might be no existing plane if the parent was + // not yet scheduled (because it may be a recurrent connection) + && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs() + // memSpace should not be already released + && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1) + { + const auto requiredData = op->getNbRequiredData(inputIdx); + const auto requiredProtected = op->getNbRequiredProtected(inputIdx); + AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). 
You may need to forward dimensions in the graph first.", + node->name(), node->type()); + + const bool isWrappable = (requiredProtected.data < requiredData.data); + const auto& memPlanes = memManager.getPlanes(parent.first); + const MemoryManager::MemoryPlane& memPlane = memPlanes.at(memPlanes.size() - parent.first->nbOutputs() + parent.second); // use at() to avoid dangling reference pointer + + if (isWrappable || !memManager.isWrapAround( + memPlane.memSpace, + memPlane.getFinalOffset() + - memPlane.memSpace->offset, + requiredSize.data)) + { + if (memPlane.getSize() > wrapAroundSize + requiredProtected.data + && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end()) + { + wrapAroundSize = memPlane.getSize() - requiredProtected.data; + if (requiredSize.data > wrapAroundSize) { + wrapAroundExtra = requiredSize.data - wrapAroundSize; + } + wrapAroundMemPlane[outputIdx] = &memPlane; + } + + if (wrapAroundExtra == 0) { + break; + } + } + } + ++inputIdx; + } + + // MemoryPlane to (re)use + const MemoryManager::MemoryPlane& memPlane + = (wrapAroundBuffer && wrapAroundSize > 0) + ? 
(*wrapAroundMemPlane[outputIdx]) : + memManager.allocate(size, childs, stride, length, count); + + if (wrapAroundBuffer && wrapAroundSize > 0) { + memManager.reallocate(memPlane, + node, 0, + size, true, wrapAroundExtra, childs, stride, length, count); + } + else { + memManager.reallocate(memPlane.memSpace, + node, memPlane.offset, + size, false, 0, childs, stride, length, count); + } + } + + memManager.releaseDependencies(node); + memManager.tick(); + } + } + + return memManager; +} + +Aidge::MemoryManager Aidge::SequentialScheduler::generateMemoryAutoConcat(bool incProducers, bool wrapAroundBuffer) const { + MemoryManager memManager; + std::map<NodePtr, MemoryManager::MemoryPlane> concatMemPlane; + + for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { + // For AutoConcat, the AsLateAsPossible scheduling policy ensures that + // when a node child is Concat, all the parents of the Concat parents + // have already been memory mapped! + // But in any case, the policy must remain the same as the scheduler + // export policy. 
+ for (const auto& node : getSequentialStaticScheduling(step, mSchedulingPolicy)) { + if (!incProducers && node->type() == Producer_Op::Type) { + memManager.releaseDependencies(node); + continue; + } + + auto itConcat = concatMemPlane.find(node); + if (itConcat != concatMemPlane.end()) { + // Skip Concat + AIDGE_INTERNAL_ASSERT(itConcat->first->type() == Concat_Op::Type); + concatMemPlane.erase(itConcat); + continue; + } + itConcat = concatMemPlane.end(); + + auto childs = node->getChildren(); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, + "Operator must be of Tensor type for node {} (of type {}).", + node->name(), node->type()); + const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + + std::shared_ptr<Node> concat = nullptr; + // If the only child is a concatenation, check if we can allocate + // the concatenation result directly and skip allocation for this + // node output: + if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) { + concat = *childs.begin(); + + for (const auto& concatParent : concat->getParents()) { + if (concatParent->getChildren().size() > 1) { + // not possible: at least one of the Concat parent has + // multiple children. + concat = nullptr; + break; + } + } + } + + if (concat) { + itConcat = concatMemPlane.find(concat); + } + + std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; + + // Allocate a memory plane for each node's output + AIDGE_INTERNAL_ASSERT(!concat || node->nbOutputs() == 1); + for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { + auto requiredSize = op->getRequiredMemory(outputIdx, {}); + auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); + auto outputFormat = (op->getOutput(outputIdx)) ? 
op->getOutput(outputIdx)->dataFormat() : DataFormat::Default; + + // If concat is not nullptr, we directly allocate the concatenation result + // Must check that we are on the right output too. + if (concat && node->getChildren(outputIdx).size() == 1) { + const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator()); + requiredSize = concatOp->getRequiredMemory(0, {}); + outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>(); + outputFormat = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dataFormat() : DataFormat::Default; + } + + AIDGE_ASSERT(requiredSize.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). You may need to forward dimensions in the graph first.", + node->name(), node->type()); + + // By default, specifies a fully monolithic memory block + std::size_t size = requiredSize.data; + std::size_t stride = 0; + std::size_t length = 1; + std::size_t count = 1; + + if (outputFormat == DataFormat::NHWC) { + size = op->getOutput(outputIdx)->dims().end()[-3]; + stride = outputDims.end()[-3]; + length = outputDims.end()[-1]; + count = outputDims.end()[-2]; + + AIDGE_INTERNAL_ASSERT(stride >= size); + AIDGE_INTERNAL_ASSERT(length == op->getOutput(outputIdx)->dims().end()[-1]); + AIDGE_INTERNAL_ASSERT(count == op->getOutput(outputIdx)->dims().end()[-2]); + } + + // Check if wrap around buffer is possible for this node + // (re-using previous node outputs memory for this node outputs). + // => only if this node is the only child of its parent(s) + std::size_t wrapAroundSize = 0; + std::size_t wrapAroundExtra = 0; + wrapAroundMemPlane.push_back(nullptr); // default value of wrapAroundMemPlane[outputIdx] + + // Select the best parent among all allocable nodes for + // reallocation, which is the one with most memory (in order + // to minimize the reallocation size). + const auto allocableNodes = (concat) ? 
concat->getParents() : std::vector<NodePtr>{node}; + for (const auto& allocableNode : allocableNodes) { + IOIndex_t inputIdx = 0; + for (const auto& parent : allocableNode->dataInputs()) { + if (parent.first && parent.first->getChildren(parent.second).size() == 1 + // there might be no existing plane if the parent was + // not yet scheduled (because it may be a recurrent connection) + && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs() + // memSpace should not be already released + && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1) + { + const auto requiredData = allocableNode->getOperator()->getNbRequiredData(inputIdx); + const auto requiredProtected = allocableNode->getOperator()->getNbRequiredProtected(inputIdx); + AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). You may need to forward dimensions in the graph first.", + node->name(), node->type()); + + const bool isWrappable = (requiredProtected.data < requiredData.data); + const auto& memPlanes = memManager.getPlanes(parent.first); + const MemoryManager::MemoryPlane& memPlane + = (concat && itConcat != concatMemPlane.end()) + ? 
itConcat->second + : memPlanes.at(memPlanes.size()-parent.first->nbOutputs()+parent.second); // use at() to avoid dangling reference pointer + + if (isWrappable || !memManager.isWrapAround( + memPlane.memSpace, + memPlane.getFinalOffset() + - memPlane.memSpace->offset, + requiredSize.data)) + { + if (memPlane.getSize() > wrapAroundSize + requiredProtected.data + && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end()) + { + wrapAroundSize = memPlane.getSize() - requiredProtected.data; + if (requiredSize.data > wrapAroundSize) { + wrapAroundExtra = requiredSize.data - wrapAroundSize; + } + wrapAroundMemPlane[outputIdx] = &memPlane; + } + + if (wrapAroundExtra == 0) { + break; + } + } + } + ++inputIdx; + } + } + + size_t concatSize = size; + size_t concatOffset = 0; + + if (concat) { + // Dependencies should be concat node *childs*, not concat node + childs = concat->getChildren(); + + // Compute concatOffset + for (auto concatParent : concat->getParents()) { + const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator()); + const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {}); + const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); + const auto parentOutputFormat = (parentOp->getOutput(outputIdx)) ? 
parentOp->getOutput(outputIdx)->dataFormat() : DataFormat::Default; + + if (concatParent == node) { + if (parentOutputFormat != DataFormat::NHWC) { + concatSize = parentRequiredSize.data; + } + break; + } + else { + // By default, specifies a fully monolithic memory block + std::size_t parentSize = parentRequiredSize.data; + + if (parentOutputFormat == DataFormat::NHWC) { + parentSize = parentOutputDims.end()[-3]; + } + + concatOffset += parentSize; + } + } + + // Size in reallocate() is counted from the offset, not from 0, + // meaning the offset must be subtracted to obtain the correct + // total size without excess. + if (concatOffset > 0) { + concatSize -= concatOffset; + } + } + + // MemoryPlane to (re)use + const MemoryManager::MemoryPlane& memPlane + = (concat && itConcat != concatMemPlane.end()) + ? itConcat->second : + (wrapAroundBuffer && wrapAroundSize > 0) + ? (*wrapAroundMemPlane[outputIdx]) : + memManager.allocate(size, childs, stride, length, count); + + if (wrapAroundBuffer && wrapAroundSize > 0) { + memManager.reallocate(memPlane, + node, concatOffset, + concatSize, true, wrapAroundExtra, childs, stride, length, count); + } + else { + memManager.reallocate(memPlane.memSpace, + node, memPlane.offset + concatOffset, + concatSize, false, 0, childs, stride, length, count); + } + + if (concat && itConcat == concatMemPlane.end()) { + concatMemPlane.emplace(concat, memPlane); + + if (wrapAroundBuffer && wrapAroundSize > 0) { + memManager.reallocate(memPlane, + concat, 0, + size, true, wrapAroundExtra, childs, stride, length, count); + } + else { + memManager.reallocate(memPlane.memSpace, + concat, memPlane.offset, + size, false, 0, childs, stride, length, count); + } + } + } + + memManager.releaseDependencies(node); + memManager.tick(); + } + } + + return memManager; +} + +void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std::shared_ptr<Aidge::Tensor>>& data) { + // Collect all data input of the graph (that are producers) + if 
(!data.empty()){ -- GitLab