From 651b8db055b29b71ac31d0627837403394172eeb Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Fri, 13 Dec 2024 15:01:48 +0000
Subject: [PATCH] Upgrade the data conversion to a generic function and
 integrate it to nodeExport class constructor

---
 aidge_core/export_utils/data_conversion.py | 28 +++++++++++----
 aidge_core/export_utils/node_export.py     | 42 +++++++++++++++++-----
 2 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/aidge_core/export_utils/data_conversion.py b/aidge_core/export_utils/data_conversion.py
index 5333c6a3b..6dba5b78c 100644
--- a/aidge_core/export_utils/data_conversion.py
+++ b/aidge_core/export_utils/data_conversion.py
@@ -1,8 +1,9 @@
 import numpy as np
 import aidge_core
 
+from typing import Dict
 
-datatype_converter_aide2c = {
+datatype_converter_aidge2c = {
     aidge_core.dtype.float64 : "double",
     aidge_core.dtype.float32 : "float",
     aidge_core.dtype.float16 : "half_float::half",
@@ -26,9 +27,24 @@ def aidge2c(datatype):
     :return: A string representing the C type
     :rtype: string
     """
-    if datatype in datatype_converter_aide2c:
-        return datatype_converter_aide2c[datatype]
+    if datatype in datatype_converter_aidge2c:
+        return datatype_converter_aidge2c[datatype]
     else:
-        # raise ValueError(f"Unsupported {datatype} aidge datatype")
-        aidge_core.Log.warn(f"Unsupported conversion of {datatype} (aidge datatype) to a C type.")
-        return None
+        raise ValueError(f"Unsupported {datatype} aidge datatype")
+
+def aidge2export_type(datatype: aidge_core.dtype, conversion_map: Dict[aidge_core.dtype, str] = datatype_converter_aidge2c) -> str:
+    """Convert a aidge datatype to the export type specified by the map passed in argument
+
+    If the aidge type is not convertible, that is to say, is not specified in the map, a value Error is raised.
+
+    :param datatype: Aidge datatype to convert
+    :type datatype: :py:object:`aidge_core.DataType`
+    :param conversion_map: Map that specify the conversion
+    :type conversion_map: Dict[:py:object:`aidge_core.DataType`, str]
+    :return: A string representing the export type
+    :rtype: string
+    """
+    if datatype in conversion_map:
+        return conversion_map[datatype]
+    else:
+        raise ValueError(f"Unsupported type conversion {datatype} aidge datatype for export")
diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index 5777814a0..c24727adf 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -3,7 +3,8 @@ from pathlib import Path
 
 from aidge_core.export_utils import data_conversion, code_generation
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Dict
+
 
 
 def get_chan(tensor: aidge_core.Tensor) -> int:
@@ -14,12 +15,19 @@ def get_chan(tensor: aidge_core.Tensor) -> int:
             return dims[1]
         elif len(dims) == 2:  # Suppose NC
             return dims[1]
+        elif len(dims) == 1:  # Suppose C (for bias)
+            return dims[0]
         else:
             return None
     elif dformat == aidge_core.dformat.nchw:
         return dims[1]
     elif dformat == aidge_core.dformat.nhwc:
-        return dims[3]
+        if len(dims) == 4:  # NHWC
+            return dims[3]
+        elif len(dims) == 2:  # NC
+            return 1
+        elif len(dims) == 1:  # C for bias
+            return 1
     elif dformat == aidge_core.dformat.chwn:
         return dims[0]
     elif dformat == aidge_core.dformat.ncdhw:
@@ -40,12 +48,19 @@ def get_height(tensor: aidge_core.Tensor) -> int:
             return dims[2]
         elif len(dims) == 2:  # Suppose NC
             return 1
+        elif len(dims) == 1:  # Suppose C for bias
+            return 1
         else:
             return None
     elif dformat == aidge_core.dformat.nchw:
         return dims[2]
     elif dformat == aidge_core.dformat.nhwc:
-        return dims[1]
+        if len(dims) == 4:  # NHWC
+            return dims[1]
+        elif len(dims) == 2:  # NC
+            return 1
+        elif len(dims) == 1:  # C for bias
+            return 1
     elif dformat == aidge_core.dformat.chwn:
         return dims[1]
     elif dformat == aidge_core.dformat.ncdhw:
@@ -66,12 +81,19 @@ def get_width(tensor: aidge_core.Tensor) -> int:
             return dims[3]
         elif len(dims) == 2:  # Suppose NC
             return 1
+        elif len(dims) == 1:  # Suppose C for bias
+            return 1
         else:
             return None
     elif dformat == aidge_core.dformat.nchw:
         return dims[3]
     elif dformat == aidge_core.dformat.nhwc:
-        return dims[2]
+        if len(dims) == 4:  # NHWC
+            return dims[2]
+        elif len(dims) == 2:  # NC
+            return 1
+        elif len(dims) == 1:  # C for bias
+            return 1
     elif dformat == aidge_core.dformat.chwn:
         return dims[2]
     elif dformat == aidge_core.dformat.ncdhw:
@@ -162,7 +184,9 @@ class ExportNode(ABC):
     """
 
     @abstractmethod
-    def __init__(self, aidge_node: aidge_core.Node, mem_info: List[dict]=None) -> None:
+    def __init__(self, aidge_node: aidge_core.Node, 
+                 mem_info: List[dict]=None, 
+                 conversion_map: Dict[aidge_core.dtype, str] = data_conversion.datatype_converter_aidge2c) -> None:
         """Create ExportNode and retrieve attributes from ``aidge_node``:
         """
 
@@ -231,8 +255,8 @@ class ExportNode(ABC):
                 self.attributes["in_dformat"][idx] = tensor.dformat()
                 self.attributes["in_format"][idx] = aidge_core.format_as(tensor.dformat())
                 self.attributes["in_dtype"][idx] = tensor.dtype()
-                self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(
-                    tensor.dtype())
+                # self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(tensor.dtype())
+                self.attributes["in_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map)
                 self.attributes["in_chan"][idx] = get_chan(tensor)
                 self.attributes["in_height"][idx] = get_height(tensor)
                 self.attributes["in_width"][idx] = get_width(tensor)
@@ -254,8 +278,8 @@ class ExportNode(ABC):
                 self.attributes["out_dformat"][idx] = tensor.dformat()
                 self.attributes["out_format"][idx] = aidge_core.format_as(tensor.dformat())
                 self.attributes["out_dtype"][idx] = tensor.dtype()
-                self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(
-                    tensor.dtype())
+                # self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(tensor.dtype())
+                self.attributes["out_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map)
                 self.attributes["out_chan"][idx] = get_chan(tensor)
                 self.attributes["out_height"][idx] = get_height(tensor)
                 self.attributes["out_width"][idx] = get_width(tensor)
-- 
GitLab