From 092cb4bd10ab46cf1365a9e90adacfe14feb346c Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 3 Sep 2024 09:12:38 +0000
Subject: [PATCH] Adapt aidge mem_info to new API

---
 aidge_core/mem_info.py | 77 ++++++++++++++++++++---------------------
 1 file changed, 35 insertions(+), 42 deletions(-)

diff --git a/aidge_core/mem_info.py b/aidge_core/mem_info.py
index 3319e4807..b607a19f3 100644
--- a/aidge_core/mem_info.py
+++ b/aidge_core/mem_info.py
@@ -70,14 +70,14 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder
     # In the export, we currently use an unified memory buffer whose size
     # is determined by the memory peak usage
     mem_size = mem_manager.get_peak_usage()
-    mem_info = []
+    mem_info = {}
 
     mem_planes = mem_manager.get_planes()
 
     for node in scheduler.get_static_scheduling():
         if node.type() == "Producer":
-            continue  # Skipping memory management for producers
-        if node in nodes_at_input:
+            mem_info[node] = [] # No meminfo for producer
+        elif node in nodes_at_input:
             # Input memory management (suppose tensor ends with [:, channel, height, width]))
             tensor = node.get_operator().get_output(0)
             if tensor is None:
@@ -86,47 +86,40 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder
                 raise RuntimeError(
                     f"Input producer dimensions must be with [:, channel, height, width] but got {tensor.dims()} instead")
 
-            name = node.name()
-            offset = 0                  # Suppose input data is stored outside the export function
-            # so the memory offset is not important to consider
             # TODO : use get_chan get_height and get_width function !
-            size = tensor.dims()[-3]    # Should be nb_channels
-            stride = tensor.dims()[-3]  # Should be nb_channels
-            length = tensor.dims()[-1]  # Should be width
-            count = tensor.dims()[-2]   # Should be height
-            cont_offset = 0             # Suppose input data is stored outside the export function
-            # so the memory offset is not important to consider
-            # Size of input
-            cont_size = tensor.dims()[-1] * \
-                tensor.dims()[-2] * tensor.dims()[-3]
-            wrap_offset = 0     # No wrapping
-            wrap_size = 0       # No wrapping
+            node_mem_info = []
+            node_mem_info.append({
+                    "size": tensor.dims()[-3],          # Should be nb_channels
+                    "offset": 0,                        # Suppose input data is stored outside the export function
+                                                        # so the memory offset is not important to consider
+                    "stride": tensor.dims()[-3],        # Should be nb_channels
+                    "length": tensor.dims()[-1],        # Should be width
+                    "count":  tensor.dims()[-2],        # Should be height
+                    "cont_offset": 0,                   # Suppose input data is stored outside the export function
+                                                        # so the memory offset is not important to consider
+                    "cont_size": tensor.dims()[-1] * \
+                                 tensor.dims()[-2] * \
+                                 tensor.dims()[-3],     # Size of input
+                    "wrap_offset": 0,                   # No wrapping
+                    "wrap_size": 0                      # No wrapping
+                })
+            mem_info[node] = node_mem_info
         else:
-            plane = mem_planes[node][0]
-            name = node.name()
-            offset = plane.offset
-            size = plane.size
-            stride = plane.stride
-            length = plane.length
-            count = plane.count
-            cont_offset = plane.get_contiguous_offset()
-            cont_size = plane.get_contiguous_size()
-            wrap_offset = plane.get_wrapped_offset()
-            wrap_size = plane.get_wrapped_size()
-
-        mem_info.append({
-            "layer_name": name,
-            "size": size,
-            "offset": offset,
-            "stride": stride,
-            "length": length,
-            "count": count,
-            "cont_offset": cont_offset,
-            "cont_size": cont_size,
-            "wrap_offset": wrap_offset,
-            "wrap_size": wrap_size
-        })
-
+            node_mem_info = []
+            for out_id in range(node.get_nb_outputs()):
+                plane = mem_planes[node][out_id]
+                node_mem_info.append({
+                    "size": plane.size,
+                    "offset": plane.offset,
+                    "stride": plane.stride,
+                    "length": plane.length,
+                    "count": plane.count,
+                    "cont_offset": plane.get_contiguous_offset(),
+                    "cont_size": plane.get_contiguous_size(),
+                    "wrap_offset": plane.get_wrapped_offset(),
+                    "wrap_size": plane.get_wrapped_size()
+                })
+            mem_info[node] = node_mem_info
     return mem_size, mem_info
 
 
-- 
GitLab