Compare revisions

Changes are shown as if the source revision was being merged into the target revision.
Commits on Source (6)
@@ -70,6 +70,7 @@
 #include "aidge/utils/Attributes.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/DynamicAttributes.hpp"
+#include "aidge/utils/Random.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
@@ -20,8 +20,6 @@
 #include "aidge/data/Database.hpp"
 #include "aidge/data/Data.hpp"

 namespace Aidge {

 /**
@@ -33,14 +31,35 @@ class DataProvider {
 private:
     // Dataset providing the data to the dataProvider
     const Database& mDatabase;

+    // Desired size of the produced batches
+    const std::size_t mBatchSize;
+    // Enable random shuffling of the fetch order between passes
+    const bool mShuffle;
+    // Drop the last non-full batch
+    const bool mDropLast;
     // Number of modalities in one item
     const std::size_t mNumberModality;
-    std::vector<std::vector<std::size_t>> mDataSizes;
-    std::vector<std::string> mDataBackends;
-    std::vector<DataType> mDataTypes;
-    // Desired size of the produced batches
-    const std::size_t mBatchSize;
+    // Number of items in the database
+    std::size_t mNbItems;
+    // Fetch order of the database items
+    std::vector<unsigned int> mBatches;
+    // Index of the current batch
+    std::size_t mIndexBatch;
+    // Total number of batches in one pass
+    std::size_t mNbBatch;
+    // Size of the last batch
+    std::size_t mLastBatchSize;
+    // Dimensions, backend and data type of each modality
+    std::vector<std::vector<std::size_t>> mDataDims;
+    std::vector<std::string> mDataBackends;
+    std::vector<DataType> mDataTypes;

 public:
     /**
@@ -48,15 +67,76 @@ public:
      * @param database database from which to load the data.
      * @param batchSize number of data samples per batch.
+     * @param shuffle whether to randomly shuffle the fetch order between passes.
+     * @param dropLast whether to drop the last non-full batch.
      */
-    DataProvider(const Database& database, const std::size_t batchSize);
+    DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);

 public:
     /**
-     * @brief Create a batch for each data modality in the database. The returned batch contain the data as sorted in the database.
-     * @param startIndex the starting index in the database to start the batch from.
+     * @brief Create a batch for each data modality in the database.
      * @return a vector of tensors. Each tensor is a batch corresponding to one modality.
      */
-    std::vector<std::shared_ptr<Tensor>> readBatch(const std::size_t startIndex) const;
+    std::vector<std::shared_ptr<Tensor>> readBatch() const;
+
+    /**
+     * @brief Get the number of batches in one pass over the database.
+     * @return std::size_t
+     */
+    inline std::size_t getNbBatch() const {
+        return mNbBatch;
+    }
+
+    /**
+     * @brief Get the index of the current batch.
+     * @return std::size_t
+     */
+    inline std::size_t getIndexBatch() const {
+        return mIndexBatch;
+    }
+
+    /**
+     * @brief Reset the internal batch index to zero.
+     */
+    inline void resetIndexBatch() {
+        mIndexBatch = 0;
+    }
+
+    /**
+     * @brief Increment the internal batch index.
+     */
+    inline void incrementIndexBatch() {
+        ++mIndexBatch;
+    }
+
+    /**
+     * @brief Set up the batch fetch order for one pass over the database.
+     */
+    void setBatches();
+
+    /**
+     * @brief End condition of the DataProvider for one pass over the database.
+     * @return true when all batches have been fetched, false otherwise.
+     */
+    inline bool done() const {
+        return (mIndexBatch == mNbBatch);
+    }
+
+    // Functions for the Python iterator protocol (__iter__ and __next__), defined in pybind.cpp
+
+    /**
+     * @brief __iter__ method for the iterator protocol.
+     * @return DataProvider*
+     */
+    DataProvider* iter();
+
+    /**
+     * @brief __next__ method for the iterator protocol.
+     * @return std::vector<std::shared_ptr<Aidge::Tensor>>
+     */
+    std::vector<std::shared_ptr<Aidge::Tensor>> next();
 };
 } // namespace Aidge
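
For orientation, the public API above composes into an epoch loop like the following sketch. The free function `runOneEpoch` and the batch size are illustrative, not part of the commit; the body mirrors exactly what `iter()`/`next()` do on the Python side:

#include <memory>
#include <vector>
#include "aidge/data/DataProvider.hpp"

void runOneEpoch(const Aidge::Database& db) {
    Aidge::DataProvider provider(db, /*batchSize=*/32, /*shuffle=*/true, /*dropLast=*/false);
    provider.resetIndexBatch();  // same steps as iter()
    provider.setBatches();
    while (!provider.done()) {
        provider.incrementIndexBatch();     // same steps as next()
        auto batch = provider.readBatch();  // one tensor per modality
        // ... run the network on `batch` ...
    }
}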
New file aidge/utils/Random.hpp:
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_RANDOM_H_
#define AIDGE_RANDOM_H_

#include <algorithm>
#include <random>
#include <vector>

namespace Random {

// Shuffle the vector in place with a freshly seeded Mersenne Twister.
// `inline` keeps the definition ODR-safe when this header is included
// in several translation units.
inline void randShuffle(std::vector<unsigned int>& vec) {
    std::random_device rd;
    std::mt19937 g(rd());
    std::shuffle(vec.begin(), vec.end(), g);
}

}  // namespace Random

#endif  // AIDGE_RANDOM_H_
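
A minimal standalone usage sketch of the helper, shuffling the kind of index vector that `setBatches()` builds below:

#include <iostream>
#include <numeric>
#include <vector>
#include "aidge/utils/Random.hpp"

int main() {
    std::vector<unsigned int> order(10);
    std::iota(order.begin(), order.end(), 0U);  // 0, 1, ..., 9
    Random::randShuffle(order);                 // in-place random permutation
    for (const auto idx : order) { std::cout << idx << ' '; }
    return 0;
}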
@@ -4,19 +4,33 @@
 #include "aidge/data/Database.hpp"

 namespace py = pybind11;

 namespace Aidge {

+// __iter__ method for the iterator protocol
+DataProvider* DataProvider::iter() {
+    resetIndexBatch();
+    setBatches();
+    return this;
+}
+
+// __next__ method for the iterator protocol
+std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() {
+    if (!done()) {
+        incrementIndexBatch();
+        return readBatch();
+    } else {
+        throw py::stop_iteration();
+    }
+}
+
 void init_DataProvider(py::module& m){

     py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
-        .def(py::init<Database&, std::size_t>(), py::arg("database"), py::arg("batchSize"))
-        .def("read_batch", &DataProvider::readBatch, py::arg("start_index"),
-        R"mydelimiter(
-        Return a batch of each data modality.
-
-        :param start_index: Database starting index to read the batch from
-        :type start_index: int
-        )mydelimiter");
+        .def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle"), py::arg("drop_last"))
+        .def("__iter__", &DataProvider::iter)
+        .def("__next__", &DataProvider::next)
+        .def("__len__", &DataProvider::getNbBatch);
 }
 }
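
These three bindings are what make a plain Python loop work. Roughly, `for batch in provider:` behaves like the following C++ transliteration; a sketch for illustration only, where `consume` stands in for hypothetical user code:

#include <pybind11/pybind11.h>
#include "aidge/data/DataProvider.hpp"

void pythonStyleLoop(Aidge::DataProvider& provider) {
    provider.iter();                       // __iter__: reset the index, rebuild the fetch order
    try {
        for (;;) {
            auto batch = provider.next();  // __next__: next batch, or throws when done
            // consume(batch);             // hypothetical user code
        }
    } catch (const pybind11::stop_iteration&) {
        // pybind11 translates this into Python's StopIteration, ending the for-loop
    }
}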
@@ -96,10 +96,18 @@ void init_Tensor(py::module& m){
             return py::cast(b.get<double>(idx));
         case DataType::Float32:
             return py::cast(b.get<float>(idx));
+        case DataType::Int8:
+            return py::cast(b.get<std::int8_t>(idx));
+        case DataType::Int16:
+            return py::cast(b.get<std::int16_t>(idx));
         case DataType::Int32:
             return py::cast(b.get<std::int32_t>(idx));
         case DataType::Int64:
             return py::cast(b.get<std::int64_t>(idx));
+        case DataType::UInt8:
+            return py::cast(b.get<std::uint8_t>(idx));
+        case DataType::UInt16:
+            return py::cast(b.get<std::uint16_t>(idx));
         default:
             return py::none();
         }
@@ -111,10 +119,18 @@ void init_Tensor(py::module& m){
             return py::cast(b.get<double>(coordIdx));
         case DataType::Float32:
             return py::cast(b.get<float>(coordIdx));
+        case DataType::Int8:
+            return py::cast(b.get<std::int8_t>(coordIdx));
+        case DataType::Int16:
+            return py::cast(b.get<std::int16_t>(coordIdx));
         case DataType::Int32:
             return py::cast(b.get<std::int32_t>(coordIdx));
         case DataType::Int64:
             return py::cast(b.get<std::int64_t>(coordIdx));
+        case DataType::UInt8:
+            return py::cast(b.get<std::uint8_t>(coordIdx));
+        case DataType::UInt16:
+            return py::cast(b.get<std::uint16_t>(coordIdx));
         default:
             return py::none();
         }
@@ -141,6 +157,12 @@ void init_Tensor(py::module& m){
             break;
         case DataType::Float32:
             dataFormatDescriptor = py::format_descriptor<float>::format();
             break;
+        case DataType::Int8:
+            dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+            break;
+        case DataType::Int16:
+            dataFormatDescriptor = py::format_descriptor<std::int16_t>::format();
+            break;
         case DataType::Int32:
             dataFormatDescriptor = py::format_descriptor<std::int32_t>::format();
@@ -148,6 +170,12 @@ void init_Tensor(py::module& m){
         case DataType::Int64:
             dataFormatDescriptor = py::format_descriptor<std::int64_t>::format();
             break;
+        case DataType::UInt8:
+            dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+            break;
+        case DataType::UInt16:
+            dataFormatDescriptor = py::format_descriptor<std::uint16_t>::format();
+            break;
         default:
             throw py::value_error("Unsupported data format");
         }
@@ -13,45 +13,56 @@
 #include <cstddef>  // std::size_t
 #include <memory>
 #include <vector>
+#include <cmath>    // std::floor, std::ceil
+#include <numeric>  // std::iota (used by setBatches below)

 #include "aidge/data/Database.hpp"
 #include "aidge/data/DataProvider.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"  // AIDGE_ASSERT
+#include "aidge/utils/Random.hpp"

-Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize)
+Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast)
     : mDatabase(database),
+      mBatchSize(batchSize),
+      mShuffle(shuffle),
+      mDropLast(dropLast),
       mNumberModality(database.getItem(0).size()),
-      mBatchSize(batchSize)
+      mNbItems(mDatabase.getLen()),
+      mIndexBatch(0)
 {
     // Iterate over each data modality in the database.
     // Record the tensor dimensions, data type and backend of each modality, so later items can be checked against them.
     for (const auto& modality : mDatabase.getItem(0)) {
-        mDataSizes.push_back(modality->dims());
+        mDataDims.push_back(modality->dims());
         // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
         // mDataBackends.push_back(item[i]->getImpl()->backend());
         mDataTypes.push_back(modality->dataType());
     }

+    // Compute the number of batches depending on the mDropLast flag.
+    // Divide in floating point: applying std::ceil to an integer quotient
+    // would otherwise always round down.
+    mNbBatch = (mDropLast) ?
+        static_cast<std::size_t>(std::floor(static_cast<double>(mNbItems) / static_cast<double>(mBatchSize))) :
+        static_cast<std::size_t>(std::ceil(static_cast<double>(mNbItems) / static_cast<double>(mBatchSize)));
 }
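
As a sanity check on that floor/ceil rule, the same count can be computed with pure integer arithmetic; a hypothetical standalone helper, not part of the commit:

#include <cstddef>

std::size_t nbBatch(std::size_t nbItems, std::size_t batchSize, bool dropLast) {
    return dropLast ? nbItems / batchSize                     // floor: full batches only
                    : (nbItems + batchSize - 1) / batchSize;  // ceil: keep the short last batch
}
// e.g. nbBatch(10, 3, /*dropLast=*/true) == 3 and nbBatch(10, 3, false) == 4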
-std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const std::size_t startIndex) const
+std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const
 {
-    assert((startIndex) <= mDatabase.getLen() && " DataProvider readBatch : database fetch out of bounds");
+    AIDGE_ASSERT(mIndexBatch <= mNbBatch, "Cannot fetch more data than available in database");

     // Determine the batch size (may differ for the last batch)
-    const std::size_t current_batch_size = ((startIndex + mBatchSize) > mDatabase.getLen()) ?
-                                            mDatabase.getLen()-startIndex :
-                                            mBatchSize;
+    std::size_t current_batch_size;
+    if (mIndexBatch == mNbBatch) {
+        current_batch_size = mLastBatchSize;
+    } else {
+        current_batch_size = mBatchSize;
+    }

     // Create batch tensors (dimensions, backends, datatype) for each modality
     std::vector<std::shared_ptr<Tensor>> batchTensors;
-    auto dataBatchSize = mDataSizes;
+    auto dataBatchDims = mDataDims;
     for (std::size_t i = 0; i < mNumberModality; ++i) {
-        dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
+        dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size);
         auto batchData = std::make_shared<Tensor>();
-        batchData->resize(dataBatchSize[i]);
-        // batchData->setBackend(mDataBackends[i]);
+        batchData->resize(dataBatchDims[i]);
+        batchData->setBackend("cpu");
         batchData->setDataType(mDataTypes[i]);
         batchTensors.push_back(batchData);
@@ -60,7 +71,8 @@

     // Fetch each database item and concatenate each data modality into the batch tensors
     for (std::size_t i = 0; i < current_batch_size; ++i){
-        auto dataItem = mDatabase.getItem(startIndex+i);
+        auto dataItem = mDatabase.getItem(mBatches[(mIndexBatch-1)*mBatchSize+i]);

         // assert same number of modalities
         assert(dataItem.size() == mNumberModality && "DataProvider readBatch : items from the database have an inconsistent number of modalities.");
@@ -69,7 +81,7 @@
         for (std::size_t j = 0; j < mNumberModality; ++j) {
             auto dataSample = dataItem[j];

             // Assert tensor sizes
-            assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size");
+            assert(dataSample->dims() == mDataDims[j] && "DataProvider readBatch : corrupted data size");

             // Assert implementation backend
             // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend");
@@ -82,4 +94,31 @@
         }
     }

     return batchTensors;
-}
\ No newline at end of file
+}
+
+void Aidge::DataProvider::setBatches(){
+    // Build the fetch order: one index per database item, in order,
+    // then shuffled if requested
+    mBatches.clear();
+    mBatches.resize(mNbItems);
+    std::iota(mBatches.begin(), mBatches.end(), 0U);
+
+    if (mShuffle) {
+        Random::randShuffle(mBatches);
+    }
+
+    if (mNbItems % mBatchSize != 0) { // The last batch is not full
+        const std::size_t lastBatchSize = mNbItems % mBatchSize;
+        if (mDropLast) { // Drop the last non-full batch
+            AIDGE_ASSERT(lastBatchSize <= mBatches.size(), "Last batch is bigger than the database itself");
+            mBatches.erase(mBatches.end() - lastBatchSize, mBatches.end());
+            mLastBatchSize = mBatchSize;
+        } else { // Keep the last non-full batch
+            mLastBatchSize = lastBatchSize;
+        }
+    } else { // The last batch is full
+        mLastBatchSize = mBatchSize;
+    }
+}
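
To see the effect of setBatches() concretely, the sketch below reproduces the grouping for 10 items and a batch size of 3 with shuffling off (illustrative values, standalone code):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    const std::size_t nbItems = 10, batchSize = 3;
    std::vector<unsigned int> order(nbItems);
    std::iota(order.begin(), order.end(), 0U);  // the fetch order setBatches() builds
    for (std::size_t start = 0; start < order.size(); start += batchSize) {
        const std::size_t end = std::min(start + batchSize, order.size());
        for (std::size_t i = start; i < end; ++i) { std::cout << order[i] << ' '; }
        std::cout << '\n';  // prints: 0 1 2 / 3 4 5 / 6 7 8 / 9
    }
    // With dropLast == true, setBatches() erases the trailing {9} instead,
    // leaving three full batches and mLastBatchSize == batchSize.
    return 0;
}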