-
Maxence Naud authored
- [Optimize] add const, inline, noexcept where possible - [Utils] add DataUtils file for conversion from cv type to cpp type - [Syntax] Stimulis -> Stimuli and Stimuli -> Stimulus - [#define] Use standard Aidge syntax format - [Licence] add where missing - [#include] add what is used and remove what is not - [class] uniformize class member definition order - [types] replace size_t with std::size_t from <cstddef> for uniformization - [types] replace integer types with exact-width integers from <cstdint> - Remove end-of-line spaces
Maxence Naud authored - [Optimize] add const, inline, noexcept where possible - [Utils] add DataUtils file for conversion from cv type to cpp type - [Syntax] Stimulis -> Stimuli and Stimuli -> Stimulus - [#define] Use standard Aidge syntax format - [Licence] add where missing - [#include] add what is used and remove what is not - [class] uniformize class member definition order - [types] replace size_t with std::size_t from <cstddef> for uniformization - [types] replace integer types with exact-width integers from <cstdint> - Remove end-of-line spaces
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_DataProvider.cpp 2.41 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>

#include <algorithm>  // std::equal, std::min
#include <cstddef>    // std::size_t
#include <cstdint>    // std::uint8_t
#include <string>     // std::string

#include "aidge/backend/opencv/database/MNIST.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/DataProvider.hpp"
// #include "aidge/backend/opencv/data/TensorImpl.hpp"
// #include "aidge/backend/cpu/data/TensorImpl.hpp"

using namespace Aidge;
/**
 * Checks that DataProvider::readBatch packs MNIST samples into batch tensors
 * in database order. For every global sample index k, the k-th slice of the
 * data batch tensor (batch[0]) and of the label batch tensor (batch[1]) must
 * hold exactly the values of mnist.getItem(k)[0] and mnist.getItem(k)[1].
 *
 * Every batch is checked, including a trailing partial one when getLen() is
 * not a multiple of batchSize.
 */
TEST_CASE("DataProvider instanciation & test mnist","[Data][OpenCV]") {
    // Create database
    // NOTE(review): machine-specific absolute path; this test can only pass on a
    // host where the MNIST test set exists at this location. TODO make configurable.
    const std::string path = "/data1/is156025/tb256203/dev/eclipse_aidge/aidge/user_tests/test_mnist_database";
    const bool train = false;
    MNIST mnist(path, train);

    // DataProvider settings
    const std::size_t batchSize = 256;
    const std::size_t nbItems = mnist.getLen();
    // Integer ceiling division. The previous std::ceil(getLen() / batchSize)
    // truncated in integer arithmetic *before* rounding, so the last partial
    // batch was silently skipped.
    const std::size_t number_batch = (nbItems + batchSize - 1) / batchSize;

    // Instantiate the dataloader
    DataProvider provider(mnist, batchSize);

    // Perform the tests on the batches
    for (std::size_t i = 0; i < number_batch; ++i) {
        const std::size_t startIndex = i * batchSize;
        // The last batch may contain fewer than batchSize samples: never index
        // past the end of the database.
        const std::size_t currentBatchSize = std::min(batchSize, nbItems - startIndex);

        // NOTE(review): assumes readBatch accepts a start index whose remaining
        // item count is smaller than batchSize (partial last batch) — confirm.
        const auto batch = provider.readBatch(startIndex);
        // NOTE(review): assumes images are stored as uint8 and labels as int,
        // matching the casts the original test used — confirm against the backend.
        const auto* data_batch_ptr = static_cast<const std::uint8_t*>(batch[0]->getImpl()->rawPtr());
        const auto* label_batch_ptr = static_cast<const int*>(batch[1]->getImpl()->rawPtr());

        for (std::size_t s = 0; s < currentBatchSize; ++s) {
            // Load the reference sample once (it used to be fetched twice).
            const auto item = mnist.getItem(startIndex + s);
            const auto& data = item[0];
            const auto& label = item[1];

            const std::size_t size_data = data->size();
            const std::size_t size_label = label->size();
            const auto* data_ptr = static_cast<const std::uint8_t*>(data->getImpl()->rawPtr());
            const auto* label_ptr = static_cast<const int*>(label->getImpl()->rawPtr());

            // Sample s occupies a contiguous slice of size_data (resp. size_label)
            // elements in the batch tensor; compare the whole slice at once
            // instead of issuing one assertion per element.
            REQUIRE(std::equal(data_ptr, data_ptr + size_data, data_batch_ptr + size_data * s));
            REQUIRE(std::equal(label_ptr, label_ptr + size_label, label_batch_ptr + size_label * s));
        }
    }
}