From 7050788e6988db85ea06f73478e9799b472a392f Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Wed, 4 Dec 2024 13:21:47 +0000
Subject: [PATCH] Add integer datatypes for weightInterleaving

---
 include/aidge/backend/cpu/data/TensorImpl.hpp |  8 +++
 include/aidge/data/Data.hpp                   | 15 +++--
 python_binding/data/pybind_Tensor.cpp         |  8 +--
 src/backend/cpu/data/TensorImpl.cpp           | 56 +++++++++++++++++++
 4 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index 3cd4fd517..2115b660f 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -132,6 +132,14 @@ REGISTRAR(Tensor, {"cpu", DataType::Int3}, Aidge::TensorImpl_cpu<int8_t>::create
 REGISTRAR(Tensor, {"cpu", DataType::UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_Int4}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_Int3}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Quad_Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Quad_UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Octo_Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index dcc45961a..5303d61f9 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -157,8 +157,17 @@ struct Octo_BinaryType {
 };
 
 
-template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
-template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
+// template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
+// template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
+
+template <Aidge::DataType D> struct WeightInterleavingType { static const Aidge::DataType type; };
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int4>::type = Aidge::DataType::Dual_Int4;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt4>::type = Aidge::DataType::Dual_UInt4;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int3>::type = Aidge::DataType::Dual_Int3;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt3>::type = Aidge::DataType::Dual_UInt3;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int2>::type = Aidge::DataType::Quad_Int2;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt2>::type = Aidge::DataType::Quad_UInt2;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Binary>::type = Aidge::DataType::Octo_Binary;
 
 
 template <typename T> struct NativeType { static const Aidge::DataType type; };
@@ -228,8 +237,6 @@ template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t
 
 template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type;
 
-
-
 }
 
 
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index d57384d42..35e60e158 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -389,7 +389,7 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<std::uint8_t>(idx));
             case DataType::UInt2:
                 return py::cast(b.get<std::uint8_t>(idx));
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 return py::cast(b.get<std::uint8_t>(idx));
             case DataType::UInt16:
                 return py::cast(b.get<std::uint16_t>(idx));
@@ -444,7 +444,7 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<std::uint8_t>(coordIdx));
             case DataType::UInt2:
                 return py::cast(b.get<std::uint8_t>(coordIdx));
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 return py::cast(b.get<std::uint8_t>(coordIdx));
             case DataType::UInt16:
                 return py::cast(b.get<std::uint16_t>(coordIdx));
@@ -519,7 +519,7 @@ void init_Tensor(py::module& m){
             case DataType::UInt2:
                 b.set(idx, castToNativeType<std::uint8_t>(val));
                 break;
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 b.set(idx, castToNativeType<std::uint8_t>(val));
                 break;
             case DataType::UInt16:
@@ -599,7 +599,7 @@ void init_Tensor(py::module& m){
             case DataType::UInt2:
                 b.set(coordIdx, castToNativeType<std::uint8_t>(val));
                 break;
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 b.set(coordIdx, castToNativeType<std::uint8_t>(val));
                 break;
             case DataType::UInt16:
diff --git a/src/backend/cpu/data/TensorImpl.cpp b/src/backend/cpu/data/TensorImpl.cpp
index 506287a0c..236e5bb8e 100644
--- a/src/backend/cpu/data/TensorImpl.cpp
+++ b/src/backend/cpu/data/TensorImpl.cpp
@@ -95,6 +95,62 @@ void Aidge::TensorImpl_cpu<T>::copyCast(const void *src, const Aidge::DataType s
             std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
                     dstT);
             break;
+        case DataType::Int4:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::UInt4:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_Int4:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_UInt4:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Int3:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::UInt3:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_Int3:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_UInt3:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Int2:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::UInt2:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Quad_Int2:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Quad_UInt2:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Binary:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Octo_Binary:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
             break;
-- 
GitLab