From b43080da831066bf171b52e296550b55bf541161 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Thu, 21 Nov 2024 14:07:01 +0000 Subject: [PATCH] Add support of UInt4, Int3, UInt3, Int2, UInt2 --- include/aidge/backend/cpu/data/TensorImpl.hpp | 5 ++++ include/aidge/data/Data.hpp | 29 +++++++++++++++++++ python_binding/data/pybind_Tensor.cpp | 15 ++++++++++ src/data/Tensor.cpp | 10 +++++++ 4 files changed, 59 insertions(+) diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 5f4e25772..770a37b07 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -127,6 +127,11 @@ REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::crea REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create); REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create); REGISTRAR(Tensor, {"cpu", DataType::Int4}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int3}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int2}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create); diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index d4d206409..acf5b950e 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -106,12 +106,36 @@ namespace { struct Int4Type { std::int8_t value; }; +struct UInt4Type { + std::uint8_t value; +}; +struct Int3Type { + std::int8_t value; +}; +struct UInt3Type { + std::uint8_t value; +}; +struct Int2Type { + std::int8_t value; +}; +struct UInt2Type { + std::uint8_t value; +}; + +template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; }; +template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4; + template <typename T> struct NativeType { static const Aidge::DataType type; }; template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64; template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32; template <> const Aidge::DataType NativeType<half_float::half>::type = Aidge::DataType::Float16; template <> const Aidge::DataType NativeType<Int4Type>::type = Aidge::DataType::Int4; +template <> const Aidge::DataType NativeType<UInt4Type>::type = Aidge::DataType::UInt4; +template <> const Aidge::DataType NativeType<Int3Type>::type = Aidge::DataType::Int3; +template <> const Aidge::DataType NativeType<UInt3Type>::type = Aidge::DataType::UInt3; +template <> const Aidge::DataType NativeType<Int2Type>::type = Aidge::DataType::Int2; +template <> const Aidge::DataType NativeType<UInt2Type>::type = Aidge::DataType::UInt2; template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8; template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16; template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32; @@ -139,6 +163,11 @@ template <> struct cpptype<Aidge::DataType::Float16> { using type = half_float:: template <> struct cpptype<Aidge::DataType::Float32> { using type = float; }; template <> struct cpptype<Aidge::DataType::Float64> { using type = double; }; template <> struct cpptype<Aidge::DataType::Int4> { using type = Int4Type; }; +template <> struct cpptype<Aidge::DataType::UInt4> { using type = UInt4Type; }; +template <> struct cpptype<Aidge::DataType::Int3> { using type = Int3Type; }; +template <> struct cpptype<Aidge::DataType::UInt3> { using type = UInt3Type; }; +template <> struct cpptype<Aidge::DataType::Int2> { using type = Int2Type; }; +template <> struct cpptype<Aidge::DataType::UInt2> { using type = UInt2Type; }; template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; }; template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; }; template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; }; diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 7e910be57..d94acaa81 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -505,6 +505,21 @@ void init_Tensor(py::module& m){ case DataType::Int4: dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); break; + case DataType::UInt4: + dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); + break; + case DataType::Int3: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; + case DataType::UInt3: + dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); + break; + case DataType::Int2: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; + case DataType::UInt2: + dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); + break; case DataType::Int8: dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); break; diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 2ba19a749..2a06aacbe 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -251,6 +251,16 @@ std::string Aidge::Tensor::toString() const { return std::to_string(static_cast<half_float::half*>(ptr)[idx]); case DataType::Int4: return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::UInt4: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); + case DataType::Int3: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::UInt3: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); + case DataType::Int2: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::UInt2: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); case DataType::Int8: return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Int16: -- GitLab