From 0279070b75ce5ea5fa77e5f30c3cb25047f3e09a Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Wed, 24 Jul 2024 17:12:50 +0200 Subject: [PATCH] add backend cuda support for DataProvider --- include/aidge/data/DataProvider.hpp | 4 +++- python_binding/data/pybind_DataProvider.cpp | 2 +- src/data/DataProvider.cpp | 7 +++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp index 62d10a698..f3e0ff43d 100644 --- a/include/aidge/data/DataProvider.hpp +++ b/include/aidge/data/DataProvider.hpp @@ -56,6 +56,8 @@ private: // Size of the Last batch std::size_t mLastBatchSize; + std::string mBackend; + // Store each modality dimensions, backend and type std::vector<std::vector<std::size_t>> mDataDims; std::vector<std::string> mDataBackends; @@ -67,7 +69,7 @@ public: * @param database database from which to load the data. * @param batchSize number of data samples per batch. */ - DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false); + DataProvider(const Database& database, const std::size_t batchSize, const std::string& backend = "cpu", const bool shuffle = false, const bool dropLast = false); public: /** diff --git a/python_binding/data/pybind_DataProvider.cpp b/python_binding/data/pybind_DataProvider.cpp index 2f652aff5..c0b7218cd 100644 --- a/python_binding/data/pybind_DataProvider.cpp +++ b/python_binding/data/pybind_DataProvider.cpp @@ -27,7 +27,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() { void init_DataProvider(py::module& m){ py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider") - .def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle"), py::arg("drop_last")) + .def(py::init<Database&, std::size_t, std::string, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("backend"), py::arg("shuffle"), py::arg("drop_last")) .def("__iter__", &DataProvider::iter) .def("__next__", &DataProvider::next) .def("__len__", &DataProvider::getNbBatch); diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp index fc6b842ed..7f4eb71aa 100644 --- a/src/data/DataProvider.cpp +++ b/src/data/DataProvider.cpp @@ -23,9 +23,10 @@ #include "aidge/utils/Random.hpp" -Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast) +Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const std::string& backend, const bool shuffle, const bool dropLast) : mDatabase(database), mBatchSize(batchSize), + mBackend(backend), mShuffle(shuffle), mDropLast(dropLast), mNumberModality(database.getItem(0).size()), @@ -63,7 +64,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size); auto batchData = std::make_shared<Tensor>(); batchData->resize(dataBatchDims[i]); - batchData->setBackend("cpu"); + batchData->setBackend(mBackend); batchData->setDataType(mDataTypes[i]); batchTensors.push_back(batchData); } @@ -78,6 +79,8 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con // Browse each modularity in the database item for (std::size_t j = 0; j < mNumberModality; ++j) { + + dataItem[j]->setBackend(mBackend); auto dataSample = dataItem[j]; // Assert tensor sizes -- GitLab