Commit 01f4ca8b authored by Thibault Allenet's avatar Thibault Allenet

Add DataProvider iterator for Python, with shuffle and drop-last batch options

parent da8c26ab
Related merge requests: !105 (version 0.2.0), !86 (Dataprovider iterator)
@@ -20,8 +20,6 @@
#include "aidge/data/Database.hpp"
#include "aidge/data/Data.hpp"
namespace Aidge {
/**
@@ -33,14 +31,35 @@ class DataProvider {
private:
// Dataset providing the data to the dataProvider
const Database& mDatabase;
// Desired size of the produced batches
const std::size_t mBatchSize;
// Enable random shuffling for learning
const bool mShuffle;
// Drops the last non-full batch
const bool mDropLast;
// Number of modalities in one item
const std::size_t mNumberModality;
std::vector<std::vector<std::size_t>> mDataSizes;
std::vector<std::string> mDataBackends;
std::vector<DataType> mDataTypes;
// Desired size of the produced batches
const std::size_t mBatchSize;
// mNbItems contains the number of items in the database
std::size_t mNbItems;
// mBatches stores the order in which database items are fetched
std::vector<unsigned int> mBatches;
// mIndexBatch is the index of the current batch
std::size_t mIndexBatch;
// mNbBatch contains the number of batches
std::size_t mNbBatch;
// Size of the last batch
std::size_t mLastBatchSize;
// Store each modality dimensions, backend and type
std::vector<std::vector<std::size_t>> mDataDims;
std::vector<std::string> mDataBackends;
std::vector<DataType> mDataTypes;
public:
/**
@@ -48,15 +67,63 @@ public:
* @param database database from which to load the data.
* @param batchSize number of data samples per batch.
*/
DataProvider(const Database& database, const std::size_t batchSize);
DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
public:
/**
* @brief Create a batch for each data modality in the database. The returned batch contains the data in the order stored in the database.
* @param startIndex the starting index in the database to start the batch from.
* @brief Create a batch for each data modality in the database.
* @return a vector of tensors. Each tensor is a batch corresponding to one modality.
*/
std::vector<std::shared_ptr<Tensor>> readBatch(const std::size_t startIndex) const;
std::vector<std::shared_ptr<Tensor>> readBatch() const;
/**
* @brief Get the number of batches
*
* @return std::size_t
*/
inline std::size_t getNbBatch(){
return mNbBatch;
};
/**
* @brief Get the index of the current batch
*
* @return std::size_t
*/
inline std::size_t getIndexBatch(){
return mIndexBatch;
};
/**
* @brief Reset the internal index batch that browses the data of the database to zero.
*/
inline void resetIndexBatch(){
mIndexBatch = 0;
};
/**
* @brief Increment the internal index batch that browses the data of the database.
*/
inline void incrementIndexBatch(){
++mIndexBatch;
};
void setBatches();
/**
* @brief End-of-iteration condition.
*
* @return true when all batches have been fetched, false otherwise
*/
inline bool done(){
return (mIndexBatch == mNbBatch);
};
// Functions for the Python iterator protocol, __iter__ and __next__ (defined in pybind.cpp)
// __iter__ method for iterator protocol
DataProvider* iter();
// __next__ method for iterator protocol
std::vector<std::shared_ptr<Aidge::Tensor>> next();
};
} // namespace Aidge
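Taken together, the new members form a small iteration state machine: iter() resets mIndexBatch and rebuilds mBatches, next() checks done(), increments the index and reads one batch, and getNbBatch() backs __len__. A minimal Python sketch of that bookkeeping (illustrative names, not the Aidge API):

class IterSketch:
    def __init__(self, nb_batch):
        self.nb_batch = nb_batch     # mirrors mNbBatch
        self.index_batch = 0         # mirrors mIndexBatch

    def __iter__(self):              # mirrors DataProvider::iter()
        self.index_batch = 0         # resetIndexBatch()
        return self

    def __next__(self):              # mirrors DataProvider::next()
        if self.index_batch == self.nb_batch:   # done()
            raise StopIteration
        self.index_batch += 1        # incrementIndexBatch()
        return self.index_batch - 1  # stands in for readBatch()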
@@ -4,19 +4,33 @@
#include "aidge/data/Database.hpp"
namespace py = pybind11;
namespace Aidge {
// __iter__ method for iterator protocol
DataProvider* DataProvider::iter(){
resetIndexBatch();
setBatches();
return this;
}
// __next__ method for iterator protocol
std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() {
if (!done()){
incrementIndexBatch();
return readBatch();
} else {
throw py::stop_iteration();
}
}
void init_DataProvider(py::module& m){
py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
.def(py::init<Database&, std::size_t>(), py::arg("database"), py::arg("batchSize"))
.def("read_batch", &DataProvider::readBatch, py::arg("start_index"),
R"mydelimiter(
Return a batch of each data modality.
:param start_index: Database starting index to read the batch from
:type start_index: int
)mydelimiter");
.def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle"), py::arg("drop_last"))
.def("__iter__", &DataProvider::iter)
.def("__next__", &DataProvider::next)
.def("__len__", &DataProvider::getNbBatch);
}
}
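With these bindings a DataProvider behaves like any other Python iterable. A usage sketch, assuming a built aidge_core module and a hypothetical Database subclass instance my_db with two modalities (data and label):

import aidge_core

provider = aidge_core.DataProvider(my_db, batch_size=32, shuffle=True, drop_last=False)
print(len(provider))       # __len__ -> getNbBatch(), batches per epoch
for batch in provider:     # __iter__ resets the index and rebuilds the batch order
    data, label = batch    # one Tensor per modality, batch dimension first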
@@ -13,45 +13,56 @@
#include <cstddef> // std::size_t
#include <memory>
#include <vector>
#include <cmath>
#include <numeric>  // std::iota
#include "aidge/data/Database.hpp"
#include "aidge/data/DataProvider.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Random.hpp"
Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize)
Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast)
: mDatabase(database),
mBatchSize(batchSize),
mShuffle(shuffle),
mDropLast(dropLast),
mNumberModality(database.getItem(0).size()),
mBatchSize(batchSize)
mNbItems(mDatabase.getLen()),
mIndexBatch(0)
{
// Iterate over each data modality in the database.
// Record the tensor dimensions, data type and backend of each modality so that later items can be checked for consistency.
for (const auto& modality : mDatabase.getItem(0)) {
mDataSizes.push_back(modality->dims());
mDataDims.push_back(modality->dims());
// assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
// mDataBackends.push_back(item[i]->getImpl()->backend());
mDataTypes.push_back(modality->dataType());
}
// Compute the number of batches depending on the mDropLast boolean.
// Note: the operands must be promoted to floating point before std::ceil,
// otherwise the integer division truncates before the rounding can happen.
mNbBatch = (mDropLast) ?
(mNbItems / mBatchSize) :
static_cast<std::size_t>(std::ceil(static_cast<float>(mNbItems) / static_cast<float>(mBatchSize)));
}
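As a quick numeric check of the two branches: with 10 items and a batch size of 3, dropping the last batch gives floor(10/3) = 3 full batches (one item left unused per epoch), while keeping it gives ceil(10/3) = 4 batches, the last holding a single item. In plain Python:

import math
nb_items, batch_size = 10, 3
print(nb_items // batch_size)            # drop_last=True  -> 3
print(math.ceil(nb_items / batch_size))  # drop_last=False -> 4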
std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const std::size_t startIndex) const
std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const
{
assert((startIndex) <= mDatabase.getLen() && " DataProvider readBatch : database fetch out of bounds");
// Determine the batch size (may differ for the last batch)
const std::size_t current_batch_size = ((startIndex + mBatchSize) > mDatabase.getLen()) ?
mDatabase.getLen()-startIndex :
mBatchSize;
AIDGE_ASSERT(mIndexBatch <= mNbBatch, "Cannot fetch more data than available in database");
std::size_t current_batch_size;
if (mIndexBatch == mNbBatch) {
current_batch_size = mLastBatchSize;
} else {
current_batch_size = mBatchSize;
}
// Create batch tensors (dimensions, backends, datatype) for each modality
std::vector<std::shared_ptr<Tensor>> batchTensors;
auto dataBatchSize = mDataSizes;
auto dataBatchDims = mDataDims;
for (std::size_t i = 0; i < mNumberModality; ++i) {
dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size);
auto batchData = std::make_shared<Tensor>();
batchData->resize(dataBatchSize[i]);
// batchData->setBackend(mDataBackends[i]);
batchData->resize(dataBatchDims[i]);
batchData->setBackend("cpu");
batchData->setDataType(mDataTypes[i]);
batchTensors.push_back(batchData);
@@ -60,7 +71,8 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
// Fetch each database item and concatenate each data modality into the batch tensors
for (std::size_t i = 0; i < current_batch_size; ++i){
auto dataItem = mDatabase.getItem(startIndex+i);
auto dataItem = mDatabase.getItem(mBatches[(mIndexBatch-1)*mBatchSize+i]);
// auto dataItem = mDatabase.getItem(startIndex+i);
// assert same number of modalities
assert(dataItem.size() == mNumberModality && "DataProvider readBatch : items from the database have an inconsistent number of modalities.");
@@ -69,7 +81,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
auto dataSample = dataItem[j];
// Assert tensor sizes
assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size");
assert(dataSample->dims() == mDataDims[j] && "DataProvider readBatch : corrupted Data size");
// Assert implementation backend
// assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend");
@@ -82,4 +94,31 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
}
}
return batchTensors;
}
\ No newline at end of file
}
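Note the indexing: next() increments mIndexBatch before calling readBatch(), so the batch currently being filled is number mIndexBatch - 1, and its items are a contiguous slice of the (possibly shuffled) mBatches vector. The equivalent slice, as an illustrative Python helper:

def batch_indices(batches, index_batch, batch_size, current_batch_size):
    # Items consumed by the batch that readBatch() is currently filling.
    start = (index_batch - 1) * batch_size
    return batches[start:start + current_batch_size]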
void Aidge::DataProvider::setBatches(){
mBatches.clear();
mBatches.resize(mNbItems);
std::iota(mBatches.begin(), mBatches.end(), 0U);
if (mShuffle){
Random::randShuffle(mBatches);
}
if (mNbItems % mBatchSize != 0) { // The last batch is not full
std::size_t lastBatchSize = static_cast<std::size_t>(mNbItems % mBatchSize);
if (mDropLast) { // Remove the last non-full batch
AIDGE_ASSERT(lastBatchSize <= mBatches.size(), "Last batch is larger than the database size");
mBatches.erase(mBatches.end() - lastBatchSize, mBatches.end());
mLastBatchSize = mBatchSize;
} else { // Keep the last non-full batch
mLastBatchSize = lastBatchSize;
}
} else { // The last batch is full
mLastBatchSize = mBatchSize;
}
}
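A walk-through of setBatches() on the same 10-item, batch-size-3 example, sketched in Python (random.shuffle stands in for Random::randShuffle):

import random

nb_items, batch_size, shuffle, drop_last = 10, 3, True, True
batches = list(range(nb_items))        # std::iota fills 0..9
if shuffle:
    random.shuffle(batches)
remainder = nb_items % batch_size      # 1 -> the last batch is not full
if remainder != 0 and drop_last:
    batches = batches[:-remainder]     # erase the trailing partial batch
    last_batch_size = batch_size       # every remaining batch is full
else:
    last_batch_size = remainder if remainder != 0 else batch_size
print(len(batches), last_batch_size)   # prints: 9 3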