Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
ConvImpl.cpp 7.94 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm> // std::find_if, std::max
#include <cassert>
#include <chrono>  // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread>  // std::this_thread::sleep_for
#include <vector>

#include "aidge/utils/Types.h"
#include "aidge/operator/Conv.hpp"

#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/ConvImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"

template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::forward() {
    // Inputs #0 (data) and #1 (weights) are mandatory.
    // NOTE: these checks are compiled out under NDEBUG.
    assert(mOp.getRawInput(0) && "missing input #0");
    assert(mOp.getRawInput(1) && "missing input #1");

    Tensor& output = *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));

    // Convert input data to match the output tensor (no overhead if not needed!)
    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(mInput0Fallback, output);
    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(mInput1Fallback, output);

    // Input #2 (bias) is optional: forward_() only reads it when
    // mOp.getRawInput(2) is non-null. When it is absent, hand forward_() an
    // empty placeholder instead of dereferencing a null pointer.
    const Tensor noBias(output.dataType());
    const Tensor& input2 = mOp.getRawInput(2)
        ? std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(mInput2Fallback, output)
        : noBias;

    // Lazily create the cuDNN convolution descriptor (first call only).
    if (mConvDesc == nullptr) {
        const Conv_Op<DIM>& convOp = static_cast<const Conv_Op<DIM>&>(mOp);
        // Bind each attribute once so begin()/end() come from the same object.
        const auto& strideDims = convOp.template getAttr<ConvAttr::StrideDims>();
        const auto& dilationDims = convOp.template getAttr<ConvAttr::DilationDims>();

        const std::vector<int> strides(strideDims.begin(), strideDims.end());
        // No padding attribute is read here: zero padding is always passed to cuDNN.
        const std::vector<int> paddings(DIM, 0);
        const std::vector<int> upscales(dilationDims.begin(), dilationDims.end());

        CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
        CHECK_CUDNN_STATUS(
            cudnnSetConvolutionNdDescriptor(mConvDesc,
                                            DIM,
                                            paddings.data(),
                                            strides.data(),
                                            upscales.data(),
                                            CUDNN_CROSS_CORRELATION,
                                            DataTypeToCudnn(output.dataType())));
    }

    // Lazily create the cuDNN filter descriptor from the weight dimensions.
    if (mFilterDesc == nullptr) {
        const auto& weightDims = input1.dims();
        const std::vector<int> kernels(weightDims.begin(), weightDims.end());

        CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
        CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
                                                      DataTypeToCudnn(input1.dataType()),
                                                      CUDNN_TENSOR_NCHW,
                                                      static_cast<int>(kernels.size()),
                                                      kernels.data()));
    }

    // Select the forward algorithm and allocate its workspace (first call only).
    // mFwdWorkspace doubles as the "already done" flag for this step.
    if (mFwdWorkspace == nullptr) {
        const auto inputDesc = dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(input0);
        const auto outputDesc = dynamic_cast<TensorImpl_cuda_*>(output.getImpl().get())->getCudnnTensorDesc(output);

        int maxAlgoIterations = 0;
        CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(),
                                                                       &maxAlgoIterations));
        assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionForwardAlgorithm");

        // Let cuDNN benchmark the candidate algorithms on the actual tensors.
        int returnAlgoCounts = 0;
        std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
        CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
                            CudaContext::cudnnHandle(),
                            inputDesc,
                            mFilterDesc,
                            mConvDesc,
                            outputDesc,
                            maxAlgoIterations,
                            &returnAlgoCounts,
                            returnFwdAlgo.data()));

        // Results are sorted by ascending compute time. Take the fastest one
        // whose trial run succeeded; if none did, keep the first entry (the
        // previous behavior) and let the forward call report the failure.
        const auto returnedEnd = returnFwdAlgo.begin() + std::max(returnAlgoCounts, 0);
        const auto fastestOk = std::find_if(returnFwdAlgo.begin(), returnedEnd,
            [](const cudnnConvolutionFwdAlgoPerf_t& perf) { return perf.status == CUDNN_STATUS_SUCCESS; });
        mFwdAlgo = (fastestOk != returnedEnd) ? fastestOk->algo : returnFwdAlgo.front().algo;

        // Query the workspace required by the chosen algorithm.
        size_t workspaceSize = 0;
        CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
            CudaContext::cudnnHandle(),
            inputDesc,
            mFilterDesc,
            mConvDesc,
            outputDesc,
            mFwdAlgo,
            &workspaceSize));

        // Some algorithms need no workspace. A 0-byte cudaMalloc() may leave
        // mFwdWorkspace null (NOTE(review): confirm on target CUDA versions),
        // which would re-run the costly search above on every call. Allocate
        // at least one byte; cuDNN still receives the real size.
        CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, std::max<size_t>(workspaceSize, 1)));
        mWorkspaceSize = workspaceSize;
    }

    // Do the actual forward computation.
    // The template parameter only selects the type of the scaling factors:
    // double for Float64 convolutions, float otherwise.
    if (output.dataType() == DataType::Float64) {
        forward_<double>(input0, input1, input2);
    }
    else {
        forward_<float>(input0, input1, input2);
    }
}

