From c21cbc8fd5bd38a9795e23272c05d0198c97b540 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Fri, 23 Feb 2024 13:36:36 +0000 Subject: [PATCH] Add support for Int8, Int16, Uint8, Uint16 in tensor binding --- python_binding/data/pybind_Tensor.cpp | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index e07f70eaa..f8a0567bd 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -96,10 +96,18 @@ void init_Tensor(py::module& m){ return py::cast(b.get<double>(idx)); case DataType::Float32: return py::cast(b.get<float>(idx)); + case DataType::Int8: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Int16: + return py::cast(b.get<std::int16_t>(idx)); case DataType::Int32: return py::cast(b.get<std::int32_t>(idx)); case DataType::Int64: return py::cast(b.get<std::int64_t>(idx)); + case DataType::UInt8: + return py::cast(b.get<std::uint8_t>(idx)); + case DataType::UInt16: + return py::cast(b.get<std::uint16_t>(idx)); default: return py::none(); } @@ -111,10 +119,18 @@ void init_Tensor(py::module& m){ return py::cast(b.get<double>(coordIdx)); case DataType::Float32: return py::cast(b.get<float>(coordIdx)); + case DataType::Int8: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Int16: + return py::cast(b.get<std::int16_t>(coordIdx)); case DataType::Int32: return py::cast(b.get<std::int32_t>(coordIdx)); case DataType::Int64: return py::cast(b.get<std::int64_t>(coordIdx)); + case DataType::UInt8: + return py::cast(b.get<std::uint8_t>(coordIdx)); + case DataType::UInt16: + return py::cast(b.get<std::uint16_t>(coordIdx)); default: return py::none(); } @@ -141,6 +157,12 @@ void init_Tensor(py::module& m){ break; case DataType::Float32: dataFormatDescriptor = py::format_descriptor<float>::format(); + break;; + case DataType::Int8: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; + case DataType::Int16: + dataFormatDescriptor = py::format_descriptor<std::int16_t>::format(); break; case DataType::Int32: dataFormatDescriptor = py::format_descriptor<std::int32_t>::format(); @@ -148,6 +170,12 @@ void init_Tensor(py::module& m){ case DataType::Int64: dataFormatDescriptor = py::format_descriptor<std::int64_t>::format(); break; + case DataType::UInt8: + dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); + break; + case DataType::UInt16: + dataFormatDescriptor = py::format_descriptor<std::uint16_t>::format(); + break; default: throw py::value_error("Unsupported data format"); } -- GitLab