From a343e1187404164d86fb4693a5b1dd037d165fb5 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Mon, 15 Jan 2024 14:59:26 +0000
Subject: [PATCH] Add pybind DataProvider

---
 python_binding/data/pybind_DataProvider.cpp | 22 +++++++++++++++++++++
 python_binding/pybind_core.cpp              |  4 +++-
 2 files changed, 25 insertions(+), 1 deletion(-)
 create mode 100644 python_binding/data/pybind_DataProvider.cpp

diff --git a/python_binding/data/pybind_DataProvider.cpp b/python_binding/data/pybind_DataProvider.cpp
new file mode 100644
index 000000000..dfdf18894
--- /dev/null
+++ b/python_binding/data/pybind_DataProvider.cpp
@@ -0,0 +1,22 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include "aidge/data/DataProvider.hpp"
+#include "aidge/data/Database.hpp"
+
+namespace py = pybind11;
+namespace Aidge {
+
+void init_DataProvider(py::module& m){
+
+    py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
+          .def(py::init<Database&, std::size_t>(), py::arg("database"), py::arg("batchSize"))
+          .def("read_batch", &DataProvider::readBatch, py::arg("start_index"),
+          R"mydelimiter(
+          Return a batch of each data modality.
+
+          :param start_index: Database starting index to read the batch from
+          :type start_index: int
+          )mydelimiter");
+    
+}
+}
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index ea61f05ad..f4dd72cc1 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -15,7 +15,8 @@ namespace py = pybind11;
 
 namespace Aidge {
 void init_Data(py::module&);
-void init_Database(py::module& m);
+void init_Database(py::module&);
+void init_DataProvider(py::module&);
 void init_Tensor(py::module&);
 void init_OperatorImpl(py::module&);
 void init_Attributes(py::module&);
@@ -67,6 +68,7 @@ void init_TensorUtils(py::module&);
 void init_Aidge(py::module& m){
     init_Data(m);
     init_Database(m);
+    init_DataProvider(m);
     init_Tensor(m);
 
     init_Node(m);
-- 
GitLab