From 11821374e9828c73e0c28adf28cc8ceaae42ef13 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 26 Mar 2024 12:42:47 +0000
Subject: [PATCH] Add binding to overload database.

---
 python_binding/data/pybind_Database.cpp | 35 ++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 4 deletions(-)

diff --git a/python_binding/data/pybind_Database.cpp b/python_binding/data/pybind_Database.cpp
index 903e692ca..4bc28a19d 100644
--- a/python_binding/data/pybind_Database.cpp
+++ b/python_binding/data/pybind_Database.cpp
@@ -1,13 +1,40 @@
 #include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
 #include "aidge/data/Database.hpp"
+#include "aidge/data/Tensor.hpp"
 
 namespace py = pybind11;
 namespace Aidge {
 
-void init_Database(py::module& m){
+/**
+ * @brief Trampoline class for binding
+ *
+ */
+class pyDatabase : public Database {
+   public:
+    using Database::Database;  // Inherit constructors
 
-    py::class_<Database, std::shared_ptr<Database>>(m,"Database");
+    std::vector<std::shared_ptr<Tensor>> getItem(
+        const std::size_t index) const override {
+        PYBIND11_OVERRIDE_PURE_NAME(std::vector<std::shared_ptr<Tensor>>, Database,
+                               "get_item", getItem, index);
+    }
+    std::size_t getLen() const noexcept override {
+        PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "len", getLen);
+    }
+    std::size_t getNbModalities() const noexcept override {
+        PYBIND11_OVERRIDE_PURE_NAME(std::size_t, Database, "get_nb_modalities",
+                               getNbModalities);
+    }
+};
 
-    
-}
+void init_Database(py::module& m) {
+    py::class_<Database, std::shared_ptr<Database>, pyDatabase>(
+        m, "Database", py::dynamic_attr())
+        .def(py::init<>())
+        .def("get_item", &Database::getItem)
+        .def("len", &Database::getLen)
+        .def("get_nb_modalities", &Database::getNbModalities);
 }
+}  // namespace Aidge
-- 
GitLab