Skip to content
Snippets Groups Projects

Dataprovider iterator

Merged Thibault Allenet requested to merge dataproviderIterator into dev
6 files
+ 231
38
Compare changes
  • Side-by-side
  • Inline
Files
6
@@ -20,8 +20,6 @@
@@ -20,8 +20,6 @@
#include "aidge/data/Database.hpp"
#include "aidge/data/Database.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Data.hpp"
namespace Aidge {
namespace Aidge {
/**
/**
@@ -33,14 +31,35 @@ class DataProvider {
@@ -33,14 +31,35 @@ class DataProvider {
private:
private:
// Dataset providing the data to the dataProvider
// Dataset providing the data to the dataProvider
const Database& mDatabase;
const Database& mDatabase;
 
 
// Desired size of the produced batches
 
const std::size_t mBatchSize;
 
 
// Enable random shuffling for learning
 
const bool mShuffle;
 
// Drops the last non-full batch
 
const bool mDropLast;
 
 
// Number of modalities in one item
const std::size_t mNumberModality;
const std::size_t mNumberModality;
std::vector<std::vector<std::size_t>> mDataSizes;
std::vector<std::string> mDataBackends;
std::vector<DataType> mDataTypes;
// Desired size of the produced batches
// mNbItems contains the number of items in the database
const std::size_t mBatchSize;
std::size_t mNbItems;
 
// mBatches contains the call order of each database item
 
std::vector<unsigned int> mBatches;
 
// mIndexBatch is the index of the current batch
 
std::size_t mIndexBatch;
 
 
// mNbBatch contains the number of batches
 
std::size_t mNbBatch;
 
// Size of the Last batch
 
std::size_t mLastBatchSize;
 
 
// Store each modality dimensions, backend and type
 
std::vector<std::vector<std::size_t>> mDataDims;
 
std::vector<std::string> mDataBackends;
 
std::vector<DataType> mDataTypes;
public:
public:
/**
/**
@@ -48,15 +67,76 @@ public:
@@ -48,15 +67,76 @@ public:
* @param database database from which to load the data.
* @param database database from which to load the data.
* @param batchSize number of data samples per batch.
* @param batchSize number of data samples per batch.
*/
*/
DataProvider(const Database& database, const std::size_t batchSize);
DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
public:
public:
/**
/**
* @brief Create a batch for each data modality in the database. The returned batch contains the data as sorted in the database.
* @brief Create a batch for each data modality in the database.
* @param startIndex the starting index in the database to start the batch from.
* @return a vector of tensors. Each tensor is a batch corresponding to one modality.
* @return a vector of tensors. Each tensor is a batch corresponding to one modality.
*/
*/
std::vector<std::shared_ptr<Tensor>> readBatch(const std::size_t startIndex) const;
std::vector<std::shared_ptr<Tensor>> readBatch() const;
 
 
/**
 
* @brief Get the Number of Batch.
 
*
 
* @return std::size_t
 
*/
 
inline std::size_t getNbBatch(){
 
return mNbBatch;
 
};
 
 
/**
 
* @brief Get the current Index Batch.
 
*
 
* @return std::size_t
 
*/
 
inline std::size_t getIndexBatch(){
 
return mIndexBatch;
 
};
 
 
/**
 
* @brief Reset the internal batch index that browses the data of the database to zero.
 
*/
 
inline void resetIndexBatch(){
 
mIndexBatch = 0;
 
};
 
 
/**
 
* @brief Increment the internal batch index that browses the data of the database.
 
*/
 
inline void incrementIndexBatch(){
 
++mIndexBatch;
 
};
 
 
/**
 
* @brief Setup the batches for one pass on the database.
 
*/
 
void setBatches();
 
 
/**
 
* @brief End condition of dataProvider for one pass on the database.
 
*
 
* @return true when all batches were fetched, false otherwise
 
*/
 
inline bool done(){
 
return (mIndexBatch == mNbBatch);
 
};
 
 
 
// Functions for python iterator iter and next (definition in pybind.cpp)
 
/**
 
* @brief __iter__ method for iterator protocol
 
*
 
* @return DataProvider*
 
*/
 
DataProvider* iter();
 
 
/**
 
* @brief __next__ method for iterator protocol
 
*
 
* @return std::vector<std::shared_ptr<Aidge::Tensor>>
 
*/
 
std::vector<std::shared_ptr<Aidge::Tensor>> next();
};
};
} // namespace Aidge
} // namespace Aidge
Loading