From 651b8db055b29b71ac31d0627837403394172eeb Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Fri, 13 Dec 2024 15:01:48 +0000 Subject: [PATCH] Upgrade the data conversion to a generic function and integrate it to nodeExport class constructor --- aidge_core/export_utils/data_conversion.py | 28 +++++++++++---- aidge_core/export_utils/node_export.py | 42 +++++++++++++++++----- 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/aidge_core/export_utils/data_conversion.py b/aidge_core/export_utils/data_conversion.py index 5333c6a3b..6dba5b78c 100644 --- a/aidge_core/export_utils/data_conversion.py +++ b/aidge_core/export_utils/data_conversion.py @@ -1,8 +1,9 @@ import numpy as np import aidge_core +from typing import Dict -datatype_converter_aide2c = { +datatype_converter_aidge2c = { aidge_core.dtype.float64 : "double", aidge_core.dtype.float32 : "float", aidge_core.dtype.float16 : "half_float::half", @@ -26,9 +27,24 @@ def aidge2c(datatype): :return: A string representing the C type :rtype: string """ - if datatype in datatype_converter_aide2c: - return datatype_converter_aide2c[datatype] + if datatype in datatype_converter_aidge2c: + return datatype_converter_aidge2c[datatype] else: - # raise ValueError(f"Unsupported {datatype} aidge datatype") - aidge_core.Log.warn(f"Unsupported conversion of {datatype} (aidge datatype) to a C type.") - return None + raise ValueError(f"Unsupported {datatype} aidge datatype") + +def aidge2export_type(datatype: aidge_core.dtype, conversion_map: Dict[aidge_core.dtype, str] = datatype_converter_aidge2c) -> str: + """Convert a aidge datatype to the export type specified by the map passed in argument + + If the aidge type is not convertible, that is to say, is not specified in the map, a value Error is raised. + + :param datatype: Aidge datatype to convert + :type datatype: :py:object:`aidge_core.DataType` + :param conversion_map: Map that specify the conversion + :type conversion_map: Dict[:py:object:`aidge_core.DataType`, str] + :return: A string representing the export type + :rtype: string + """ + if datatype in conversion_map: + return conversion_map[datatype] + else: + raise ValueError(f"Unsupported type conversion {datatype} aidge datatype for export") diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index 5777814a0..c24727adf 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -3,7 +3,8 @@ from pathlib import Path from aidge_core.export_utils import data_conversion, code_generation from abc import ABC, abstractmethod -from typing import List +from typing import List, Dict + def get_chan(tensor: aidge_core.Tensor) -> int: @@ -14,12 +15,19 @@ def get_chan(tensor: aidge_core.Tensor) -> int: return dims[1] elif len(dims) == 2: # Suppose NC return dims[1] + elif len(dims) == 1: # Suppose C (for bias) + return dims[0] else: return None elif dformat == aidge_core.dformat.nchw: return dims[1] elif dformat == aidge_core.dformat.nhwc: - return dims[3] + if len(dims) == 4: # NHWC + return dims[3] + elif len(dims) == 2: # NC + return 1 + elif len(dims) == 1: # C for bias + return 1 elif dformat == aidge_core.dformat.chwn: return dims[0] elif dformat == aidge_core.dformat.ncdhw: @@ -40,12 +48,19 @@ def get_height(tensor: aidge_core.Tensor) -> int: return dims[2] elif len(dims) == 2: # Suppose NC return 1 + elif len(dims) == 1: # Suppose C for bias + return 1 else: return None elif dformat == aidge_core.dformat.nchw: return dims[2] elif dformat == aidge_core.dformat.nhwc: - return dims[1] + if len(dims) == 4: # NHWC + return dims[1] + elif len(dims) == 2: # NC + return 1 + elif len(dims) == 1: # C for bias + return 1 elif dformat == aidge_core.dformat.chwn: return dims[1] elif dformat == aidge_core.dformat.ncdhw: @@ -66,12 +81,19 @@ def get_width(tensor: aidge_core.Tensor) -> int: return dims[3] elif len(dims) == 2: # Suppose NC return 1 + elif len(dims) == 1: # Suppose C for bias + return 1 else: return None elif dformat == aidge_core.dformat.nchw: return dims[3] elif dformat == aidge_core.dformat.nhwc: - return dims[2] + if len(dims) == 4: # NHWC + return dims[2] + elif len(dims) == 2: # NC + return 1 + elif len(dims) == 1: # C for bias + return 1 elif dformat == aidge_core.dformat.chwn: return dims[2] elif dformat == aidge_core.dformat.ncdhw: @@ -162,7 +184,9 @@ class ExportNode(ABC): """ @abstractmethod - def __init__(self, aidge_node: aidge_core.Node, mem_info: List[dict]=None) -> None: + def __init__(self, aidge_node: aidge_core.Node, + mem_info: List[dict]=None, + conversion_map: Dict[aidge_core.dtype, str] = data_conversion.datatype_converter_aidge2c) -> None: """Create ExportNode and retrieve attributes from ``aidge_node``: """ @@ -231,8 +255,8 @@ class ExportNode(ABC): self.attributes["in_dformat"][idx] = tensor.dformat() self.attributes["in_format"][idx] = aidge_core.format_as(tensor.dformat()) self.attributes["in_dtype"][idx] = tensor.dtype() - self.attributes["in_cdtype"][idx] = data_conversion.aidge2c( - tensor.dtype()) + # self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(tensor.dtype()) + self.attributes["in_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map) self.attributes["in_chan"][idx] = get_chan(tensor) self.attributes["in_height"][idx] = get_height(tensor) self.attributes["in_width"][idx] = get_width(tensor) @@ -254,8 +278,8 @@ class ExportNode(ABC): self.attributes["out_dformat"][idx] = tensor.dformat() self.attributes["out_format"][idx] = aidge_core.format_as(tensor.dformat()) self.attributes["out_dtype"][idx] = tensor.dtype() - self.attributes["out_cdtype"][idx] = data_conversion.aidge2c( - tensor.dtype()) + # self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(tensor.dtype()) + self.attributes["out_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map) self.attributes["out_chan"][idx] = get_chan(tensor) self.attributes["out_height"][idx] = get_height(tensor) self.attributes["out_width"][idx] = get_width(tensor) -- GitLab