From 23a110dc213c8ec3342912ce15f2b47751fa64fe Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Wed, 4 Dec 2024 13:41:39 +0000 Subject: [PATCH] Update implementation kernels for new lowbit integer datatypes --- .../WeightInterleavingImpl_kernels.hpp | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp index 422afab5..f2347fd2 100644 --- a/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp @@ -23,7 +23,8 @@ namespace Aidge { * @param compactData The output array storing the compacted data. * @param nb_bits The number of bits to extract from each `data` element (must be less than 8). */ - void compact_data(const std::int8_t* data, std::size_t dataSize, std::int8_t* compactData, std::uint8_t nb_bits) { + template <typename T> + void compact_data(const T* data, std::size_t dataSize, T* compactData, std::uint8_t nb_bits) { AIDGE_ASSERT(nb_bits > 0 && nb_bits < 5, "Cannot compact with the given nb_bits"); // Ensure valid bit width // Mask to extract `nb_bits` from each data element @@ -41,7 +42,7 @@ namespace Aidge { // Main loop to process data in groups of `nbSlot` for (std::size_t i = 0; i < nbFullCompactbytes; ++i) { - std::int8_t compact = 0; + T compact = 0; for (unsigned int j = 0; j < nbSlot; ++j) { compact |= (data[i * nbSlot + j] & mask); // Apply mask to keep `nb_bits` only @@ -91,14 +92,39 @@ void WeightInterleavingImpl_cpu_forward_kernel(const DimSize_t input_interleavin REGISTRAR(WeightInterleavingImpl_cpu, - {ImplSpec::IOSpec{DataType::Int4, DataFormat::NHWC}}, + {ImplSpec::IOSpec{DataType::Int4, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Int4>::type, DataFormat::NHWC}}, {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 4>, nullptr}); REGISTRAR(WeightInterleavingImpl_cpu, - {ImplSpec::IOSpec{DataType::Int3, DataFormat::NHWC}}, + {ImplSpec::IOSpec{DataType::Int3, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Int3>::type, DataFormat::NHWC}}, {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 3>, nullptr}); REGISTRAR(WeightInterleavingImpl_cpu, - {ImplSpec::IOSpec{DataType::Int2, DataFormat::NHWC}}, + {ImplSpec::IOSpec{DataType::Int2, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Int2>::type, DataFormat::NHWC}}, {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 2>, nullptr}); +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::Binary, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Binary>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 1>, nullptr}); + +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::UInt4, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::UInt4>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<uint8_t, uint8_t, 4>, nullptr}); +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::UInt3, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::UInt3>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<uint8_t, uint8_t, 3>, nullptr}); +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::UInt2, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::UInt2>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<uint8_t, uint8_t, 2>, nullptr}); + + +// REGISTRAR(WeightInterleavingImpl_cpu, +// {ImplSpec::IOSpec{DataType::Int4, DataFormat::NHWC}}, +// {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 4>, nullptr}); +// REGISTRAR(WeightInterleavingImpl_cpu, +// {ImplSpec::IOSpec{DataType::Int3, DataFormat::NHWC}}, +// {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 3>, nullptr}); +// REGISTRAR(WeightInterleavingImpl_cpu, +// {ImplSpec::IOSpec{DataType::Int2, DataFormat::NHWC}}, +// {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 2>, nullptr}); + } -- GitLab