Skip to content
Snippets Groups Projects
Commit e754cb7b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'user/cguillon/dev/dtype-enums' into 'dev'

[DataType] Complete python binding to map all dtype values

See merge request !168
parents f4c537de 25666644
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!168[DataType] Complete python binding to map all dtype values
Pipeline #50775 passed
......@@ -17,20 +17,42 @@ namespace py = pybind11;
namespace Aidge {
void init_Data(py::module& m){
// TODO : extend with more values !
py::enum_<DataType>(m, "dtype")
.value("float64", DataType::Float64)
.value("float32", DataType::Float32)
.value("float16", DataType::Float16)
.value("int8", DataType::Int8)
.value("int16", DataType::Int16)
.value("int32", DataType::Int32)
.value("int64", DataType::Int64)
.value("uint8", DataType::UInt8)
.value("uint16", DataType::UInt16)
.value("uint32", DataType::UInt32)
.value("uint64", DataType::UInt64)
;
// Define enumeration names for python as lowercase dtype name
// This defined enum names compatible with basic numpy dtype
// name such as: float32, flot64, [u]int32, [u]int64, ...
auto python_enum_name = [](const DataType& dtype) {
auto str_lower = [](std::string& str) {
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c){
return std::tolower(c);
});
};
auto dtype_name = std::string(Aidge::format_as(dtype));
str_lower(dtype_name);
return dtype_name;
};
// Auto generate enumeration names from lowercase dtype strings
std::vector<std::string> enum_names;
for (auto dtype_str : EnumStrings<Aidge::DataType>::data) {
auto dtype = static_cast<DataType>(enum_names.size());
auto enum_name = python_enum_name(dtype);
enum_names.push_back(enum_name);
}
// Define python side enumeration aidge_core.dtype
auto e_dtype = py::enum_<DataType>(m, "dtype");
// Add enum value for each enum name
for (std::size_t idx = 0; idx < enum_names.size(); idx++) {
e_dtype.value(enum_names[idx].c_str(), static_cast<DataType>(idx));
}
// Define str() to return the bare enum name value, it allows
// to compare directly for instance str(tensor.dtype())
// with str(nparray.dtype)
e_dtype.def("__str__", [enum_names](const DataType& dtype) {
return enum_names[static_cast<int>(dtype)];
}, py::prepend());;
py::class_<Data, std::shared_ptr<Data>>(m,"Data");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment