Commit e5fa197b authored by Thibault Allenet, committed by Olivier BICHLER

Add type aidge::int4 with tensor implementation int8_t

parent b60ad289
Part of 2 merge requests: !318 ([Upd] Release version 0.5.0) and !262 (Low bit support for ARM Cortex-M export).
@@ -126,6 +126,7 @@ REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::crea
 REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int4}, Aidge::TensorImpl_cpu<int8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
...
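To make the effect of this hunk concrete, here is a minimal, self-contained sketch of the registration pattern. It only mimics the REGISTRAR lines above (the registry, TensorImpl classes and names here are simplified stand-ins, not Aidge's actual machinery): DataType::Int4 is mapped onto the int8_t CPU implementation, so each 4-bit value occupies a full byte of storage.

// Hypothetical sketch of the (backend, dtype) -> factory registration above.
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

enum class DataType { Int4, Int8, Int16 };            // reduced for the sketch

struct TensorImpl {                                    // simplified stand-in base class
    virtual ~TensorImpl() = default;
    virtual std::size_t scalarSize() const = 0;        // bytes used per stored element
};

template <typename T>
struct TensorImpl_cpu : TensorImpl {
    std::vector<T> data;
    std::size_t scalarSize() const override { return sizeof(T); }
    static std::unique_ptr<TensorImpl> create() { return std::make_unique<TensorImpl_cpu<T>>(); }
};

// (backend, dtype) -> creator, mimicking REGISTRAR(Tensor, {"cpu", DataType::X}, ...)
std::map<std::pair<std::string, DataType>, std::function<std::unique_ptr<TensorImpl>()>> registry = {
    {{"cpu", DataType::Int8},  &TensorImpl_cpu<std::int8_t>::create},
    {{"cpu", DataType::Int4},  &TensorImpl_cpu<std::int8_t>::create},   // Int4 reuses the int8_t backend
    {{"cpu", DataType::Int16}, &TensorImpl_cpu<std::int16_t>::create},
};

int main() {
    auto int4Impl = registry.at({"cpu", DataType::Int4})();
    // With this implementation, Int4 and Int8 both report 1 byte per element.
    return int4Impl->scalarSize() == 1 ? 0 : 1;
}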
@@ -104,10 +104,17 @@ private:
 }
 namespace {
+// Define a distinct type alias for Int4
+struct Int4Type {
+    std::int8_t value;
+};
 template <typename T> struct NativeType { static const Aidge::DataType type; };
 template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64;
 template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32;
 template <> const Aidge::DataType NativeType<half_float::half>::type = Aidge::DataType::Float16;
+template <> const Aidge::DataType NativeType<Int4Type>::type = Aidge::DataType::Int4;
 template <> const Aidge::DataType NativeType<std::int8_t>::type = Aidge::DataType::Int8;
 template <> const Aidge::DataType NativeType<std::int16_t>::type = Aidge::DataType::Int16;
 template <> const Aidge::DataType NativeType<std::int32_t>::type = Aidge::DataType::Int32;
@@ -134,6 +141,7 @@ template <Aidge::DataType D> struct cpptype {
 template <> struct cpptype<Aidge::DataType::Float16> { using type = half_float::half; };
 template <> struct cpptype<Aidge::DataType::Float32> { using type = float; };
 template <> struct cpptype<Aidge::DataType::Float64> { using type = double; };
+template <> struct cpptype<Aidge::DataType::Int4> { using type = Int4Type; };
 template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; };
 template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; };
 template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; };
@@ -144,6 +152,9 @@ template <> struct cpptype<Aidge::DataType::UInt32> { using type = std::uint32_t
 template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t; };
 template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type;
 }
...
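The key idea in this hunk is the tag type: Int4Type is a distinct struct that merely wraps an std::int8_t, so the NativeType and cpptype traits can tell Int4 apart from Int8 at compile time even though both share the same 8-bit storage. The sketch below reproduces that pattern in a self-contained, simplified form (the real code uses out-of-class static const definitions and the full Aidge::DataType enum).

// Self-contained sketch of the Int4Type tag-type + trait pattern.
#include <cstdint>
#include <type_traits>

enum class DataType { Int4, Int8 };                    // reduced for the sketch

struct Int4Type { std::int8_t value; };                // distinct tag type for Int4

template <typename T> struct NativeType;               // C++ type -> DataType
template <> struct NativeType<Int4Type>    { static constexpr DataType type = DataType::Int4; };
template <> struct NativeType<std::int8_t> { static constexpr DataType type = DataType::Int8; };

template <DataType D> struct cpptype;                  // DataType -> C++ type
template <> struct cpptype<DataType::Int4> { using type = Int4Type; };
template <> struct cpptype<DataType::Int8> { using type = std::int8_t; };
template <DataType D> using cpptype_t = typename cpptype<D>::type;

// Round-tripping through the traits keeps Int4 and Int8 distinct.
static_assert(std::is_same<cpptype_t<DataType::Int4>, Int4Type>::value, "Int4 maps to the tag type");
static_assert(NativeType<cpptype_t<DataType::Int4>>::type == DataType::Int4, "and maps back again");
static_assert(!std::is_same<cpptype_t<DataType::Int4>, cpptype_t<DataType::Int8>>::value,
              "Int4 and Int8 stay distinguishable despite identical storage");

int main() { return 0; }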
@@ -226,6 +226,8 @@ static T castToNativeType(const py::object val_obj) {
 DataType dtype;
 getConservativeNativeVal(val_obj, &val, &dtype);
 switch (dtype) {
+case DataType::Int4:
+    return (T)val.i8;
 case DataType::Int8:
     return (T)val.i8;
 case DataType::Int16:
@@ -497,6 +499,9 @@ void init_Tensor(py::module& m){
 case DataType::Float32:
     dataFormatDescriptor = py::format_descriptor<float>::format();
     break;;
+case DataType::Int4:
+    dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+    break;
 case DataType::Int8:
     dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
     break;
...
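Both Python-binding hunks make the same point: as far as the bindings are concerned, an Int4 value is handled exactly like its int8_t storage. The cast reads the 8-bit union member, and the buffer-protocol descriptor is the one for std::int8_t (the struct-module code "b"). The sketch below restates that dispatch without the pybind11 dependency; the enum, union and function names are simplified stand-ins, not the actual binding code.

// Hypothetical sketch of the Int4 handling added to the Python bindings.
#include <cstdint>
#include <stdexcept>
#include <string>

enum class DataType { Int4, Int8, Int16, Float32 };

// Simplified stand-in for the py::format_descriptor<T>::format() dispatch.
std::string formatDescriptor(DataType dtype) {
    switch (dtype) {
    case DataType::Int4:                     // Int4 deliberately falls through:
    case DataType::Int8:    return "b";      // signed 8-bit storage for both
    case DataType::Int16:   return "h";
    case DataType::Float32: return "f";
    }
    throw std::runtime_error("unsupported dtype");
}

// Simplified stand-in for castToNativeType(): the value union only needs its
// 8-bit member to serve both Int4 and Int8.
union NativeVal { std::int8_t i8; std::int16_t i16; float f32; };

template <typename T>
T castToNativeType(NativeVal val, DataType dtype) {
    switch (dtype) {
    case DataType::Int4:     // same storage, same conversion as Int8
    case DataType::Int8:    return static_cast<T>(val.i8);
    case DataType::Int16:   return static_cast<T>(val.i16);
    case DataType::Float32: return static_cast<T>(val.f32);
    }
    throw std::runtime_error("unsupported dtype");
}

int main() {
    NativeVal v; v.i8 = -3;                  // a value within the int4 range [-8, 7]
    return (formatDescriptor(DataType::Int4) == "b" &&
            castToNativeType<int>(v, DataType::Int4) == -3) ? 0 : 1;
}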
@@ -250,6 +250,8 @@ std::string Aidge::Tensor::toString() const {
     return std::to_string(static_cast<float*>(ptr)[idx]);
 case DataType::Float16:
     return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
+case DataType::Int4:
+    return std::to_string(static_cast<int8_t*>(ptr)[idx]);
 case DataType::Int8:
     return std::to_string(static_cast<int8_t*>(ptr)[idx]);
 case DataType::Int16:
...
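The toString() branch follows directly from the storage choice: since Int4 data lives in an int8_t buffer, printing an element is identical for Int4 and Int8, i.e. read one byte and stringify it. A reduced, self-contained sketch of that dispatch (enum and function name simplified for illustration):

// Sketch of the element-printing branch added above, limited to Int4/Int8.
#include <cstdint>
#include <iostream>
#include <string>

enum class DataType { Int4, Int8 };

std::string elementToString(DataType dtype, void* ptr, std::size_t idx) {
    switch (dtype) {
    case DataType::Int4:   // stored as one int8_t per value
    case DataType::Int8:
        return std::to_string(static_cast<std::int8_t*>(ptr)[idx]);
    }
    return {};
}

int main() {
    std::int8_t buf[] = {7, -8, 3};          // int4-range values, one per byte
    for (std::size_t i = 0; i < 3; ++i)
        std::cout << elementToString(DataType::Int4, buf, i) << ' ';
    std::cout << '\n';
    return 0;
}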