From 7050788e6988db85ea06f73478e9799b472a392f Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Wed, 4 Dec 2024 13:21:47 +0000 Subject: [PATCH] Add integer datatypes for weightInterleaving --- include/aidge/backend/cpu/data/TensorImpl.hpp | 8 +++ include/aidge/data/Data.hpp | 15 +++-- python_binding/data/pybind_Tensor.cpp | 8 +-- src/backend/cpu/data/TensorImpl.cpp | 56 +++++++++++++++++++ 4 files changed, 79 insertions(+), 8 deletions(-) diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 3cd4fd517..2115b660f 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -132,6 +132,14 @@ REGISTRAR(Tensor, {"cpu", DataType::Int3}, Aidge::TensorImpl_cpu<int8_t>::create REGISTRAR(Tensor, {"cpu", DataType::UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create); REGISTRAR(Tensor, {"cpu", DataType::Int2}, Aidge::TensorImpl_cpu<int8_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Dual_Int4}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Dual_Int3}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Quad_Int2}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Quad_UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Binary}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Octo_Binary}, Aidge::TensorImpl_cpu<int8_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create); diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index dcc45961a..5303d61f9 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -157,8 +157,17 @@ struct Octo_BinaryType { }; -template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; }; -template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4; +// template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; }; +// template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4; + +template <Aidge::DataType D> struct WeightInterleavingType { static const Aidge::DataType type; }; +template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int4>::type = Aidge::DataType::Dual_Int4; +template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt4>::type = Aidge::DataType::Dual_UInt4; +template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int3>::type = Aidge::DataType::Dual_Int3; +template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt3>::type = Aidge::DataType::Dual_UInt3; +template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int2>::type = Aidge::DataType::Quad_Int2; +template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt2>::type = Aidge::DataType::Quad_UInt2; +template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Binary>::type = Aidge::DataType::Octo_Binary; template <typename T> struct NativeType { static const Aidge::DataType type; }; @@ -228,8 +237,6 @@ template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type; - - } diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index d57384d42..35e60e158 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -389,7 +389,7 @@ void init_Tensor(py::module& m){ return py::cast(b.get<std::uint8_t>(idx)); case DataType::UInt2: return py::cast(b.get<std::uint8_t>(idx)); - case DataType::Dual_UInt2: + case DataType::Quad_UInt2: return py::cast(b.get<std::uint8_t>(idx)); case DataType::UInt16: return py::cast(b.get<std::uint16_t>(idx)); @@ -444,7 +444,7 @@ void init_Tensor(py::module& m){ return py::cast(b.get<std::uint8_t>(coordIdx)); case DataType::UInt2: return py::cast(b.get<std::uint8_t>(coordIdx)); - case DataType::Dual_UInt2: + case DataType::Quad_UInt2: return py::cast(b.get<std::uint8_t>(coordIdx)); case DataType::UInt16: return py::cast(b.get<std::uint16_t>(coordIdx)); @@ -519,7 +519,7 @@ void init_Tensor(py::module& m){ case DataType::UInt2: b.set(idx, castToNativeType<std::uint8_t>(val)); break; - case DataType::Dual_UInt2: + case DataType::Quad_UInt2: b.set(idx, castToNativeType<std::uint8_t>(val)); break; case DataType::UInt16: @@ -599,7 +599,7 @@ void init_Tensor(py::module& m){ case DataType::UInt2: b.set(coordIdx, castToNativeType<std::uint8_t>(val)); break; - case DataType::Dual_UInt2: + case DataType::Quad_UInt2: b.set(coordIdx, castToNativeType<std::uint8_t>(val)); break; case DataType::UInt16: diff --git a/src/backend/cpu/data/TensorImpl.cpp b/src/backend/cpu/data/TensorImpl.cpp index 506287a0c..236e5bb8e 100644 --- a/src/backend/cpu/data/TensorImpl.cpp +++ b/src/backend/cpu/data/TensorImpl.cpp @@ -95,6 +95,62 @@ void Aidge::TensorImpl_cpu<T>::copyCast(const void *src, const Aidge::DataType s std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, dstT); break; + case DataType::Int4: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; + case DataType::UInt4: + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + dstT); + break; + case DataType::Dual_Int4: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; + case DataType::Dual_UInt4: + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + dstT); + break; + case DataType::Int3: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; + case DataType::UInt3: + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + dstT); + break; + case DataType::Dual_Int3: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; + case DataType::Dual_UInt3: + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + dstT); + break; + case DataType::Int2: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; + case DataType::UInt2: + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + dstT); + break; + case DataType::Quad_Int2: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; + case DataType::Quad_UInt2: + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + dstT); + break; + case DataType::Binary: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; + case DataType::Octo_Binary: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + dstT); + break; default: AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); break; -- GitLab