From ebef5c349a3beeb567a5a6a010de96c3c7c8cecf Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 21 Sep 2023 14:57:06 +0200 Subject: [PATCH] Fixed CuDNN descriptor lazy initialization --- include/aidge/backend/cuda/operator/ConvImpl.hpp | 5 +---- src/operator/ConvImpl.cpp | 6 ++++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp index 1c1a516..9391998 100644 --- a/include/aidge/backend/cuda/operator/ConvImpl.hpp +++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp @@ -44,10 +44,7 @@ private: void* mWorkspace = nullptr; public: - ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) { - CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc)); - CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc)); - } + ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {} static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) { return std::make_unique<ConvImpl_cuda>(op); diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp index e76538a..35efa84 100644 --- a/src/operator/ConvImpl.cpp +++ b/src/operator/ConvImpl.cpp @@ -78,12 +78,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { assert(mOp.getInput(0) && "missing input #0"); assert(mOp.getInput(1) && "missing input #1"); - // Initialize CuDNN convolution descriptor + // Lazy-initialize CuDNN convolution descriptor if (mConvDesc == nullptr) { const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().begin(), mOp.template get<ConvParam::StrideDims>().end()); const std::vector<int> paddings(DIM, 0); const std::vector<int> upscales(mOp.template get<ConvParam::DilationDims>().begin(), mOp.template get<ConvParam::DilationDims>().end()); + CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc)); CHECK_CUDNN_STATUS( cudnnSetConvolutionNdDescriptor(mConvDesc, DIM, @@ -94,10 +95,11 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { DataTypeToCudnn(mOp.getOutput(0)->dataType()))); } - // Initialize CuDNN filter descriptor + // Lazy-initialize CuDNN filter descriptor if (mFilterDesc == nullptr) { const std::vector<int> kernels(mOp.getInput(1)->dims().begin(), mOp.getInput(1)->dims().end()); + CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc)); CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc, DataTypeToCudnn(mOp.getInput(1)->dataType()), CUDNN_TENSOR_NCHW, -- GitLab