Skip to content
Snippets Groups Projects
Commit 95f7c06f authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added support for auto concatenation

parent c312252d
No related branches found
No related tags found
2 merge requests!318[Upd] release version 0.5.0,!290[Add] support for auto-concatenation and Fix multiple adaptToBackend() issues
Pipeline #62245 canceled
......@@ -90,6 +90,12 @@ public:
NotConnected
};
/**
 * @enum EarlyLateSort
 * @brief Ordering modes for the node list returned by getStaticScheduling().
 */
enum class EarlyLateSort {
// Keep the order in which the static schedule was generated.
Default,
// Stable-sort nodes by earliest possible execution step (ties broken by latest).
AsSoonAsPossible,
// Stable-sort nodes by latest possible execution step (ties broken by earliest).
AsLateAsPossible
};
/**
* @struct PriorProducersConsumers
* @brief Manages producer-consumer relationships for nodes.
......@@ -124,9 +130,10 @@ public:
/**
* @brief Get the static scheduling order of nodes.
* @param step The step of the static schedule to retrieve (default is 0).
* @param sorting Sorting mode.
* @return Vector of shared pointers to Nodes in their scheduled order.
*/
std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0) const;
std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0, EarlyLateSort sorting = EarlyLateSort::Default) const;
/**
* @brief Get the GraphView associated with this Scheduler.
......@@ -156,6 +163,15 @@ public:
*/
MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const;
/**
 * @brief Generate the memory layout for the current static scheduling, with auto-
 * concatenation: the Concat operator is replaced by direct allocation
 * when possible.
 * @param incProducers If true, include the producers in the memory layout.
 * @param wrapAroundBuffer If true, allow wrapping in memory planes.
 * @return MemoryManager describing the allocated memory planes.
 */
