From eb48b9ea8921d280a4aefde9ea913db5df3f8576 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Wed, 6 Dec 2023 14:00:18 +0000
Subject: [PATCH] Add DataProvider to create batches

---
 include/aidge/data/DataProvider.hpp | 49 +++++++++++++++++++
 src/data/DataProvider.cpp           | 76 +++++++++++++++++++++++++++++
 2 files changed, 125 insertions(+)
 create mode 100644 include/aidge/data/DataProvider.hpp
 create mode 100644 src/data/DataProvider.cpp

diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp
new file mode 100644
index 000000000..5ca47ce7b
--- /dev/null
+++ b/include/aidge/data/DataProvider.hpp
@@ -0,0 +1,49 @@
+#ifndef DATAPROVIDER_H_
+#define DATAPROVIDER_H_
+
+#include "aidge/data/Database.hpp"
+#include "aidge/data/Data.hpp"
+
+namespace Aidge {
+
+
+/**
+ * @brief Data Provider. Takes a database and composes batches by fetching data from it.
+ * @todo Implement a drop-last-batch option. Currently the last batch is returned with fewer elements.
+ * @todo Implement readRandomBatch to compose batches from the database with a random sampling strategy. Necessary for training.
+ */
+class DataProvider {
+
+public:
+    /**
+     * @brief Constructor of Data Provider.
+     * @param database Database from which to load the data.
+     * @param batchSize Number of data samples per batch.
+     */
+    DataProvider(Database& database, size_t batchSize);
+
+    /**
+     * @brief Create a batch for each data modality in the database. Each returned batch contains the data in database order.
+     * @param startIndex Index in the database from which the batch starts.
+     * @return A vector of tensors, one batch per modality.
+     */
+    std::vector<std::shared_ptr<Tensor>> readBatch(size_t startIndex);
+
+protected:
+
+    // Database providing the data to the DataProvider
+    Database& mDatabase;
+
+    size_t mNumberModality;
+    std::vector<std::vector<size_t>> mDataSizes;
+    std::vector<std::string> mDataBackends;
+    std::vector<DataType> mDataTypes;
+
+    // Desired size of the produced batches
+    size_t mBatchSize;
+
+};
+
+} // namespace Aidge
+
+#endif /* DATAPROVIDER_H_ */
\ No newline at end of file
diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp
new file mode 100644
index 000000000..cdac0cda4
--- /dev/null
+++ b/src/data/DataProvider.cpp
@@ -0,0 +1,76 @@
+#include <cassert>
+
+#include "aidge/data/DataProvider.hpp"
+
+using namespace Aidge;
+
+DataProvider::DataProvider(Database& database, size_t batchSize)
+    :
+    mDatabase(database),
+    mBatchSize(batchSize)
+{
+    // Get the tensor dimensions, data type and backend of each modality, so that every later item can be checked against them
+    auto item = mDatabase.getItem(0);
+    mNumberModality = item.size();
+
+    // Iterate over each data modality in the database
+    for (std::size_t i = 0; i < mNumberModality; ++i) {
+        mDataSizes.push_back(item[i]->dims());
+        // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
+        // mDataBackends.push_back(item[i]->getImpl()->backend());
+        mDataTypes.push_back(item[i]->dataType());
+    }
+}
+
+std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(size_t startIndex)
+{
+    assert(startIndex < mDatabase.getLen() && "DataProvider readBatch: database fetch out of bounds");
+
+
+    // Determine the batch size (may differ for the last batch)
+    size_t current_batch_size;
+    if ((startIndex + mBatchSize) > mDatabase.getLen()) {
+        current_batch_size = mDatabase.getLen() - startIndex;
+    } else {
+        current_batch_size = mBatchSize;
+    }
+
+    // Create batch tensors (dimensions, backend, data type) for each modality
+    std::vector<std::shared_ptr<Tensor>> batchTensors;
+    auto dataBatchSize = mDataSizes;
+    for (std::size_t i = 0; i < mNumberModality; ++i) {
+        dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
+        auto batchData = std::make_shared<Tensor>();
+        batchData->resize(dataBatchSize[i]);
+        // batchData->setBackend(mDataBackends[i]);
+        batchData->setBackend("cpu");
+        batchData->setDatatype(mDataTypes[i]);
+        batchTensors.push_back(batchData);
+    }
+
+    // Fetch each database item and concatenate each data modality into the batch tensors
+    for (std::size_t i = 0; i < current_batch_size; ++i) {
+
+        auto dataItem = mDatabase.getItem(startIndex + i);
+        // Assert the same number of modalities
+        assert(dataItem.size() == mNumberModality && "DataProvider readBatch: item from database has an inconsistent number of modalities.");
+
+        // Browse each modality in the database item
+        for (std::size_t j = 0; j < mNumberModality; ++j) {
+            auto dataSample = dataItem[j];
+
+            // Assert tensor sizes
+            assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch: corrupted data size");
+
+            // Assert implementation backend
+            // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch: corrupted data backend");
+
+            // Assert data type
+            assert(dataSample->dataType() == mDataTypes[j] && "DataProvider readBatch: corrupted data DataType");
+
+            // Concatenate into the batch tensor
+            batchTensors[j]->getImpl()->copy(dataSample->getImpl()->rawPtr(), dataSample->size(), i * dataSample->size());
+        }
+    }
+    return batchTensors;
+}
\ No newline at end of file
--
GitLab
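
For reviewers, a minimal usage sketch of the new class (not part of the patch). It
relies only on what the patch itself uses: a concrete Database implementation and
its getLen() accessor, plus the DataProvider constructor and readBatch(). The
function name iterateDataset and the batch size of 32 are illustrative.

    #include <cstddef>
    #include <memory>
    #include <vector>

    #include "aidge/data/DataProvider.hpp"

    // Walk a database front to back, one batch at a time.
    void iterateDataset(Aidge::Database& db) {
        const std::size_t batchSize = 32;
        Aidge::DataProvider provider(db, batchSize);

        for (std::size_t start = 0; start < db.getLen(); start += batchSize) {
            // One tensor per modality, each with a leading batch dimension.
            // The last batch may hold fewer than batchSize samples, since a
            // drop-last option is not implemented yet (see the @todo notes).
            std::vector<std::shared_ptr<Aidge::Tensor>> batch = provider.readBatch(start);
            // For a supervised dataset, batch[0] might be the inputs and
            // batch[1] the labels, depending on how the database orders
            // its modalities.
        }
    }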