diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index a710bbc628481fc815e54cd9d10fe27885237b8f..dcc45961a44303d7d3ab150619108c61362006e4 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -28,10 +28,14 @@ enum class DataType { Float16, BFloat16, Binary, + Octo_Binary, Ternary, Int2, + Quad_Int2, Int3, + Dual_Int3, Int4, + Dual_Int4, Int5, Int6, Int7, @@ -40,8 +44,11 @@ enum class DataType { Int32, Int64, UInt2, + Quad_UInt2, UInt3, + Dual_UInt3, UInt4, + Dual_UInt4, UInt5, UInt6, UInt7, @@ -124,6 +131,31 @@ struct Int2Type { struct UInt2Type { std::uint8_t value; }; +struct Dual_Int4Type { + std::int8_t value; +}; +struct Dual_UInt4Type { + std::uint8_t value; +}; +struct Dual_Int3Type { + std::int8_t value; +}; +struct Dual_UInt3Type { + std::uint8_t value; +}; +struct Quad_Int2Type { + std::int8_t value; +}; +struct Quad_UInt2Type { + std::uint8_t value; +}; +struct BinaryType { + std::int8_t value; +}; +struct Octo_BinaryType { + std::uint8_t value; +}; + template <Aidge::DataType D> struct AidgeNbBits { static const int nbBits; }; template <> const int AidgeNbBits<Aidge::DataType::Int4>::nbBits = 4; @@ -139,6 +171,12 @@ template <> const Aidge::DataType NativeType<Int3Type>::type = Aidge::DataType:: template <> const Aidge::DataType NativeType<UInt3Type>::type = Aidge::DataType::UInt3; template <> const Aidge::DataType NativeType<Int2Type>::type = Aidge::DataType::Int2; template <> const Aidge::DataType NativeType<UInt2Type>::type = Aidge::DataType::UInt2; +template <> const Aidge::DataType NativeType<Dual_Int4Type>::type = Aidge::DataType::Dual_Int4; +template <> const Aidge::DataType NativeType<Dual_UInt4Type>::type = Aidge::DataType::Dual_UInt4; +template <> const Aidge::DataType NativeType<Dual_Int3Type>::type = Aidge::DataType::Dual_Int3; +template <> const Aidge::DataType NativeType<Dual_UInt3Type>::type = Aidge::DataType::Dual_UInt3; +template <> const Aidge::DataType NativeType<Quad_Int2Type>::type = Aidge::DataType::Quad_Int2; +template <> const Aidge::DataType NativeType<Quad_UInt2Type>::type = Aidge::DataType::Quad_UInt2; template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8; template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16; template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32; @@ -150,9 +188,9 @@ template <> const Aidge::DataType NativeType<std::uint64_t>::type = Aidge::DataT template <> const char* const EnumStrings<Aidge::DataType>::data[] - = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Ternary", - "Int2", "Int3", "Int4", "Int5", "Int6", "Int7", "Int8", "Int16", - "Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6", + = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Octo_Binary", "Ternary", + "Int2", "Quad_Int2", "Int3", "Dual_Int3", "Int4", "Dual_Int4", "Int5", "Int6", "Int7", "Int8", "Int16", + "Int32", "Int64", "UInt2", "Quad_UInt2", "UInt3", "Dual_UInt3", "UInt4", "Dual_UInt4", "UInt5", "UInt6", "UInt7", "UInt8", "UInt16", "UInt32", "UInt64", "Any"}; template <> @@ -171,6 +209,14 @@ template <> struct cpptype<Aidge::DataType::Int3> { using type = Int3Type; }; template <> struct cpptype<Aidge::DataType::UInt3> { using type = UInt3Type; }; template <> struct cpptype<Aidge::DataType::Int2> { using type = Int2Type; }; template <> struct cpptype<Aidge::DataType::UInt2> { using type = UInt2Type; }; +template <> struct cpptype<Aidge::DataType::Dual_Int4> { using type = Dual_Int4Type; }; +template <> struct cpptype<Aidge::DataType::Dual_UInt4> { using type = Dual_UInt4Type; }; +template <> struct cpptype<Aidge::DataType::Dual_Int3> { using type = Dual_Int3Type; }; +template <> struct cpptype<Aidge::DataType::Dual_UInt3> { using type = Dual_UInt3Type; }; +template <> struct cpptype<Aidge::DataType::Quad_Int2> { using type = Quad_Int2Type; }; +template <> struct cpptype<Aidge::DataType::Quad_UInt2> { using type = Quad_UInt2Type; }; +template <> struct cpptype<Aidge::DataType::Binary> { using type = BinaryType; }; +template <> struct cpptype<Aidge::DataType::Octo_Binary> { using type = Octo_BinaryType; }; template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; }; template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; }; template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; }; diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 429a218ac67107aa7975931516f0324d8ca8bc8b..d57384d428090323b962af502dd75c553fcb20d3 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -355,6 +355,22 @@ void init_Tensor(py::module& m){ return py::cast(b.get<float>(idx)); case DataType::Int8: return py::cast(b.get<std::int8_t>(idx)); + case DataType::Int4: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Dual_Int4: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Int3: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Dual_Int3: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Int2: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Quad_Int2: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Binary: + return py::cast(b.get<std::int8_t>(idx)); + case DataType::Octo_Binary: + return py::cast(b.get<std::int8_t>(idx)); case DataType::Int16: return py::cast(b.get<std::int16_t>(idx)); case DataType::Int32: @@ -363,6 +379,18 @@ void init_Tensor(py::module& m){ return py::cast(b.get<std::int64_t>(idx)); case DataType::UInt8: return py::cast(b.get<std::uint8_t>(idx)); + case DataType::UInt4: + return py::cast(b.get<std::uint8_t>(idx)); + case DataType::Dual_UInt4: + return py::cast(b.get<std::uint8_t>(idx)); + case DataType::UInt3: + return py::cast(b.get<std::uint8_t>(idx)); + case DataType::Dual_UInt3: + return py::cast(b.get<std::uint8_t>(idx)); + case DataType::UInt2: + return py::cast(b.get<std::uint8_t>(idx)); + case DataType::Dual_UInt2: + return py::cast(b.get<std::uint8_t>(idx)); case DataType::UInt16: return py::cast(b.get<std::uint16_t>(idx)); case DataType::UInt32: @@ -382,6 +410,22 @@ void init_Tensor(py::module& m){ return py::cast(b.get<float>(coordIdx)); case DataType::Int8: return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Int4: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Dual_Int4: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Int3: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Dual_Int3: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Int2: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Quad_Int2: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Binary: + return py::cast(b.get<std::int8_t>(coordIdx)); + case DataType::Octo_Binary: + return py::cast(b.get<std::int8_t>(coordIdx)); case DataType::Int16: return py::cast(b.get<std::int16_t>(coordIdx)); case DataType::Int32: @@ -390,6 +434,18 @@ void init_Tensor(py::module& m){ return py::cast(b.get<std::int64_t>(coordIdx)); case DataType::UInt8: return py::cast(b.get<std::uint8_t>(coordIdx)); + case DataType::UInt4: + return py::cast(b.get<std::uint8_t>(coordIdx)); + case DataType::Dual_UInt4: + return py::cast(b.get<std::uint8_t>(coordIdx)); + case DataType::UInt3: + return py::cast(b.get<std::uint8_t>(coordIdx)); + case DataType::Dual_UInt3: + return py::cast(b.get<std::uint8_t>(coordIdx)); + case DataType::UInt2: + return py::cast(b.get<std::uint8_t>(coordIdx)); + case DataType::Dual_UInt2: + return py::cast(b.get<std::uint8_t>(coordIdx)); case DataType::UInt16: return py::cast(b.get<std::uint16_t>(coordIdx)); case DataType::UInt32: @@ -412,6 +468,30 @@ void init_Tensor(py::module& m){ case DataType::Int8: b.set(idx, castToNativeType<std::int8_t>(val)); break; + case DataType::Int4: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Dual_Int4: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Int3: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Dual_Int3: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Int2: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Quad_Int2: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Binary: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Octo_Binary: + b.set(idx, castToNativeType<std::int8_t>(val)); + break; case DataType::Int16: b.set(idx, castToNativeType<std::int16_t>(val)); break; @@ -424,6 +504,24 @@ void init_Tensor(py::module& m){ case DataType::UInt8: b.set(idx, castToNativeType<std::uint8_t>(val)); break; + case DataType::UInt4: + b.set(idx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::Dual_UInt4: + b.set(idx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::UInt3: + b.set(idx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::Dual_UInt3: + b.set(idx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::UInt2: + b.set(idx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::Dual_UInt2: + b.set(idx, castToNativeType<std::uint8_t>(val)); + break; case DataType::UInt16: b.set(idx, castToNativeType<std::uint16_t>(val)); break; @@ -450,6 +548,30 @@ void init_Tensor(py::module& m){ case DataType::Int8: b.set(coordIdx, castToNativeType<std::int8_t>(val)); break; + case DataType::Int4: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Dual_Int4: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Int3: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Dual_Int3: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Int2: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Quad_Int2: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Binary: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; + case DataType::Octo_Binary: + b.set(coordIdx, castToNativeType<std::int8_t>(val)); + break; case DataType::Int16: b.set(coordIdx, castToNativeType<std::int16_t>(val)); break; @@ -462,6 +584,24 @@ void init_Tensor(py::module& m){ case DataType::UInt8: b.set(coordIdx, castToNativeType<std::uint8_t>(val)); break; + case DataType::UInt4: + b.set(coordIdx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::Dual_UInt4: + b.set(coordIdx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::UInt3: + b.set(coordIdx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::Dual_UInt3: + b.set(coordIdx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::UInt2: + b.set(coordIdx, castToNativeType<std::uint8_t>(val)); + break; + case DataType::Dual_UInt2: + b.set(coordIdx, castToNativeType<std::uint8_t>(val)); + break; case DataType::UInt16: b.set(coordIdx, castToNativeType<std::uint16_t>(val)); break; @@ -517,6 +657,30 @@ void init_Tensor(py::module& m){ case DataType::UInt2: dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); break; + case DataType::Dual_Int4: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; + case DataType::Dual_UInt4: + dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); + break; + case DataType::Dual_Int3: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; + case DataType::Dual_UInt3: + dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); + break; + case DataType::Quad_Int2: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; + case DataType::Quad_UInt2: + dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format(); + break; + case DataType::Binary: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; + case DataType::Octo_Binary: + dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); + break; case DataType::Int8: dataFormatDescriptor = py::format_descriptor<std::int8_t>::format(); break; diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index d406f8d0d9f0a9b412465d8524a5d4410f9d48d1..7eb09fe6d91d7a37cef89c98a66ef190332c8a21 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -251,6 +251,22 @@ std::string Tensor::toString() const { return std::to_string(static_cast<float*>(ptr)[idx]); case DataType::Float16: return std::to_string(static_cast<half_float::half*>(ptr)[idx]); + case DataType::Binary: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::Octo_Binary: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::Dual_Int4: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::Dual_UInt4: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); + case DataType::Dual_Int3: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::Dual_UInt3: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); + case DataType::Quad_Int2: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::Quad_UInt2: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); case DataType::Int4: return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::UInt4: