diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index c6595360b17ee08eaa82d483987914adc67b60a8..1d4eae0776b66a16e6472a51661b22fe281e6f6b 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -17,20 +17,42 @@ namespace py = pybind11; namespace Aidge { void init_Data(py::module& m){ - // TODO : extend with more values ! - py::enum_<DataType>(m, "dtype") - .value("float64", DataType::Float64) - .value("float32", DataType::Float32) - .value("float16", DataType::Float16) - .value("int8", DataType::Int8) - .value("int16", DataType::Int16) - .value("int32", DataType::Int32) - .value("int64", DataType::Int64) - .value("uint8", DataType::UInt8) - .value("uint16", DataType::UInt16) - .value("uint32", DataType::UInt32) - .value("uint64", DataType::UInt64) - ; + // Define enumeration names for python as lowercase dtype name + // This defined enum names compatible with basic numpy dtype + // name such as: float32, flot64, [u]int32, [u]int64, ... + auto python_enum_name = [](const DataType& dtype) { + auto str_lower = [](std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c){ + return std::tolower(c); + }); + }; + auto dtype_name = std::string(Aidge::format_as(dtype)); + str_lower(dtype_name); + return dtype_name; + }; + // Auto generate enumeration names from lowercase dtype strings + std::vector<std::string> enum_names; + for (auto dtype_str : EnumStrings<Aidge::DataType>::data) { + auto dtype = static_cast<DataType>(enum_names.size()); + auto enum_name = python_enum_name(dtype); + enum_names.push_back(enum_name); + } + + // Define python side enumeration aidge_core.dtype + auto e_dtype = py::enum_<DataType>(m, "dtype"); + + // Add enum value for each enum name + for (std::size_t idx = 0; idx < enum_names.size(); idx++) { + e_dtype.value(enum_names[idx].c_str(), static_cast<DataType>(idx)); + } + + // Define str() to return the bare enum name value, it allows + // to compare directly for instance str(tensor.dtype()) + // with str(nparray.dtype) + e_dtype.def("__str__", [enum_names](const DataType& dtype) { + return enum_names[static_cast<int>(dtype)]; + }, py::prepend());; py::class_<Data, std::shared_ptr<Data>>(m,"Data");