/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cassert>
#include <chrono>     // std::chrono::milliseconds
#include <cstdio>     // printf
#include <functional> // std::multiplies
#include <numeric>    // std::accumulate
#include <thread>     // std::this_thread::sleep_for
#include <vector>

#include "aidge/utils/Types.h"
#include "aidge/operator/Conv.hpp"

#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/ConvImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"

template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
    assert(mOp.getInput(inputIdx) && "requires valid input");

    // Requires the whole tensors
    const auto &inputDims = std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->dims();

    return std::accumulate(inputDims.begin(), inputDims.end(), Aidge::NbElts_t(1), std::multiplies<NbElts_t>());
}

template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
    // No input data needs to be protected: with the direct convolution
    // algorithm and no padding, the operation may run in-place.
    return 0;
}

template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                         const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
    // The full output tensor is produced regardless of how much input data is
    // currently available, so the required memory is its total element count.
    assert(outputIdx == 0 && "operator has only one output");
    (void) outputIdx;

    NbElts_t requiredMem = 1;
    for (const auto dim : std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims()) {
        requiredMem *= static_cast<NbElts_t>(dim);
    }
    return requiredMem;
}

template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
    // Total number of elements consumed so far on the given input.
    const std::size_t idx = static_cast<std::size_t>(inputIdx);
    assert(idx < mNbConsumedData.size());
    return mNbConsumedData[idx];
}

template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
    // Total number of elements produced so far on the (single) output.
    const std::size_t idx = static_cast<std::size_t>(outputIdx);
    assert((outputIdx == 0) && (idx < mNbProducedData.size()));
    return mNbProducedData[idx];
}

template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::updateConsummerProducer(){
    // Advance the producer-consumer bookkeeping by one forward pass:
    // each input is consumed by the minimum amount required for a pass...
    for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) {
        mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
    }

    // ...and the single output is produced in full.
    mNbProducedData[0] += getRequiredMemory(0, {});
}

template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::forward() {
    // Configures the cuDNN convolution and filter descriptors for this
    // operator, benchmarks all available forward algorithms, and stores the
    // fastest one in mFwdAlgo. Note: this method only *selects* the
    // algorithm — the convolution itself is not executed here.
    // FIXME: uncomment the following code once memory handling will work
    assert(mOp.getInput(0) && "missing input #0");
    assert(mOp.getInput(1) && "missing input #1");
    assert(mOp.getInput(2) && "missing input #2");

    // cuDNN expects per-spatial-dimension parameters in reverse order
    // relative to the Aidge dims, hence the rbegin/rend constructors.
    const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().rbegin(), mOp.template get<ConvParam::StrideDims>().rend());
    // NOTE(review): paddings are hard-coded to zero — confirm the Conv
    // operator genuinely carries no padding attribute, otherwise it is
    // silently dropped here.
    const std::vector<int> paddings(DIM, 0);
    const std::vector<int> upscales(mOp.template get<ConvParam::DilationDims>().rbegin(), mOp.template get<ConvParam::DilationDims>().rend());

    // NOTE(review): the compute type is taken from input #2 (presumably the
    // bias) — verify this matches the intended accumulation precision.
    CHECK_CUDNN_STATUS(
        cudnnSetConvolutionNdDescriptor(mConvDesc,
                                        DIM,
                                        &paddings[0],
                                        &strides[0],
                                        &upscales[0],
                                        CUDNN_CROSS_CORRELATION,
                                        DataTypeToCudnn(mOp.getInput(2)->dataType())));

    // Kernel (input #1) dimensions, also reversed for cuDNN.
    const std::vector<int> cudaKernelDims(mOp.getInput(1)->dims().rbegin(),
                                          mOp.getInput(1)->dims().rend());

    // NOTE(review): a fresh filter descriptor is created on every call with
    // no matching destroy — calling forward() more than once leaks the
    // previous descriptor. Confirm where mFilterDesc ownership/cleanup lives.
    CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
    CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
                                                DataTypeToCudnn(mOp.getInput(1)->dataType()),
                                                CUDNN_TENSOR_NCHW,
                                                cudaKernelDims.size(),
                                                &cudaKernelDims[0]));

    // Query how many candidate forward algorithms cuDNN can evaluate.
    int maxAlgoIterations = 0;
    cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(),
                                                &maxAlgoIterations);

    assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionForwardAlgorithm");

    int returnAlgoCounts = 0;

    std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
/**************************************************************************************************************
https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnFindConvolutionForwardAlgorithm
This function attempts all cuDNN algorithms (including CUDNN_TENSOR_OP_MATH and CUDNN_DEFAULT_MATH
versions of algorithms where CUDNN_TENSOR_OP_MATH may be available) for cudnnConvolutionForward(),
using memory allocated via cudaMalloc(), and outputs performance metrics to a user-allocated array
of cudnnConvolutionFwdAlgoPerf_t. These metrics are written in sorted fashion where the first element
has the lowest compute time. The total number of resulting algorithms can be queried through
the API cudnnGetConvolutionForwardMaxCount().
***************************************************************************************************************/

    // NOTE(review): the input/output tensor descriptors are obtained through a
    // hard-coded cast to TensorImpl_cuda<float> — wrong for any other dtype,
    // as the existing FIXMEs already flag.
    CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
                        CudaContext::cudnnHandle(),
                        static_cast<TensorImpl_cuda<float>*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),  // FIXME: PLAIN WRONG
                        mFilterDesc,
                        mConvDesc,
                        static_cast<TensorImpl_cuda<float>*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),  // FIXME: PLAIN WRONG
                        maxAlgoIterations,
                        &returnAlgoCounts,
                        &returnFwdAlgo[0]));
    // std::cout << "Layer " << mName << "(" << k  << ")"
    //     << " cuDNN forward algorithm heuristic results: " << std::endl;

    // Map each benchmarked algorithm enum to a printable name. The printing
    // below is commented out, so algoName is currently unused (this loop has
    // no observable effect and may trigger an unused-variable warning).
    for(int fwdAlgo = 0; fwdAlgo < maxAlgoIterations; ++fwdAlgo)
    {
        std::string algoName
                            = (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_GEMM)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_GEMM"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_DIRECT)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_FFT)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_FFT"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
                            : (returnFwdAlgo[fwdAlgo].algo
                                    == CUDNN_CONVOLUTION_FWD_ALGO_COUNT)
                                ? "CUDNN_CONVOLUTION_FWD_ALGO_COUNT"
                            : "Undetermined Algorithm";


        // std::cout << "----> Forward convolution algorithm: " << algoName
        //     << " [" << returnFwdAlgo[fwdAlgo].time << " ms][" << returnFwdAlgo[fwdAlgo].memory / 1.0e6 << " MB]"
        //     << std::endl;
    }
    // Results are sorted by compute time (see the quoted cuDNN documentation
    // above), so index 0 is the fastest algorithm.
    mFwdAlgo = returnFwdAlgo[0].algo;
}

template <Aidge::DimIdx_t DIM>
Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
    // NOTE(review): intentionally empty — but neither mConvDesc nor
    // mFilterDesc (created in forward()) is destroyed here. Confirm whether
    // another owner releases these cuDNN descriptors; otherwise they leak.
}

template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::backward() {
    // Backward pass is not available for the CUDA convolution backend yet.
    printf("Not implemented yet.\n");
}


// Template declarations
void ConvImpl_cuda_template_declaration ()
{
    Aidge::ConvImpl_cuda<2> ConvImpl_cuda2(Aidge::Conv_Op<2>());
}