Skip to content
Snippets Groups Projects
Commit 54630773 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Add datatypes for compacted low bits integers

parent 3eafaa31
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!262Low bit support for ARM Cortex-M export
Pipeline #60704 failed
...@@ -28,10 +28,14 @@ enum class DataType { ...@@ -28,10 +28,14 @@ enum class DataType {
Float16, Float16,
BFloat16, BFloat16,
Binary, Binary,
Octo_Binary,
Ternary, Ternary,
Int2, Int2,
Quad_Int2,
Int3, Int3,
Dual_Int3,
Int4, Int4,
Dual_Int4,
Int5, Int5,
Int6, Int6,
Int7, Int7,
...@@ -40,8 +44,11 @@ enum class DataType { ...@@ -40,8 +44,11 @@ enum class DataType {
Int32, Int32,
Int64, Int64,
UInt2, UInt2,
Quad_UInt2,
UInt3, UInt3,
Dual_UInt3,
UInt4, UInt4,
Dual_UInt4,
UInt5, UInt5,
UInt6, UInt6,
UInt7, UInt7,
...@@ -124,6 +131,31 @@ struct Int2Type { ...@@ -124,6 +131,31 @@ struct Int2Type {
struct UInt2Type { struct UInt2Type {
std::uint8_t value; std::uint8_t value;
}; };
struct Dual_Int4Type {
std::int8_t value;
};
struct Dual_UInt4Type {
std::uint8_t value;
};
struct Dual_Int3Type {
std::int8_t value;
};
struct Dual_UInt3Type {
std::uint8_t value;
};
struct Quad_Int2Type {
std::int8_t value;
};
struct Quad_UInt2Type {
std::uint8_t value;
};
struct BinaryType {
std::int8_t value;
};
struct Octo_BinaryType {
std::uint8_t value;
};
template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; }; template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; };
template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4; template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4;
...@@ -139,6 +171,12 @@ template <> const Aidge::DataType NativeType<Int3Type>::type = Aidge::DataType:: ...@@ -139,6 +171,12 @@ template <> const Aidge::DataType NativeType<Int3Type>::type = Aidge::DataType::
template <> const Aidge::DataType NativeType<UInt3Type>::type = Aidge::DataType::UInt3; template <> const Aidge::DataType NativeType<UInt3Type>::type = Aidge::DataType::UInt3;
template <> const Aidge::DataType NativeType<Int2Type>::type = Aidge::DataType::Int2; template <> const Aidge::DataType NativeType<Int2Type>::type = Aidge::DataType::Int2;
template <> const Aidge::DataType NativeType<UInt2Type>::type = Aidge::DataType::UInt2; template <> const Aidge::DataType NativeType<UInt2Type>::type = Aidge::DataType::UInt2;
template <> const Aidge::DataType NativeType<Dual_Int4Type>::type = Aidge::DataType::Dual_Int4;
template <> const Aidge::DataType NativeType<Dual_UInt4Type>::type = Aidge::DataType::Dual_UInt4;
template <> const Aidge::DataType NativeType<Dual_Int3Type>::type = Aidge::DataType::Dual_Int3;
template <> const Aidge::DataType NativeType<Dual_UInt3Type>::type = Aidge::DataType::Dual_UInt3;
template <> const Aidge::DataType NativeType<Quad_Int2Type>::type = Aidge::DataType::Quad_Int2;
template <> const Aidge::DataType NativeType<Quad_UInt2Type>::type = Aidge::DataType::Quad_UInt2;
template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8; template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8;
template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16; template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16;
template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32; template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32;
...@@ -150,9 +188,9 @@ template <> const Aidge::DataType NativeType<std::uint64_t>::type = Aidge::DataT ...@@ -150,9 +188,9 @@ template <> const Aidge::DataType NativeType<std::uint64_t>::type = Aidge::DataT
template <> template <>
const char* const EnumStrings<Aidge::DataType>::data[] const char* const EnumStrings<Aidge::DataType>::data[]
= {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Ternary", = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Octo_Binary", "Ternary",
"Int2", "Int3", "Int4", "Int5", "Int6", "Int7", "Int8", "Int16", "Int2", "Quad_Int2", "Int3", "Dual_Int3", "Int4", "Dual_Int4", "Int5", "Int6", "Int7", "Int8", "Int16",
"Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6", "Int32", "Int64", "UInt2", "Quad_UInt2", "UInt3", "Dual_UInt3", "UInt4", "Dual_UInt4", "UInt5", "UInt6",
"UInt7", "UInt8", "UInt16", "UInt32", "UInt64", "Any"}; "UInt7", "UInt8", "UInt16", "UInt32", "UInt64", "Any"};
template <> template <>
...@@ -171,6 +209,14 @@ template <> struct cpptype<Aidge::DataType::Int3> { using type = Int3Type; }; ...@@ -171,6 +209,14 @@ template <> struct cpptype<Aidge::DataType::Int3> { using type = Int3Type; };
template <> struct cpptype<Aidge::DataType::UInt3> { using type = UInt3Type; }; template <> struct cpptype<Aidge::DataType::UInt3> { using type = UInt3Type; };
template <> struct cpptype<Aidge::DataType::Int2> { using type = Int2Type; }; template <> struct cpptype<Aidge::DataType::Int2> { using type = Int2Type; };
template <> struct cpptype<Aidge::DataType::UInt2> { using type = UInt2Type; }; template <> struct cpptype<Aidge::DataType::UInt2> { using type = UInt2Type; };
template <> struct cpptype<Aidge::DataType::Dual_Int4> { using type = Dual_Int4Type; };
template <> struct cpptype<Aidge::DataType::Dual_UInt4> { using type = Dual_UInt4Type; };
template <> struct cpptype<Aidge::DataType::Dual_Int3> { using type = Dual_Int3Type; };
template <> struct cpptype<Aidge::DataType::Dual_UInt3> { using type = Dual_UInt3Type; };
template <> struct cpptype<Aidge::DataType::Quad_Int2> { using type = Quad_Int2Type; };
template <> struct cpptype<Aidge::DataType::Quad_UInt2> { using type = Quad_UInt2Type; };
template <> struct cpptype<Aidge::DataType::Binary> { using type = BinaryType; };
template <> struct cpptype<Aidge::DataType::Octo_Binary> { using type = Octo_BinaryType; };
template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; }; template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; };
template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; }; template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; };
template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; }; template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; };
......
...@@ -355,6 +355,22 @@ void init_Tensor(py::module& m){ ...@@ -355,6 +355,22 @@ void init_Tensor(py::module& m){
return py::cast(b.get<float>(idx)); return py::cast(b.get<float>(idx));
case DataType::Int8: case DataType::Int8:
return py::cast(b.get<std::int8_t>(idx)); return py::cast(b.get<std::int8_t>(idx));
case DataType::Int4:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Dual_Int4:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Int3:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Dual_Int3:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Int2:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Quad_Int2:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Binary:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Octo_Binary:
return py::cast(b.get<std::int8_t>(idx));
case DataType::Int16: case DataType::Int16:
return py::cast(b.get<std::int16_t>(idx)); return py::cast(b.get<std::int16_t>(idx));
case DataType::Int32: case DataType::Int32:
...@@ -363,6 +379,18 @@ void init_Tensor(py::module& m){ ...@@ -363,6 +379,18 @@ void init_Tensor(py::module& m){
return py::cast(b.get<std::int64_t>(idx)); return py::cast(b.get<std::int64_t>(idx));
case DataType::UInt8: case DataType::UInt8:
return py::cast(b.get<std::uint8_t>(idx)); return py::cast(b.get<std::uint8_t>(idx));
case DataType::UInt4:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::Dual_UInt4:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::UInt3:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::Dual_UInt3:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::UInt2:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::Dual_UInt2:
return py::cast(b.get<std::uint8_t>(idx));
case DataType::UInt16: case DataType::UInt16:
return py::cast(b.get<std::uint16_t>(idx)); return py::cast(b.get<std::uint16_t>(idx));
case DataType::UInt32: case DataType::UInt32:
...@@ -382,6 +410,22 @@ void init_Tensor(py::module& m){ ...@@ -382,6 +410,22 @@ void init_Tensor(py::module& m){
return py::cast(b.get<float>(coordIdx)); return py::cast(b.get<float>(coordIdx));
case DataType::Int8: case DataType::Int8:
return py::cast(b.get<std::int8_t>(coordIdx)); return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Int4:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Dual_Int4:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Int3:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Dual_Int3:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Int2:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Quad_Int2:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Binary:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Octo_Binary:
return py::cast(b.get<std::int8_t>(coordIdx));
case DataType::Int16: case DataType::Int16:
return py::cast(b.get<std::int16_t>(coordIdx)); return py::cast(b.get<std::int16_t>(coordIdx));
case DataType::Int32: case DataType::Int32:
...@@ -390,6 +434,18 @@ void init_Tensor(py::module& m){ ...@@ -390,6 +434,18 @@ void init_Tensor(py::module& m){
return py::cast(b.get<std::int64_t>(coordIdx)); return py::cast(b.get<std::int64_t>(coordIdx));
case DataType::UInt8: case DataType::UInt8:
return py::cast(b.get<std::uint8_t>(coordIdx)); return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::UInt4:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::Dual_UInt4:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::UInt3:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::Dual_UInt3:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::UInt2:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::Dual_UInt2:
return py::cast(b.get<std::uint8_t>(coordIdx));
case DataType::UInt16: case DataType::UInt16:
return py::cast(b.get<std::uint16_t>(coordIdx)); return py::cast(b.get<std::uint16_t>(coordIdx));
case DataType::UInt32: case DataType::UInt32:
...@@ -412,6 +468,30 @@ void init_Tensor(py::module& m){ ...@@ -412,6 +468,30 @@ void init_Tensor(py::module& m){
case DataType::Int8: case DataType::Int8:
b.set(idx, castToNativeType<std::int8_t>(val)); b.set(idx, castToNativeType<std::int8_t>(val));
break; break;
case DataType::Int4:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Dual_Int4:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Int3:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Dual_Int3:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Int2:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Quad_Int2:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Binary:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Octo_Binary:
b.set(idx, castToNativeType<std::int8_t>(val));
break;
case DataType::Int16: case DataType::Int16:
b.set(idx, castToNativeType<std::int16_t>(val)); b.set(idx, castToNativeType<std::int16_t>(val));
break; break;
...@@ -424,6 +504,24 @@ void init_Tensor(py::module& m){ ...@@ -424,6 +504,24 @@ void init_Tensor(py::module& m){
case DataType::UInt8: case DataType::UInt8:
b.set(idx, castToNativeType<std::uint8_t>(val)); b.set(idx, castToNativeType<std::uint8_t>(val));
break; break;
case DataType::UInt4:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt4:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt3:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt3:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt2:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt2:
b.set(idx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt16: case DataType::UInt16:
b.set(idx, castToNativeType<std::uint16_t>(val)); b.set(idx, castToNativeType<std::uint16_t>(val));
break; break;
...@@ -450,6 +548,30 @@ void init_Tensor(py::module& m){ ...@@ -450,6 +548,30 @@ void init_Tensor(py::module& m){
case DataType::Int8: case DataType::Int8:
b.set(coordIdx, castToNativeType<std::int8_t>(val)); b.set(coordIdx, castToNativeType<std::int8_t>(val));
break; break;
case DataType::Int4:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Dual_Int4:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Int3:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Dual_Int3:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Int2:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Quad_Int2:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Binary:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Octo_Binary:
b.set(coordIdx, castToNativeType<std::int8_t>(val));
break;
case DataType::Int16: case DataType::Int16:
b.set(coordIdx, castToNativeType<std::int16_t>(val)); b.set(coordIdx, castToNativeType<std::int16_t>(val));
break; break;
...@@ -462,6 +584,24 @@ void init_Tensor(py::module& m){ ...@@ -462,6 +584,24 @@ void init_Tensor(py::module& m){
case DataType::UInt8: case DataType::UInt8:
b.set(coordIdx, castToNativeType<std::uint8_t>(val)); b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break; break;
case DataType::UInt4:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt4:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt3:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt3:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt2:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::Dual_UInt2:
b.set(coordIdx, castToNativeType<std::uint8_t>(val));
break;
case DataType::UInt16: case DataType::UInt16:
b.set(coordIdx, castToNativeType<std::uint16_t>(val)); b.set(coordIdx, castToNativeType<std::uint16_t>(val));
break; break;
...@@ -517,6 +657,30 @@ void init_Tensor(py::module& m){ ...@@ -517,6 +657,30 @@ void init_Tensor(py::module& m){
case DataType::UInt2: case DataType::UInt2:
dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
break; break;
case DataType::Dual_Int4:
dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
break;
case DataType::Dual_UInt4:
dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
break;
case DataType::Dual_Int3:
dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
break;
case DataType::Dual_UInt3:
dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
break;
case DataType::Quad_Int2:
dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
break;
case DataType::Quad_UInt2:
dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
break;
case DataType::Binary:
dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
break;
case DataType::Octo_Binary:
dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
break;
case DataType::Int8: case DataType::Int8:
dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
break; break;
......
...@@ -251,6 +251,22 @@ std::string Tensor::toString() const { ...@@ -251,6 +251,22 @@ std::string Tensor::toString() const {
return std::to_string(static_cast<float*>(ptr)[idx]); return std::to_string(static_cast<float*>(ptr)[idx]);
case DataType::Float16: case DataType::Float16:
return std::to_string(static_cast<half_float::half*>(ptr)[idx]); return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
case DataType::Binary:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Octo_Binary:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Dual_Int4:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Dual_UInt4:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
case DataType::Dual_Int3:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Dual_UInt3:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
case DataType::Quad_Int2:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Quad_UInt2:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
case DataType::Int4: case DataType::Int4:
return std::to_string(static_cast<int8_t*>(ptr)[idx]); return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::UInt4: case DataType::UInt4:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment