Skip to content
Snippets Groups Projects

Learning backend cuda

Merged Houssem ROUIS requested to merge hrouis/aidge_core:learning_backend_cuda into dev
Files
18
@@ -35,6 +35,9 @@ private:
@@ -35,6 +35,9 @@ private:
// Desired size of the produced batches
// Desired size of the produced batches
const std::size_t mBatchSize;
const std::size_t mBatchSize;
 
// The backend for data tensors
 
std::string mBackend;
 
// Enable random shuffling for learning
// Enable random shuffling for learning
const bool mShuffle;
const bool mShuffle;
@@ -67,7 +70,7 @@ public:
@@ -67,7 +70,7 @@ public:
* @param database database from which to load the data.
* @param database database from which to load the data.
* @param batchSize number of data samples per batch.
* @param batchSize number of data samples per batch.
*/
*/
DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
DataProvider(const Database& database, const std::size_t batchSize, const std::string& backend = "cpu", const bool shuffle = false, const bool dropLast = false);
public:
public:
/**
/**
Loading