From 9fb6df60ed21ea765f8682172f125585bb6a69c8 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 4 Jun 2024 11:50:52 +0000
Subject: [PATCH] Update export to support more than just float datatype.

---
 .../operator_export/producer.py               | 24 +++++++++++++++++--
 python_binding/data/pybind_Data.cpp           |  2 ++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/aidge_core/aidge_export_aidge/operator_export/producer.py b/aidge_core/aidge_export_aidge/operator_export/producer.py
index 870ec319a..c04019043 100644
--- a/aidge_core/aidge_export_aidge/operator_export/producer.py
+++ b/aidge_core/aidge_export_aidge/operator_export/producer.py
@@ -1,9 +1,24 @@
 from aidge_core.aidge_export_aidge.utils import operator_register
 from aidge_core.aidge_export_aidge import ROOT_EXPORT
-from aidge_core import ExportNode, generate_file, generate_str
+from aidge_core import DataType, ExportNode, generate_file, generate_str
 import numpy as np
 from pathlib import Path
 
+# Convert aidge datatype to C++ type
+datatype_converter = {
+    DataType.Float64 : "double",
+    DataType.Float32 : "float",
+    DataType.Float16 : "half_float::half",
+    DataType.Int8    : "int8_t",
+    DataType.Int16   : "int16_t",
+    DataType.Int32   : "int32_t",
+    DataType.Int64   : "int64_t",
+    DataType.UInt8   : "uint8_t",
+    DataType.UInt16  : "uint16_t",
+    DataType.UInt32  : "uint32_t",
+    DataType.UInt64  : "uint64_t"
+}
+
 
 @operator_register("Producer")
 class Producer(ExportNode):
@@ -24,11 +39,16 @@ class Producer(ExportNode):
         filepath = export_folder / f"include/{include_path}"
 
         aidge_tensor = self.operator.get_output(0)
+        aidge_type = aidge_tensor.dtype()
+        if aidge_type in datatype_converter:
+            datatype = datatype_converter[aidge_type]
+        else:
+            raise RuntimeError(f"No conversion found for data type {aidge_type}.")
         generate_file(
             filepath,
             ROOT_EXPORT / "templates/parameter.jinja",
             dims = aidge_tensor.dims(),
-            data_t = "float", # TODO : get data from producer
+            data_t = datatype, # TODO : get data from producer
             name = self.tensor_name,
             values = str(aidge_tensor)
         )
diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp
index bca246c94..955b510e6 100644
--- a/python_binding/data/pybind_Data.cpp
+++ b/python_binding/data/pybind_Data.cpp
@@ -22,9 +22,11 @@ void init_Data(py::module& m){
     .value("Float32", DataType::Float32)
     .value("Float16", DataType::Float16)
     .value("Int8", DataType::Int8)
+    .value("Int16", DataType::Int16)
     .value("Int32", DataType::Int32)
     .value("Int64", DataType::Int64)
     .value("UInt8", DataType::UInt8)
+    .value("UInt16", DataType::UInt16)
     .value("UInt32", DataType::UInt32)
     .value("UInt64", DataType::UInt64)
     ;
-- 
GitLab