Project: aidge_core (Cyril Moineau)

Commit 01f4ca8b, authored 1 year ago by Thibault Allenet

Add DataProvider iterator for python and shuffle and droplast batch

Parent: da8c26ab
Changes: 3 changed files with 158 additions and 38 deletions

- include/aidge/data/DataProvider.hpp (+78, −11)
- python_binding/data/pybind_DataProvider.cpp (+22, −8)
- src/data/DataProvider.cpp (+58, −19)
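Taken together, the three changes turn DataProvider into a Python iterable with optional shuffling and drop-last behaviour. A minimal usage sketch, assuming an aidge_core build that exposes these bindings and some Database implementation (the MyDatabase class here is a hypothetical placeholder, not part of this commit):

```python
import aidge_core

# Hypothetical Database implementation; any aidge_core.Database subclass works.
db = MyDatabase()

# New constructor signature from this commit: shuffle randomizes the item
# order on each pass, drop_last discards the final non-full batch.
provider = aidge_core.DataProvider(db, batch_size=32, shuffle=True, drop_last=True)

# __len__ now reports the number of batches.
print(len(provider))

# The provider is now directly iterable; each step yields a list of
# tensors, one batch per data modality.
for batch in provider:
    data, labels = batch[0], batch[1]
```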
include/aidge/data/DataProvider.hpp (+78, −11)
```diff
@@ -20,8 +20,6 @@
 #include "aidge/data/Database.hpp"
 #include "aidge/data/Data.hpp"
 
 namespace Aidge {
 
 /**
@@ -33,14 +31,35 @@ class DataProvider {
 private:
     // Dataset providing the data to the dataProvider
     const Database& mDatabase;
 
+    // Desired size of the produced batches
+    const std::size_t mBatchSize;
+
+    // Enable random shuffling for learning
+    const bool mShuffle;
+
+    // Drops the last non-full batch
+    const bool mDropLast;
+
     // Number of modality in one item
     const std::size_t mNumberModality;
 
-    std::vector<std::vector<std::size_t>> mDataSizes;
-    std::vector<std::string> mDataBackends;
-    std::vector<DataType> mDataTypes;
-
-    // Desired size of the produced batches
-    const std::size_t mBatchSize;
+    // mNbItems contains the number of items in the database
+    std::size_t mNbItems;
+
+    // mBatches contains the call order of each database item
+    std::vector<unsigned int> mBatches;
+
+    // mIndex browsing the number of batch
+    std::size_t mIndexBatch;
+
+    // mNbBatch contains the number of batch
+    std::size_t mNbBatch;
+
+    // Size of the Last batch
+    std::size_t mLastBatchSize;
+
+    // Store each modality dimensions, backend and type
+    std::vector<std::vector<std::size_t>> mDataDims;
+    std::vector<std::string> mDataBackends;
+    std::vector<DataType> mDataTypes;
 
 public:
     /**
@@ -48,15 +67,63 @@ public:
      * @param database database from which to load the data.
      * @param batchSize number of data samples per batch.
      */
-    DataProvider(const Database& database, const std::size_t batchSize);
+    DataProvider(const Database& database, const std::size_t batchSize, const bool shuffle = false, const bool dropLast = false);
 
 public:
     /**
-     * @brief Create a batch for each data modality in the database. The returned batch contain the data as sorted in the database.
-     * @param startIndex the starting index in the database to start the batch from.
+     * @brief Create a batch for each data modality in the database.
      * @return a vector of tensors. Each tensor is a batch corresponding to one modality.
      */
-    std::vector<std::shared_ptr<Tensor>> readBatch(const std::size_t startIndex) const;
+    std::vector<std::shared_ptr<Tensor>> readBatch() const;
+
+    /**
+     * @brief Get the Number of Batch
+     *
+     * @return std::size_t
+     */
+    inline std::size_t getNbBatch(){ return mNbBatch; };
+
+    /**
+     * @brief Get the current Index Batch
+     *
+     * @return std::size_t
+     */
+    inline std::size_t getIndexBatch(){ return mIndexBatch; };
+
+    /**
+     * @brief Reset the internal index batch that browses the data of the database to zero.
+     */
+    inline void resetIndexBatch(){ mIndexBatch = 0; };
+
+    /**
+     * @brief Increment the internal index batch that browses the data of the database.
+     */
+    inline void incrementIndexBatch(){ ++mIndexBatch; };
+
+    void setBatches();
+
+    /**
+     * @brief End of dataProvider condition
+     *
+     * @return true when all batch were fetched, False otherwise
+     */
+    inline bool done(){ return (mIndexBatch == mNbBatch); };
+
+    // Functions for python iterator iter and next (definition in pybind.cpp)
+
+    // __iter__ method for iterator protocol
+    DataProvider* iter();
+
+    // __next__ method for iterator protocol
+    std::vector<std::shared_ptr<Aidge::Tensor>> next();
 };
 } // namespace Aidge
```
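The inline helpers declared above (getNbBatch, getIndexBatch, resetIndexBatch, incrementIndexBatch, done) form the small state machine that iter() and next() drive. Spelled out by hand in Python, reusing the provider object from the earlier sketch, one pass is equivalent to:

```python
# Manual drive of the protocol; a plain for-loop does exactly this.
it = iter(provider)        # DataProvider::iter -> resetIndexBatch() + setBatches()
while True:
    try:
        batch = next(it)   # DataProvider::next -> incrementIndexBatch() + readBatch()
    except StopIteration:  # raised once done(), i.e. mIndexBatch == mNbBatch
        break
    consume(batch)         # 'consume' is a hypothetical callback
```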
python_binding/data/pybind_DataProvider.cpp (+22, −8)
```diff
@@ -4,19 +4,33 @@
 #include "aidge/data/Database.hpp"
 
 namespace py = pybind11;
 
 namespace Aidge {
 
+// __iter__ method for iterator protocol
+DataProvider* DataProvider::iter(){
+    resetIndexBatch();
+    setBatches();
+    return this;
+}
+
+// __next__ method for iterator protocol
+std::vector<std::shared_ptr<Aidge::Tensor>> DataProvider::next() {
+    if (!done()){
+        incrementIndexBatch();
+        return readBatch();
+    } else {
+        throw py::stop_iteration();
+    }
+}
+
 void init_DataProvider(py::module& m){
 
     py::class_<DataProvider, std::shared_ptr<DataProvider>>(m, "DataProvider")
-        .def(py::init<Database&, std::size_t>(), py::arg("database"), py::arg("batchSize"))
-        .def("read_batch", &DataProvider::readBatch, py::arg("start_index"),
-        R"mydelimiter(
-        Return a batch of each data modality.
-
-        :param start_index: Database starting index to read the batch from
-        :type start_index: int
-        )mydelimiter");
+        .def(py::init<Database&, std::size_t, bool, bool>(), py::arg("database"), py::arg("batch_size"), py::arg("shuffle"), py::arg("drop_last"))
+        .def("__iter__", &DataProvider::iter)
+        .def("__next__", &DataProvider::next)
+        .def("__len__", &DataProvider::getNbBatch);
 }
 }
```
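Binding iter under __iter__ and next under __next__ is all that Python's for statement requires, and the py::stop_iteration thrown from C++ surfaces as a normal StopIteration. A pure-Python model of the behaviour the bindings implement (an illustrative sketch with the storage helpers elided, not generated code):

```python
class DataProviderModel:
    """Pure-Python model of the iterator protocol bound above."""

    def __iter__(self):
        self._index_batch = 0     # resetIndexBatch()
        self._set_batches()       # rebuild (and optionally shuffle) the item order
        return self

    def __next__(self):
        if self._index_batch == self._nb_batch:   # done()
            raise StopIteration                   # py::stop_iteration in C++
        self._index_batch += 1                    # incrementIndexBatch()
        return self._read_batch()                 # one tensor per modality

    def __len__(self):
        return self._nb_batch                     # getNbBatch()
```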
src/data/DataProvider.cpp (+58, −19)
```diff
@@ -13,45 +13,56 @@
 #include <cstddef>  // std::size_t
 #include <memory>
 #include <vector>
+#include <cmath>
 
 #include "aidge/data/Database.hpp"
 #include "aidge/data/DataProvider.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/utils/Random.hpp"
 
-Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize)
+Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::size_t batchSize, const bool shuffle, const bool dropLast)
     : mDatabase(database),
+      mBatchSize(batchSize),
+      mShuffle(shuffle),
+      mDropLast(dropLast),
       mNumberModality(database.getItem(0).size()),
-      mBatchSize(batchSize)
+      mNbItems(mDatabase.getLen()),
+      mIndexBatch(0)
 {
     // Iterating on each data modality in the database
     // Get the tensor dimensions, datatype and backend of each modality to ensure each data have the same
     for (const auto& modality : mDatabase.getItem(0)) {
-        mDataSizes.push_back(modality->dims());
+        mDataDims.push_back(modality->dims());
         // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
         // mDataBackends.push_back(item[i]->getImpl()->backend());
         mDataTypes.push_back(modality->dataType());
     }
+
+    // Compute the number of bacthes depending on mDropLast boolean
+    mNbBatch = (mDropLast) ?
+        static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) :
+        static_cast<std::size_t>(std::ceil(mNbItems / mBatchSize));
 }
 
-std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const std::size_t startIndex) const
+std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const
 {
-    assert((startIndex) <= mDatabase.getLen() && " DataProvider readBatch : database fetch out of bounds");
+    AIDGE_ASSERT(mIndexBatch <= mNbBatch, "Cannot fetch more data than available in database");
 
     // Determine the batch size (may differ for the last batch)
-    const std::size_t current_batch_size = ((startIndex + mBatchSize) > mDatabase.getLen()) ?
-                                            mDatabase.getLen() - startIndex :
-                                            mBatchSize;
+    std::size_t current_batch_size;
+    if (mIndexBatch == mNbBatch) {
+        current_batch_size = mLastBatchSize;
+    } else {
+        current_batch_size = mBatchSize;
+    }
 
     // Create batch tensors (dimensions, backends, datatype) for each modality
     std::vector<std::shared_ptr<Tensor>> batchTensors;
-    auto dataBatchSize = mDataSizes;
+    auto dataBatchDims = mDataDims;
     for (std::size_t i = 0; i < mNumberModality; ++i) {
-        dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
+        dataBatchDims[i].insert(dataBatchDims[i].begin(), current_batch_size);
         auto batchData = std::make_shared<Tensor>();
-        batchData->resize(dataBatchSize[i]);
+        batchData->resize(dataBatchDims[i]);
         // batchData->setBackend(mDataBackends[i]);
         batchData->setBackend("cpu");
         batchData->setDataType(mDataTypes[i]);
         batchTensors.push_back(batchData);
@@ -60,7 +71,8 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
     // Call each database item and concatenate each data modularity in the batch tensors
     for (std::size_t i = 0; i < current_batch_size; ++i){
-        auto dataItem = mDatabase.getItem(startIndex+i);
+        auto dataItem = mDatabase.getItem(mBatches[(mIndexBatch-1)*mBatchSize+i]);
+        // auto dataItem = mDatabase.getItem(startIndex+i);
         // assert same number of modalities
         assert(dataItem.size() == mNumberModality && "DataProvider readBatch : item from database have inconsistent number of modality.");
@@ -69,7 +81,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
         auto dataSample = dataItem[j];
 
         // Assert tensor sizes
-        assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size");
+        assert(dataSample->dims() == mDataDims[j] && "DataProvider readBatch : corrupted Data size");
 
         // Assert implementation backend
         // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend");
@@ -82,4 +94,31 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch(const
         }
     }
     return batchTensors;
-}
\ No newline at end of file
+}
+
+void Aidge::DataProvider::setBatches(){
+
+    mBatches.clear();
+    mBatches.resize(mNbItems);
+    std::iota(mBatches.begin(), mBatches.end(), 0U);
+
+    if (mShuffle){
+        Random::randShuffle(mBatches);
+    }
+
+    if (mNbItems % mBatchSize != 0){
+        // The last batch is not full
+        std::size_t lastBatchSize = static_cast<std::size_t>(mNbItems % mBatchSize);
+        if (mDropLast){
+            // Remove the last non-full batch
+            AIDGE_ASSERT(lastBatchSize <= mBatches.size(), "Last batch bigger than the size of database");
+            mBatches.erase(mBatches.end() - lastBatchSize, mBatches.end());
+            mLastBatchSize = mBatchSize;
+        } else {
+            // Keep the last non-full batch
+            mLastBatchSize = lastBatchSize;
+        }
+    } else {
+        // The last batch is full
+        mLastBatchSize = mBatchSize;
+    }
+}
```
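setBatches builds the (possibly shuffled) index order that readBatch then slices one batch at a time via mBatches[(mIndexBatch-1)*mBatchSize + i], and it fixes the size of the final batch. A worked example of that bookkeeping for 10 items and a batch size of 3 (a sketch of the logic, not code from the commit):

```python
items, batch_size = list(range(10)), 3
remainder = len(items) % batch_size    # 1 -> the last batch is not full

# drop_last=True: the trailing remainder is erased; 9 items remain,
# giving 3 full batches and last_batch_size == batch_size.
kept = items[:len(items) - remainder]

# drop_last=False: the remainder forms a final, smaller batch:
# [[0,1,2], [3,4,5], [6,7,8], [9]] with last_batch_size == 1.
batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
```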