Skip to content
Snippets Groups Projects
Commit 0279070b authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

Add CUDA backend support for DataProvider

parent 5a68b2db
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!178Learning backend cuda
......@@ -56,6 +56,8 @@ private:
// Size of the last batch
std::size_t mLastBatchSize;
std::string mBackend;
// Store each modality's dimensions, backend, and data type
std::vector<std::vector<std::size_t>> mDataDims;
std::vector<std::string> mDataBackends;
......@@ -67,7 +69,7 @@ public:
 * @param database database from which to load the data.
 * @param batchSize number of data samples per batch.
 * @param backend name of the backend on which batch tensors are allocated (e.g. "cpu"); defaults to "cpu".
 * @param shuffle if true, shuffle the sample order before batching; defaults to false.
 * @param dropLast if true, drop the final incomplete batch when the dataset size is not a multiple of batchSize; defaults to false.
 */
DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
DataProvider(const Database& database, const std::size_t batchSize, const std::string& backend = "cpu", const bool shuffle = false, const bool dropLast = false);
public:
/**
......
......@@ -27,7 +27,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() {
void init_DataProvider(py::module& m){
py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
.def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle"), py::arg("drop_last"))
.def(py::init<Database&, std::size_t, std::string, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("backend"), py::arg("shuffle"), py::arg("drop_last"))
.def("__iter__", &DataProvider::iter)
.def("__next__", &DataProvider::next)
.def("__len__", &DataProvider::getNbBatch);
......
......@@ -23,9 +23,10 @@
#include "aidge/utils/Random.hpp"
Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast)
Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const std::string& backend, const bool shuffle, const bool dropLast)
: mDatabase(database),
mBatchSize(batchSize),
mBackend(backend),
mShuffle(shuffle),
mDropLast(dropLast),
mNumberModality(database.getItem(0).size()),
......@@ -63,7 +64,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con
dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size);
auto batchData = std::make_shared<Tensor>();
batchData->resize(dataBatchDims[i]);
batchData->setBackend("cpu");
batchData->setBackend(mBackend);
batchData->setDataType(mDataTypes[i]);
batchTensors.push_back(batchData);
}
......@@ -78,6 +79,8 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con
// Browse each modality in the database item
for (std::size_t j = 0; j < mNumberModality; ++j) {
dataItem[j]->setBackend(mBackend);
auto dataSample = dataItem[j];
// Assert tensor sizes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment