diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp
index 9e0e457b49fe40b2a6e9e3ce5c5e4b77bee1d93e..6c4ca93ce28c0a8c769606f07b1badee676423fd 100644
--- a/include/aidge/aidge.hpp
+++ b/include/aidge/aidge.hpp
@@ -70,6 +70,7 @@
 #include "aidge/utils/Attributes.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/DynamicAttributes.hpp"
+#include "aidge/utils/Random.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp
index 5c7a1c73ce4ad4eb512a446879cb1ad9b673eb2f..62d10a6983e8cf5fd8e2730d3203bed97284e336 100644
--- a/include/aidge/data/DataProvider.hpp
+++ b/include/aidge/data/DataProvider.hpp
@@ -20,8 +20,6 @@
 #include "aidge/data/Database.hpp"
 #include "aidge/data/Data.hpp"
 
-
-
 namespace Aidge {
 
 /**
@@ -33,14 +31,35 @@ class DataProvider {
 private:
     // Dataset providing the data to the dataProvider
     const Database& mDatabase;
+
+    // Desired size of the produced batches
+    const std::size_t mBatchSize;
+
+    // Enable random shuffling of the item order at each pass on the database
+    const bool mShuffle;
+    // Drop the last non-full batch
+    const bool mDropLast;
+
+    // Number of modalities in one item
     const std::size_t mNumberModality;
-    std::vector<std::vector<std::size_t>> mDataSizes;
-    std::vector<std::string> mDataBackends;
-    std::vector<DataType> mDataTypes;
-    // Desired size of the produced batches
-    const std::size_t mBatchSize;
+    // Number of items in the database
+    std::size_t mNbItems;
+    // Order in which the database items are fetched
+    std::vector<unsigned int> mBatches;
+    // Index of the current batch
+    std::size_t mIndexBatch;
+
+    // Number of batches in one pass on the database
+    std::size_t mNbBatch;
+    // Size of the last batch
+    std::size_t mLastBatchSize;
+
+    // Dimensions, backend and data type of each modality
+    std::vector<std::vector<std::size_t>> mDataDims;
+    std::vector<std::string> mDataBackends;
+    std::vector<DataType> mDataTypes;
 
 public:
     /**
@@ -48,15 +67,76 @@ public:
      * @param database database from which to load the data.
      * @param batchSize number of data samples per batch.
+     * @param shuffle randomly shuffle the item order at each pass on the database.
+     * @param dropLast drop the last batch when the database size is not a multiple of batchSize.
      */
-    DataProvider(const Database& database, const std::size_t batchSize);
+    DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
 
 public:
     /**
-     * @brief Create a batch for each data modality in the database. The returned batch contain the data as sorted in the database.
-     * @param startIndex the starting index in the database to start the batch from.
+     * @brief Create a batch for each data modality in the database.
      * @return a vector of tensors. Each tensor is a batch corresponding to one modality.
      */
-    std::vector<std::shared_ptr<Tensor>> readBatch(const std::size_t startIndex) const;
+    std::vector<std::shared_ptr<Tensor>> readBatch() const;
+
+    /**
+     * @brief Get the number of batches in one pass on the database.
+     *
+     * @return std::size_t
+     */
+    inline std::size_t getNbBatch() const {
+        return mNbBatch;
+    }
+
+    /**
+     * @brief Get the current batch index.
+     *
+     * @return std::size_t
+     */
+    inline std::size_t getIndexBatch() const {
+        return mIndexBatch;
+    }
+
+    /**
+     * @brief Reset to zero the internal batch index that browses the data of the database.
+     */
+    inline void resetIndexBatch() {
+        mIndexBatch = 0;
+    }
+
+    /**
+     * @brief Increment the internal batch index that browses the data of the database.
+     */
+    inline void incrementIndexBatch() {
+        ++mIndexBatch;
+    }
+
+    /**
+     * @brief Set up the batches for one pass on the database.
+     */
+    void setBatches();
+
+    /**
+     * @brief End condition of the DataProvider for one pass on the database.
+     *
+     * @return true when all batches have been fetched, false otherwise
+     */
+    inline bool done() const {
+        return (mIndexBatch == mNbBatch);
+    }
+
+
+    // Functions for the Python iterator protocol (defined in pybind_DataProvider.cpp)
+    /**
+     * @brief __iter__ method for the iterator protocol
+     *
+     * @return DataProvider*
+     */
+    DataProvider* iter();
+
+    /**
+     * @brief __next__ method for the iterator protocol
+     *
+     * @return std::vector<std::shared_ptr<Aidge::Tensor>>
+     */
+    std::vector<std::shared_ptr<Aidge::Tensor>> next();
 };
 
 } // namespace Aidge
diff --git a/include/aidge/utils/Random.hpp b/include/aidge/utils/Random.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..704609c0c778c7065a580b86fc67aea7e9d3525d
--- /dev/null
+++ b/include/aidge/utils/Random.hpp
@@ -0,0 +1,31 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_RANDOM_H_
+#define AIDGE_RANDOM_H_
+
+#include <algorithm>
+#include <random>
+#include <vector>
+
+namespace Random {
+
+// Shuffle the vector in place with a freshly seeded Mersenne Twister.
+// Must be inline: this header is included from several translation units.
+inline void randShuffle(std::vector<unsigned int>& vec) {
+    std::random_device rd;
+    std::mt19937 g(rd());
+    std::shuffle(vec.begin(), vec.end(), g);
+}
+
+} // namespace Random
+
+#endif // AIDGE_RANDOM_H_
diff --git a/python_binding/data/pybind_DataProvider.cpp b/python_binding/data/pybind_DataProvider.cpp
index dfdf188946673c4e2a7ea2dc0829312758d80f96..2f652aff5008f8008952ffb1bb6fb1738021b436 100644
--- a/python_binding/data/pybind_DataProvider.cpp
+++ b/python_binding/data/pybind_DataProvider.cpp
@@ -4,19 +4,33 @@
 #include "aidge/data/Database.hpp"
 
 namespace py = pybind11;
+
 namespace Aidge {
 
+// __iter__ method for the Python iterator protocol
+DataProvider* DataProvider::iter(){
+    resetIndexBatch();
+    setBatches();
+    return this;
+}
+
+// __next__ method for the Python iterator protocol
+std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() {
+    if (!done()){
+        incrementIndexBatch();
+        return readBatch();
+    } else {
+        throw py::stop_iteration();
+    }
+}
+
 void init_DataProvider(py::module& m){
 
     py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
-        .def(py::init<Database&, std::size_t>(), py::arg("database"), py::arg("batchSize"))
-        .def("read_batch", &DataProvider::readBatch, py::arg("start_index"),
-        R"mydelimiter(
-        Return a batch of each data modality.
-
-        :param start_index: Database starting index to read the batch from
-        :type start_index: int
-        )mydelimiter");
+        .def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle") = false, py::arg("drop_last") = false)
+        .def("__iter__", &DataProvider::iter)
+        .def("__next__", &DataProvider::next)
+        .def("__len__", &DataProvider::getNbBatch);
 }
 }
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index e07f70eaa7de8dc4daa489ec93c8fd9273559ff2..f8a0567bdc7bb27bdff1137a020857cac5a45604 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -96,10 +96,18 @@ void init_Tensor(py::module& m){
             return py::cast(b.get<double>(idx));
         case DataType::Float32:
             return py::cast(b.get<float>(idx));
+        case DataType::Int8:
+            return py::cast(b.get<std::int8_t>(idx));
+        case DataType::Int16:
+            return py::cast(b.get<std::int16_t>(idx));
         case DataType::Int32:
             return py::cast(b.get<std::int32_t>(idx));
         case DataType::Int64:
             return py::cast(b.get<std::int64_t>(idx));
+        case DataType::UInt8:
+            return py::cast(b.get<std::uint8_t>(idx));
+        case DataType::UInt16:
+            return py::cast(b.get<std::uint16_t>(idx));
         default:
             return py::none();
@@ -111,10 +119,18 @@ void init_Tensor(py::module& m){
             return py::cast(b.get<double>(coordIdx));
         case DataType::Float32:
             return py::cast(b.get<float>(coordIdx));
+        case DataType::Int8:
+            return py::cast(b.get<std::int8_t>(coordIdx));
+        case DataType::Int16:
+            return py::cast(b.get<std::int16_t>(coordIdx));
         case DataType::Int32:
             return py::cast(b.get<std::int32_t>(coordIdx));
         case DataType::Int64:
             return py::cast(b.get<std::int64_t>(coordIdx));
+        case DataType::UInt8:
+            return py::cast(b.get<std::uint8_t>(coordIdx));
+        case DataType::UInt16:
+            return py::cast(b.get<std::uint16_t>(coordIdx));
         default:
             return py::none();
@@ -141,6 +157,12 @@ void init_Tensor(py::module& m){
             break;
         case DataType::Float32:
             dataFormatDescriptor = py::format_descriptor<float>::format();
+            break;
+        case DataType::Int8:
+            dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+            break;
+        case DataType::Int16:
+            dataFormatDescriptor = py::format_descriptor<std::int16_t>::format();
             break;
         case DataType::Int32:
             dataFormatDescriptor = py::format_descriptor<std::int32_t>::format();
@@ -148,6 +170,12 @@ void init_Tensor(py::module& m){
         case DataType::Int64:
             dataFormatDescriptor = py::format_descriptor<std::int64_t>::format();
             break;
+        case DataType::UInt8:
+            dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+            break;
+        case DataType::UInt16:
+            dataFormatDescriptor = py::format_descriptor<std::uint16_t>::format();
+            break;
         default:
             throw py::value_error("Unsupported data format");
diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp
index dffb5745d9e324856548387069bcf1d5ff6a7b48..7783ed86cf4ae1d8672cc6a35a97ca9a996457b6 100644
--- a/src/data/DataProvider.cpp
+++ b/src/data/DataProvider.cpp
@@ -13,45 +13,56 @@
 #include <cstddef>  // std::size_t
 #include <memory>
 #include <vector>
+#include <cmath>    // std::ceil
+#include <numeric>  // std::iota
 
 #include "aidge/data/Database.hpp"
 #include "aidge/data/DataProvider.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"  // AIDGE_ASSERT
+#include "aidge/utils/Random.hpp"
 
-Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize)
+Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast)
     : mDatabase(database),
+      mBatchSize(batchSize),
+      mShuffle(shuffle),
+      mDropLast(dropLast),
       mNumberModality(database.getItem(0).size()),
-      mBatchSize(batchSize)
+      mNbItems(mDatabase.getLen()),
+      mIndexBatch(0)
 {
     // Iterating on each data modality in the database
     // Get the tensor dimensions, datatype and backend of each modality to ensure each data have the same
     for (const auto& modality : mDatabase.getItem(0)) {
-        mDataSizes.push_back(modality->dims());
+        mDataDims.push_back(modality->dims());
         // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
-        // mDataBackends.push_back(item[i]->getImpl()->backend());
         mDataTypes.push_back(modality->dataType());
     }
+
+    // Compute the number of batches depending on mDropLast.
+    // Cast to float before dividing: an integer division would truncate before std::ceil is applied.
+    mNbBatch = (mDropLast) ?
+        (mNbItems / mBatchSize) :
+        static_cast<std::size_t>(std::ceil(static_cast<float>(mNbItems) / static_cast<float>(mBatchSize)));
 }
 
-std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const std::size_t startIndex) const
+std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const
 {
-    assert((startIndex) <= mDatabase.getLen() && " DataProvider readBatch : database fetch out of bounds");
-
-
-    // Determine the batch size (may differ for the last batch)
-    const std::size_t current_batch_size = ((startIndex + mBatchSize) > mDatabase.getLen()) ?
-                                           mDatabase.getLen()-startIndex :
-                                           mBatchSize;
+    AIDGE_ASSERT(mIndexBatch <= mNbBatch, "Cannot fetch more data than available in database");
+
+    // readBatch() expects mIndexBatch to have been incremented beforehand (see next()),
+    // so it ranges from 1 to mNbBatch; only the last batch may be smaller than mBatchSize
+    std::size_t current_batch_size;
+    if (mIndexBatch == mNbBatch) {
+        current_batch_size = mLastBatchSize;
+    } else {
+        current_batch_size = mBatchSize;
+    }
 
     // Create batch tensors (dimensions, backends, datatype) for each modality
     std::vector<std::shared_ptr<Tensor>> batchTensors;
-    auto dataBatchSize = mDataSizes;
+    auto dataBatchDims = mDataDims;
     for (std::size_t i = 0; i < mNumberModality; ++i) {
-        dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
+        dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size);
         auto batchData = std::make_shared<Tensor>();
-        batchData->resize(dataBatchSize[i]);
-        // batchData->setBackend(mDataBackends[i]);
+        batchData->resize(dataBatchDims[i]);
         batchData->setBackend("cpu");
         batchData->setDataType(mDataTypes[i]);
         batchTensors.push_back(batchData);
@@ -60,7 +71,8 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
 
     // Call each database item and concatenate each data modularity in the batch tensors
    for (std::size_t i = 0; i < current_batch_size; ++i){
 
-        auto dataItem = mDatabase.getItem(startIndex+i);
+        // mBatches holds the (possibly shuffled) item order; -1 because mIndexBatch was already incremented
+        auto dataItem = mDatabase.getItem(mBatches[(mIndexBatch-1)*mBatchSize+i]);
 
         // assert same number of modalities
         assert(dataItem.size() == mNumberModality && "DataProvider readBatch : item from database have inconsistent number of modality.");
@@ -69,7 +81,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
             auto dataSample = dataItem[j];
 
             // Assert tensor sizes
-            assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size");
+            assert(dataSample->dims() == mDataDims[j] && "DataProvider readBatch : corrupted Data size");
 
             // Assert implementation backend
             // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend");
@@ -82,4 +94,31 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
         }
     }
     return batchTensors;
-}
\ No newline at end of file
+}
+
+
+void Aidge::DataProvider::setBatches(){
+
+    // (Re)initialize the fetch order with the identity permutation...
+    mBatches.clear();
+    mBatches.resize(mNbItems);
+    std::iota(mBatches.begin(), mBatches.end(), 0U);
+
+    // ...optionally shuffled for training
+    if (mShuffle){
+        Random::randShuffle(mBatches);
+    }
+
+    if (mNbItems % mBatchSize != 0){ // The last batch is not full
+        const std::size_t lastBatchSize = mNbItems % mBatchSize;
+        if (mDropLast){ // Remove the last non-full batch
+            AIDGE_ASSERT(lastBatchSize <= mBatches.size(), "Last batch is bigger than the database size");
+            mBatches.erase(mBatches.end() - lastBatchSize, mBatches.end());
+            mLastBatchSize = mBatchSize;
+        } else { // Keep the last non-full batch
+            mLastBatchSize = lastBatchSize;
+        }
+    } else { // The last batch is full
+        mLastBatchSize = mBatchSize;
+    }
+}
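
Reviewer note: the sketch below shows how the new iterator protocol added by this patch is meant to be consumed from Python. It is illustrative only, not part of the diff: `MyDatabase` is a hypothetical stand-in for any concrete `Database` implementation (this patch ships none), and the two-variable unpacking assumes a two-modality (data, label) database.

```python
import aidge_core

db = MyDatabase()  # hypothetical: any concrete aidge_core.Database subclass

# shuffle/drop_last map to the new C++ constructor flags (both default to false)
provider = aidge_core.DataProvider(db, batch_size=32, shuffle=True, drop_last=False)

print(len(provider))        # __len__  -> getNbBatch()
for batch in provider:      # __iter__ -> resetIndexBatch() + setBatches()
    # __next__ -> readBatch(): one list of Tensors per batch, one entry per modality
    data, label = batch     # assumes a two-modality (data, label) database
```

With `drop_last=False` the final batch is simply returned with a smaller first dimension; with `drop_last=True` it is skipped entirely and `len(provider)` shrinks accordingly.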