From 092cb4bd10ab46cf1365a9e90adacfe14feb346c Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 3 Sep 2024 09:12:38 +0000 Subject: [PATCH] Adapt aidge mem_info to new API --- aidge_core/mem_info.py | 86 +++++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 42 deletions(-) diff --git a/aidge_core/mem_info.py b/aidge_core/mem_info.py index 3319e4807..b607a19f3 100644 --- a/aidge_core/mem_info.py +++ b/aidge_core/mem_info.py @@ -70,14 +70,14 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder # In the export, we currently use an unified memory buffer whose size # is determined by the memory peak usage mem_size = mem_manager.get_peak_usage() - mem_info = [] + mem_info = {} mem_planes = mem_manager.get_planes() for node in scheduler.get_static_scheduling(): if node.type() == "Producer": - continue # Skipping memory management for producers - if node in nodes_at_input: + mem_info[node] = [] # No meminfo for producer + elif node in nodes_at_input: # Input memory management (suppose tensor ends with [:, channel, height, width])) tensor = node.get_operator().get_output(0) if tensor is None: @@ -86,47 +86,49 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder raise RuntimeError( f"Input producer dimensions must be with [:, channel, height, width] but got {tensor.dims()} instead") - name = node.name() - offset = 0 # Suppose input data is stored outside the export function - # so the memory offset is not important to consider # TODO : use get_chan get_height and get_width function ! - size = tensor.dims()[-3] # Should be nb_channels - stride = tensor.dims()[-3] # Should be nb_channels - length = tensor.dims()[-1] # Should be width - count = tensor.dims()[-2] # Should be height - cont_offset = 0 # Suppose input data is stored outside the export function - # so the memory offset is not important to consider - # Size of input - cont_size = tensor.dims()[-1] * \ - tensor.dims()[-2] * tensor.dims()[-3] - wrap_offset = 0 # No wrapping - wrap_size = 0 # No wrapping + node_mem_info.append({ + "size": tensor.dims()[-3], # Should be nb_channels + "offset": 0, # Suppose input data is stored outside the export function + # so the memory offset is not important to consider + "stride": tensor.dims()[-3], # Should be nb_channels + "length": tensor.dims()[-1], # Should be width + "count": tensor.dims()[-2], # Should be height + "cont_offset": 0, # Suppose input data is stored outside the export function + # so the memory offset is not important to consider + "cont_size": tensor.dims()[-1] * \ + tensor.dims()[-2] * \ + tensor.dims()[-3], # Size of input + "wrap_offset": 0, # No wrapping + "wrap_size": 0 # No wrapping + }) + mem_info[node] = [{ + "size": plane.size, + "offset": plane.offset, + "stride": plane.stride, + "length": plane.length, + "count": plane.count, + "cont_offset": plane.get_contiguous_offset(), + "cont_size": plane.get_contiguous_size(), + "wrap_offset": plane.get_wrapped_offset(), + "wrap_size": plane.get_wrapped_size() + }] else: - plane = mem_planes[node][0] - name = node.name() - offset = plane.offset - size = plane.size - stride = plane.stride - length = plane.length - count = plane.count - cont_offset = plane.get_contiguous_offset() - cont_size = plane.get_contiguous_size() - wrap_offset = plane.get_wrapped_offset() - wrap_size = plane.get_wrapped_size() - - mem_info.append({ - "layer_name": name, - "size": size, - "offset": offset, - "stride": stride, - "length": length, - "count": count, - "cont_offset": cont_offset, - "cont_size": cont_size, - "wrap_offset": wrap_offset, - "wrap_size": wrap_size - }) - + node_mem_info = [] + for out_id in range(node.get_nb_outputs()): + plane = mem_planes[node][out_id] + node_mem_info.append({ + "size": plane.size, + "offset": plane.offset, + "stride": plane.stride, + "length": plane.length, + "count": plane.count, + "cont_offset": plane.get_contiguous_offset(), + "cont_size": plane.get_contiguous_size(), + "wrap_offset": plane.get_wrapped_offset(), + "wrap_size": plane.get_wrapped_size() + }) + mem_info[node] = node_mem_info return mem_size, mem_info -- GitLab