Skip to content
Snippets Groups Projects
  • Maxence Naud's avatar
    129c2f99
    [Upd] Code rework · 129c2f99
    Maxence Naud authored
    - [Optimize] add const, inline, noexcept where possible
    - [Utils] add DataUtils file for conversion from cv type to cpp type
    - [Syntax] Stimulis -> Stimuli and Stimuli -> Stimulus
    - [#define] Use standard Aidge syntax format
    - [Licence] add where missing
    - [#include] add what is used and remove what is not
    - [class] uniformize class member definition order
    - [types] change size_t for std::size_t from <cstddef> for uniformization
    - [types] change integer types for exact-width integers from <cstdint>
    - Remove end-of-line spaces
    129c2f99
    History
    [Upd] Code rework
    Maxence Naud authored
    - [Optimize] add const, inline, noexcept where possible
    - [Utils] add DataUtils file for conversion from cv type to cpp type
    - [Syntax] Stimulis -> Stimuli and Stimuli -> Stimulus
    - [#define] Use standard Aidge syntax format
    - [Licence] add where missing
    - [#include] add what is used and remove what is not
    - [class] uniformize class member definition order
    - [types] change size_t for std::size_t from <cstddef> for uniformization
    - [types] change integer types for exact-width integers from <cstdint>
    - Remove end-of-line spaces
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_DataProvider.cpp 2.41 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>  // std::equal
#include <cstdint>    // std::uint8_t
#include <string>     // std::string

#include <catch2/catch_test_macros.hpp>

#include "aidge/backend/opencv/database/MNIST.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/DataProvider.hpp"


// #include "aidge/backend/opencv/data/TensorImpl.hpp"
// #include "aidge/backend/cpu/data/TensorImpl.hpp"

using namespace Aidge;

/**
 * @brief Compare the batches returned by DataProvider::readBatch() with the
 *        items fetched one by one from the MNIST database.
 *
 * For every complete batch, sample `s` of the batch is expected to occupy the
 * contiguous range [s * size, (s + 1) * size) of the batch buffer, for both
 * the data tensor (index 0) and the label tensor (index 1).
 *
 * NOTE(review): the raw buffers are read as uint8 (data) and int (labels), as
 * the casts below assume -- confirm against the MNIST database implementation.
 */
TEST_CASE("DataProvider instanciation & test mnist","[Data][OpenCV]") {

    // Create database
    // NOTE(review): absolute, machine-specific path; the test can only run on
    // a host where this directory exists. Consider making it configurable.
    std::string path = "/data1/is156025/tb256203/dev/eclipse_aidge/aidge/user_tests/test_mnist_database";
    bool train = false;
    MNIST mnist(path, train);

    // DataProvider settings
    const unsigned int batchSize = 256;
    // Only complete batches are checked. The previous std::ceil() wrapper had
    // no effect because the division is already truncated on integers, so a
    // trailing partial batch was (and still is) left out of the comparison.
    const unsigned int number_batch = static_cast<unsigned int>(mnist.getLen() / batchSize);

    // Instantiate the dataloader
    DataProvider provider(mnist, batchSize);

    // Perform the tests on the batches
    for (unsigned int i = 0; i < number_batch; ++i) {
        const unsigned int batchStart = i * batchSize;

        auto batch = provider.readBatch(batchStart);
        const auto data_batch_ptr = static_cast<const std::uint8_t*>(batch[0]->getImpl()->rawPtr());
        const auto label_batch_ptr = static_cast<const int*>(batch[1]->getImpl()->rawPtr());

        for (unsigned int s = 0; s < batchSize; ++s) {
            // Fetch the reference item once and reuse both of its tensors.
            auto item = mnist.getItem(batchStart + s);
            const auto& data = item[0];
            const auto& label = item[1];

            const auto size_data = data->size();
            const auto size_label = label->size();
            const auto data_ptr = static_cast<const std::uint8_t*>(data->getImpl()->rawPtr());
            const auto label_ptr = static_cast<const int*>(label->getImpl()->rawPtr());

            // Reported only if one of the assertions below fails.
            INFO("batch " << i << ", sample " << s);

            // One assertion per tensor instead of one per element keeps the
            // number of Catch2 assertions (and the run time) reasonable.
            REQUIRE(std::equal(data_ptr, data_ptr + size_data,
                               data_batch_ptr + size_data * s));
            REQUIRE(std::equal(label_ptr, label_ptr + size_label,
                               label_batch_ptr + size_label * s));
        }
    }
}