From 11821374e9828c73e0c28adf28cc8ceaae42ef13 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 26 Mar 2024 12:42:47 +0000 Subject: [PATCH] Add binding to overload database. --- python_binding/data/pybind_Database.cpp | 35 ++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/python_binding/data/pybind_Database.cpp b/python_binding/data/pybind_Database.cpp index 903e692ca..4bc28a19d 100644 --- a/python_binding/data/pybind_Database.cpp +++ b/python_binding/data/pybind_Database.cpp @@ -1,13 +1,40 @@ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> + #include "aidge/data/Database.hpp" +#include "aidge/data/Tensor.hpp" namespace py = pybind11; namespace Aidge { -void init_Database(py::module& m){ +/** + * @brief Trampoline class for binding + * + */ +class pyDatabase : public Database { + public: + using Database::Database; // Inherit constructors - py::class_<Database, std::shared_ptr<Database>>(m,"Database"); + std::vector<std::shared_ptr<Tensor>> getItem( + const std::size_t index) const override { + PYBIND11_OVERRIDE_PURE_NAME(std::vector<std::shared_ptr<Tensor>>, Database, + "get_item", getItem, index); + } + std::size_t getLen() const noexcept override { + PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "len", getLen); + } + std::size_t getNbModalities() const noexcept override { + PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "get_nb_modalities", + getNbModalities); + } +}; - -} +void init_Database(py::module& m) { + py::class_<Database, std::shared_ptr<Database>, pyDatabase>( + m, "Database", py::dynamic_attr()) + .def(py::init<>()) + .def("get_item", &Database::getItem) + .def("len", &Database::getLen) + .def("get_nb_modalities", &Database::getNbModalities); } +} // namespace Aidge -- GitLab