diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index 3cd4fd517ecbad3151eb9bfdbf9003d9874b8c12..2115b660fa38d3d077eaa9c416525a23c1d4c536 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -132,6 +132,14 @@ REGISTRAR(Tensor, {"cpu", DataType::Int3}, Aidge::TensorImpl_cpu<int8_t>::create
 REGISTRAR(Tensor, {"cpu", DataType::UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_Int4}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_Int3}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Quad_Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Quad_UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Octo_Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index dcc45961a44303d7d3ab150619108c61362006e4..5303d61f9ca0bc28687c9300506220c0d34a5c70 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -157,8 +157,17 @@ struct Octo_BinaryType {
 };
 
 
-template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
-template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
+// template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
+// template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
+
+template <Aidge::DataType D> struct WeightInterleavingType { static const Aidge::DataType type; };
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int4>::type = Aidge::DataType::Dual_Int4;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt4>::type = Aidge::DataType::Dual_UInt4;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int3>::type = Aidge::DataType::Dual_Int3;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt3>::type = Aidge::DataType::Dual_UInt3;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int2>::type = Aidge::DataType::Quad_Int2;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt2>::type = Aidge::DataType::Quad_UInt2;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Binary>::type = Aidge::DataType::Octo_Binary;
 
 template <typename T> struct NativeType { static const Aidge::DataType type; };
 
@@ -228,8 +237,6 @@ template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t
 
 template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type;
 
 
-
-
 }
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index d57384d428090323b962af502dd75c553fcb20d3..35e60e1589ce5599affbc2b466171acc6bf4ef01 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -389,7 +389,7 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<std::uint8_t>(idx));
             case DataType::UInt2:
                 return py::cast(b.get<std::uint8_t>(idx));
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 return py::cast(b.get<std::uint8_t>(idx));
             case DataType::UInt16:
                 return py::cast(b.get<std::uint16_t>(idx));
@@ -444,7 +444,7 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<std::uint8_t>(coordIdx));
             case DataType::UInt2:
                 return py::cast(b.get<std::uint8_t>(coordIdx));
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 return py::cast(b.get<std::uint8_t>(coordIdx));
             case DataType::UInt16:
                 return py::cast(b.get<std::uint16_t>(coordIdx));
@@ -519,7 +519,7 @@ void init_Tensor(py::module& m){
             case DataType::UInt2:
                 b.set(idx, castToNativeType<std::uint8_t>(val));
                 break;
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 b.set(idx, castToNativeType<std::uint8_t>(val));
                 break;
             case DataType::UInt16:
@@ -599,7 +599,7 @@ void init_Tensor(py::module& m){
             case DataType::UInt2:
                 b.set(coordIdx, castToNativeType<std::uint8_t>(val));
                 break;
-            case DataType::Dual_UInt2:
+            case DataType::Quad_UInt2:
                 b.set(coordIdx, castToNativeType<std::uint8_t>(val));
                 break;
             case DataType::UInt16:
diff --git a/src/backend/cpu/data/TensorImpl.cpp b/src/backend/cpu/data/TensorImpl.cpp
index 506287a0c520915e6426f1f0b64d9c562c754d33..236e5bb8e1e867d5a0dad85571d754bc9e2a2a22 100644
--- a/src/backend/cpu/data/TensorImpl.cpp
+++ b/src/backend/cpu/data/TensorImpl.cpp
@@ -95,6 +95,62 @@ void Aidge::TensorImpl_cpu<T>::copyCast(const void *src, const Aidge::DataType s
             std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
                     dstT);
             break;
+        case DataType::Int4:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::UInt4:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_Int4:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_UInt4:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Int3:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::UInt3:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_Int3:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Dual_UInt3:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Int2:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::UInt2:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Quad_Int2:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Quad_UInt2:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Binary:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
+        case DataType::Octo_Binary:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    dstT);
+            break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
             break;
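
Usage sketch (illustrative, not part of the patch): the WeightInterleavingType trait added in Data.hpp maps a base sub-byte type to the compact type it is stored as after weight interleaving. Only the trait name, its type member, and the mappings come from the diff above; the include path and the using-directive below are assumptions about how the header is consumed.

#include <cassert>

#include "aidge/data/Data.hpp"

int main() {
    using namespace Aidge;  // assumption: brings DataType (and the trait, if declared in the Aidge namespace) into scope

    // Each specialization resolves a base type to its interleaved storage counterpart.
    assert(WeightInterleavingType<DataType::Int4>::type   == DataType::Dual_Int4);   // two 4-bit values per byte
    assert(WeightInterleavingType<DataType::UInt2>::type  == DataType::Quad_UInt2);  // four 2-bit values per byte
    assert(WeightInterleavingType<DataType::Binary>::type == DataType::Octo_Binary); // eight 1-bit values per byte

    return 0;
}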
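
For intuition only: the Dual_/Quad_/Octo_ prefixes suggest how many sub-byte weights share one 8-bit storage element, which is consistent with every new type above being backed by TensorImpl_cpu<int8_t> or TensorImpl_cpu<uint8_t>. The standalone sketch below packs four 2-bit values into one byte in a Quad_UInt2-style layout; the bit order chosen here (value i in the i-th low-to-high 2-bit slot) is an assumption for illustration and does not describe Aidge's actual interleaving kernels.

#include <array>
#include <cassert>
#include <cstdint>

// Pack four 2-bit unsigned weights into one byte (hypothetical Quad_UInt2-style storage).
std::uint8_t packQuadUInt2(const std::array<std::uint8_t, 4>& w) {
    std::uint8_t byte = 0;
    for (std::size_t i = 0; i < 4; ++i) {
        byte |= static_cast<std::uint8_t>((w[i] & 0x03u) << (2 * i));  // assumed bit order: slot i = bits [2i+1:2i]
    }
    return byte;
}

// Recover the i-th 2-bit weight from a packed byte.
std::uint8_t unpackQuadUInt2(std::uint8_t byte, std::size_t i) {
    return (byte >> (2 * i)) & 0x03u;
}

int main() {
    const std::array<std::uint8_t, 4> weights{3, 0, 2, 1};
    const std::uint8_t packed = packQuadUInt2(weights);
    for (std::size_t i = 0; i < 4; ++i) {
        assert(unpackQuadUInt2(packed, i) == weights[i]);  // round-trip check
    }
    return 0;
}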