From 052f6d607121864a7a7f1b6d33758d7c47782249 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 9 Jul 2024 09:25:20 +0000
Subject: [PATCH] Heavy refactor of node_export to retrieve more attributes.

---
 aidge_core/export_utils/node_export.py | 201 +++++++++++++++++++++----
 1 file changed, 169 insertions(+), 32 deletions(-)

diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index 80c37dd0a..78a103aea 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -1,61 +1,198 @@
 import aidge_core
 
+from aidge_core.export_utils import data_conversion
 from abc import ABC, abstractmethod
 
 
+def get_chan(tensor: aidge_core.Tensor) -> int:
+    """Given a tensor return the number of channel
+    """
+    dformat = tensor.dformat()
+    dims = tensor.dims()
+    if len(dims) == 4: # Suppose NCHW
+        return dims[1]
+    elif len(dims) == 2: # Suppose NC
+        return dims[1]
+    else:
+        return None
+    # if dformat == aidge_core.dformat.Default:
+    #     return None
+    # elif dformat == aidge_core.dformat.NCHW:
+    #     return dims[1]
+    # elif dformat == aidge_core.dformat.NHWC:
+    #     return dims[3]
+    # elif dformat == aidge_core.dformat.CHWN:
+    #     return dims[0]
+    # elif dformat == aidge_core.dformat.NCDHW:
+    #     return dims[1]
+    # elif dformat == aidge_core.dformat.NDHWC:
+    #     return dims[4]
+    # elif dformat == aidge_core.dformat.CDHWN:
+    #     return dims[0]
+    # else:
+    #     raise RuntimeError(f"Unknown dataformat: {dformat}")
+
+
+def get_height(tensor: aidge_core.Tensor) -> int:
+    dformat = tensor.dformat()
+    dims = tensor.dims()
+    if len(dims) == 4: # Suppose NCHW
+        return dims[2]
+    elif len(dims) == 2: # Suppose NC
+        return 1
+    else:
+        return None
+    # TODO: use when dformat is fully supported
+    # if dformat == aidge_core.dformat.Default:
+    #     return None
+    # elif dformat == aidge_core.dformat.NCHW:
+    #     return dims[2]
+    # elif dformat == aidge_core.dformat.NHWC:
+    #     return dims[1]
+    # elif dformat == aidge_core.dformat.CHWN:
+    #     return dims[1]
+    # elif dformat == aidge_core.dformat.NCDHW:
+    #     return dims[3]
+    # elif dformat == aidge_core.dformat.NDHWC:
+    #     return dims[2]
+    # elif dformat == aidge_core.dformat.CDHWN:
+    #     return dims[2]
+    # else:
+    #     raise RuntimeError(f"Unknown dataformat: {dformat}")
+
+
+def get_width(tensor: aidge_core.Tensor) -> int:
+    dformat = tensor.dformat()
+    dims = tensor.dims()
+    if len(dims) == 4: # Suppose NCHW
+        return dims[3]
+    elif len(dims) == 2: # Suppose NC
+        return 1
+    else:
+        return None
+    # if dformat == aidge_core.dformat.Default:
+    #     return None
+    # elif dformat == aidge_core.dformat.NCHW:
+    #     return dims[3]
+    # elif dformat == aidge_core.dformat.NHWC:
+    #     return dims[2]
+    # elif dformat == aidge_core.dformat.CHWN:
+    #     return dims[2]
+    # elif dformat == aidge_core.dformat.NCDHW:
+    #     return dims[4]
+    # elif dformat == aidge_core.dformat.NDHWC:
+    #     return dims[3]
+    # elif dformat == aidge_core.dformat.CDHWN:
+    #     return dims[3]
+    # else:
+    #     raise RuntimeError(f"Unknown dataformat: {dformat}")
+
+
 class ExportNode(ABC):
     """Abstract class to interface node with export generation.
+
+    This class expose a dictionary ``attributes`` which contains all the information required to generate an export:
+    - All the attributes of the Aidge node are automatically fetch, the key to get an attribute is the attribute name in python format, example ``no_bias``
+    - **name**: Name of the Node, ``str``
+    - **nb_in**: Number of inputs, ``int``
+    - **in_name**: unique name for each input, if no input node the name is ``{node_name}_input_{in_id}``, if there is a parent, the name is ``{parent_name}_output_{out_id}``, ``list[str]``
+    - **in_dims**: A list of the dimension for each inputs, ``list[list[int]]``
+    - **in_chan**: A list of channel for each inputs, deduced by the dataformat, ``list[int]``
+    - **in_height**: A list of height for each inputs, deduced by the dataformat, ``list[int]``
+    - **in_width**: A list of width for each inputs, deduced by the dataformat, ``list[int]``
+    - **in_dtype**: A list of type (Aidge format) for each input, ``List[:py:class:`aidge_core.dtype`]``
+    - **in_cdtype**: A list of type (C/C++ format) for each input, ``List[str]``
+    - **out_name**: unique name for each output, the name is ``{name}_output_{out_id}``, ``list[str]``
+    - **nb_out**: Number of outputs, ``int``
+    - **out_dims**: A list of the dimension for each inputs, ``list[list[int]]``
+    - **out_chan**: A list of channel for each outputs, deduced by the dataformat, ``list[int]``
+    - **out_height**: A list of height for each outputs, deduced by the dataformat, ``list[int]``
+    - **out_width**: A list of width for each outputs, deduced by the dataformat, ``list[int]``
+    - **out_dtype**: A list of type (Aidge format) for each output, ``List[:py:class:`aidge_core.dtype`]``
+    - **out_cdtype**: A list of type (C/C++ format) for each output, ``List[str]``
+    - **is_output**: True if the node is an output node, ``bool``
+    - **is_input**: True if the node is an input node, ``bool``
     """
 
     @abstractmethod
     def __init__(self, aidge_node: aidge_core.Node) -> None:
