diff --git a/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp index 422afab59178732dbcb2427892fbf930e97cbb45..f2347fd2d7ad3e9adfa134ce1413b6348e08c064 100644 --- a/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/WeightInterleavingImpl_kernels.hpp @@ -23,7 +23,8 @@ namespace Aidge { * @param compactData The output array storing the compacted data. * @param nb_bits The number of bits to extract from each `data` element (must be less than 8). */ - void compact_data(const std::int8_t* data, std::size_t dataSize, std::int8_t* compactData, std::uint8_t nb_bits) { + template <typename T> + void compact_data(const T* data, std::size_t dataSize, T* compactData, std::uint8_t nb_bits) { AIDGE_ASSERT(nb_bits > 0 && nb_bits < 5, "Cannot compact with the given nb_bits"); // Ensure valid bit width // Mask to extract `nb_bits` from each data element @@ -41,7 +42,7 @@ namespace Aidge { // Main loop to process data in groups of `nbSlot` for (std::size_t i = 0; i < nbFullCompactbytes; ++i) { - std::int8_t compact = 0; + T compact = 0; for (unsigned int j = 0; j < nbSlot; ++j) { compact |= (data[i * nbSlot + j] & mask); // Apply mask to keep `nb_bits` only @@ -91,14 +92,39 @@ void WeightInterleavingImpl_cpu_forward_kernel(const DimSize_t input_interleavin REGISTRAR(WeightInterleavingImpl_cpu, - {ImplSpec::IOSpec{DataType::Int4, DataFormat::NHWC}}, + {ImplSpec::IOSpec{DataType::Int4, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Int4>::type, DataFormat::NHWC}}, {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 4>, nullptr}); REGISTRAR(WeightInterleavingImpl_cpu, - {ImplSpec::IOSpec{DataType::Int3, DataFormat::NHWC}}, + {ImplSpec::IOSpec{DataType::Int3, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Int3>::type, DataFormat::NHWC}}, {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 3>, nullptr}); REGISTRAR(WeightInterleavingImpl_cpu, - {ImplSpec::IOSpec{DataType::Int2, DataFormat::NHWC}}, + {ImplSpec::IOSpec{DataType::Int2, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Int2>::type, DataFormat::NHWC}}, {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 2>, nullptr}); +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::Binary, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::Binary>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 1>, nullptr}); + +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::UInt4, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::UInt4>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<uint8_t, uint8_t, 4>, nullptr}); +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::UInt3, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::UInt3>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<uint8_t, uint8_t, 3>, nullptr}); +REGISTRAR(WeightInterleavingImpl_cpu, + {ImplSpec::IOSpec{DataType::UInt2, DataFormat::NHWC}, ImplSpec::IOSpec{WeightInterleavingType<DataType::UInt2>::type, DataFormat::NHWC}}, + {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<uint8_t, uint8_t, 2>, nullptr}); + + +// REGISTRAR(WeightInterleavingImpl_cpu, +// {ImplSpec::IOSpec{DataType::Int4, DataFormat::NHWC}}, +// {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 4>, nullptr}); +// REGISTRAR(WeightInterleavingImpl_cpu, +// {ImplSpec::IOSpec{DataType::Int3, DataFormat::NHWC}}, +// {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 3>, nullptr}); +// REGISTRAR(WeightInterleavingImpl_cpu, +// {ImplSpec::IOSpec{DataType::Int2, DataFormat::NHWC}}, +// {ProdConso::defaultModel, Aidge::WeightInterleavingImpl_cpu_forward_kernel<int8_t, int8_t, 2>, nullptr}); + }