From 95f7c06f314f02d2b0376643388c3769b958174b Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 20 Dec 2024 18:35:05 +0100
Subject: [PATCH] Add support for auto-concatenation
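
The Concat operator can often be elided at memory-planning time: when
every parent of a Concat has that Concat as its only child, each parent
output can be written directly at the proper offset inside a single
memory plane allocated for the Concat output, making the copy performed
by Concat unnecessary.

This patch adds:
- an EarlyLateSort mode to Scheduler::getStaticScheduling(), to iterate
  the nodes of a static schedule in as-soon-as-possible or
  as-late-as-possible order (the latter guarantees that all siblings
  feeding a Concat are visited back-to-back);
- Scheduler::generateMemoryAutoConcat(), which performs the direct
  allocation described above and falls back to regular allocation when
  a Concat cannot be elided;
- the corresponding Python bindings.

Python usage sketch (a minimal example, assuming the bindings are
exposed in the aidge_core module and that graph_view is an existing
GraphView containing a Concat operator):

    import aidge_core

    scheduler = aidge_core.SequentialScheduler(graph_view)
    scheduler.generate_scheduling()

    # Nodes can now be retrieved in as-late-as-possible order, which is
    # also the order used internally by generate_memory_auto_concat().
    nodes = scheduler.get_static_scheduling(
        step=0, sorting=aidge_core.EarlyLateSort.AsLateAsPossible)

    # Memory layout in which eligible Concat nodes are elided.
    mem_manager = scheduler.generate_memory_auto_concat(
        inc_producers=False, wrap_around_buffer=True)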

---
 include/aidge/scheduler/Scheduler.hpp         |  18 +-
 python_binding/scheduler/pybind_Scheduler.cpp |   9 +-
 src/scheduler/Scheduler.cpp                   | 221 +++++++++++++++++-
 3 files changed, 244 insertions(+), 4 deletions(-)

diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 28ecde6d9..51f62ed1b 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -90,6 +90,12 @@ public:
         NotConnected
     };
 
+    enum class EarlyLateSort {
+        Default,
+        AsSoonAsPossible,
+        AsLateAsPossible
+    };
+
     /**
      * @struct PriorProducersConsumers
      * @brief Manages producer-consumer relationships for nodes.
@@ -124,9 +130,10 @@ public:
     /**
      * @brief Get the static scheduling order of nodes.
      * @param step The step of the static schedule to retrieve (default is 0).
+     * @param sorting Sorting mode: Default (keep the scheduling order), AsSoonAsPossible or AsLateAsPossible.
      * @return Vector of shared pointers to Nodes in their scheduled order.
      */
-    std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0) const;
+    std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0, EarlyLateSort sorting = EarlyLateSort::Default) const;
 
     /**
      * @brief Get the GraphView associated with this Scheduler.
@@ -156,6 +163,15 @@ public:
     */
     MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const;
 
+    /**
+     * @brief Generate the memory layout for the current static scheduling,
+     * with auto-concatenation: when possible, the Concat operator is elided
+     * and its inputs are allocated directly in the concatenated output memory.
+     * @param incProducers If true, include the producers in the memory layout.
+     * @param wrapAroundBuffer If true, allow wrapping in memory planes.
+     */
+    MemoryManager generateMemoryAutoConcat(bool incProducers = false, bool wrapAroundBuffer = false) const;
+
     /**
      * @brief Connect input tensors to the data input of the GraphView.
      * In case of multiple data input tensors, they are mapped to producers in
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index 472af2a94..7eb0db712 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -21,6 +21,12 @@
 namespace py = pybind11;
 namespace Aidge {
 void init_Scheduler(py::module& m){
+    py::enum_<Scheduler::EarlyLateSort>(m, "EarlyLateSort")
+        .value("Default", Scheduler::EarlyLateSort::Default)
+        .value("AsSoonAsPossible", Scheduler::EarlyLateSort::AsSoonAsPossible)
+        .value("AsLateAsPossible", Scheduler::EarlyLateSort::AsLateAsPossible)
+        .export_values();
+
     py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("graph_view", &Scheduler::graphView)
@@ -28,9 +34,10 @@ void init_Scheduler(py::module& m){
     .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name"))
     .def("resetScheduling", &Scheduler::resetScheduling)
     .def("generate_scheduling", &Scheduler::generateScheduling)
-    .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
+    .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = EarlyLateSort::Default)
     .def("graph_view", &Scheduler::graphView)
     .def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
+    .def("generate_memory_auto_concat", &Scheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
     ;
 
     py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 2e9dc034e..9b655331a 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -33,6 +33,7 @@
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Concat.hpp"
 #include "aidge/utils/Log.hpp"
 #include "aidge/utils/Types.h"
 
@@ -561,6 +562,212 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
     return memManager;
 }
 
+Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducers, bool wrapAroundBuffer) const {
+    MemoryManager memManager;
+    std::map<NodePtr, MemoryManager::MemoryPlane> concatMemPlane;
+
+    for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) {
+        // AsLateAsPossible ensures that when a node's child is a Concat, the
+        // parents of all the Concat's parents have already been memory-mapped!
+        for (const auto& node : getStaticScheduling(step, EarlyLateSort::AsLateAsPossible)) {
+            if (!incProducers && node->type() == Producer_Op::Type) {
+                memManager.releaseDependencies(node);
+                continue;
+            }
+
+            auto itConcat = concatMemPlane.find(node);
+            if (itConcat != concatMemPlane.end()) {
+                // Skip the Concat node: its plane was already allocated with its first parent
+                AIDGE_INTERNAL_ASSERT(itConcat->first->type() == Concat_Op::Type);
+                concatMemPlane.erase(itConcat);
+                continue;
+            }
+            itConcat = concatMemPlane.end();
+
+            auto childs = node->getChildren();
+            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor,
+                "Operator must be of Tensor type for node {} (of type {}).",
+                node->name(), node->type());
+            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
+
+            std::shared_ptr<Node> concat = nullptr;
+            // If the node's only child is a Concat, check whether we can
+            // allocate the concatenation result directly and skip the
+            // allocation of this node's output:
+            if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) {
+                concat = *childs.begin();
+
+                for (const auto& concatParent : concat->getParents()) {
+                    if (concatParent->getChildren().size() > 1) {
+                        // Not possible: at least one of the Concat parents
+                        // has multiple children.
+                        concat = nullptr;
+                        break;
+                    }
+                }
+            }
+
+            if (concat) {
+                itConcat = concatMemPlane.find(concat);
+            }
+
+            std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
+
+            // Allocate a memory plane for each output of the node
+            AIDGE_INTERNAL_ASSERT(!concat || node->nbOutputs() == 1);
+            for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
+                auto requiredSize = op->getRequiredMemory(outputIdx, {});
+                auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
+
+                // If concat is not nullptr, directly allocate the concatenation
+                // result. Also check that we are on the right output.
+                if (concat && node->getChildren(outputIdx).size() == 1) {
+                    const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator());
+                    requiredSize = concatOp->getRequiredMemory(0, {});
+                    outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>();
+                }
+
+                AIDGE_ASSERT(requiredSize.type == Elts_t::Data,
+                    "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                    node->name(), node->type());
+
+                // By default, specifies a fully monolithic memory block
+                std::size_t size = requiredSize.data;
+                std::size_t stride = 0;
+                std::size_t length = 1;
+                std::size_t count = 1;
+
+                if (outputDims.size() > 3) {
+                    // If possible, assume an NCHW layout: size is this node's channel count, stride the full (possibly concatenated) channel count
+                    size = op->getOutput(outputIdx)->dims().end()[-3];
+                    stride = outputDims.end()[-3];
+                    length = outputDims.end()[-1];
+                    count = outputDims.end()[-2];
+
+                    AIDGE_INTERNAL_ASSERT(stride >= size);
+                    AIDGE_INTERNAL_ASSERT(length == op->getOutput(outputIdx)->dims().end()[-1]);
+                    AIDGE_INTERNAL_ASSERT(count == op->getOutput(outputIdx)->dims().end()[-2]);
+                }
+
+                // Check if a wrap-around buffer is possible for this node
+                // (re-using a previous node's output memory for this node's outputs)
+                // => only if this node is the only child of its parent(s)
+                std::size_t wrapAroundSize = 0;
+                std::size_t wrapAroundExtra = 0;
+                wrapAroundMemPlane.push_back(nullptr); // default value of wrapAroundMemPlane[outputIdx]
+
+                // Select the best parent among all allocable nodes for
+                // reallocation, i.e. the one with the most memory (in order
+                // to minimize the reallocation size).
+                const auto allocableNodes = (concat) ? concat->getParents() : std::vector<NodePtr>{node};
+                for (const auto& allocableNode : allocableNodes) {
+                    IOIndex_t inputIdx = 0;
+                    for (const auto& parent : allocableNode->dataInputs()) {
+                        if (parent.first && parent.first->getChildren(parent.second).size() == 1
+                            // there might be no existing plane if the parent was
+                            // not yet scheduled (because it may be a recurrent connection)
+                            && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
+                            // memSpace should not be already released
+                            && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
+                        {
+                            const auto requiredData = allocableNode->getOperator()->getNbRequiredData(inputIdx);
+                            const auto requiredProtected = allocableNode->getOperator()->getNbRequiredProtected(inputIdx);
+                            AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data,
+                                "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                                node->name(), node->type());
+
+                            const bool isWrappable = (requiredProtected.data < requiredData.data);
+                            const MemoryManager::MemoryPlane& memPlane
+                                = (concat && itConcat != concatMemPlane.end())
+                                    ? itConcat->second
+                                    : memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];
+
+                            if (isWrappable || !memManager.isWrapAround(
+                                        memPlane.memSpace,
+                                        memPlane.getFinalOffset()
+                                            - memPlane.memSpace->offset,
+                                        requiredSize.data))
+                            {
+                                if (memPlane.getSize() > wrapAroundSize + requiredProtected.data
+                                    && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
+                                {
+                                    wrapAroundSize = memPlane.getSize() - requiredProtected.data;
+                                    if (requiredSize.data > wrapAroundSize) {
+                                        wrapAroundExtra = requiredSize.data - wrapAroundSize;
+                                    }
+                                    wrapAroundMemPlane[outputIdx] = &memPlane;
+                                }
+
+                                if (wrapAroundExtra == 0) {
+                                    break;
+                                }
+                            }
+                        }
+                        ++inputIdx;
+                    }
+                }
+
+                std::size_t concatOffset = 0;
+
+                if (concat) {
+                    // Dependencies should be the Concat node's *children*, not the Concat node itself
+                    childs = concat->getChildren();
+
+                    // Compute concatOffset: sum of the sizes of the Concat inputs preceding this node
+                    for (auto concatParent : concat->getParents()) {
+                        if (concatParent == node) {
+                            break;
+                        }
+                        else {
+                            const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator());
+                            const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {});
+                            const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
+
+                            // By default, specifies a fully monolithic memory block
+                            std::size_t parentSize = parentRequiredSize.data;
+
+                            if (parentOutputDims.size() > 3) {
+                                // If possible, assume an NCHW layout (parentSize is the parent's channel count)
+                                parentSize = parentOutputDims.end()[-3];
+                            }
+
+                            concatOffset += parentSize;
+                        }
+                    }
+                }
+
+                // MemoryPlane to (re)use
+                const MemoryManager::MemoryPlane& memPlane
+                    = (concat && itConcat != concatMemPlane.end())
+                        ? itConcat->second :
+                      (wrapAroundBuffer && wrapAroundSize > 0)
+                        ? (*wrapAroundMemPlane[outputIdx]) :
+                            memManager.allocate(size, childs, stride, length, count);
+
+                if (wrapAroundBuffer && wrapAroundSize > 0) {
+                    memManager.reallocate(memPlane,
+                        node, concatOffset,
+                        size, true, wrapAroundExtra, childs, stride, length, count);
+                }
+                else {
+                    memManager.reallocate(memPlane.memSpace,
+                        node, memPlane.offset + concatOffset,
+                        size, false, 0, childs, stride, length, count);
+                }
+
+                if (concat && itConcat == concatMemPlane.end()) {
+                    concatMemPlane.emplace(concat, memPlane);
+                }
+            }
+
+            memManager.releaseDependencies(node);
+            memManager.tick();
+        }
+    }
+
+    return memManager;
+}
+
 void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data){
     // This version of connect inputs only connects tensor inputs in input data producers.
     auto inputNodes = mGraphView->getOrderedInputs();
@@ -649,11 +856,21 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName)
     fmt::print(fp.get(), "\n");
 }
 
-std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const {
+std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step, EarlyLateSort sorting) const {
     AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?");
     AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step);
 
-    const auto& staticSchedule = mStaticSchedule.at(step);
+    std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(step).begin(), mStaticSchedule.at(step).end());
+
+    if (sorting == EarlyLateSort::AsSoonAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
+    }
+    else if (sorting == EarlyLateSort::AsLateAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early < rhs->early)); });
+    }
+
     std::vector<std::shared_ptr<Node>> schedule;
     std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
     return schedule;
-- 
GitLab