diff --git a/aidge_core/aidge_export_aidge/operator_export/producer.py b/aidge_core/aidge_export_aidge/operator_export/producer.py index 870ec319af470c8882b45402d3952de60dd0327d..c04019043a6cd7d1a0de95a0ba32f2bb7a3a4bec 100644 --- a/aidge_core/aidge_export_aidge/operator_export/producer.py +++ b/aidge_core/aidge_export_aidge/operator_export/producer.py @@ -1,9 +1,24 @@ from aidge_core.aidge_export_aidge.utils import operator_register from aidge_core.aidge_export_aidge import ROOT_EXPORT -from aidge_core import ExportNode, generate_file, generate_str +from aidge_core import DataType, ExportNode, generate_file, generate_str import numpy as np from pathlib import Path +# Convert aidge datatype to C++ type +datatype_converter = { + DataType.Float64 : "double", + DataType.Float32 : "float", + DataType.Float16 : "half_float::half", + DataType.Int8 : "int8_t", + DataType.Int16 : "int16_t", + DataType.Int32 : "int32_t", + DataType.Int64 : "int64_t", + DataType.UInt8 : "uint8_t", + DataType.UInt16 : "uint16_t", + DataType.UInt32 : "uint32_t", + DataType.UInt64 : "uint64_t" +} + @operator_register("Producer") class Producer(ExportNode): @@ -24,11 +39,16 @@ class Producer(ExportNode): filepath = export_folder / f"include/{include_path}" aidge_tensor = self.operator.get_output(0) + aidge_type = aidge_tensor.dtype() + if aidge_type in datatype_converter: + datatype = datatype_converter[aidge_type] + else: + raise RuntimeError(f"No conversion found for data type {aidge_type}.") generate_file( filepath, ROOT_EXPORT / "templates/parameter.jinja", dims = aidge_tensor.dims(), - data_t = "float", # TODO : get data from producer + data_t = datatype, # TODO : get data from producer name = self.tensor_name, values = str(aidge_tensor) ) diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index bca246c94434b280a12d070526ad4ffb2c7fbe7b..955b510e6cce6712e4738c0064836dbb733a3c3d 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -22,9 +22,11 @@ void init_Data(py::module& m){ .value("Float32", DataType::Float32) .value("Float16", DataType::Float16) .value("Int8", DataType::Int8) + .value("Int16", DataType::Int16) .value("Int32", DataType::Int32) .value("Int64", DataType::Int64) .value("UInt8", DataType::UInt8) + .value("UInt16", DataType::UInt16) .value("UInt32", DataType::UInt32) .value("UInt64", DataType::UInt64) ;