diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 7526680d46ed86268fdf021f9da5d73f6d0b263f..ea63f581400b5dc0669793a8d5f4c351885a9bb9 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -1,9 +1,6 @@
 #ifndef AIDGE_BACKEND_CUDA_DATA_TENSORIMPL_H_
 #define AIDGE_BACKEND_CUDA_DATA_TENSORIMPL_H_
 
-#include <thrust/equal.h>
-#include <thrust/device_ptr.h>
-
 #include "aidge/backend/TensorImpl.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Registrar.hpp"
@@ -13,8 +10,20 @@
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
 
 namespace Aidge {
+/**
+ * @brief Abstract base class for the TensorImpl_cuda class template.
+ * @details Its purpose is to provide access to methods that are specific to
+ * this implementation (and therefore not present in the generic TensorImpl
+ * class), but that can be used without knowing the tensor's data type.
+ */
 class TensorImpl_cuda_ {
 public:
+    /**
+     * @brief Return the CuDNN tensor descriptor of the tensor.
+     * @details This method uses lazy initialization for the descriptor
+     * (which is therefore mutable in the derived class).
+     * @return cudnnTensorDescriptor_t CuDNN tensor descriptor.
+     */
     virtual const cudnnTensorDescriptor_t& getCudnnTensorDesc() const = 0;
 };
 
diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index 07c8a7f0d4b8c9663b17a42e216619e8e927ba89..1c1a516f031fed2727bea3282bb1d1cd2e9e6214 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -33,18 +33,18 @@ template <DimIdx_t DIM>
 class ConvImpl_cuda : public OperatorImpl {
 private:
     const Conv_Op<DIM> &mOp;
-    std::array<NbElts_t, 3> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData;
+    std::array<NbElts_t, 3> mNbConsumedData = {0, 0, 0};
+    std::array<NbElts_t, 1> mNbProducedData = {0};
+    // CuDNN-specific variables
+    cudnnConvolutionDescriptor_t mConvDesc = nullptr;
+    cudnnFilterDescriptor_t mFilterDesc = nullptr;
+    cudnnConvolutionFwdAlgo_t mFwdAlgo;
     size_t mWorkspaceSize = 0;
     void* mWorkspace = nullptr;
-    cudnnFilterDescriptor_t mFilterDesc;
-    cudnnConvolutionFwdAlgo_t mFwdAlgo;
-    cudnnConvolutionDescriptor_t mConvDesc;
-
 public:
-    ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op), mNbConsumedData({0, 0, 0}), mNbProducedData({0}) {
-        CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
-        CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
-    }
+    // The CuDNN descriptors are created lazily in forward(), where the
+    // (mConvDesc == nullptr) checks perform the one-time initialization.
+    // Creating them here would prevent those checks from ever firing.
+    ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {}
diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu
index 4104d27321fb75500fa55dd696bf1c589d69e76e..ba2a1348b4c5eae8499884bfe2488d67f016a060 100644
--- a/src/data/TensorImpl.cu
+++ b/src/data/TensorImpl.cu
@@ -11,6 +11,9 @@
 
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 
+#include <thrust/equal.h>
+#include <thrust/device_ptr.h>
+
 template <class T>
 bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl);
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 7d46072725566d901dadd60b92a42e9b9491c3aa..e76538a4127edbde3c6fc530fb7ccd4ff5220ec2 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -78,118 +78,75 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     assert(mOp.getInput(0) && "missing input #0");
     assert(mOp.getInput(1) && "missing input #1");
 
-    const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().begin(), mOp.template get<ConvParam::StrideDims>().end());
-    const std::vector<int> paddings(DIM, 0);
-    const std::vector<int> upscales(mOp.template get<ConvParam::DilationDims>().begin(), mOp.template get<ConvParam::DilationDims>().end());
-
-    CHECK_CUDNN_STATUS(
-        cudnnSetConvolutionNdDescriptor(mConvDesc,
-                                        DIM,
-                                        &paddings[0],
-                                        &strides[0],
-                                        &upscales[0],
-                                        CUDNN_CROSS_CORRELATION,
-                                        DataTypeToCudnn(mOp.getOutput(0)->dataType())));
-
-    const std::vector<int> kernels(mOp.getInput(1)->dims().begin(), mOp.getInput(1)->dims().end());
-
-    CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
-                                                  DataTypeToCudnn(mOp.getInput(1)->dataType()),
-                                                  CUDNN_TENSOR_NCHW,
-                                                  kernels.size(),
-                                                  &kernels[0]));
-
-    int maxAlgoIterations = 0;
-    cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(),
-                                                &maxAlgoIterations);
-
-    assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionForwardAlgorithm");
-
-    int returnAlgoCounts = 0;
-
-    std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
-
-    /**************************************************************************************************************
-    https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnFindConvolutionForwardAlgorithm
-    This function attempts all cuDNN algorithms (including CUDNN_TENSOR_OP_MATH and CUDNN_DEFAULT_MATH
-    versions of algorithms where CUDNN_TENSOR_OP_MATH may be available) for cudnnConvolutionForward(),
-    using memory allocated via cudaMalloc(), and outputs performance metrics to a user-allocated array
-    of cudnnConvolutionFwdAlgoPerf_t. These metrics are written in sorted fashion where the first element
-    has the lowest compute time. The total number of resulting algorithms can be queried through
-    the API cudnnGetConvolutionForwardMaxCount().
-    ***************************************************************************************************************/
-    CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
-        CudaContext::cudnnHandle(),
-        dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
-        mFilterDesc,
-        mConvDesc,
-        dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
-        maxAlgoIterations,
-        &returnAlgoCounts,
-        &returnFwdAlgo[0]));
-    // std::cout << "Layer " << mName << "(" << k << ")"
-    //     << " cuDNN forward algorithm heuristic results: " << std::endl;
-
-    for(int fwdAlgo = 0; fwdAlgo < maxAlgoIterations; ++fwdAlgo)
-    {
-        std::string algoName
-            = (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM)
-            ? "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM"
-            : (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
-            ? "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM"
-            : (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_GEMM)
-            ? "CUDNN_CONVOLUTION_FWD_ALGO_GEMM"
-            : (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_DIRECT)
-            ? "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT"
-            : (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_FFT)
-            ? "CUDNN_CONVOLUTION_FWD_ALGO_FFT"
-            : (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
-            ? "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING"
-            : (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD)
-            ? "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD"
-            : (returnFwdAlgo[fwdAlgo].algo
-               == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED)
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED" - : (returnFwdAlgo[fwdAlgo].algo - == CUDNN_CONVOLUTION_FWD_ALGO_COUNT) - ? "CUDNN_CONVOLUTION_FWD_ALGO_COUNT" - : "Undetermined Algorithm"; - - - // std::cout << "----> Forward convolution algorithm: " << algoName - // << " [" << returnFwdAlgo[fwdAlgo].time << " ms][" << returnFwdAlgo[fwdAlgo].memory / 1.0e6 << " MB]" - // << std::endl; + // Initialize CuDNN convolution descriptor + if (mConvDesc == nullptr) { + const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().begin(), mOp.template get<ConvParam::StrideDims>().end()); + const std::vector<int> paddings(DIM, 0); + const std::vector<int> upscales(mOp.template get<ConvParam::DilationDims>().begin(), mOp.template get<ConvParam::DilationDims>().end()); + + CHECK_CUDNN_STATUS( + cudnnSetConvolutionNdDescriptor(mConvDesc, + DIM, + &paddings[0], + &strides[0], + &upscales[0], + CUDNN_CROSS_CORRELATION, + DataTypeToCudnn(mOp.getOutput(0)->dataType()))); } - mFwdAlgo = returnFwdAlgo[0].algo; - size_t workspaceSize = 0; + // Initialize CuDNN filter descriptor + if (mFilterDesc == nullptr) { + const std::vector<int> kernels(mOp.getInput(1)->dims().begin(), mOp.getInput(1)->dims().end()); - CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize( - CudaContext::cudnnHandle(), - dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(), - mFilterDesc, - mConvDesc, - dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(), - mFwdAlgo, - &workspaceSize)); + CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc, + DataTypeToCudnn(mOp.getInput(1)->dataType()), + CUDNN_TENSOR_NCHW, + kernels.size(), + &kernels[0])); + } - if (mWorkspaceSize != workspaceSize) { - if (mWorkspace != nullptr) { - cudaFree(mWorkspace); - mWorkspaceSize = 0; - } + // Set forward algorithm and allocate the required workspace + if (mWorkspace == nullptr) { + // Find the best CuDNN forward algorithm (the one with the lowest compute time) + int maxAlgoIterations = 0; + cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(), + &maxAlgoIterations); + + assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionForwardAlgorithm"); + + int returnAlgoCounts = 0; + std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations); + + CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm( + CudaContext::cudnnHandle(), + dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(), + mFilterDesc, + mConvDesc, + dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(), + maxAlgoIterations, + &returnAlgoCounts, + &returnFwdAlgo[0])); + mFwdAlgo = returnFwdAlgo[0].algo; + + // Allocate the workspace required by the chosen CuDNN forward algorithm + size_t workspaceSize = 0; + + CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize( + CudaContext::cudnnHandle(), + dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(), + mFilterDesc, + mConvDesc, + dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(), + mFwdAlgo, + &workspaceSize)); CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, workspaceSize)); mWorkspaceSize = workspaceSize; } + // Do the actual forward computation + // Template is only for scaling parameters, which are always in float + // excepted when the convolution is performed in double precision. 
     if (mOp.getOutput(0)->dataType() == DataType::Float64) {
         forward_<double>();
     }
@@ -219,6 +176,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
         dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
         mOp.getOutput(0)->getImpl()->rawPtr()));
 
+    // Add bias (if there is any)
     if (mOp.getInput(2) && mOp.getInput(2)->size() > 0) {
         // Bias tensor needs to have the same number of dims as the output tensor for cudnnAddTensor()
         std::vector<DimSize_t> biasDims(DIM+2, 1);
@@ -228,7 +186,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
         Tensor bias(mOp.getInput(2)->dataType());
         bias.setBackend("cuda");
         bias.resize(biasDims);
-        // TODO: find a more elegant solution
+        // TODO: find a more elegant solution(?)
 
         CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
                                           &alpha,
@@ -242,8 +200,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
 
 template <Aidge::DimIdx_t DIM>
 Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
-    cudnnDestroyConvolutionDescriptor(mConvDesc);
-    cudnnDestroyFilterDescriptor(mFilterDesc);
+    if (mConvDesc != nullptr) {
+        cudnnDestroyConvolutionDescriptor(mConvDesc);
+    }
+
+    if (mFilterDesc != nullptr) {
+        cudnnDestroyFilterDescriptor(mFilterDesc);
+    }
 
     if (mWorkspace != nullptr) {
         cudaFree(mWorkspace);
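
For reference, the following standalone sketch condenses the lazy one-time cuDNN setup pattern that this patch adopts in ConvImpl_cuda. It is illustrative only: the LazyConv2d class and its member names are made up for this sketch, only the cuDNN/CUDA calls are real API, and it assumes a hardcoded 2D float convolution (no padding, unit strides, no dilation) with pre-built tensor/filter descriptors.

// --- Illustrative sketch, not part of the patch ---
#include <cassert>
#include <vector>

#include <cuda_runtime.h>
#include <cudnn.h>

class LazyConv2d {
    cudnnConvolutionDescriptor_t mConvDesc = nullptr;
    cudnnConvolutionFwdAlgo_t mFwdAlgo;
    void* mWorkspace = nullptr;
    size_t mWorkspaceSize = 0;

public:
    // Nothing is created here: no CUDA context is needed at construction time
    LazyConv2d() = default;

    // Called at the beginning of every forward(); only the first call pays
    // for descriptor creation, algorithm benchmarking and workspace allocation
    void setupOnce(cudnnHandle_t handle,
                   cudnnTensorDescriptor_t xDesc,
                   cudnnFilterDescriptor_t wDesc,
                   cudnnTensorDescriptor_t yDesc) {
        if (mConvDesc == nullptr) {
            cudnnCreateConvolutionDescriptor(&mConvDesc);
            cudnnSetConvolution2dDescriptor(mConvDesc,
                                            0, 0,  // zero padding
                                            1, 1,  // unit strides
                                            1, 1,  // no dilation
                                            CUDNN_CROSS_CORRELATION,
                                            CUDNN_DATA_FLOAT);
        }

        if (mWorkspace == nullptr) {
            int maxAlgoCount = 0;
            cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &maxAlgoCount);
            assert(maxAlgoCount > 0);

            // Benchmark the available algorithms; results come back sorted
            // by increasing compute time, so perf[0] is the fastest one
            int returnedAlgoCount = 0;
            std::vector<cudnnConvolutionFwdAlgoPerf_t> perf(maxAlgoCount);
            cudnnFindConvolutionForwardAlgorithm(handle, xDesc, wDesc,
                                                 mConvDesc, yDesc, maxAlgoCount,
                                                 &returnedAlgoCount, perf.data());
            mFwdAlgo = perf[0].algo;

            // Allocate the workspace once, for the chosen algorithm
            cudnnGetConvolutionForwardWorkspaceSize(handle, xDesc, wDesc,
                                                    mConvDesc, yDesc, mFwdAlgo,
                                                    &mWorkspaceSize);
            cudaMalloc(&mWorkspace, mWorkspaceSize);
        }
    }

    // The destructor must tolerate the case where setupOnce() never ran
    ~LazyConv2d() {
        if (mConvDesc != nullptr)
            cudnnDestroyConvolutionDescriptor(mConvDesc);
        if (mWorkspace != nullptr)
            cudaFree(mWorkspace);
    }
};

This is the same structure forward() implements above, except that the stride and dilation parameters are taken from the Conv_Op attributes rather than hardcoded, and every cuDNN/CUDA call is wrapped in CHECK_CUDNN_STATUS/CHECK_CUDA_STATUS.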