Skip to content
Snippets Groups Projects

[DataType] Complete python binding to map all dtype values

@@ -17,20 +17,42 @@ namespace py = pybind11;
@@ -17,20 +17,42 @@ namespace py = pybind11;
namespace Aidge {
namespace Aidge {
void init_Data(py::module& m){
void init_Data(py::module& m){
// TODO : extend with more values !
// Define enumeration names for python as lowercase dtype name
py::enum_<DataType>(m, "dtype")
// This defined enum names compatible with basic numpy dtype
.value("float64", DataType::Float64)
// name such as: float32, flot64, [u]int32, [u]int64, ...
.value("float32", DataType::Float32)
auto python_enum_name = [](const DataType& dtype) {
.value("float16", DataType::Float16)
auto str_lower = [](std::string& str) {
.value("int8", DataType::Int8)
std::transform(str.begin(), str.end(), str.begin(),
.value("int16", DataType::Int16)
[](unsigned char c){
.value("int32", DataType::Int32)
return std::tolower(c);
.value("int64", DataType::Int64)
});
.value("uint8", DataType::UInt8)
};
.value("uint16", DataType::UInt16)
auto dtype_name = std::string(Aidge::format_as(dtype));
.value("uint32", DataType::UInt32)
str_lower(dtype_name);
.value("uint64", DataType::UInt64)
return dtype_name;
;
};
 
// Auto generate enumeration names from lowercase dtype strings
 
std::vector<std::string> enum_names;
 
for (auto dtype_str : EnumStrings<Aidge::DataType>::data) {
 
auto dtype = static_cast<DataType>(enum_names.size());
 
auto enum_name = python_enum_name(dtype);
 
enum_names.push_back(enum_name);
 
}
 
 
// Define python side enumeration aidge_core.dtype
 
auto e_dtype = py::enum_<DataType>(m, "dtype");
 
 
// Add enum value for each enum name
 
for (std::size_t idx = 0; idx < enum_names.size(); idx++) {
 
e_dtype.value(enum_names[idx].c_str(), static_cast<DataType>(idx));
 
}
 
 
// Define str() to return the bare enum name value, it allows
 
// to compare directly for instance str(tensor.dtype())
 
// with str(nparray.dtype)
 
e_dtype.def("__str__", [enum_names](const DataType& dtype) {
 
return enum_names[static_cast<int>(dtype)];
 
}, py::prepend());;
py::class_<Data, std::shared_ptr<Data>>(m,"Data");
py::class_<Data, std::shared_ptr<Data>>(m,"Data");
Loading