diff --git a/aidge_core/mem_info.py b/aidge_core/mem_info.py
index 87f286bcbcde3bc656e7997e65ec0078437e72e9..cabc2c72ee973babdf0342ba82057f7ab0769b52 100644
--- a/aidge_core/mem_info.py
+++ b/aidge_core/mem_info.py
@@ -109,7 +109,7 @@ def log_meminfo(mem_manager:aidge_core.MemoryManager, path: Path, diplay_names:b
     aidge_core.Log.notice(f"Generated memory management info at: {path}")
 
 
-def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder: Path = None, wrapping: bool = False, display_names: bool=True) -> Tuple[int, List[dict]]:
+def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder: Path = None, wrapping: bool = False, auto_concat: bool = False, display_names: bool=True) -> Tuple[int, List[dict]]:
     """Generates optimized memory information for a computation graph managed by a scheduler.
 
     This function analyzes the memory usage of a computation graph, determining the memory peak
@@ -125,6 +125,9 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder
     :param wrapping: Boolean flag to enable or disable wrap-around buffer optimization.
                      Defaults to `False`.
     :type wrapping: bool, optional
+    :param auto_concat: Boolean flag to enable or disable auto-concatenation optimization.
+                        Defaults to `False`.
+    :type auto_concat: bool, optional
     :param diplay_names: If True Node names are diplayed in the memory plot alongside their block, defaults to False
     :type diplay_names: bool, optional
     :return: A tuple containing the peak memory size and a list of memory information for each
@@ -138,8 +141,12 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder
     # scheduler.generate_scheduling()
     # Generate the memory manager
     # So far, the Producers are not take in consideration in the meory manager => inc_producers=False
-    mem_manager = scheduler.generate_memory(
-        inc_producers=False, wrap_around_buffer=wrapping)
+    if auto_concat:
+        mem_manager = scheduler.generate_memory_auto_concat(
+            inc_producers=False, wrap_around_buffer=wrapping)
+    else:
+        mem_manager = scheduler.generate_memory(
+            inc_producers=False, wrap_around_buffer=wrapping)
 
     # List of nodes which are connected at the input of the graph (None if input is not connected)
     nodes_at_input = [n[0] for n in scheduler.graph_view().inputs()]
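The new `auto_concat` flag simply routes memory generation through `Scheduler::generateMemoryAutoConcat` instead of `generateMemory`. A minimal usage sketch; the `graph` variable, the output folder name and the choice of `SequentialScheduler` are assumptions for illustration, not part of this change:

    from pathlib import Path
    import aidge_core
    from aidge_core.mem_info import generate_optimized_memory_info

    # `graph` is assumed to be an already built GraphView with dims forwarded
    scheduler = aidge_core.SequentialScheduler(graph)
    scheduler.generate_scheduling()

    peak_mem, mem_info = generate_optimized_memory_info(
        scheduler,
        stats_folder=Path("mem_stats"),  # hypothetical folder for the memory plot
        wrapping=True,                   # wrap-around buffer optimization
        auto_concat=True,                # new: allocate Concat inputs in the Concat plane
        display_names=True)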
diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index 5f16c480c233c0aee23962549c1d86695af81d89..0a3f5dc4d0ea00ddb5c8d0b8885269c882f7f705 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -124,6 +124,12 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView);
 */
 void explicitTranspose(std::shared_ptr<GraphView> graphView);
 
+/**
+ * Replace a single meta operator by its micro graph.
+ * @return true if node is indeed a meta operator and could be expanded.
+*/
+bool expandMetaOp(std::shared_ptr<Node> node);
+
 /**
  * Flatten the graph by replacing the meta operators by their micro graph.
  * @param recursive If true, recursively replace meta operators until there is
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 28ecde6d9319ae05be20f591cc9a6a4e2a29acc0..51f62ed1b7412abed4e0b850183fb101f208a69e 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -90,6 +90,12 @@ public:
         NotConnected
     };
 
+    enum class EarlyLateSort {
+        Default,
+        AsSoonAsPossible,
+        AsLateAsPossible
+    };
+
     /**
      * @struct PriorProducersConsumers
      * @brief Manages producer-consumer relationships for nodes.
@@ -124,9 +130,10 @@ public:
     /**
      * @brief Get the static scheduling order of nodes.
      * @param step The step of the static schedule to retrieve (default is 0).
+     * @param sorting Sorting mode.
      * @return Vector of shared pointers to Nodes in their scheduled order.
      */
-    std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0) const;
+    std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0, EarlyLateSort sorting = EarlyLateSort::Default) const;
 
     /**
      * @brief Get the GraphView associated with this Scheduler.
@@ -156,6 +163,15 @@ public:
      */
     MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const;
 
+    /**
+     * Generate the memory layout for the current static scheduling, with auto-
+     * concatenation: the Concat operator is replaced by direct allocation
+     * when possible.
+     * @param incProducers If true, include the producers in the memory layout.
+     * @param wrapAroundBuffer If true, allow wrapping in memory planes.
+     */
+    MemoryManager generateMemoryAutoConcat(bool incProducers = false, bool wrapAroundBuffer = false) const;
+
     /**
      * @brief Connect input tensors to the data input of the GraphView.
      * In case of multiple data input tensors, they are mapped to producers in
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index 472af2a9465b121593613492f5120ddc9d7fe254..d4cd7da44d148aef669c893405cff3101b6090b5 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -21,6 +21,12 @@ namespace py = pybind11;
 
 namespace Aidge {
 void init_Scheduler(py::module& m){
+    py::enum_<Scheduler::EarlyLateSort>(m, "EarlyLateSort")
+        .value("Default", Scheduler::EarlyLateSort::Default)
+        .value("AsSoonAsPossible", Scheduler::EarlyLateSort::AsSoonAsPossible)
+        .value("AsLateAsPossible", Scheduler::EarlyLateSort::AsLateAsPossible)
+        .export_values();
+
     py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("graph_view", &Scheduler::graphView)
@@ -28,9 +34,10 @@ void init_Scheduler(py::module& m){
     .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name"))
     .def("resetScheduling", &Scheduler::resetScheduling)
     .def("generate_scheduling", &Scheduler::generateScheduling)
-    .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
+    .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::EarlyLateSort::Default)
    .def("graph_view", &Scheduler::graphView)
     .def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
+    .def("generate_memory_auto_concat", &Scheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
     ;
 
     py::class_<SequentialScheduler,
               std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
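Both additions are reachable from Python through these bindings: the module-level `EarlyLateSort` enum, the extra `sorting` argument of `get_static_scheduling`, and `generate_memory_auto_concat`. A short sketch, assuming `graph` is an existing GraphView:

    import aidge_core

    scheduler = aidge_core.SequentialScheduler(graph)
    scheduler.generate_scheduling()

    # Same static schedule step, in default order and sorted "as late as possible"
    default_order = scheduler.get_static_scheduling(step=0)
    alap_order = scheduler.get_static_scheduling(
        step=0, sorting=aidge_core.EarlyLateSort.AsLateAsPossible)

    # Concat-aware memory layout, optionally reusing parent buffers (wrap-around)
    mem_manager = scheduler.generate_memory_auto_concat(
        inc_producers=False, wrap_around_buffer=True)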
diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp
index c74b538a4e566b3b88e77dd4d097344d52838505..8a4924c0e3e6c8f0e6625232451a4267c5d9b318 100644
--- a/src/backend/OperatorImpl.cpp
+++ b/src/backend/OperatorImpl.cpp
@@ -250,9 +250,10 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             && requiredIOSpec.type != IOSpec.type) {
             const auto cast = Cast(IOSpec.type);
+            cast->getOperator()->setBackend(node->getOperator()->backend());
             cast->addChild(parent, 0, i);
 
-            op->getInput(i)->setDataType(IOSpec.type);
+            op->getInput(i)->setDataType(requiredIOSpec.type);
         }
 
         // Input format
@@ -263,10 +264,11 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format);
             auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
             transposeOp->getOperator()->setDataFormat(IOSpec.format);
-            transposeOp->getOperator()->setDataType(IOSpec.type);
+            transposeOp->getOperator()->setDataType(requiredIOSpec.type);
+            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
             transposeOp->addChild(parent, 0, i);
 
-            op->getInput(i)->setDataFormat(IOSpec.format);
+            op->getInput(i)->setDataFormat(requiredIOSpec.format);
         }
 
         // Input dims
@@ -301,6 +303,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             && requiredIOSpec.type != IOSpec.type) {
             const auto cast = Cast(requiredIOSpec.type);
+            cast->getOperator()->setBackend(node->getOperator()->backend());
             parent->addChild(cast, i, 0);
 
             op->getOutput(i)->setDataType(IOSpec.type);
@@ -315,6 +318,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
             transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
             transposeOp->getOperator()->setDataType(requiredIOSpec.type);
+            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
             parent->addChild(transposeOp, i, 0);
 
             op->getOutput(i)->setDataFormat(IOSpec.format);
@@ -340,7 +344,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
         }
     }
 
-    return MetaOperator(std::string("Adapted_" + op->type()).c_str(), getConnectedGraphView(node));
+    auto adaptedGraph = getConnectedGraphView(node);
+    if (adaptedGraph->getNodes().size() > 1) {
+        return MetaOperator(std::string("Adapted_" + op->type()).c_str(), adaptedGraph);
+    }
+    else {
+        return node;
+    }
 }
 
 std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const {
@@ -354,8 +364,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSp
         auto adaptation = getAdaptation(availableSpec, requiredSpecs);
 
         if (adaptation) {
-            auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
-            adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
+            if (adaptation->getOperator()->isAtomic()) {
+                adaptations.insert(std::make_pair(adaptation, 1));
+            }
+            else {
+                auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
+                adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
+            }
         }
     }
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index c47f3c33efd8348f8bac4f0ab2221e39e3d5e62a..a14ae4187707490cfb70681fc418daf961cb053b 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -572,6 +572,7 @@ void Tensor::copyTranspose(const Tensor& src, const std::vector<DimSize_t>& tran
         }
     }
 
+    AIDGE_ASSERT(mImpl, "Tensor::copyTranspose(): an implementation is required, use setBackend() first!");
     std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, newDims);
 
     std::vector<size_t> indices(newDims.size(), 0);
"aidge/utils/Log.hpp" #include "aidge/utils/Types.h" @@ -475,8 +476,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr std::size_t length = 1; std::size_t count = 1; - if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) { - // If it is possible, assume a NCHW layout + if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dataFormat() == DataFormat::NHWC) { size = op->getOutput(outputIdx)->dims().end()[-3]; stride = size; length = op->getOutput(outputIdx)->dims().end()[-1]; @@ -561,6 +561,228 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr return memManager; } +Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducers, bool wrapAroundBuffer) const { + MemoryManager memManager; + std::map<NodePtr, MemoryManager::MemoryPlane> concatMemPlane; + + for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { + // AsLateAsPossible ensures that when a node child is Concat, all the parents + // of the Concat parents have already been memory mapped! + for (const auto& node : getStaticScheduling(step, EarlyLateSort::AsLateAsPossible)) { + if (!incProducers && node->type() == Producer_Op::Type) { + memManager.releaseDependencies(node); + continue; + } + + auto itConcat = concatMemPlane.find(node); + if (itConcat != concatMemPlane.end()) { + // Skip Concat + AIDGE_INTERNAL_ASSERT(itConcat->first->type() == Concat_Op::Type); + concatMemPlane.erase(itConcat); + continue; + } + itConcat = concatMemPlane.end(); + + auto childs = node->getChildren(); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, + "Operator must be of Tensor type for node {} (of type {}).", + node->name(), node->type()); + const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + + std::shared_ptr<Node> concat = nullptr; + // If the only child is a concatenation, check if we can allocate + // the concatenation result directly and skip allocation for this + // node output: + if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) { + concat = *childs.begin(); + + for (const auto& concatParent : concat->getParents()) { + if (concatParent->getChildren().size() > 1) { + // not possible: at least one of the Concat parent has + // multiple children. + concat = nullptr; + break; + } + } + } + + if (concat) { + itConcat = concatMemPlane.find(concat); + } + + std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; + + // Allocate a memory plane for each node's output + AIDGE_INTERNAL_ASSERT(!concat || node->nbOutputs() == 1); + for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { + auto requiredSize = op->getRequiredMemory(outputIdx, {}); + auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); + auto outputFormat = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dataFormat() : DataFormat::Default; + + // If concat is not nullptr, we directly allocate the concatenation result + // Must check that we are on the right output too. + if (concat && node->getChildren(outputIdx).size() == 1) { + const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator()); + requiredSize = concatOp->getRequiredMemory(0, {}); + outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>(); + outputFormat = (concatOp->getOutput(0)) ? 
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 2e9dc034ef0bf2b0be2ee27c26b1995a2d0e4244..ac73b8264502b970cead955262e482eb97592b84 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -33,6 +33,7 @@
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Concat.hpp"
 #include "aidge/utils/Log.hpp"
 #include "aidge/utils/Types.h"
 
@@ -475,8 +476,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
                 std::size_t length = 1;
                 std::size_t count = 1;
 
-                if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) {
-                    // If it is possible, assume a NCHW layout
+                if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dataFormat() == DataFormat::NHWC) {
                     size = op->getOutput(outputIdx)->dims().end()[-3];
                     stride = size;
                     length = op->getOutput(outputIdx)->dims().end()[-1];
@@ -561,6 +561,228 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
     return memManager;
 }
 
+Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducers, bool wrapAroundBuffer) const {
+    MemoryManager memManager;
+    std::map<NodePtr, MemoryManager::MemoryPlane> concatMemPlane;
+
+    for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) {
+        // AsLateAsPossible ensures that when a node child is Concat, all the parents
+        // of the Concat parents have already been memory mapped!
+        for (const auto& node : getStaticScheduling(step, EarlyLateSort::AsLateAsPossible)) {
+            if (!incProducers && node->type() == Producer_Op::Type) {
+                memManager.releaseDependencies(node);
+                continue;
+            }
+
+            auto itConcat = concatMemPlane.find(node);
+            if (itConcat != concatMemPlane.end()) {
+                // Skip Concat
+                AIDGE_INTERNAL_ASSERT(itConcat->first->type() == Concat_Op::Type);
+                concatMemPlane.erase(itConcat);
+                continue;
+            }
+            itConcat = concatMemPlane.end();
+
+            auto childs = node->getChildren();
+            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor,
+                "Operator must be of Tensor type for node {} (of type {}).",
+                node->name(), node->type());
+            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
+
+            std::shared_ptr<Node> concat = nullptr;
+            // If the only child is a concatenation, check if we can allocate
+            // the concatenation result directly and skip allocation for this
+            // node output:
+            if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) {
+                concat = *childs.begin();
+
+                for (const auto& concatParent : concat->getParents()) {
+                    if (concatParent->getChildren().size() > 1) {
+                        // not possible: at least one of the Concat parent has
+                        // multiple children.
+                        concat = nullptr;
+                        break;
+                    }
+                }
+            }
+
+            if (concat) {
+                itConcat = concatMemPlane.find(concat);
+            }
+
+            std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
+
+            // Allocate a memory plane for each node's output
+            AIDGE_INTERNAL_ASSERT(!concat || node->nbOutputs() == 1);
+            for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
+                auto requiredSize = op->getRequiredMemory(outputIdx, {});
+                auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
+                auto outputFormat = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dataFormat() : DataFormat::Default;
+
+                // If concat is not nullptr, we directly allocate the concatenation result
+                // Must check that we are on the right output too.
+                if (concat && node->getChildren(outputIdx).size() == 1) {
+                    const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator());
+                    requiredSize = concatOp->getRequiredMemory(0, {});
+                    outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>();
+                    outputFormat = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dataFormat() : DataFormat::Default;
+                }
+
+                AIDGE_ASSERT(requiredSize.type == Elts_t::Data,
+                    "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                    node->name(), node->type());
+
+                // By default, specifies a fully monolithic memory block
+                std::size_t size = requiredSize.data;
+                std::size_t stride = 0;
+                std::size_t length = 1;
+                std::size_t count = 1;
+
+                if (outputFormat == DataFormat::NHWC) {
+                    size = op->getOutput(outputIdx)->dims().end()[-3];
+                    stride = outputDims.end()[-3];
+                    length = outputDims.end()[-1];
+                    count = outputDims.end()[-2];
+
+                    AIDGE_INTERNAL_ASSERT(stride >= size);
+                    AIDGE_INTERNAL_ASSERT(length == op->getOutput(outputIdx)->dims().end()[-1]);
+                    AIDGE_INTERNAL_ASSERT(count == op->getOutput(outputIdx)->dims().end()[-2]);
+                }
+
+                // Check if wrap around buffer is possible for this node
+                // (re-using previous node outputs memory for this node outputs).
+                // => only if this node is the only child of its parent(s)
+                std::size_t wrapAroundSize = 0;
+                std::size_t wrapAroundExtra = 0;
+                wrapAroundMemPlane.push_back(nullptr); // default value of wrapAroundMemPlane[outputIdx]
+
+                // Select the best parent among all allocable nodes for
+                // reallocation, which is the one with most memory (in order
+                // to minimize the reallocation size).
+                const auto allocableNodes = (concat) ? concat->getParents() : std::vector<NodePtr>{node};
+                for (const auto& allocableNode : allocableNodes) {
+                    IOIndex_t inputIdx = 0;
+                    for (const auto& parent : allocableNode->dataInputs()) {
+                        if (parent.first && parent.first->getChildren(parent.second).size() == 1
+                            // there might be no existing plane if the parent was
+                            // not yet scheduled (because it may be a recurrent connection)
+                            && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
+                            // memSpace should not be already released
+                            && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
+                        {
+                            const auto requiredData = allocableNode->getOperator()->getNbRequiredData(inputIdx);
+                            const auto requiredProtected = allocableNode->getOperator()->getNbRequiredProtected(inputIdx);
+                            AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data,
+                                "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                                node->name(), node->type());
+
+                            const bool isWrappable = (requiredProtected.data < requiredData.data);
+                            const MemoryManager::MemoryPlane& memPlane
+                                = (concat && itConcat != concatMemPlane.end())
+                                    ? itConcat->second
+                                    : memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];
+
+                            if (isWrappable || !memManager.isWrapAround(
+                                        memPlane.memSpace,
+                                        memPlane.getFinalOffset()
+                                            - memPlane.memSpace->offset,
+                                        requiredSize.data))
+                            {
+                                if (memPlane.getSize() > wrapAroundSize + requiredProtected.data
+                                    && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
+                                {
+                                    wrapAroundSize = memPlane.getSize() - requiredProtected.data;
+                                    if (requiredSize.data > wrapAroundSize) {
+                                        wrapAroundExtra = requiredSize.data - wrapAroundSize;
+                                    }
+                                    wrapAroundMemPlane[outputIdx] = &memPlane;
+                                }
+
+                                if (wrapAroundExtra == 0) {
+                                    break;
+                                }
+                            }
+                        }
+                        ++inputIdx;
+                    }
+                }
+
+                size_t concatSize = size;
+                size_t concatOffset = 0;
+
+                if (concat) {
+                    // Dependencies should be concat node *childs*, not concat node
+                    childs = concat->getChildren();
+
+                    // Compute concatOffset
+                    for (auto concatParent : concat->getParents()) {
+                        const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator());
+                        const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {});
+                        const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
+                        const auto parentOutputFormat = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dataFormat() : DataFormat::Default;
+
+                        if (concatParent == node) {
+                            if (parentOutputFormat != DataFormat::NHWC) {
+                                concatSize = parentRequiredSize.data;
+                            }
+                            break;
+                        }
+                        else {
+                            // By default, specifies a fully monolithic memory block
+                            std::size_t parentSize = parentRequiredSize.data;
+
+                            if (parentOutputFormat == DataFormat::NHWC) {
+                                parentSize = parentOutputDims.end()[-3];
+                            }
+
+                            concatOffset += parentSize;
+                        }
+                    }
+                }
+
+                // MemoryPlane to (re)use
+                const MemoryManager::MemoryPlane& memPlane
+                    = (concat && itConcat != concatMemPlane.end())
+                        ? itConcat->second :
+                      (wrapAroundBuffer && wrapAroundSize > 0)
+                        ? (*wrapAroundMemPlane[outputIdx]) :
+                      memManager.allocate(size, childs, stride, length, count);
+
+                if (wrapAroundBuffer && wrapAroundSize > 0) {
+                    memManager.reallocate(memPlane,
+                        node, concatOffset,
+                        concatSize, true, wrapAroundExtra, childs, stride, length, count);
+                }
+                else {
+                    memManager.reallocate(memPlane.memSpace,
+                        node, memPlane.offset + concatOffset,
+                        concatSize, false, 0, childs, stride, length, count);
+                }
+
+                if (concat && itConcat == concatMemPlane.end()) {
+                    concatMemPlane.emplace(concat, memPlane);
+
+                    if (wrapAroundBuffer && wrapAroundSize > 0) {
+                        memManager.reallocate(memPlane,
+                            concat, 0,
+                            size, true, wrapAroundExtra, childs, stride, length, count);
+                    }
+                    else {
+                        memManager.reallocate(memPlane.memSpace,
+                            concat, memPlane.offset,
+                            size, false, 0, childs, stride, length, count);
+                    }
+                }
+            }
+
+            memManager.releaseDependencies(node);
+            memManager.tick();
+        }
+    }
+
+    return memManager;
+}
+
 void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data){
     // This version of connect inputs only connects tensor inputs in input data producers.
     auto inputNodes = mGraphView->getOrderedInputs();
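generateMemoryAutoConcat only folds a Concat into its parents' allocations when a node's sole child is a Concat and every parent of that Concat feeds nothing else; otherwise it falls back to the regular per-output allocation. A rough Python sketch of the same eligibility test, using only basic Node accessors (the helper name and the trailing usage comment are illustrative, not part of the patch):

    def concat_is_foldable(concat_node):
        # Mirrors the check in generateMemoryAutoConcat: every parent of the
        # Concat must have the Concat as its only child.
        if concat_node.type() != "Concat":
            return False
        return all(parent is not None and len(parent.get_children()) == 1
                   for parent in concat_node.get_parents())

    # e.g.: foldable = [n.name() for n in graph.get_nodes() if concat_is_foldable(n)]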
@@ -649,11 +871,21 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName)
     fmt::print(fp.get(), "\n");
 }
 
-std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const {
+std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step, EarlyLateSort sorting) const {
     AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?");
     AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step);
-    const auto& staticSchedule = mStaticSchedule.at(step);
+    std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(step).begin(), mStaticSchedule.at(step).end());
+
+    if (sorting == EarlyLateSort::AsSoonAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
+    }
+    else if (sorting == EarlyLateSort::AsLateAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early < rhs->early)); });
+    }
+
     std::vector<std::shared_ptr<Node>> schedule;
     std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
     return schedule;