diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp
index 5c7a1c73ce4ad4eb512a446879cb1ad9b673eb2f..43245573cab19c4e5c31de180910c9666154ff6d 100644
--- a/include/aidge/data/DataProvider.hpp
+++ b/include/aidge/data/DataProvider.hpp
@@ -20,8 +20,6 @@
 #include "aidge/data/Database.hpp"
 #include "aidge/data/Data.hpp"
 
-
-
 namespace Aidge {
 
 /**
@@ -33,14 +31,35 @@ class DataProvider {
 private:
     // Dataset providing the data to the dataProvider
     const Database& mDatabase;
+
+    // Desired size of the produced batches
+    const std::size_t mBatchSize;
+
+    // Enable random shuffling of the items (typically for training)
+    const bool mShuffle;
+
+    // Drop the last batch if it is not full
+    const bool mDropLast;
+
+    // Number of modalities in one item
     const std::size_t mNumberModality;
-    std::vector<std::vector<std::size_t>> mDataSizes;
-    std::vector<std::string> mDataBackends;
-    std::vector<DataType> mDataTypes;
 
-    // Desired size of the produced batches
-    const std::size_t mBatchSize;
+    // Number of items in the database
+    std::size_t mNbItems;
+    // Order in which the database items are fetched
+    std::vector<unsigned int> mBatches;
+    // Index of the batch currently being browsed
+    std::size_t mIndexBatch;
+
+    // Number of batches
+    std::size_t mNbBatch;
+    // Size of the last batch
+    std::size_t mLastBatchSize;
+
+    // Dimensions, backend and data type of each modality
+    std::vector<std::vector<std::size_t>> mDataDims;
+    std::vector<std::string> mDataBackends;
+    std::vector<DataType> mDataTypes;
 
 public:
     /**
@@ -48,15 +67,63 @@ public:
      * @param database database from which to load the data.
      * @param batchSize number of data samples per batch.
+     * @param shuffle if true, fetch the database items in a random order.
+     * @param dropLast if true, drop the last batch when it is not full.
      */
-    DataProvider(const Database& database, const std::size_t batchSize);
+    DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
 
 public:
     /**
-     * @brief Create a batch for each data modality in the database. The returned batch contain the data as sorted in the database.
-     * @param startIndex the starting index in the database to start the batch from.
+     * @brief Create a batch for each data modality in the database.
      * @return a vector of tensors. Each tensor is a batch corresponding to one modality.
      */
-    std::vector<std::shared_ptr<Tensor>> readBatch(const std::size_t startIndex) const;
+    std::vector<std::shared_ptr<Tensor>> readBatch() const;
+
+    /**
+     * @brief Get the number of batches.
+     * @return std::size_t
+     */
+    inline std::size_t getNbBatch() const {
+        return mNbBatch;
+    }
+
+    /**
+     * @brief Get the current batch index.
+     * @return std::size_t
+     */
+    inline std::size_t getIndexBatch() const {
+        return mIndexBatch;
+    }
+
+    /**
+     * @brief Reset to zero the internal batch index that browses the data of the database.
+     */
+    inline void resetIndexBatch() {
+        mIndexBatch = 0;
+    }
+
+    /**
+     * @brief Increment the internal batch index that browses the data of the database.
+     */
+    inline void incrementIndexBatch() {
+        ++mIndexBatch;
+    }
+
+    /**
+     * @brief Build the fetch order of the database items, shuffle it if required
+     * and compute the size of the last batch.
+     */
+    void setBatches();
+
+    /**
+     * @brief End-of-DataProvider condition.
+     * @return true when all batches have been fetched, false otherwise.
+     */
+    inline bool done() const {
+        return (mIndexBatch == mNbBatch);
+    }
+
+    // Functions for the Python iterator protocol (defined in pybind_DataProvider.cpp)
+    // __iter__ method for iterator protocol
+    DataProvider* iter();
+    // __next__ method for iterator protocol
+    std::vector<std::shared_ptr<Aidge::Tensor>> next();
 };
 
 } // namespace Aidge
diff --git a/python_binding/data/pybind_DataProvider.cpp b/python_binding/data/pybind_DataProvider.cpp
index dfdf188946673c4e2a7ea2dc0829312758d80f96..2f652aff5008f8008952ffb1bb6fb1738021b436 100644
--- a/python_binding/data/pybind_DataProvider.cpp
+++ b/python_binding/data/pybind_DataProvider.cpp
@@ -4,19 +4,33 @@
 #include "aidge/data/Database.hpp"
 
 namespace py = pybind11;
+
 namespace Aidge {
 
+// __iter__ method for iterator protocol
+DataProvider* DataProvider::iter() {
+    resetIndexBatch();
+    setBatches();
+    return this;
+}
+
+// __next__ method for iterator protocol
+std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() {
+    if (!done()) {
+        incrementIndexBatch();
+        return readBatch();
+    } else {
+        throw py::stop_iteration();
+    }
+}
+
 void init_DataProvider(py::module& m){
 
     py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
-        .def(py::init<Database&, std::size_t>(), py::arg("database"), py::arg("batchSize"))
-        .def("read_batch", &DataProvider::readBatch, py::arg("start_index"),
-        R"mydelimiter(
-        Return a batch of each data modality.
-
-        :param start_index: Database starting index to read the batch from
-        :type start_index: int
-        )mydelimiter");
+        .def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle") = false, py::arg("drop_last") = false)
+        .def("__iter__", &DataProvider::iter)
+        .def("__next__", &DataProvider::next)
+        .def("__len__", &DataProvider::getNbBatch);
 }
 }
diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp
index dffb5745d9e324856548387069bcf1d5ff6a7b48..7783ed86cf4ae1d8672cc6a35a97ca9a996457b6 100644
--- a/src/data/DataProvider.cpp
+++ b/src/data/DataProvider.cpp
@@ -13,45 +13,56 @@
 #include <cstddef> // std::size_t
 #include <memory>
 #include <vector>
+#include <numeric> // std::iota
+
 
 #include "aidge/data/Database.hpp"
 #include "aidge/data/DataProvider.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp" // AIDGE_ASSERT
+#include "aidge/utils/Random.hpp"
+
 
-Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize)
+Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast)
     : mDatabase(database),
+      mBatchSize(batchSize),
+      mShuffle(shuffle),
+      mDropLast(dropLast),
       mNumberModality(database.getItem(0).size()),
-      mBatchSize(batchSize)
+      mNbItems(mDatabase.getLen()),
+      mIndexBatch(0)
 {
     // Iterating on each data modality in the database
     // Get the tensor dimensions, datatype and backend of each modality to ensure each data have the same
     for (const auto& modality : mDatabase.getItem(0)) {
-        mDataSizes.push_back(modality->dims());
+        mDataDims.push_back(modality->dims());
         // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
-        // mDataBackends.push_back(item[i]->getImpl()->backend());
         mDataTypes.push_back(modality->dataType());
    }
+
+    // Compute the number of batches depending on mDropLast.
+    // Integer division already floors; adding mBatchSize - 1 before dividing yields the
+    // ceiling, so no floating-point round-trip through std::floor/std::ceil is needed.
+    mNbBatch = (mDropLast) ?
+                    (mNbItems / mBatchSize) :
+                    ((mNbItems + mBatchSize - 1) / mBatchSize);
 }
 
-std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const std::size_t startIndex) const
+std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const
 {
-    assert((startIndex) <= mDatabase.getLen() && " DataProvider readBatch : database fetch out of bounds");
-
-
-    // Determine the batch size (may differ for the last batch)
-    const std::size_t current_batch_size = ((startIndex + mBatchSize) > mDatabase.getLen()) ?
-                                            mDatabase.getLen()-startIndex :
-                                            mBatchSize;
+    AIDGE_ASSERT(mIndexBatch <= mNbBatch, "Cannot fetch more data than available in database");
+
+    // Determine the batch size (the last batch may not be full)
+    std::size_t current_batch_size;
+    if (mIndexBatch == mNbBatch) {
+        current_batch_size = mLastBatchSize;
+    } else {
+        current_batch_size = mBatchSize;
+    }
 
     // Create batch tensors (dimensions, backends, datatype) for each modality
     std::vector<std::shared_ptr<Tensor>> batchTensors;
-    auto dataBatchSize = mDataSizes;
+    auto dataBatchDims = mDataDims;
     for (std::size_t i = 0; i < mNumberModality; ++i) {
-        dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
+        dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size);
         auto batchData = std::make_shared<Tensor>();
-        batchData->resize(dataBatchSize[i]);
-        // batchData->setBackend(mDataBackends[i]);
+        batchData->resize(dataBatchDims[i]);
         batchData->setBackend("cpu");
         batchData->setDataType(mDataTypes[i]);
         batchTensors.push_back(batchData);
@@ -60,7 +71,8 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
 
     // Call each database item and concatenate each data modularity in the batch tensors
     for (std::size_t i = 0; i < current_batch_size; ++i){
-        auto dataItem = mDatabase.getItem(startIndex+i);
+        // mIndexBatch was incremented before this call, hence the - 1
+        auto dataItem = mDatabase.getItem(mBatches[(mIndexBatch - 1) * mBatchSize + i]);
 
         // assert same number of modalities
         assert(dataItem.size() == mNumberModality && "DataProvider readBatch : item from database have inconsistent number of modality.");
@@ -69,7 +81,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
             auto dataSample = dataItem[j];
 
             // Assert tensor sizes
-            assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size");
+            assert(dataSample->dims() == mDataDims[j] && "DataProvider readBatch : corrupted Data size");
 
             // Assert implementation backend
             // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend");
@@ -82,4 +94,31 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
         }
     }
     return batchTensors;
-}
\ No newline at end of file
+}
+
+
+void Aidge::DataProvider::setBatches(){
+
+    // (Re)build the identity ordering 0, 1, ..., mNbItems - 1
+    mBatches.clear();
+    mBatches.resize(mNbItems);
+    std::iota(mBatches.begin(),
+              mBatches.end(),
+              0U);
+
+    if (mShuffle) {
+        Random::randShuffle(mBatches);
+    }
+
+    if (mNbItems % mBatchSize != 0) { // The last batch is not full
+        const std::size_t lastBatchSize = mNbItems % mBatchSize;
+        if (mDropLast) { // Remove the items of the last non-full batch
+            AIDGE_ASSERT(lastBatchSize <= mBatches.size(), "Last batch larger than the size of the database");
+            mBatches.erase(mBatches.end() - lastBatchSize, mBatches.end());
+            mLastBatchSize = mBatchSize;
+        } else { // Keep the last non-full batch
+            mLastBatchSize = lastBatchSize;
+        }
+    } else { // The last batch is full
+        mLastBatchSize = mBatchSize;
+    }
+}
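
For reference, a minimal consumption sketch of the new API from the C++ side, mirroring what the Python `__iter__`/`__next__` bindings do. The surrounding function and the batch size of 32 are hypothetical and not part of this patch; `db` stands for any concrete `Aidge::Database` implementation:

    #include "aidge/data/Database.hpp"
    #include "aidge/data/DataProvider.hpp"

    // Hypothetical consumer of one pass over the dataset.
    void consumeOneEpoch(const Aidge::Database& db) {
        // Shuffle the items each epoch, keep the (possibly smaller) last batch.
        Aidge::DataProvider provider(db, /*batchSize=*/32, /*shuffle=*/true, /*dropLast=*/false);

        provider.setBatches();       // build and shuffle the fetch order, size the last batch
        provider.resetIndexBatch();  // start from the first batch
        while (!provider.done()) {
            provider.incrementIndexBatch();           // advance first: readBatch() indexes with mIndexBatch - 1
            const auto batch = provider.readBatch();  // one "cpu" tensor per modality
            // ... feed batch[0], batch[1], ... to the rest of the pipeline
        }
    }

On the Python side the same loop collapses to `for batch in provider: ...`, since `__iter__` calls `resetIndexBatch()` and `setBatches()` while `__next__` raises `StopIteration` once `done()` returns true.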