From eb48b9ea8921d280a4aefde9ea913db5df3f8576 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Wed, 6 Dec 2023 14:00:18 +0000
Subject: [PATCH] Add DataProvider to create batches

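DataProvider takes a Database and composes fixed-size batches from it,
returning one Tensor per data modality for each batch.

A minimal usage sketch (assuming `db` is a concrete Database
implementation, e.g. an MNIST-style dataset with a data and a label
modality):

    Aidge::DataProvider provider(db, 32);
    for (std::size_t i = 0; i < db.getLen(); i += 32) {
        auto batch = provider.readBatch(i);
        // batch[0]: data, batch[1]: labels; the leading dim is at most 32
    }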
---
 include/aidge/data/DataProvider.hpp | 63 ++++++++++++++++++++++
 src/data/DataProvider.cpp           | 78 +++++++++++++++++++++++++++++
 2 files changed, 141 insertions(+)
 create mode 100644 include/aidge/data/DataProvider.hpp
 create mode 100644 src/data/DataProvider.cpp

diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp
new file mode 100644
index 000000000..5ca47ce7b
--- /dev/null
+++ b/include/aidge/data/DataProvider.hpp
@@ -0,0 +1,63 @@
+#ifndef DATAPROVIDER_H_
+#define DATAPROVIDER_H_
+
+#include "aidge/data/Database.hpp"
+#include "aidge/data/Data.hpp"
+
+namespace Aidge {
+
+
+/**
+ * @brief Data Provider. Takes a database and composes batches by fetching data from it.
+ * @todo Implement a drop-last-batch option. Currently the last batch may contain fewer elements than the requested batch size.
+ * @todo Implement readRandomBatch to compose batches with a random sampling strategy, which is necessary for training.
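+ *
+ * Usage sketch (a minimal example, assuming `db` is a concrete Database
+ * implementation with at least one item):
+ * @code
+ * Aidge::DataProvider provider(db, 32);  // batches of 32 samples
+ * auto batch = provider.readBatch(0);    // one tensor per data modality
+ * // batch[k] has dims {32, <sample dims of modality k>}
+ * @endcode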
+ */
+class DataProvider {
+
+public:
+    /**
+     * @brief Constructor of DataProvider.
+     * @param database database from which to load the data.
+     * @param batchSize number of data samples per batch.
+     */
+    DataProvider(Database& database, size_t batchSize);
+
+    /**
+     * @brief Create a batch for each data modality in the database. The returned batch contains the data in the order stored in the database.
+     * @param startIndex index of the first database item to include in the batch.
+     * @return a vector of tensors. Each tensor is a batch corresponding to one modality.
+     */
+    std::vector<std::shared_ptr<Tensor>> readBatch(size_t startIndex);
+
+protected:
+
+    // Database providing the data to the DataProvider
+    Database& mDatabase;
+    
+    size_t mNumberModality;                      ///< Number of data modalities per database item
+    std::vector<std::vector<size_t>> mDataSizes; ///< Dims of each modality's sample tensor
+    std::vector<std::string> mDataBackends;      ///< Backend of each modality (currently unused)
+    std::vector<DataType> mDataTypes;            ///< DataType of each modality
+    
+    // Desired size of the produced batches
+    size_t mBatchSize;
+
+};
+
+} // namespace Aidge
+
+#endif /* DATAPROVIDER_H_ */
\ No newline at end of file
diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp
new file mode 100644
index 000000000..cdac0cda4
--- /dev/null
+++ b/src/data/DataProvider.cpp
@@ -0,0 +1,78 @@
+#include <cassert>
+
+#include "aidge/data/DataProvider.hpp"
+
+using namespace Aidge; 
+
+DataProvider::DataProvider(Database& database, size_t batchSize)
+    :
+    mDatabase(database),
+    mBatchSize(batchSize)
+{
+    // Inspect the first item to record the dims, data type and backend of each modality, so that every later item can be checked for consistency
+    auto item = mDatabase.getItem(0);
+    mNumberModality = item.size();
+
+    // Iterate over each data modality in the database
+    for (std::size_t i = 0; i < mNumberModality; ++i) {
+        mDataSizes.push_back(item[i]->dims());
+        // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
+        // mDataBackends.push_back(item[i]->getImpl()->backend());
+        mDataTypes.push_back(item[i]->dataType());
+    }
+}
+
+std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(size_t startIndex)
+{
+    assert(startIndex < mDatabase.getLen() && "DataProvider::readBatch(): database fetch out of bounds");
+
+    // Determine the batch size (may differ for the last batch)
+    size_t current_batch_size;
+    if ((startIndex+mBatchSize) > mDatabase.getLen()){
+        current_batch_size = mDatabase.getLen()-startIndex;
+    } else {
+        current_batch_size = mBatchSize;
+    }
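+    // e.g. with getLen() == 100 and mBatchSize == 32, startIndex values
+    // 0, 32, 64 and 96 yield batch sizes 32, 32, 32 and 4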
+
+    // Create batch tensors (dimensions, backends, datatype) for each modality
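+    // (e.g. a modality whose samples have dims {28, 28} yields a batch tensor of dims {N, 28, 28})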
+    std::vector<std::shared_ptr<Tensor>> batchTensors;
+    auto dataBatchSize = mDataSizes;
+    for (std::size_t i = 0; i < mNumberModality; ++i) {
+        dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
+        auto batchData = std::make_shared<Tensor>();
+        batchData->resize(dataBatchSize[i]);
+        // batchData->setBackend(mDataBackends[i]);
+        batchData->setBackend("cpu");
+        batchData->setDatatype(mDataTypes[i]);
+        batchTensors.push_back(batchData);
+    }
+    
+    // Fetch each database item and concatenate each data modality into the batch tensors
+    for (std::size_t i = 0; i < current_batch_size; ++i){
+
+        auto dataItem = mDatabase.getItem(startIndex+i); 
+        // Assert that the item has the expected number of modalities
+        assert(dataItem.size() == mNumberModality && "DataProvider readBatch : item from database has an inconsistent number of modalities.");
+
+        // Browse each modality in the database item
+        for (std::size_t j = 0; j < mNumberModality; ++j) {
+            auto dataSample = dataItem[j];
+            
+            // Assert tensor sizes
+            assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted data size");
+            
+            // Assert implementation backend
+            // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend");
+
+            // Assert DataType
+            assert(dataSample->dataType() == mDataTypes[j] && "DataProvider readBatch : corrupted data type");
+
+            // Concatenate into the batch tensor: sample i is copied at element offset i * sampleSize
+            batchTensors[j]->getImpl()->copy(dataSample->getImpl()->rawPtr(), dataSample->size(), i*dataSample->size());
+        }
+    }
+    return batchTensors;
+}
\ No newline at end of file
-- 
GitLab