MemoryManager generateMemoryAutoConcat(bool incProducers = false, bool wrapAroundBuffer = false) const;
/**
* @brief Connect input tensors to the data input of the GraphView.
* In case of multiple data input tensors, they are mapped to producers in
......
......@@ -21,6 +21,12 @@
namespace py = pybind11;
namespace Aidge {
// Python bindings for the Scheduler class hierarchy.
void init_Scheduler(py::module& m){
// Expose Scheduler::EarlyLateSort so Python callers can pick a sorting mode
// for get_static_scheduling().
py::enum_<Scheduler::EarlyLateSort>(m, "EarlyLateSort")
.value("Default", Scheduler::EarlyLateSort::Default)
.value("AsSoonAsPossible", Scheduler::EarlyLateSort::AsSoonAsPossible)
.value("AsLateAsPossible", Scheduler::EarlyLateSort::AsLateAsPossible)
.export_values();
py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("graph_view", &Scheduler::graphView)
......@@ -28,9 +34,10 @@ void init_Scheduler(py::module& m){
.def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &Scheduler::resetScheduling)
.def("generate_scheduling", &Scheduler::generateScheduling)
.def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
// NOTE(review): `EarlyLateSort::Default` is unqualified here although the enum
// is nested in Scheduler; this likely needs `Scheduler::EarlyLateSort::Default`
// to compile — verify against the build (the pipeline for this commit was canceled).
.def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = EarlyLateSort::Default)
.def("graph_view", &Scheduler::graphView)
.def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
.def("generate_memory_auto_concat", &Scheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
;
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
......
......@@ -33,6 +33,7 @@
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/utils/Types.h"
......@@ -561,6 +562,212 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
return memManager;
}
/**
 * @brief Generate the memory layout for the current static scheduling, with
 * auto-concatenation: when the only child of a node is a Concat whose every
 * parent has a single child, the node writes directly into the Concat's
 * output plane (at the proper offset) and the Concat node itself is skipped.
 * @param incProducers If true, also allocate memory for Producer outputs.
 * @param wrapAroundBuffer If true, allow re-using (wrapping around) a parent's
 * memory plane for the output when the producer-consumer model permits it.
 * @return MemoryManager describing the allocated memory planes.
 */
Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducers, bool wrapAroundBuffer) const {
MemoryManager memManager;
// Memory plane already allocated for a Concat node: filled when the first
// Concat parent is mapped, consumed (erased) when the Concat itself is visited.
std::map<NodePtr, MemoryManager::MemoryPlane> concatMemPlane;
for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) {
// AsLateAsPossible ensures that when a node child is Concat, all the parents
// of the Concat parents have already been memory mapped!
for (const auto& node : getStaticScheduling(step, EarlyLateSort::AsLateAsPossible)) {
if (!incProducers && node->type() == Producer_Op::Type) {
memManager.releaseDependencies(node);
continue;
}
// A Concat whose output plane was already allocated by its parents needs
// no allocation of its own: drop the bookkeeping entry and move on.
auto itConcat = concatMemPlane.find(node);
if (itConcat != concatMemPlane.end()) {
// Skip Concat
AIDGE_INTERNAL_ASSERT(itConcat->first->type() == Concat_Op::Type);
concatMemPlane.erase(itConcat);
continue;
}
// NOTE(review): redundant — itConcat is necessarily end() at this point.
itConcat = concatMemPlane.end();
auto childs = node->getChildren();
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor,
"Operator must be of Tensor type for node {} (of type {}).",
node->name(), node->type());
const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
std::shared_ptr<Node> concat = nullptr;
// If the only child is a concatenation, check if we can allocate
// the concatenation result directly and skip allocation for this
// node output:
if (childs.size() == 1 && (*childs.begin())->type() == Concat_Op::Type) {
concat = *childs.begin();
for (const auto& concatParent : concat->getParents()) {
if (concatParent->getChildren().size() > 1) {
// not possible: at least one of the Concat parent has
// multiple children.
concat = nullptr;
break;
}
}
}
if (concat) {
// Another parent of this Concat may already have allocated the shared plane.
itConcat = concatMemPlane.find(concat);
}
std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
// Allocate a memory plane for each node's output
AIDGE_INTERNAL_ASSERT(!concat || node->nbOutputs() == 1);
for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
auto requiredSize = op->getRequiredMemory(outputIdx, {});
auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
// If concat is not nullptr, we directly allocate the concatenation result
// Must check that we are on the right output too.
if (concat && node->getChildren(outputIdx).size() == 1) {
const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator());
requiredSize = concatOp->getRequiredMemory(0, {});
outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>();
}
AIDGE_ASSERT(requiredSize.type == Elts_t::Data,
"Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
node->name(), node->type());
// By default, specifies a fully monolithic memory block
std::size_t size = requiredSize.data;
std::size_t stride = 0;
std::size_t length = 1;
std::size_t count = 1;
if (outputDims.size() > 3) {
// If it is possible, assume a NCHW layout
// size is taken from the node's OWN output dims while stride comes from
// outputDims (the Concat output when concat != nullptr), so each parent
// writes its slice inside the larger concatenated plane.
size = op->getOutput(outputIdx)->dims().end()[-3];
stride = outputDims.end()[-3];
length = outputDims.end()[-1];
count = outputDims.end()[-2];
AIDGE_INTERNAL_ASSERT(stride >= size);
AIDGE_INTERNAL_ASSERT(length == op->getOutput(outputIdx)->dims().end()[-1]);
AIDGE_INTERNAL_ASSERT(count == op->getOutput(outputIdx)->dims().end()[-2]);
}
// Check if wrap around buffer is possible for this node
// (re-using previous node outputs memory for this node outputs).
// => only if this node is the only child of its parent(s)
std::size_t wrapAroundSize = 0;
std::size_t wrapAroundExtra = 0;
wrapAroundMemPlane.push_back(nullptr); // default value of wrapAroundMemPlane[outputIdx]
// Select the best parent among all allocable nodes for
// reallocation, which is the one with most memory (in order
// to minimize the reallocation size).
const auto allocableNodes = (concat) ? concat->getParents() : std::vector<NodePtr>{node};
for (const auto& allocableNode : allocableNodes) {
IOIndex_t inputIdx = 0;
for (const auto& parent : allocableNode->dataInputs()) {
if (parent.first && parent.first->getChildren(parent.second).size() == 1
// there might be no existing plane if the parent was
// not yet scheduled (because it may be a recurrent connection)
&& memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
// memSpace should not be already released
&& memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
{
const auto requiredData = allocableNode->getOperator()->getNbRequiredData(inputIdx);
const auto requiredProtected = allocableNode->getOperator()->getNbRequiredProtected(inputIdx);
AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data,
"Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
node->name(), node->type());
const bool isWrappable = (requiredProtected.data < requiredData.data);
// Candidate plane: the shared Concat plane if it already exists,
// otherwise the parent's own output plane.
const MemoryManager::MemoryPlane& memPlane
= (concat && itConcat != concatMemPlane.end())
? itConcat->second
: memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];
if (isWrappable || !memManager.isWrapAround(
memPlane.memSpace,
memPlane.getFinalOffset()
- memPlane.memSpace->offset,
requiredSize.data))
{
if (memPlane.getSize() > wrapAroundSize + requiredProtected.data
&& std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
{
wrapAroundSize = memPlane.getSize() - requiredProtected.data;
if (requiredSize.data > wrapAroundSize) {
wrapAroundExtra = requiredSize.data - wrapAroundSize;
}
wrapAroundMemPlane[outputIdx] = &memPlane;
}
// A candidate large enough to hold the output with no extra
// space is good enough: stop searching.
if (wrapAroundExtra == 0) {
break;
}
}
}
++inputIdx;
}
}
size_t concatOffset = 0;
if (concat) {
// Dependencies should be concat node *childs*, not concat node
childs = concat->getChildren();
// Compute concatOffset
// The offset of this node inside the shared Concat plane is the sum of
// the sizes of the Concat parents that precede it in parent order.
for (auto concatParent : concat->getParents()) {
if (concatParent == node) {
break;
}
else {
const auto parentOp = std::static_pointer_cast<OperatorTensor>(concatParent->getOperator());
const auto parentRequiredSize = parentOp->getRequiredMemory(outputIdx, {});
const auto parentOutputDims = (parentOp->getOutput(outputIdx)) ? parentOp->getOutput(outputIdx)->dims() : std::vector<DimSize_t>();
// By default, specifies a fully monolithic memory block
std::size_t parentSize = parentRequiredSize.data;
if (parentOutputDims.size() > 3) {
// If it is possible, assume a NCHW layout
parentSize = parentOutputDims.end()[-3];
}
concatOffset += parentSize;
}
}
}
// MemoryPlane to (re)use
// Priority: 1) the plane already created for this Concat; 2) a wrap-around
// re-use of a parent plane; 3) a fresh allocation.
const MemoryManager::MemoryPlane& memPlane
= (concat && itConcat != concatMemPlane.end())
? itConcat->second :
(wrapAroundBuffer && wrapAroundSize > 0)
? (*wrapAroundMemPlane[outputIdx]) :
memManager.allocate(size, childs, stride, length, count);
if (wrapAroundBuffer && wrapAroundSize > 0) {
memManager.reallocate(memPlane,
node, concatOffset,
size, true, wrapAroundExtra, childs, stride, length, count);
}
else {
memManager.reallocate(memPlane.memSpace,
node, memPlane.offset + concatOffset,
size, false, 0, childs, stride, length, count);
}
// First parent mapped for this Concat: remember the shared plane so the
// remaining parents (and the Concat node itself) re-use it.
if (concat && itConcat == concatMemPlane.end()) {
concatMemPlane.emplace(concat, memPlane);
}
}
memManager.releaseDependencies(node);
memManager.tick();
}
}
return memManager;
}
void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data){
// This version of connect inputs only connects tensor inputs in input data producers.
auto inputNodes = mGraphView->getOrderedInputs();
......@@ -649,11 +856,21 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName)
fmt::print(fp.get(), "\n");
}
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const {
/**
 * @brief Get the static scheduling order of nodes for a given step.
 * @param step Index of the static schedule to retrieve (default 0).
 * @param sorting Sorting mode: Default keeps the generated order;
 * AsSoonAsPossible / AsLateAsPossible stable-sort elements by their
 * early/late schedule bounds.
 * @return Vector of shared pointers to Nodes in the requested order.
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step, EarlyLateSort sorting) const {
    AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?");
    // Fixed: format arguments were swapped (size, step), producing a misleading message.
    AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", step, mStaticSchedule.size());

    // Work on a copy of the schedule elements so sorting does not mutate the
    // stored schedule.
    std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(step).begin(), mStaticSchedule.at(step).end());

    if (sorting == EarlyLateSort::AsSoonAsPossible) {
        // stable_sort preserves the generated order among equal (early, late) pairs.
        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
            [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
    }
    else if (sorting == EarlyLateSort::AsLateAsPossible) {
        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
            [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early < rhs->early)); });
    }

    std::vector<std::shared_ptr<Node>> schedule;
    schedule.reserve(staticSchedule.size());
    std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
    return schedule;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment