/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>

#include <algorithm>  // std::equal, std::min
#include <cstddef>    // std::size_t
#include <cstdint>    // std::uint8_t
#include <cstdlib>    // std::getenv
#include <string>

#include "aidge/backend/opencv/database/MNIST.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/DataProvider.hpp"
// #include "aidge/backend/opencv/data/TensorImpl.hpp"
// #include "aidge/backend/cpu/data/TensorImpl.hpp"

using namespace Aidge;

// Checks that every batch produced by DataProvider over the MNIST test split
// contains, sample by sample, exactly the data and label tensors returned by
// MNIST::getItem for the corresponding dataset index.
TEST_CASE("DataProvider instanciation & test mnist","[Data][OpenCV]")
{
    // Location of the MNIST database. It can be overridden with the
    // AIDGE_MNIST_PATH environment variable so the test is not tied to a single
    // developer machine; the historical hard-coded path is kept as fallback.
    const char* envPath = std::getenv("AIDGE_MNIST_PATH");
    const std::string path = (envPath != nullptr && envPath[0] != '\0')
        ? std::string(envPath)
        : std::string("/data1/is156025/tb256203/dev/eclipse_aidge/aidge/user_tests/test_mnist_database");

    // Create database (train = false)
    const bool train = false;
    MNIST mnist(path, train);

    // DataProvider settings
    const std::size_t batchSize = 256;
    const std::size_t nbItems = mnist.getLen();
    // Integer ceiling division. The previous `std::ceil(len / batchSize)`
    // truncated before ceil was applied, silently skipping the last partial batch.
    const std::size_t nbBatches = (nbItems + batchSize - 1) / batchSize;

    // Instantiate the data loader
    DataProvider provider(mnist, batchSize);

    // Perform the tests on the batches
    for (std::size_t i = 0; i < nbBatches; ++i) {
        const std::size_t start = i * batchSize;
        // The last batch may hold fewer than batchSize samples; never index
        // past the end of the dataset (or of the batch buffers).
        const std::size_t nbSamples = std::min(batchSize, nbItems - start);

        auto batch = provider.readBatch(start);
        const auto* dataBatchPtr  = static_cast<const std::uint8_t*>(batch[0]->getImpl()->rawPtr());
        const auto* labelBatchPtr = static_cast<const int*>(batch[1]->getImpl()->rawPtr());

        for (std::size_t s = 0; s < nbSamples; ++s) {
            // Fetch the sample once (it was previously loaded twice per sample).
            const auto item = mnist.getItem(start + s);
            const auto& data  = item[0];
            const auto& label = item[1];

            const std::size_t dataSize  = data->size();
            const std::size_t labelSize = label->size();

            const auto* dataPtr  = static_cast<const std::uint8_t*>(data->getImpl()->rawPtr());
            const auto* labelPtr = static_cast<const int*>(label->getImpl()->rawPtr());

            INFO("dataset index " << (start + s));
            // Sample s occupies the slice [size * s, size * (s + 1)) of each batch tensor.
            REQUIRE(std::equal(dataPtr, dataPtr + dataSize, dataBatchPtr + dataSize * s));
            REQUIRE(std::equal(labelPtr, labelPtr + labelSize, labelBatchPtr + labelSize * s));
        }
    }
}