From ebef5c349a3beeb567a5a6a010de96c3c7c8cecf Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 21 Sep 2023 14:57:06 +0200
Subject: [PATCH] Fixed CuDNN descriptor lazy initialization

---
 include/aidge/backend/cuda/operator/ConvImpl.hpp | 5 +----
 src/operator/ConvImpl.cpp                        | 6 ++++--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index 1c1a516..9391998 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -44,10 +44,7 @@ private:
     void* mWorkspace = nullptr;
 
 public:
-    ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {
-        CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
-        CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
-    }
+    ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {}
 
     static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) {
         return std::make_unique<ConvImpl_cuda>(op);
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index e76538a..35efa84 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -78,12 +78,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     assert(mOp.getInput(0) && "missing input #0");
     assert(mOp.getInput(1) && "missing input #1");
 
-    // Initialize CuDNN convolution descriptor
+    // Lazy-initialize CuDNN convolution descriptor
     if (mConvDesc == nullptr) {
         const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().begin(), mOp.template get<ConvParam::StrideDims>().end());
         const std::vector<int> paddings(DIM, 0);
         const std::vector<int> upscales(mOp.template get<ConvParam::DilationDims>().begin(), mOp.template get<ConvParam::DilationDims>().end());
 
+        CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
         CHECK_CUDNN_STATUS(
             cudnnSetConvolutionNdDescriptor(mConvDesc,
                                             DIM,
@@ -94,10 +95,11 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
                                             DataTypeToCudnn(mOp.getOutput(0)->dataType())));
     }
 
-    // Initialize CuDNN filter descriptor
+    // Lazy-initialize CuDNN filter descriptor
     if (mFilterDesc == nullptr) {
         const std::vector<int> kernels(mOp.getInput(1)->dims().begin(), mOp.getInput(1)->dims().end());
 
+        CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
         CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
                                                     DataTypeToCudnn(mOp.getInput(1)->dataType()),
                                                     CUDNN_TENSOR_NCHW,
-- 
GitLab