From b43080da831066bf171b52e296550b55bf541161 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Thu, 21 Nov 2024 14:07:01 +0000
Subject: [PATCH] Add support of UInt4, Int3, UInt3, Int2, UInt2

---
 include/aidge/backend/cpu/data/TensorImpl.hpp |  5 ++++
 include/aidge/data/Data.hpp                   | 29 +++++++++++++++++++
 python_binding/data/pybind_Tensor.cpp         | 15 ++++++++++
 src/data/Tensor.cpp                           | 10 +++++++
 4 files changed, 59 insertions(+)

diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index 5f4e25772..770a37b07 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -127,6 +127,11 @@ REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::crea
 REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int4}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int3}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index d4d206409..acf5b950e 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -106,12 +106,36 @@ namespace {
 struct Int4Type {
     std::int8_t value;
 };
+struct UInt4Type {
+    std::uint8_t value;
+};
+struct Int3Type {
+    std::int8_t value;
+};
+struct UInt3Type {
+    std::uint8_t value;
+};
+struct Int2Type {
+    std::int8_t value;
+};
+struct UInt2Type {
+    std::uint8_t value;
+};
+
+template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
+template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
+
 
 template <typename T> struct NativeType { static const Aidge::DataType type; };
 template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64;
 template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32;
 template <> const Aidge::DataType NativeType<half_float::half>::type = Aidge::DataType::Float16;
 template <> const Aidge::DataType NativeType<Int4Type>::type = Aidge::DataType::Int4;
+template <> const Aidge::DataType NativeType<UInt4Type>::type = Aidge::DataType::UInt4;
+template <> const Aidge::DataType NativeType<Int3Type>::type = Aidge::DataType::Int3;
+template <> const Aidge::DataType NativeType<UInt3Type>::type = Aidge::DataType::UInt3;
+template <> const Aidge::DataType NativeType<Int2Type>::type = Aidge::DataType::Int2;
+template <> const Aidge::DataType NativeType<UInt2Type>::type = Aidge::DataType::UInt2;
 template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8;
 template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16;
 template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32;
@@ -139,6 +163,11 @@ template <> struct cpptype<Aidge::DataType::Float16> { using type = half_float::
 template <> struct cpptype<Aidge::DataType::Float32> { using type = float; };
 template <> struct cpptype<Aidge::DataType::Float64> { using type = double; };
 template <> struct cpptype<Aidge::DataType::Int4> { using type = Int4Type; };
+template <> struct cpptype<Aidge::DataType::UInt4> { using type = UInt4Type; };
+template <> struct cpptype<Aidge::DataType::Int3> { using type = Int3Type; };
+template <> struct cpptype<Aidge::DataType::UInt3> { using type = UInt3Type; };
+template <> struct cpptype<Aidge::DataType::Int2> { using type = Int2Type; };
+template <> struct cpptype<Aidge::DataType::UInt2> { using type = UInt2Type; };
 template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; };
 template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; };
 template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; };
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 7e910be57..d94acaa81 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -505,6 +505,21 @@ void init_Tensor(py::module& m){
             case DataType::Int4:
                 dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
                 break;
+            case DataType::UInt4:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Int3:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::UInt3:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Int2:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::UInt2:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
             case DataType::Int8:
                 dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
                 break;
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index 2ba19a749..2a06aacbe 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -251,6 +251,16 @@ std::string Aidge::Tensor::toString() const {
                 return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
             case DataType::Int4:
                 return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::UInt4:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
+            case DataType::Int3:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::UInt3:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
+            case DataType::Int2:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::UInt2:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
             case DataType::Int8:
                 return std::to_string(static_cast<int8_t*>(ptr)[idx]);
             case DataType::Int16:
-- 
GitLab