Skip to content
Snippets Groups Projects
Commit fe867fc3 authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Added missing binding for Elts struct

parent 516ea062
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!316Improved scheduling tutorial
...@@ -34,6 +34,10 @@ public: ...@@ -34,6 +34,10 @@ public:
return std::make_unique<ProdConso>(op, true); return std::make_unique<ProdConso>(op, true);
} }
const Operator& getOperator() const noexcept {
return mOp;
}
/** /**
* @brief Minimum amount of data from a specific input required by the * @brief Minimum amount of data from a specific input required by the
* implementation to be run. * implementation to be run.
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm> // std::transform
#include <cctype> // std::tolower
#include <string> // std::string
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>
#include "aidge/data/Elts.hpp"
namespace py = pybind11;
namespace Aidge {
template <class T>
void bindEnum(py::module& m, const std::string& name) {
// Define enumeration names for python as lowercase type name
// This defined enum names compatible with basic numpy type
// name such as: float32, flot64, [u]int32, [u]int64, ...
auto python_enum_name = [](const T& type) {
auto str_lower = [](std::string& str) {
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c){
return std::tolower(c);
});
};
auto type_name = std::string(Aidge::format_as(type));
str_lower(type_name);
return type_name;
};
// Auto generate enumeration names from lowercase type strings
std::vector<std::string> enum_names;
for (auto type_str : EnumStrings<T>::data) {
auto type = static_cast<T>(enum_names.size());
auto enum_name = python_enum_name(type);
enum_names.push_back(enum_name);
}
// Define python side enumeration aidge_core.type
auto e_type = py::enum_<T>(m, name.c_str());
// Add enum value for each enum name
for (std::size_t idx = 0; idx < enum_names.size(); idx++) {
e_type.value(enum_names[idx].c_str(), static_cast<T>(idx));
}
// Define str() to return the bare enum name value, it allows
// to compare directly for instance str(tensor.type())
// with str(nparray.type)
e_type.def("__str__", [enum_names](const T& type) {
return enum_names[static_cast<int>(type)];
}, py::prepend());
}
void init_Elts(py::module& m) {
bindEnum<Elts_t::EltType>(m, "EltType");
m.def("format_as", (const char* (*)(Elts_t::EltType)) &format_as, py::arg("elt"));
py::class_<Elts_t, std::shared_ptr<Elts_t>>(
m, "Elts_t", py::dynamic_attr())
.def_static("none_elts", &Elts_t::NoneElts)
.def_static("data_elts", &Elts_t::DataElts, py::arg("data"), py::arg("token") = 1)
.def_static("token_elts", &Elts_t::TokenElts, py::arg("token"))
.def_readwrite("data", &Elts_t::data)
.def_readwrite("token", &Elts_t::token)
.def_readwrite("type", &Elts_t::type)
.def(py::self + py::self)
.def(py::self += py::self)
.def(py::self < py::self)
.def(py::self > py::self);
}
} // namespace Aidge
...@@ -21,6 +21,7 @@ void init_Random(py::module&); ...@@ -21,6 +21,7 @@ void init_Random(py::module&);
void init_Data(py::module&); void init_Data(py::module&);
void init_DataFormat(py::module&); void init_DataFormat(py::module&);
void init_DataType(py::module&); void init_DataType(py::module&);
void init_Elts(py::module&);
void init_Database(py::module&); void init_Database(py::module&);
void init_DataProvider(py::module&); void init_DataProvider(py::module&);
void init_Interpolation(py::module&); void init_Interpolation(py::module&);
...@@ -112,6 +113,7 @@ void init_Aidge(py::module& m) { ...@@ -112,6 +113,7 @@ void init_Aidge(py::module& m) {
init_Data(m); init_Data(m);
init_DataFormat(m); init_DataFormat(m);
init_DataType(m); init_DataType(m);
init_Elts(m);
init_Database(m); init_Database(m);
init_DataProvider(m); init_DataProvider(m);
init_Interpolation(m); init_Interpolation(m);
......
...@@ -104,6 +104,7 @@ void init_ProdConso(py::module& m){ ...@@ -104,6 +104,7 @@ void init_ProdConso(py::module& m){
.def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) .def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
.def_static("default_model", &ProdConso::defaultModel) .def_static("default_model", &ProdConso::defaultModel)
.def_static("in_place_model", &ProdConso::inPlaceModel) .def_static("in_place_model", &ProdConso::inPlaceModel)
.def("get_operator", &ProdConso::getOperator)
.def("get_nb_required_data", &ProdConso::getNbRequiredData) .def("get_nb_required_data", &ProdConso::getNbRequiredData)
.def("get_nb_required_protected", &ProdConso::getNbRequiredProtected) .def("get_nb_required_protected", &ProdConso::getNbRequiredProtected)
.def("get_required_memory", &ProdConso::getRequiredMemory) .def("get_required_memory", &ProdConso::getRequiredMemory)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment