Skip to content
Snippets Groups Projects

Dataprovider iterator

Merged Thibault Allenet requested to merge dataproviderIterator into dev
6 files
+ 231
38
Compare changes
  • Side-by-side
  • Inline
Files
6
@@ -20,8 +20,6 @@
#include "aidge/data/Database.hpp"
#include "aidge/data/Data.hpp"
namespace Aidge {
/**
@@ -33,14 +31,35 @@ class DataProvider {
private:
// Dataset providing the data to the dataProvider
const Database& mDatabase;
// Desired size of the produced batches
const std::size_t mBatchSize;
// Enable random shuffling for learning
const bool mShuffle;
// Drops the last non-full batch
const bool mDropLast;
// Number of modalities in one database item
const std::size_t mNumberModality;
std::vector<std::vector<std::size_t>> mDataSizes;
std::vector<std::string> mDataBackends;
std::vector<DataType> mDataTypes;
// Desired size of the produced batches
const std::size_t mBatchSize;
// mNbItems contains the number of items in the database
std::size_t mNbItems;
// mBatches contains the call order of each database item
std::vector<unsigned int> mBatches;
// mIndex browsing the number of batch
std::size_t mIndexBatch;
// mNbBatch contains the number of batch
std::size_t mNbBatch;
// Size of the Last batch
std::size_t mLastBatchSize;
// Store each modality dimensions, backend and type
std::vector<std::vector<std::size_t>> mDataDims;
std::vector<std::string> mDataBackends;
std::vector<DataType> mDataTypes;
public:
/**
@@ -48,15 +67,76 @@ public:
* @param database database from which to load the data.
* @param batchSize number of data samples per batch.
*/
DataProvider(const Database& database, const std::size_t batchSize);
DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
public:
/**
* @brief Create a batch for each data modality in the database. The returned batch contains the data as sorted in the database.
* @param startIndex the starting index in the database to start the batch from.
* @brief Create a batch for each data modality in the database.
* @return a vector of tensors. Each tensor is a batch corresponding to one modality.
*/
std::vector<std::shared_ptr<Tensor>> readBatch(const std::size_t startIndex) const;
std::vector<std::shared_ptr<Tensor>> readBatch() const;
/**
* @brief Get the Number of Batch.
*
* @return std::size_t
*/
inline std::size_t getNbBatch(){
return mNbBatch;
};
/**
* @brief Get the current Index Batch.
*
* @return std::size_t
*/
inline std::size_t getIndexBatch(){
return mIndexBatch;
};
/**
 * @brief Rewind the internal batch index to the start of the database.
 */
inline void resetIndexBatch() {
    mIndexBatch = 0;
}
/**
 * @brief Advance the internal batch index by one.
 */
inline void incrementIndexBatch() {
    mIndexBatch += 1;
}
/**
 * @brief Set up the batches for one pass on the database.
 */
void setBatches();
/**
* @brief End condition of dataProvider for one pass on the database.
*
* @return true when all batch were fetched, False otherwise
*/
inline bool done(){
return (mIndexBatch == mNbBatch);
};
// Functions backing the Python iterator protocol (__iter__ / __next__); definitions live in pybind.cpp
/**
* @brief __iter__ method for iterator protocol
*
* @return DataProvider*
*/
DataProvider* iter();
/**
* @brief __next__ method for iterator protocol
*
* @return std::vector<std::shared_ptr<Aidge::Tensor>>
*/
std::vector<std::shared_ptr<Aidge::Tensor>> next();
};
} // namespace Aidge
Loading