-        """Create ExportNode and retieve attirubtes from ``aidge_node``:
-
-        - name: aidge Node name
-        - attributes: dictionnary of attributes of the aidge Operator linked to the node, attributes name follow aidge naming convention
-        - parameters: List of parameters node, order in the list is the same as the one defined by the aidge operator
-
+        """Create ExportNode and retieve attriubtes from ``aidge_node``:
         """
+
         super().__init__()
         self.node = aidge_node
         self.operator = aidge_node.get_operator()
-        self.name = self.node.name()
-        self.attributes = self.operator.attr.dict() if self.operator.attr is not None else {} # Attributes are auto fetched from aidge operators
-
-        # rename is_leaf ?
-        self.is_last = len(self.node.get_children()) == 0
-
-
+        # Attributes are auto fetched from aidge operators
+        self.attributes = self.operator.attr.dict(
+        ) if self.operator.attr is not None else {}
+        self.attributes["name"] = self.node.name()
+        self.attributes["nb_in"] = self.node.get_nb_inputs()
+        self.attributes["nb_out"] = self.node.get_nb_outputs()
+        # TODO : this check doesn't work if we export a subgraph !
+        # Maybe we need to add the graph we want to export as parameter !
+        # Actually may be mandatory for memory manager ...
+        self.attributes["is_input"] = len(self.node.get_parents()) == 0
+        self.attributes["is_output"] = len(self.node.get_children()) == 0
         self.inputs = []
         self.outputs = []
-        self.inputs_dims = []
-        self.outputs_dims = []
-
-        for idx, parent_node in enumerate(self.node.get_parents()):
-            self.inputs.append(parent_node)
-            if parent_node is not None:
-                self.inputs_dims.append(self.operator.get_input(idx).dims())
-            else:
-                if self.operator.get_input(idx) is not None:
-                    self.inputs_dims.append(self.operator.get_input(idx).dims())
-                else:
-                    self.inputs_dims.append(None)
 
-        for idx, child_node in enumerate(self.node.get_children()):
-            self.outputs.append(child_node)
+        self.attributes["in_name"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_dims"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_dformat"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_dtype"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_cdtype"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_chan"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_height"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_width"] = [None] * self.attributes["nb_in"]
 
-        # Dirty hot fix, change it quickly
-        self.outputs_dims.append(self.operator.get_output(0).dims())
+        self.attributes["out_name"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_dims"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_dformat"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_dtype"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_cdtype"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_chan"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_height"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_width"] = [None] * self.attributes["nb_out"]
 
+        for idx, parent_node_in_id in enumerate(self.node.inputs()):
+            parent_node, out_id = parent_node_in_id
+            self.inputs.append(parent_node)
+            if self.operator.get_input(idx) is not None:
+                tensor = self.operator.get_input(idx)
+                self.attributes["in_name"][idx] = f"{self.attributes['name']}_input_{idx}" if parent_node is None else f"{parent_node.name()}_output_{out_id}"
+                self.attributes["in_dims"][idx] = tensor.dims()
+                self.attributes["in_dformat"][idx] = tensor.dformat()
+                self.attributes["in_dtype"][idx] = tensor.dtype()
+                self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(
+                    tensor.dtype())
+                self.attributes["in_chan"][idx] = get_chan(tensor)
+                self.attributes["in_height"][idx] = get_height(tensor)
+                self.attributes["in_width"][idx] = get_width(tensor)
+            else:
+                print(f"No input for {self.node.name()}")
+        for idx, list_child_node_in_id in enumerate(self.node.outputs()):
+            self.outputs += [node_in_id[0] for node_in_id in list_child_node_in_id]
+            if self.operator.get_output(idx) is not None:
+                tensor = self.operator.get_output(idx)
+                self.attributes["out_name"][idx] = f"{self.attributes['name']}_output_{idx}"
+                self.attributes["out_dims"][idx] = tensor.dims()
+                self.attributes["out_dformat"][idx] = tensor.dformat()
+                self.attributes["out_dtype"][idx] = tensor.dtype()
+                self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(
+                    tensor.dtype())
+                self.attributes["out_chan"][idx] = get_chan(tensor)
+                self.attributes["out_height"][idx] = get_height(tensor)
+                self.attributes["out_width"][idx] = get_width(tensor)
+            else:
+                print(f"No output for {self.node.name()}")
     @abstractmethod
-    def export(self, export_folder:str, list_configs:list):
+    def export(self, export_folder: str, list_configs: list):
         """Define how to export the node definition.
         """
         pass
 
     @abstractmethod
-    def forward(self, list_actions:list):
+    def forward(self, list_actions: list):
         """Define how to generate code to perform a forward pass.
         """
         pass
-
-- 
GitLab