From 0b474a55a2a1dc123c721db3e934af72e585a2ad Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Wed, 6 Nov 2024 12:56:52 +0000 Subject: [PATCH] Add type aidge::int4 with tensor implementation int8_t --- include/aidge/backend/cpu/data/TensorImpl.hpp | 1 + include/aidge/data/Data.hpp | 11 +++++++++++ python_binding/data/pybind_Tensor.cpp | 5 +++++ src/data/Tensor.cpp | 2 ++ 4 files changed, 19 insertions(+) diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 9390fe586..5f4e25772 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -126,6 +126,7 @@ REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::crea REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create); REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create); REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int4}, Aidge::TensorImpl_cpu<int8_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create); diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index 23221e653..d4d206409 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -101,10 +101,17 @@ private: } namespace { + +// Define a distinct type alias for Int4 +struct Int4Type { + std::int8_t value; +}; + template <typename T> struct NativeType { static const Aidge::DataType type; }; template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64; template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32; template <> const Aidge::DataType NativeType<half_float::half>::type = Aidge::DataType::Float16; +template <> const Aidge::DataType NativeType<Int4Type>::type = Aidge::DataType::Int4; template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8; template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16; template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32; @@ -131,6 +138,7 @@ template <Aidge::DataType D> struct cpptype { template <> struct cpptype<Aidge::DataType::Float16> { using type = half_float::half; }; template <> struct cpptype<Aidge::DataType::Float32> { using type = float; }; template <> struct cpptype<Aidge::DataType::Float64> { using type = double; }; +template <> struct cpptype<Aidge::DataType::Int4> { using type = Int4Type; }; template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; }; template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; }; template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; }; @@ -141,6 +149,9 @@ template <> struct cpptype<Aidge::DataType::UInt32> { using type = std::uint32_t template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t; }; template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type; + + + } diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index fe606cfb5..7e910be57 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -226,6 +226,8 @@ static T castToNativeType(const py::object val_obj) { DataType dtype; getConservativeNativeVal(val_obj, &val, &dtype); switch (dtype) { + case DataType::Int4: + return (T)val.i8; case DataType::Int8: return (T)val.i8; case DataType::Int16: @@ -500,6 +502,9 @@ void init_Tensor(py::module& m){ case DataType::Float32: dataFormatDescriptor = py::format_descriptor<float>::format(); break;; + case DataType::Int4: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; case DataType::Int8: dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); break; diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 6f60d2f15..2ba19a749 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -249,6 +249,8 @@ std::string Aidge::Tensor::toString() const { return std::to_string(static_cast<float*>(ptr)[idx]); case DataType::Float16: return std::to_string(static_cast<half_float::half*>(ptr)[idx]); + case DataType::Int4: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Int8: return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Int16: -- GitLab