diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 9390fe5860b5d3523886856d9b2a40752d338af5..5f4e257720bda050086f1bc17b240b410c1c8341 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -126,6 +126,7 @@ REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::crea REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create); REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create); REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int4}, Aidge::TensorImpl_cpu<int8_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create); diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index 23221e653ba725e4463b06cfabb5483a20756701..d4d206409f9de1cb026c6795b5d50ebe99567401 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -101,10 +101,17 @@ private: } namespace { + +// Define a distinct type alias for Int4 +struct Int4Type { + std::int8_t value; +}; + template <typename T> struct NativeType { static const Aidge::DataType type; }; template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64; template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32; template <> const Aidge::DataType NativeType<half_float::half>::type = Aidge::DataType::Float16; +template <> const Aidge::DataType NativeType<Int4Type>::type = Aidge::DataType::Int4; template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8; template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16; template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32; @@ -131,6 +138,7 @@ template <Aidge::DataType D> struct cpptype { template <> struct cpptype<Aidge::DataType::Float16> { using type = half_float::half; }; template <> struct cpptype<Aidge::DataType::Float32> { using type = float; }; template <> struct cpptype<Aidge::DataType::Float64> { using type = double; }; +template <> struct cpptype<Aidge::DataType::Int4> { using type = Int4Type; }; template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; }; template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; }; template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; }; @@ -141,6 +149,9 @@ template <> struct cpptype<Aidge::DataType::UInt32> { using type = std::uint32_t template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t; }; template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type; + + + } diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index fe606cfb557042d581e09da7419d80841d1dc2d4..7e910be5723e31b128406becba9a105108e2b7a5 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -226,6 +226,8 @@ static T castToNativeType(const py::object val_obj) { DataType dtype; getConservativeNativeVal(val_obj, &val, &dtype); switch (dtype) { + case DataType::Int4: + return (T)val.i8; case DataType::Int8: return (T)val.i8; case DataType::Int16: @@ -500,6 +502,9 @@ void init_Tensor(py::module& m){ case DataType::Float32: dataFormatDescriptor = py::format_descriptor<float>::format(); break;; + case DataType::Int4: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; case DataType::Int8: dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); break; diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 6f60d2f15ce0e561c32d7bc5a7561c2f8d507588..2ba19a7490aefd979f65244074485642d67c7b80 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -249,6 +249,8 @@ std::string Aidge::Tensor::toString() const { return std::to_string(static_cast<float*>(ptr)[idx]); case DataType::Float16: return std::to_string(static_cast<half_float::half*>(ptr)[idx]); + case DataType::Int4: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Int8: return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Int16: