From c21cbc8fd5bd38a9795e23272c05d0198c97b540 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Fri, 23 Feb 2024 13:36:36 +0000
Subject: [PATCH] Add support for Int8, Int16, Uint8, Uint16 in tensor binding

---
 python_binding/data/pybind_Tensor.cpp | 28 +++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index e07f70eaa..f8a0567bd 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -96,10 +96,18 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<double>(idx));
             case DataType::Float32:
                 return py::cast(b.get<float>(idx));
+            case DataType::Int8:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Int16:
+                return py::cast(b.get<std::int16_t>(idx));
             case DataType::Int32:
                 return py::cast(b.get<std::int32_t>(idx));
             case DataType::Int64:
                 return py::cast(b.get<std::int64_t>(idx));
+            case DataType::UInt8:
+                return py::cast(b.get<std::uint8_t>(idx));
+            case DataType::UInt16:
+                return py::cast(b.get<std::uint16_t>(idx));
             default:
                 return py::none();
         }
@@ -111,10 +119,18 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<double>(coordIdx));
             case DataType::Float32:
                 return py::cast(b.get<float>(coordIdx));
+            case DataType::Int8:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Int16:
+                return py::cast(b.get<std::int16_t>(coordIdx));
             case DataType::Int32:
                 return py::cast(b.get<std::int32_t>(coordIdx));
             case DataType::Int64:
                 return py::cast(b.get<std::int64_t>(coordIdx));
+            case DataType::UInt8:
+                return py::cast(b.get<std::uint8_t>(coordIdx));
+            case DataType::UInt16:
+                return py::cast(b.get<std::uint16_t>(coordIdx));
             default:
                 return py::none();
         }
@@ -141,6 +157,12 @@ void init_Tensor(py::module& m){
                 break;
             case DataType::Float32:
                 dataFormatDescriptor = py::format_descriptor<float>::format();
+                break;;
+            case DataType::Int8:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::Int16:
+                dataFormatDescriptor = py::format_descriptor<std::int16_t>::format();
                 break;
             case DataType::Int32:
                 dataFormatDescriptor = py::format_descriptor<std::int32_t>::format();
@@ -148,6 +170,12 @@ void init_Tensor(py::module& m){
             case DataType::Int64:
                 dataFormatDescriptor = py::format_descriptor<std::int64_t>::format();
                 break;
+            case DataType::UInt8:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::UInt16:
+                dataFormatDescriptor = py::format_descriptor<std::uint16_t>::format();
+                break;
             default:
                 throw py::value_error("Unsupported data format");
         }
-- 
GitLab