Skip to content
Snippets Groups Projects
Commit 9fb6df60 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Update export to support more than just float datatype.

parent ee13ce6e
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!115Aidge export
from aidge_core.aidge_export_aidge.utils import operator_register from aidge_core.aidge_export_aidge.utils import operator_register
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core import ExportNode, generate_file, generate_str from aidge_core import DataType, ExportNode, generate_file, generate_str
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
# Convert aidge datatype to C++ type
datatype_converter = {
DataType.Float64 : "double",
DataType.Float32 : "float",
DataType.Float16 : "half_float::half",
DataType.Int8 : "int8_t",
DataType.Int16 : "int16_t",
DataType.Int32 : "int32_t",
DataType.Int64 : "int64_t",
DataType.UInt8 : "uint8_t",
DataType.UInt16 : "uint16_t",
DataType.UInt32 : "uint32_t",
DataType.UInt64 : "uint64_t"
}
@operator_register("Producer") @operator_register("Producer")
class Producer(ExportNode): class Producer(ExportNode):
...@@ -24,11 +39,16 @@ class Producer(ExportNode): ...@@ -24,11 +39,16 @@ class Producer(ExportNode):
filepath = export_folder / f"include/{include_path}" filepath = export_folder / f"include/{include_path}"
aidge_tensor = self.operator.get_output(0) aidge_tensor = self.operator.get_output(0)
aidge_type = aidge_tensor.dtype()
if aidge_type in datatype_converter:
datatype = datatype_converter[aidge_type]
else:
raise RuntimeError(f"No conversion found for data type {aidge_type}.")
generate_file( generate_file(
filepath, filepath,
ROOT_EXPORT / "templates/parameter.jinja", ROOT_EXPORT / "templates/parameter.jinja",
dims = aidge_tensor.dims(), dims = aidge_tensor.dims(),
data_t = "float", # TODO : get data from producer data_t = datatype, # TODO : get data from producer
name = self.tensor_name, name = self.tensor_name,
values = str(aidge_tensor) values = str(aidge_tensor)
) )
......
...@@ -22,9 +22,11 @@ void init_Data(py::module& m){ ...@@ -22,9 +22,11 @@ void init_Data(py::module& m){
.value("Float32", DataType::Float32) .value("Float32", DataType::Float32)
.value("Float16", DataType::Float16) .value("Float16", DataType::Float16)
.value("Int8", DataType::Int8) .value("Int8", DataType::Int8)
.value("Int16", DataType::Int16)
.value("Int32", DataType::Int32) .value("Int32", DataType::Int32)
.value("Int64", DataType::Int64) .value("Int64", DataType::Int64)
.value("UInt8", DataType::UInt8) .value("UInt8", DataType::UInt8)
.value("UInt16", DataType::UInt16)
.value("UInt32", DataType::UInt32) .value("UInt32", DataType::UInt32)
.value("UInt64", DataType::UInt64) .value("UInt64", DataType::UInt64)
; ;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment