diff --git a/python_binding/data/pybind_Database.cpp b/python_binding/data/pybind_Database.cpp index 903e692ca3d14d6ae25f0d6f151b1b08d557d924..4bc28a19d350236933c3b6c139e9e3a4d980fa3f 100644 --- a/python_binding/data/pybind_Database.cpp +++ b/python_binding/data/pybind_Database.cpp @@ -1,13 +1,40 @@ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> + #include "aidge/data/Database.hpp" +#include "aidge/data/Tensor.hpp" namespace py = pybind11; namespace Aidge { -void init_Database(py::module& m){ +/** + * @brief Trampoline class for binding + * + */ +class pyDatabase : public Database { + public: + using Database::Database; // Inherit constructors - py::class_<Database, std::shared_ptr<Database>>(m,"Database"); + std::vector<std::shared_ptr<Tensor>> getItem( + const std::size_t index) const override { + PYBIND11_OVERRIDE_PURE_NAME(std::vector<std::shared_ptr<Tensor>>, Database, + "get_item", getItem, index); + } + std::size_t getLen() const noexcept override { + PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "len", getLen); + } + std::size_t getNbModalities() const noexcept override { + PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "get_nb_modalities", + getNbModalities); + } +}; - -} +void init_Database(py::module& m) { + py::class_<Database, std::shared_ptr<Database>, pyDatabase>( + m, "Database", py::dynamic_attr()) + .def(py::init<>()) + .def("get_item", &Database::getItem) + .def("len", &Database::getLen) + .def("get_nb_modalities", &Database::getNbModalities); } +} // namespace Aidge