Skip to content
Snippets Groups Projects
Commit 651b8db0 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Upgrade the data conversion to a generic function and integrate it to nodeExport class constructor

parent e7b038f2
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!262Low bit support for ARM Cortex-M export
Pipeline #61729 passed
import numpy as np
import aidge_core
from typing import Dict
datatype_converter_aide2c = {
datatype_converter_aidge2c = {
aidge_core.dtype.float64 : "double",
aidge_core.dtype.float32 : "float",
aidge_core.dtype.float16 : "half_float::half",
......@@ -26,9 +27,24 @@ def aidge2c(datatype):
:return: A string representing the C type
:rtype: string
"""
if datatype in datatype_converter_aide2c:
return datatype_converter_aide2c[datatype]
if datatype in datatype_converter_aidge2c:
return datatype_converter_aidge2c[datatype]
else:
# raise ValueError(f"Unsupported {datatype} aidge datatype")
aidge_core.Log.warn(f"Unsupported conversion of {datatype} (aidge datatype) to a C type.")
return None
raise ValueError(f"Unsupported {datatype} aidge datatype")
def aidge2export_type(datatype: aidge_core.dtype, conversion_map: Dict[aidge_core.dtype, str] = datatype_converter_aidge2c) -> str:
"""Convert a aidge datatype to the export type specified by the map passed in argument
If the aidge type is not convertible, that is to say, is not specified in the map, a value Error is raised.
:param datatype: Aidge datatype to convert
:type datatype: :py:object:`aidge_core.DataType`
:param conversion_map: Map that specify the conversion
:type conversion_map: Dict[:py:object:`aidge_core.DataType`, str]
:return: A string representing the export type
:rtype: string
"""
if datatype in conversion_map:
return conversion_map[datatype]
else:
raise ValueError(f"Unsupported type conversion {datatype} aidge datatype for export")
......@@ -3,7 +3,8 @@ from pathlib import Path
from aidge_core.export_utils import data_conversion, code_generation
from abc import ABC, abstractmethod
from typing import List
from typing import List, Dict
def get_chan(tensor: aidge_core.Tensor) -> int:
......@@ -14,12 +15,19 @@ def get_chan(tensor: aidge_core.Tensor) -> int:
return dims[1]
elif len(dims) == 2: # Suppose NC
return dims[1]
elif len(dims) == 1: # Suppose C (for bias)
return dims[0]
else:
return None
elif dformat == aidge_core.dformat.nchw:
return dims[1]
elif dformat == aidge_core.dformat.nhwc:
return dims[3]
if len(dims) == 4: # NHWC
return dims[3]
elif len(dims) == 2: # NC
return 1
elif len(dims) == 1: # C for bias
return 1
elif dformat == aidge_core.dformat.chwn:
return dims[0]
elif dformat == aidge_core.dformat.ncdhw:
......@@ -40,12 +48,19 @@ def get_height(tensor: aidge_core.Tensor) -> int:
return dims[2]
elif len(dims) == 2: # Suppose NC
return 1
elif len(dims) == 1: # Suppose C for bias
return 1
else:
return None
elif dformat == aidge_core.dformat.nchw:
return dims[2]
elif dformat == aidge_core.dformat.nhwc:
return dims[1]
if len(dims) == 4: # NHWC
return dims[1]
elif len(dims) == 2: # NC
return 1
elif len(dims) == 1: # C for bias
return 1
elif dformat == aidge_core.dformat.chwn:
return dims[1]
elif dformat == aidge_core.dformat.ncdhw:
......@@ -66,12 +81,19 @@ def get_width(tensor: aidge_core.Tensor) -> int:
return dims[3]
elif len(dims) == 2: # Suppose NC
return 1
elif len(dims) == 1: # Suppose C for bias
return 1
else:
return None
elif dformat == aidge_core.dformat.nchw:
return dims[3]
elif dformat == aidge_core.dformat.nhwc:
return dims[2]
if len(dims) == 4: # NHWC
return dims[2]
elif len(dims) == 2: # NC
return 1
elif len(dims) == 1: # C for bias
return 1
elif dformat == aidge_core.dformat.chwn:
return dims[2]
elif dformat == aidge_core.dformat.ncdhw:
......@@ -162,7 +184,9 @@ class ExportNode(ABC):
"""
@abstractmethod
def __init__(self, aidge_node: aidge_core.Node, mem_info: List[dict]=None) -> None:
def __init__(self, aidge_node: aidge_core.Node,
mem_info: List[dict]=None,
conversion_map: Dict[aidge_core.dtype, str] = data_conversion.datatype_converter_aidge2c) -> None:
"""Create ExportNode and retrieve attributes from ``aidge_node``:
"""
......@@ -231,8 +255,8 @@ class ExportNode(ABC):
self.attributes["in_dformat"][idx] = tensor.dformat()
self.attributes["in_format"][idx] = aidge_core.format_as(tensor.dformat())
self.attributes["in_dtype"][idx] = tensor.dtype()
self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(
tensor.dtype())
# self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(tensor.dtype())
self.attributes["in_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map)
self.attributes["in_chan"][idx] = get_chan(tensor)
self.attributes["in_height"][idx] = get_height(tensor)
self.attributes["in_width"][idx] = get_width(tensor)
......@@ -254,8 +278,8 @@ class ExportNode(ABC):
self.attributes["out_dformat"][idx] = tensor.dformat()
self.attributes["out_format"][idx] = aidge_core.format_as(tensor.dformat())
self.attributes["out_dtype"][idx] = tensor.dtype()
self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(
tensor.dtype())
# self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(tensor.dtype())
self.attributes["out_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map)
self.attributes["out_chan"][idx] = get_chan(tensor)
self.attributes["out_height"][idx] = get_height(tensor)
self.attributes["out_width"][idx] = get_width(tensor)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment