/********************************************************************************
 * Copyright (c) 2024 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>  // std::generate
#include <array>
#include <cstddef>    // std::size_t
#include <cstdint>    // std::uint16_t
#include <functional> // std::multiplies
#include <memory>     // std::shared_ptr, std::unique_ptr
#include <numeric>    // std::accumulate
#include <random>     // std::random_device, std::mt19937, std::uniform_real_distribution
#include <vector>

#include <catch2/catch_test_macros.hpp>

#include "aidge/backend/cpu.hpp"
#include "aidge/backend/cuda.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/TensorUtils.hpp"

namespace Aidge {

/**
 * Compares the CUDA implementation of the Ln operator with the CPU one.
 *
 * For NBTRIALS randomly shaped Float32 tensors, the same random input is fed
 * to an Ln operator on each backend and both outputs must match according to
 * approxEq<float>.
 */
TEST_CASE("[gpu/operator] Ln", "[Ln][GPU]") {
    constexpr std::uint16_t NBTRIALS = 10;

    // Random generators for the input values and for the tensor shape.
    std::random_device rd;
    std::mt19937 gen(rd());
    // Inputs are drawn from [0.1, 1.1): strictly positive values only.
    std::uniform_real_distribution<float> valueDist(0.1f, 1.1f);
    std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(1), std::size_t(10));
    std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(8));

    // Deleter used to release the device buffer on scope exit, including when
    // a failing REQUIRE aborts the test case.
    auto freeDevice = [](float* ptr) { cudaFree(ptr); };

    for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
        // Create Ln Operator CUDA
        std::shared_ptr<Node> myLnCUDA = Ln();
        auto op_cuda = std::static_pointer_cast<OperatorTensor>(myLnCUDA->getOperator());

        // Create Ln Operator CPU
        std::shared_ptr<Node> myLnCPU = Ln();
        auto op_cpu = std::static_pointer_cast<OperatorTensor>(myLnCPU->getOperator());
        op_cpu->setDataType(DataType::Float32);
        op_cpu->setBackend("cpu");

        // Draw a random shape: 1 to 8 dimensions of size 1 to 10 each.
        const std::size_t nbDims = nbDimsDist(gen);
        std::vector<std::size_t> dims;
        dims.reserve(nbDims);
        for (std::size_t i = 0; i < nbDims; ++i) {
            dims.push_back(dimSizeDist(gen));
        }
        const std::size_t nb_elements = std::accumulate(
            dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());

        // Host-side input values, shared by the CPU tensor and copied to the device.
        std::vector<float> array0(nb_elements);
        std::generate(array0.begin(), array0.end(), [&]() { return valueDist(gen); });

        // input0 CUDA
        std::shared_ptr<Tensor> T0_cuda = std::make_shared<Tensor>();
        T0_cuda->setDataType(DataType::Float32);
        T0_cuda->setBackend("cuda");
        T0_cuda->resize(dims);
        op_cuda->associateInput(0, T0_cuda);

        float* array0_d = nullptr;
        REQUIRE(cudaMalloc(reinterpret_cast<void**>(&array0_d), sizeof(float) * nb_elements) == cudaSuccess);
        // The test keeps ownership of the device buffer and frees it itself,
        // exactly as the explicit cudaFree() did before.
        std::unique_ptr<float, decltype(freeDevice)> array0_d_owner(array0_d, freeDevice);
        REQUIRE(cudaMemcpy(array0_d, array0.data(), sizeof(float) * nb_elements, cudaMemcpyHostToDevice) == cudaSuccess);
        T0_cuda->getImpl()->setRawPtr(array0_d, nb_elements);

        // input0 CPU
        std::shared_ptr<Tensor> T0_cpu = std::make_shared<Tensor>();
        op_cpu->associateInput(0, T0_cpu);
        T0_cpu->setDataType(DataType::Float32);
        T0_cpu->setBackend("cpu");
        T0_cpu->resize(dims);
        T0_cpu->getImpl()->setRawPtr(array0.data(), nb_elements);

        // forward CUDA
        op_cuda->setDataType(DataType::Float32);
        op_cuda->setBackend("cuda");
        op_cuda->forward();

        // forward CPU
        op_cpu->forward();

        // Bring the CUDA output to the same backend/type as the CPU output
        // before comparing them.
        std::shared_ptr<Tensor> outputFallback;
        const auto& cudaOutput = op_cuda->getOutput(0)->refCastFrom(outputFallback, *op_cpu->getOutput(0));
        REQUIRE(approxEq<float>(cudaOutput, *(op_cpu->getOutput(0))));
    }
}
} // namespace Aidge
