/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <chrono>      // std::micro, std::chrono::time_point,
                       // std::chrono::steady_clock, std::chrono::system_clock,
                       // std::chrono::duration
#include <cmath>       // std::nearbyint
#include <cstddef>     // std::size_t
#include <cstdint>     // std::uint16_t
#include <functional>  // std::multiplies
#include <memory>
#include <numeric>     // std::accumulate
#include <random>      // std::random_device, std::mt19937
                       // std::uniform_int_distribution, std::uniform_real_distribution
#include <vector>

#include <catch2/catch_test_macros.hpp>
#include <fmt/core.h>

#include "aidge/backend/cpu/data/TensorImpl.hpp"
#include "aidge/backend/cpu/operator/RoundImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Round.hpp"
#include "aidge/utils/TensorUtils.hpp"

namespace Aidge {

/**
 * Checks the CPU forward kernel of the Round operator against a reference
 * computed element-wise with std::nearbyint.
 *
 * For each of NBTRIALS trials a Float32 tensor with 1 to 3 dimensions (each of
 * size 2 to 5) is filled with random values drawn from [-15, 15), fed to the
 * operator, and the operator output is compared to the reference tensor with
 * approxEq. The accumulated time spent in forward() is logged at the end.
 */
TEST_CASE("[cpu/operator] Round_Test", "[Round][CPU]") {
    constexpr std::uint16_t NBTRIALS = 15;

    // Random generators for the tensor values and shapes
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> valueDist(-15, 15);
    std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2), std::size_t(5));
    std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(3));

    // Create Round Operator
    std::shared_ptr<Node> myRound = Round();
    auto op = std::static_pointer_cast<OperatorTensor>(myRound->getOperator());
    op->setDataType(DataType::Float32);
    op->setBackend("cpu");

    // Create the single input Tensor
    std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>();
    op->associateInput(0, T0);
    T0->setDataType(DataType::Float32);
    T0->setBackend("cpu");

    // Create the Tensor holding the expected results
    std::shared_ptr<Tensor> Tres = std::make_shared<Tensor>();
    Tres->setDataType(DataType::Float32);
    Tres->setBackend("cpu");

    // Accumulated execution time of the 'forward()' calls
    std::chrono::duration<double, std::micro> duration{};

    SECTION("Round [Forward]") {
        SECTION("Test Forward Kernel") {
            std::size_t number_of_operation = 0;

            for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {

                // Draw a random shape for this trial
                const std::size_t nbDims = nbDimsDist(gen);
                std::vector<std::size_t> dims;
                dims.reserve(nbDims);
                for (std::size_t i = 0; i < nbDims; ++i) {
                    dims.push_back(dimSizeDist(gen));
                }
                const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
                number_of_operation += nb_elements;

                // Input values and expected (reference) values.
                // std::vector owns the storage so it is released even when
                // REQUIRE fails or forward() throws and the loop body is left
                // early.
                std::vector<float> array0(nb_elements);
                std::vector<float> result(nb_elements);

                for (std::size_t i = 0; i < nb_elements; ++i) {
                    array0[i] = valueDist(gen);
                    result[i] = std::nearbyint(array0[i]);
                }

                // input0: the tensor only borrows the buffer for this iteration
                T0->resize(dims);
                T0->getImpl()->setRawPtr(array0.data(), nb_elements);

                // expected results
                Tres->resize(dims);
                Tres->getImpl()->setRawPtr(result.data(), nb_elements);

                op->forwardDims();

                // steady_clock is monotonic, hence suited to measure an interval
                const auto start = std::chrono::steady_clock::now();
                myRound->forward();
                const auto end = std::chrono::steady_clock::now();
                duration += std::chrono::duration_cast<std::chrono::microseconds>(end - start);

                REQUIRE(approxEq<float>(*(op->getOutput(0)), *Tres));
            }
            Log::info("number of elements over time spent: {}\n", (number_of_operation / duration.count()));
            Log::info("total time: {} μs\n", duration.count());
        }
    }
}
} // namespace Aidge