diff --git a/aidge_core/mem_info.py b/aidge_core/mem_info.py index c7ca85bbd73bd205850b19616e53fda210749a80..8946c4dbdbab23b29f6c687db2d12026b15343c5 100644 --- a/aidge_core/mem_info.py +++ b/aidge_core/mem_info.py @@ -54,7 +54,7 @@ def _gnuplot_installed(): aidge_core.Log.warn("Gnuplot command found but failed to run.") return False -def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder: Path = None, wrapping: bool = False) -> Tuple[int, List[dict]]: +def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder: Path = None, wrapping: bool = False, auto_concat: bool = False) -> Tuple[int, List[dict]]: """Generates optimized memory information for a computation graph managed by a scheduler. This function analyzes the memory usage of a computation graph, determining the memory peak @@ -70,6 +70,9 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder :param wrapping: Boolean flag to enable or disable wrap-around buffer optimization. Defaults to `False`. :type wrapping: bool, optional + :param auto_concat: Boolean flag to enable or disable auto-concatenation optimization. + Defaults to `False`. + :type auto_concat: bool, optional :return: A tuple containing the peak memory size and a list of memory information for each scheduled node. The memory information for each node includes details such as size, offset, stride, length, count, and optional wrap-around details. 
@@ -81,8 +84,12 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder # scheduler.generate_scheduling() # Generate the memory manager # So far, the Producers are not taken into consideration in the memory manager => inc_producers=False - mem_manager = scheduler.generate_memory( - inc_producers=False, wrap_around_buffer=wrapping) + if auto_concat: + mem_manager = scheduler.generate_memory_auto_concat( + inc_producers=False, wrap_around_buffer=wrapping) + else: + mem_manager = scheduler.generate_memory( + inc_producers=False, wrap_around_buffer=wrapping) # List of nodes which are connected at the input of the graph (None if input is not connected) nodes_at_input = [n[0] for n in scheduler.graph_view().inputs()] @@ -137,18 +144,19 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder }) else: for out_id in range(node.get_nb_outputs()): - plane = mem_planes[node][out_id] - node_mem_info.append({ - "size": plane.size, - "offset": plane.get_contiguous_offset(), - "stride": plane.stride, - "length": plane.length, - "count": plane.count, - "cont_offset": plane.get_contiguous_offset(), - "cont_size": plane.get_contiguous_size(), - "wrap_offset": plane.get_wrapped_offset(), - "wrap_size": plane.get_wrapped_size() - }) + if node in mem_planes: + plane = mem_planes[node][out_id] + node_mem_info.append({ + "size": plane.size, + "offset": plane.get_contiguous_offset(), + "stride": plane.stride, + "length": plane.length, + "count": plane.count, + "cont_offset": plane.get_contiguous_offset(), + "cont_size": plane.get_contiguous_size(), + "wrap_offset": plane.get_wrapped_offset(), + "wrap_size": plane.get_wrapped_size() + }) mem_info[node] = node_mem_info return mem_size, mem_info diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 9b655331a990473d3958e0a6a4a3fb9ded598813..242a2d0e6936bce1459946ee72e7d790d15c4fda 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -476,8 
+476,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr std::size_t length = 1; std::size_t count = 1; - if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) { - // If it is possible, assume a NCHW layout + if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dataFormat() == DataFormat::NHWC) { size = op->getOutput(outputIdx)->dims().end()[-3]; stride = size; length = op->getOutput(outputIdx)->dims().end()[-1]; @@ -618,6 +617,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { auto requiredSize = op->getRequiredMemory(outputIdx, {}); auto outputDims = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dims() : std::vector<DimSize_t>(); + auto outputFormat = (op->getOutput(outputIdx)) ? op->getOutput(outputIdx)->dataFormat() : DataFormat::Default; // If concat is not nullptr, we directly allocate the concatenation result // Must check that we are on the right output too. @@ -625,6 +625,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer const auto concatOp = std::static_pointer_cast<OperatorTensor>(concat->getOperator()); requiredSize = concatOp->getRequiredMemory(0, {}); outputDims = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dims() : std::vector<DimSize_t>(); + outputFormat = (concatOp->getOutput(0)) ? concatOp->getOutput(0)->dataFormat() : DataFormat::Default; } AIDGE_ASSERT(requiredSize.type == Elts_t::Data, @@ -637,8 +638,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer std::size_t length = 1; std::size_t count = 1; - if (outputDims.size() > 3) { - // If it is possible, assume a NCHW layout + if (outputFormat == DataFormat::NHWC) { size = op->getOutput(outputIdx)->dims().end()[-3]; stride = outputDims.end()[-3]; length = outputDims.end()[-1];