From fe867fc3febf10b497e98a773f108bdf11d0aaa7 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 29 Jan 2025 18:39:02 +0100 Subject: [PATCH] Added missing binding for Elts struct --- include/aidge/scheduler/ProdConso.hpp | 4 + python_binding/data/pybind_Elts.cpp | 85 +++++++++++++++++++ python_binding/pybind_core.cpp | 2 + python_binding/scheduler/pybind_ProdConso.cpp | 1 + 4 files changed, 92 insertions(+) create mode 100644 python_binding/data/pybind_Elts.cpp diff --git a/include/aidge/scheduler/ProdConso.hpp b/include/aidge/scheduler/ProdConso.hpp index f30e00afa..bc42cb36c 100644 --- a/include/aidge/scheduler/ProdConso.hpp +++ b/include/aidge/scheduler/ProdConso.hpp @@ -34,6 +34,10 @@ public: return std::make_unique<ProdConso>(op, true); } + const Operator& getOperator() const noexcept { + return mOp; + } + /** * @brief Minimum amount of data from a specific input required by the * implementation to be run. diff --git a/python_binding/data/pybind_Elts.cpp b/python_binding/data/pybind_Elts.cpp new file mode 100644 index 000000000..59a8211e2 --- /dev/null +++ b/python_binding/data/pybind_Elts.cpp @@ -0,0 +1,85 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> // std::transform +#include <cctype> // std::tolower +#include <string> // std::string +#include <vector> + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <pybind11/operators.h> + +#include "aidge/data/Elts.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <class T> +void bindEnum(py::module& m, const std::string& name) { + // Define enumeration names for python as lowercase type name + // This defined enum names compatible with basic numpy type + // name such as: float32, flot64, [u]int32, [u]int64, ... + auto python_enum_name = [](const T& type) { + auto str_lower = [](std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c){ + return std::tolower(c); + }); + }; + auto type_name = std::string(Aidge::format_as(type)); + str_lower(type_name); + return type_name; + }; + + // Auto generate enumeration names from lowercase type strings + std::vector<std::string> enum_names; + for (auto type_str : EnumStrings<T>::data) { + auto type = static_cast<T>(enum_names.size()); + auto enum_name = python_enum_name(type); + enum_names.push_back(enum_name); + } + + // Define python side enumeration aidge_core.type + auto e_type = py::enum_<T>(m, name.c_str()); + + // Add enum value for each enum name + for (std::size_t idx = 0; idx < enum_names.size(); idx++) { + e_type.value(enum_names[idx].c_str(), static_cast<T>(idx)); + } + + // Define str() to return the bare enum name value, it allows + // to compare directly for instance str(tensor.type()) + // with str(nparray.type) + e_type.def("__str__", [enum_names](const T& type) { + return enum_names[static_cast<int>(type)]; + }, py::prepend()); +} + +void init_Elts(py::module& m) { + bindEnum<Elts_t::EltType>(m, "EltType"); + m.def("format_as", (const char* (*)(Elts_t::EltType)) &format_as, py::arg("elt")); + + py::class_<Elts_t, std::shared_ptr<Elts_t>>( + m, "Elts_t", py::dynamic_attr()) + .def_static("none_elts", &Elts_t::NoneElts) + .def_static("data_elts", &Elts_t::DataElts, py::arg("data"), py::arg("token") = 1) + .def_static("token_elts", &Elts_t::TokenElts, py::arg("token")) + .def_readwrite("data", &Elts_t::data) + .def_readwrite("token", &Elts_t::token) + .def_readwrite("type", &Elts_t::type) + .def(py::self + py::self) + .def(py::self += py::self) + .def(py::self < py::self) + .def(py::self > py::self); +} + +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 1f35373f3..cc6f0bf25 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -21,6 +21,7 @@ void init_Random(py::module&); void init_Data(py::module&); void init_DataFormat(py::module&); void init_DataType(py::module&); +void init_Elts(py::module&); void init_Database(py::module&); void init_DataProvider(py::module&); void init_Interpolation(py::module&); @@ -112,6 +113,7 @@ void init_Aidge(py::module& m) { init_Data(m); init_DataFormat(m); init_DataType(m); + init_Elts(m); init_Database(m); init_DataProvider(m); init_Interpolation(m); diff --git a/python_binding/scheduler/pybind_ProdConso.cpp b/python_binding/scheduler/pybind_ProdConso.cpp index abd6d5379..547e2258d 100644 --- a/python_binding/scheduler/pybind_ProdConso.cpp +++ b/python_binding/scheduler/pybind_ProdConso.cpp @@ -104,6 +104,7 @@ void init_ProdConso(py::module& m){ .def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) .def_static("default_model", &ProdConso::defaultModel) .def_static("in_place_model", &ProdConso::inPlaceModel) + .def("get_operator", &ProdConso::getOperator) .def("get_nb_required_data", &ProdConso::getNbRequiredData) .def("get_nb_required_protected", &ProdConso::getNbRequiredProtected) .def("get_required_memory", &ProdConso::getRequiredMemory) -- GitLab