From 95f7c06f314f02d2b0376643388c3769b958174b Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 20 Dec 2024 18:35:05 +0100
Subject: [PATCH] Added support for auto concatenation

---
 include/aidge/scheduler/Scheduler.hpp         |  18 +-
 python_binding/scheduler/pybind_Scheduler.cpp |   9 +-
 src/scheduler/Scheduler.cpp                   | 221 +++++++++++++++++-
 3 files changed, 244 insertions(+), 4 deletions(-)

diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 28ecde6d9..51f62ed1b 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -90,6 +90,12 @@ public:
         NotConnected
     };
 
+    enum class EarlyLateSort {
+        Default,
+        AsSoonAsPossible,
+        AsLateAsPossible
+    };
+
     /**
      * @struct PriorProducersConsumers
      * @brief Manages producer-consumer relationships for nodes.
@@ -124,9 +130,10 @@ public:
     /**
      * @brief Get the static scheduling order of nodes.
     * @param step The step of the static schedule to retrieve (default is 0).
+     * @param sorting Sorting mode: Default (scheduling order), AsSoonAsPossible or AsLateAsPossible.
     * @return Vector of shared pointers to Nodes in their scheduled order.
     */
-    std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0) const;
+    std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0, EarlyLateSort sorting = EarlyLateSort::Default) const;
 
     /**
      * @brief Get the GraphView associated with this Scheduler.
@@ -156,6 +163,15 @@ public:
      */
     MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const;
 
+    /**
+     * @brief Generate the memory layout for the current static scheduling,
+     * with auto-concatenation: the Concat operator is replaced by direct
+     * allocation of its inputs inside its output memory plane when possible.
+     * @param incProducers If true, include the producers in the memory layout.
+     * @param wrapAroundBuffer If true, allow wrapping in memory planes.
+     */
+    MemoryManager generateMemoryAutoConcat(bool incProducers = false, bool wrapAroundBuffer = false) const;
+
     /**
      * @brief Connect input tensors to the data input of the GraphView.
      * In case of multiple data input tensors, they are mapped to producers in
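To make the intent of the new declaration concrete, here is a minimal C++ usage sketch; the graph `model`, the helper name `planMemory`, and the assumption that dimensions were already forwarded are illustrative, not part of this patch:

    #include <memory>

    #include "aidge/graph/GraphView.hpp"
    #include "aidge/scheduler/SequentialScheduler.hpp"

    // Minimal sketch: plan memory for an already built GraphView `model`
    // whose tensor dimensions have been forwarded.
    void planMemory(std::shared_ptr<Aidge::GraphView> model) {
        Aidge::SequentialScheduler scheduler(model);
        scheduler.generateScheduling();

        // Baseline layout: one memory plane per node output; a Concat node
        // copies its inputs into its own plane.
        const auto memBaseline = scheduler.generateMemory(false, true);

        // Auto-concatenation: when possible, each Concat input is written
        // directly at its offset inside the Concat result plane, and the
        // Concat node itself is skipped.
        const auto memAutoConcat = scheduler.generateMemoryAutoConcat(false, true);
    }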
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index 472af2a94..7eb0db712 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -21,6 +21,12 @@ namespace py = pybind11;
 
 namespace Aidge {
 void init_Scheduler(py::module& m){
+    py::enum_<Scheduler::EarlyLateSort>(m, "EarlyLateSort")
+        .value("Default", Scheduler::EarlyLateSort::Default)
+        .value("AsSoonAsPossible", Scheduler::EarlyLateSort::AsSoonAsPossible)
+        .value("AsLateAsPossible", Scheduler::EarlyLateSort::AsLateAsPossible)
+        .export_values();
+
     py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("graph_view", &Scheduler::graphView)
@@ -28,9 +34,10 @@ void init_Scheduler(py::module& m){
     .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name"))
     .def("resetScheduling", &Scheduler::resetScheduling)
     .def("generate_scheduling", &Scheduler::generateScheduling)
-    .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
+    .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::EarlyLateSort::Default)
     .def("graph_view", &Scheduler::graphView)
     .def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
+    .def("generate_memory_auto_concat", &Scheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
     ;
 
 py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
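On the C++ side, the new sorting mode controls how nodes with equal scheduling freedom are ordered. A short sketch of the difference (the diamond-graph scenario and the helper name `inspectOrders` are illustrative assumptions):

    #include "aidge/scheduler/Scheduler.hpp"

    // For a diamond graph A -> (B, C) -> D, B and C share the same earliest
    // execution step but may have different latest admissible steps, so the
    // two modes can order them differently.
    void inspectOrders(const Aidge::Scheduler& scheduler) {
        // Nodes ordered by earliest possible execution step
        // (ties broken by latest step).
        const auto asap = scheduler.getStaticScheduling(0,
            Aidge::Scheduler::EarlyLateSort::AsSoonAsPossible);

        // Nodes ordered by latest admissible execution step; this is the
        // order that generateMemoryAutoConcat() below iterates on.
        const auto alap = scheduler.getStaticScheduling(0,
            Aidge::Scheduler::EarlyLateSort::AsLateAsPossible);
    }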
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 2e9dc034e..9b655331a 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -33,6 +33,7 @@
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Concat.hpp"
 #include "aidge/utils/Log.hpp"
 #include "aidge/utils/Types.h"
 
@@ -561,6 +562,212 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
     return memManager;
 }
 
+Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducers, bool wrapAroundBuffer) const {
+    MemoryManager memManager;
+    std::map<NodePtr, MemoryManager::MemoryPlane> concatMemPlane;
+
+    for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) {
+        // AsLateAsPossible ensures that when a node's child is a Concat,
+        // the parents of all the Concat's parents have already been
+        // memory-mapped!
+        for (const auto& node : getStaticScheduling(step, EarlyLateSort::AsLateAsPossible)) {
+            if (!incProducers && node->type() == Producer_Op::Type) {
+                memManager.releaseDependencies(node);
+                continue;
+            }
+
+            auto itConcat = concatMemPlane.find(node);
+            if (itConcat != concatMemPlane.end()) {
+                // Skip Concat: its memory plane was already allocated
+                AIDGE_INTERNAL_ASSERT(itConcat->first->type() == Concat_Op::Type);
+                concatMemPlane.erase(itConcat);
+                continue;
+            }
+            itConcat = concatMemPlane.end();
+
+            auto childs = node->getChildren();
+
+            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor,
+                "Operator must be of Tensor type for node {} (of type {}).",
+                node->name(), node->type());
+            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
+
+            std::shared_ptr<Node> concat = nullptr;
+            // If the only child is a concatenation, check if we can allocate
+            // the concatenation result directly and skip allocation for this
+            // node's output:
+            if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) {
+                concat = *childs.begin();
+
+                for (const auto& concatParent : concat->getParents()) {
+                    if (concatParent->getChildren().size() > 1) {
+                        // Not possible: at least one of the Concat parents has
+                        // multiple children.
+                        concat = nullptr;
+                        break;
+                    }
+                }
+            }
+
+            if (concat) {
+                itConcat = concatMemPlane.find(concat);
+            }
+
+            std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
+
+            // Allocate a memory plane for each of the node's outputs
+            AIDGE_INTERNAL_ASSERT(!concat || node->nbOutputs() == 1);
+            for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
+                auto requiredSize = op->getRequiredMemory(outputIdx, {});
+                auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
+
+                // If concat is not nullptr, we directly allocate the concatenation result.
+                // Must check that we are on the right output too.
+                if (concat && node->getChildren(outputIdx).size() == 1) {
+                    const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator());
+                    requiredSize = concatOp->getRequiredMemory(0, {});
+                    outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>();
+                }
+
+                AIDGE_ASSERT(requiredSize.type == Elts_t::Data,
+                    "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                    node->name(), node->type());
+
+                // By default, specifies a fully monolithic memory block
+                std::size_t size = requiredSize.data;
+                std::size_t stride = 0;
+                std::size_t length = 1;
+                std::size_t count = 1;
+
+                if (outputDims.size() > 3) {
+                    // If it is possible, assume a NCHW layout
+                    size = op->getOutput(outputIdx)->dims().end()[-3];
+                    stride = outputDims.end()[-3];
+                    length = outputDims.end()[-1];
+                    count = outputDims.end()[-2];
+
+                    AIDGE_INTERNAL_ASSERT(stride >= size);
+                    AIDGE_INTERNAL_ASSERT(length == op->getOutput(outputIdx)->dims().end()[-1]);
+                    AIDGE_INTERNAL_ASSERT(count == op->getOutput(outputIdx)->dims().end()[-2]);
+                }
+
+                // Check if a wrap-around buffer is possible for this node
+                // (re-using a previous node's output memory for this node's outputs).
+                // => only possible if this node is the only child of its parent(s)
+                std::size_t wrapAroundSize = 0;
+                std::size_t wrapAroundExtra = 0;
+                wrapAroundMemPlane.push_back(nullptr); // default value of wrapAroundMemPlane[outputIdx]
+
+                // Select the best parent among all allocable nodes for
+                // reallocation, which is the one with the most memory (in order
+                // to minimize the reallocation size).
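+                // Illustration with hypothetical sizes: if two allocable
+                // parents already own planes of 512 and 1024 bytes (nothing
+                // protected), the 1024-byte plane is selected; a 1536-byte
+                // output then only needs wrapAroundExtra = 512 extra bytes
+                // instead of a fresh 1536-byte allocation.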
+                const auto allocableNodes = (concat) ? concat->getParents() : std::vector<NodePtr>{node};
+                for (const auto& allocableNode : allocableNodes) {
+                    IOIndex_t inputIdx = 0;
+                    for (const auto& parent : allocableNode->dataInputs()) {
+                        if (parent.first && parent.first->getChildren(parent.second).size() == 1
+                            // there might be no existing plane if the parent was
+                            // not yet scheduled (because it may be a recurrent connection)
+                            && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
+                            // memSpace should not be already released
+                            && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
+                        {
+                            const auto requiredData = allocableNode->getOperator()->getNbRequiredData(inputIdx);
+                            const auto requiredProtected = allocableNode->getOperator()->getNbRequiredProtected(inputIdx);
+                            AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data,
+                                "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                                node->name(), node->type());
+
+                            const bool isWrappable = (requiredProtected.data < requiredData.data);
+                            const MemoryManager::MemoryPlane& memPlane
+                                = (concat && itConcat != concatMemPlane.end())
+                                    ? itConcat->second
+                                    : memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];
+
+                            if (isWrappable || !memManager.isWrapAround(
+                                        memPlane.memSpace,
+                                        memPlane.getFinalOffset()
+                                            - memPlane.memSpace->offset,
+                                        requiredSize.data))
+                            {
+                                if (memPlane.getSize() > wrapAroundSize + requiredProtected.data
+                                    && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
+                                {
+                                    wrapAroundSize = memPlane.getSize() - requiredProtected.data;
+                                    if (requiredSize.data > wrapAroundSize) {
+                                        wrapAroundExtra = requiredSize.data - wrapAroundSize;
+                                    }
+                                    wrapAroundMemPlane[outputIdx] = &memPlane;
+                                }
+
+                                if (wrapAroundExtra == 0) {
+                                    break;
+                                }
+                            }
+                        }
+                        ++inputIdx;
+                    }
+                }
+
+                std::size_t concatOffset = 0;
+
+                if (concat) {
+                    // Dependencies should be the Concat node's childs, not the
+                    // Concat node itself
+                    childs = concat->getChildren();
+
+                    // Compute concatOffset
+                    for (auto concatParent : concat->getParents()) {
+                        if (concatParent == node) {
+                            break;
+                        }
+                        else {
+                            const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator());
+                            const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {});
+                            const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
+
+                            // By default, specifies a fully monolithic memory block
+                            std::size_t parentSize = parentRequiredSize.data;
+
+                            if (parentOutputDims.size() > 3) {
+                                // If it is possible, assume a NCHW layout
+                                parentSize = parentOutputDims.end()[-3];
+                            }
+
+                            concatOffset += parentSize;
+                        }
+                    }
+                }
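+                // Illustration with hypothetical dims: concatenating two NCHW
+                // tensors of 16 and 32 channels along the channel axis gives
+                // concatOffset = 0 for the first parent and 16 for the second,
+                // so each parent writes directly inside the Concat result plane.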
+                // MemoryPlane to (re)use
+                const MemoryManager::MemoryPlane& memPlane
+                    = (concat && itConcat != concatMemPlane.end())
+                        ? itConcat->second :
+                      (wrapAroundBuffer && wrapAroundSize > 0)
+                        ? (*wrapAroundMemPlane[outputIdx]) :
+                          memManager.allocate(size, childs, stride, length, count);
+
+                if (wrapAroundBuffer && wrapAroundSize > 0) {
+                    memManager.reallocate(memPlane,
+                        node, concatOffset,
+                        size, true, wrapAroundExtra, childs, stride, length, count);
+                }
+                else {
+                    memManager.reallocate(memPlane.memSpace,
+                        node, memPlane.offset + concatOffset,
+                        size, false, 0, childs, stride, length, count);
+                }
+
+                if (concat && itConcat == concatMemPlane.end()) {
+                    concatMemPlane.emplace(concat, memPlane);
+                }
+            }
+
+            memManager.releaseDependencies(node);
+            memManager.tick();
+        }
+    }
+
+    return memManager;
+}
+
 void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data){
     // This version of connect inputs only connects tensor inputs in input data producers.
     auto inputNodes = mGraphView->getOrderedInputs();
@@ -649,11 +856,21 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName)
     fmt::print(fp.get(), "\n");
 }
 
-std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const {
+std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step, EarlyLateSort sorting) const {
     AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?");
     AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", step, mStaticSchedule.size());
 
-    const auto& staticSchedule = mStaticSchedule.at(step);
+    std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(step).begin(), mStaticSchedule.at(step).end());
+
+    if (sorting == EarlyLateSort::AsSoonAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
+    }
+    else if (sorting == EarlyLateSort::AsLateAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early < rhs->early)); });
+    }
+
     std::vector<std::shared_ptr<Node>> schedule;
     std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
     return schedule;
--
GitLab
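For reference, the graph pattern this optimization targets: the rewrite applies only when every parent of the Concat has that Concat as its sole child. A hedged sketch of such an eligible pattern, assuming the usual Aidge node factories (the ReLU branches, names, and the Concat factory arguments are illustrative choices, not from this patch):

    #include <memory>

    #include "aidge/graph/Node.hpp"
    #include "aidge/operator/Concat.hpp"
    #include "aidge/operator/ReLU.hpp"

    // Two branches feeding a Concat and nothing else: eligible for
    // auto-concatenation, so branch1 and branch2 can write directly at
    // their channel offsets inside the Concat output plane.
    std::shared_ptr<Aidge::Node> makeConcatPattern() {
        auto branch1 = Aidge::ReLU("branch1");
        auto branch2 = Aidge::ReLU("branch2");
        auto concat = Aidge::Concat(2, 1, "concat"); // 2 inputs, channel axis
        branch1->addChild(concat, 0, 0);
        branch2->addChild(concat, 0, 1);
        return concat;
    }

If either branch also fed another node, the code above in generateMemoryAutoConcat() resets `concat` to nullptr and falls back to the regular per-output allocation.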