diff --git a/unit_tests/Test_DataProvider.cpp b/unit_tests/Test_DataProvider.cpp deleted file mode 100644 index 69f4f7b871c60cbbe074e4f8f4a700c7fa68899c..0000000000000000000000000000000000000000 --- a/unit_tests/Test_DataProvider.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <catch2/catch_test_macros.hpp> - -#include "aidge/backend/opencv/database/MNIST.hpp" -#include "aidge/data/Tensor.hpp" -#include "aidge/data/DataProvider.hpp" - - -// #include "aidge/backend/opencv/data/TensorImpl.hpp" -// #include "aidge/backend/cpu/data/TensorImpl.hpp" - -using namespace Aidge; - -TEST_CASE("DataProvider instanciation & test mnist","[Data][OpenCV]") { - - // Create database - std::string path = "/data1/is156025/tb256203/dev/eclipse_aidge/aidge/user_tests/test_mnist_database"; - bool train = false; - MNIST mnist(path, train); - - // DataProvider settings - unsigned int batchSize = 256; - unsigned int number_batch = std::ceil(mnist.getLen() / batchSize); - - // Instanciate the dataloader - DataProvider provider(mnist, batchSize); - - // Perform the tests on the batches - for (unsigned int i = 0; i < number_batch; ++i){ - auto batch = provider.readBatch(i*batchSize); - auto data_batch_ptr = static_cast<uint8_t*>(batch[0]->getImpl()->rawPtr()); - auto label_batch_ptr = static_cast<int*>(batch[1]->getImpl()->rawPtr()); - - for (unsigned int s = 0; s < batchSize; ++s){ - auto data = mnist.getItem(i*batchSize+s)[0]; - auto label = mnist.getItem(i*batchSize+s)[1]; - unsigned int size_data = data->size(); - unsigned int size_label = label->size(); - auto data_ptr = static_cast<uint8_t*>(data->getImpl()->rawPtr()); - auto label_ptr = static_cast<int*>(label->getImpl()->rawPtr()); - for (unsigned int j = 0; j < size_data; ++j){ - auto element_data = data_ptr[j]; - auto element_data_batch = data_batch_ptr[size_data*s+j]; - REQUIRE(element_data == element_data_batch); - } - for (unsigned int j = 0; j < size_label; ++j){ - auto element_label = label_ptr[j]; - auto element_label_batch = label_batch_ptr[size_label*s+j]; - REQUIRE(element_label == element_label_batch); - } - } - } -} \ No newline at end of file