template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
    // cuDNN blending factors: dst = alpha * result + beta * dst.
    const T alpha = 1.0f;
    const T beta = 0.0f;

    // Look up the cuDNN tensor descriptor exposed by a tensor's CUDA implementation.
    const auto cudnnDescOf = [](const Tensor& tensor) {
        return dynamic_cast<TensorImpl_cuda_*>(tensor.getImpl().get())->getCudnnTensorDesc(tensor);
    };

    Tensor& output = *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));

    // output = conv(input0, input1); beta = 0 overwrites the previous content.
    CHECK_CUDNN_STATUS(cudnnConvolutionForward(CudaContext::cudnnHandle(),
                                               &alpha,
                                               cudnnDescOf(input0),
                                               input0.getImpl()->rawPtr(),
                                               mFilterDesc,
                                               input1.getImpl()->rawPtr(),
                                               mConvDesc,
                                               mFwdAlgo,
                                               mFwdWorkspace,
                                               mWorkspaceSize,
                                               &beta,
                                               cudnnDescOf(output),
                                               output.getImpl()->rawPtr()));

    // Without a (non-empty) bias input there is nothing left to do.
    const bool hasBias = mOp.getRawInput(2) && input2.size() > 0;
    if (!hasBias) {
        return;
    }

    // cudnnAddTensor() needs a bias descriptor with the same rank as the
    // output: shape (1, N, 1, ..., 1) where N is the number of bias elements.
    std::vector<DimSize_t> biasDims(DIM + 2, 1);
    biasDims[1] = input2.size();

    // Dummy CUDA tensor used only to obtain a matching cuDNN descriptor.
    // TODO: find a more elegant solution(?)
    Tensor biasShape(input2.dataType());
    biasShape.setBackend("cuda");
    biasShape.resize(biasDims);

    // output = alpha * bias + alpha * output, i.e. the bias is accumulated.
    CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
                                      &alpha,
                                      cudnnDescOf(biasShape),
                                      input2.getImpl()->rawPtr(),
                                      &alpha,
                                      cudnnDescOf(output),
                                      output.getImpl()->rawPtr()));
}

template <Aidge::DimIdx_t DIM>
Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
    // Release the cuDNN/CUDA resources that forward() creates lazily.
    // Each handle is released only if it was actually created.
    // Return statuses are not checked here.
    if (mConvDesc) { cudnnDestroyConvolutionDescriptor(mConvDesc); }
    if (mFilterDesc) { cudnnDestroyFilterDescriptor(mFilterDesc); }
    if (mFwdWorkspace) { cudaFree(mFwdWorkspace); }
}


// Template declarations
template class Aidge::ConvImpl_cuda<2>;