Forked from
Eclipse Projects / aidge / aidge_core
1376 commits behind the upstream repository.
-
Add Python enum values for all defined dtypes, named after the lower-cased dtype names. Add a __str__() binding so that str(tensor.dtype()) returns the bare enumeration name, suitable for comparison with numpy dtype names.
Add Python enum values for all defined dtypes, named after the lower-cased dtype names. Add a __str__() binding so that str(tensor.dtype()) returns the bare enumeration name, suitable for comparison with numpy dtype names.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
pybind_Data.cpp 2.14 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
// pybind11 (and thus Python.h) must be included before any standard header.
#include <pybind11/pybind11.h>

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "aidge/data/Data.hpp"
namespace py = pybind11;
namespace Aidge {

/**
 * @brief Register the Python bindings for Aidge data types.
 *
 * Adds to @p m:
 *  - ``dtype``: an enumeration with one value per DataType, each named after
 *    the lower-cased DataType name, so that the names match the basic numpy
 *    dtype names such as float32, float64, [u]int32, [u]int64, ...;
 *    ``str()`` on a value returns that bare name;
 *  - ``Data``: the Data class, held by std::shared_ptr.
 *
 * @param m Python module to which the bindings are added.
 */
void init_Data(py::module& m) {
    // Python enum name of a DataType: its Aidge name, lower-cased.
    // std::tolower receives an unsigned char to avoid undefined behaviour
    // on negative char values.
    const auto python_enum_name = [](const DataType& dtype) {
        std::string dtype_name(Aidge::format_as(dtype));
        std::transform(dtype_name.begin(), dtype_name.end(), dtype_name.begin(),
                       [](unsigned char c) {
                           return static_cast<char>(std::tolower(c));
                       });
        return dtype_name;
    };

    // Build one Python name per DataType, indexed by the DataType's integer
    // value. EnumStrings<DataType>::data is iterated only to get the number of
    // DataType values; entry i is assumed to correspond to DataType(i).
    std::vector<std::string> enum_names;
    for (const auto& dtype_str : EnumStrings<Aidge::DataType>::data) {
        static_cast<void>(dtype_str);  // only the number of entries matters
        const auto dtype = static_cast<DataType>(enum_names.size());
        enum_names.push_back(python_enum_name(dtype));
    }

    // Python side enumeration: aidge_core.dtype, one value per name.
    auto e_dtype = py::enum_<DataType>(m, "dtype");
    for (std::size_t idx = 0; idx < enum_names.size(); ++idx) {
        e_dtype.value(enum_names[idx].c_str(), static_cast<DataType>(idx));
    }

    // Make str() return the bare enum name, so that for instance
    // str(tensor.dtype()) can be compared directly with str(nparray.dtype).
    // py::prepend() makes pybind11 try this overload before any existing
    // __str__ overload. enum_names is captured by value so the names outlive
    // this function.
    e_dtype.def("__str__", [enum_names](const DataType& dtype) {
        // .at() reports an out-of-range value as std::out_of_range (raised as
        // IndexError in Python) instead of reading past the vector.
        return enum_names.at(static_cast<std::size_t>(dtype));
    }, py::prepend());

    py::class_<Data, std::shared_ptr<Data>>(m, "Data");
}

}  // namespace Aidge