Skip to content
Snippets Groups Projects
Commit 7050788e authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Add integer datatypes for weightInterleaving

parent e7a13ff5
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!262Low bit support for ARM Cortex-M export
......@@ -132,6 +132,14 @@ REGISTRAR(Tensor, {"cpu", DataType::Int3}, Aidge::TensorImpl_cpu<int8_t>::create
REGISTRAR(Tensor, {"cpu", DataType::UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Dual_Int4}, Aidge::TensorImpl_cpu<int8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Dual_Int3}, Aidge::TensorImpl_cpu<int8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Quad_Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Quad_UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Octo_Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
......
......@@ -157,8 +157,17 @@ struct Octo_BinaryType {
};
template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
// template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
// template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
template <Aidge::DataType D> struct WeightInterleavingType { static const Aidge::DataType type; };
template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int4>::type = Aidge::DataType::Dual_Int4;
template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt4>::type = Aidge::DataType::Dual_UInt4;
template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int3>::type = Aidge::DataType::Dual_Int3;
template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt3>::type = Aidge::DataType::Dual_UInt3;
template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int2>::type = Aidge::DataType::Quad_Int2;
template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt2>::type = Aidge::DataType::Quad_UInt2;
template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Binary>::type = Aidge::DataType::Octo_Binary;
template <typename T> struct NativeType { static const Aidge::DataType type; };
......@@ -228,8 +237,6 @@ template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t
template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type;
}
......
......@@ -389,7 +389,7 @@ void init_Tensor(py::module& m){
return py::cast(b.get<std::uint8_t>(idx));
case DataType::UInt2:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::Dual_UInt2:
case DataType::Quad_UInt2:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::UInt16:
return py::cast(b.get<std::uint16_t>(idx));
......@@ -444,7 +444,7 @@ void init_Tensor(py::module& m){
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::UInt2:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::Dual_UInt2:
case DataType::Quad_UInt2:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::UInt16:
return py::cast(b.get<std::uint16_t>(coordIdx));
......@@ -519,7 +519,7 @@ void init_Tensor(py::module& m){
case DataType::UInt2:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt2:
case DataType::Quad_UInt2:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt16:
......@@ -599,7 +599,7 @@ void init_Tensor(py::module& m){
case DataType::UInt2:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt2:
case DataType::Quad_UInt2:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt16:
......
......@@ -95,6 +95,62 @@ void Aidge::TensorImpl_cpu<T>::copyCast(const void *src, const Aidge::DataType s
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
dstT);
break;
case DataType::Int4:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
case DataType::UInt4:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
dstT);
break;
case DataType::Dual_Int4:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
case DataType::Dual_UInt4:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
dstT);
break;
case DataType::Int3:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
case DataType::UInt3:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
dstT);
break;
case DataType::Dual_Int3:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
case DataType::Dual_UInt3:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
dstT);
break;
case DataType::Int2:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
case DataType::UInt2:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
dstT);
break;
case DataType::Quad_Int2:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
case DataType::Quad_UInt2:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
dstT);
break;
case DataType::Binary:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
case DataType::Octo_Binary:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
dstT);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
break;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment