From 0279070b75ce5ea5fa77e5f30c3cb25047f3e09a Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Wed, 24 Jul 2024 17:12:50 +0200
Subject: [PATCH] add backend cuda support for DataProvider

---
 include/aidge/data/DataProvider.hpp         | 4 +++-
 python_binding/data/pybind_DataProvider.cpp | 2 +-
 src/data/DataProvider.cpp                   | 7 +++++--
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp
index 62d10a698..f3e0ff43d 100644
--- a/include/aidge/data/DataProvider.hpp
+++ b/include/aidge/data/DataProvider.hpp
@@ -56,6 +56,8 @@ private:
     // Size of the Last batch
     std::size_t mLastBatchSize;
 
+    std::string mBackend;
+
     // Store each modality dimensions, backend and type
     std::vector<std::vector<std::size_t>> mDataDims;
     std::vector<std::string> mDataBackends;
@@ -67,7 +69,7 @@ public:
      * @param database database from which to load the data.
      * @param batchSize number of data samples per batch.
      */
-    DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
+    DataProvider(const Database& database, const std::size_t batchSize, const std::string& backend = "cpu", const bool shuffle = false, const bool dropLast = false);
 
 public:
     /**
diff --git a/python_binding/data/pybind_DataProvider.cpp b/python_binding/data/pybind_DataProvider.cpp
index 2f652aff5..c0b7218cd 100644
--- a/python_binding/data/pybind_DataProvider.cpp
+++ b/python_binding/data/pybind_DataProvider.cpp
@@ -27,7 +27,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() {
 void init_DataProvider(py::module& m){
 
     py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
-        .def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle"), py::arg("drop_last"))
+        .def(py::init<Database&, std::size_t, std::string, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("backend"), py::arg("shuffle"), py::arg("drop_last"))
         .def("__iter__", &DataProvider::iter)
         .def("__next__", &DataProvider::next)
         .def("__len__", &DataProvider::getNbBatch);
diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp
index fc6b842ed..7f4eb71aa 100644
--- a/src/data/DataProvider.cpp
+++ b/src/data/DataProvider.cpp
@@ -23,9 +23,10 @@
 #include "aidge/utils/Random.hpp"
 
 
-Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast)
+Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const std::string& backend, const bool shuffle, const bool dropLast)
     : mDatabase(database),
       mBatchSize(batchSize),
+      mBackend(backend),
       mShuffle(shuffle),
       mDropLast(dropLast),
       mNumberModality(database.getItem(0).size()),
@@ -63,7 +64,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con
         dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size);
         auto batchData = std::make_shared<Tensor>();
         batchData->resize(dataBatchDims[i]);
-        batchData->setBackend("cpu");
+        batchData->setBackend(mBackend);
         batchData->setDataType(mDataTypes[i]);
         batchTensors.push_back(batchData);
     }
@@ -78,6 +79,8 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con
 
         // Browse each modularity in the database item
         for (std::size_t j = 0; j < mNumberModality; ++j) {
+
+            dataItem[j]->setBackend(mBackend);
             auto dataSample = dataItem[j];
 
             // Assert tensor sizes
-- 
GitLab