diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp index 1c1a516f031fed2727bea3282bb1d1cd2e9e6214..939199862b7c745dafe27226c45e08005539a301 100644 --- a/include/aidge/backend/cuda/operator/ConvImpl.hpp +++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp @@ -44,10 +44,7 @@ private: void* mWorkspace = nullptr; public: - ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) { - CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc)); - CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc)); - } + ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {} static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) { return std::make_unique<ConvImpl_cuda>(op); diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp index e76538a4127edbde3c6fc530fb7ccd4ff5220ec2..35efa84dc3d0d7f54a7a463d9100115ac8a79fdd 100644 --- a/src/operator/ConvImpl.cpp +++ b/src/operator/ConvImpl.cpp @@ -78,12 +78,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { assert(mOp.getInput(0) && "missing input #0"); assert(mOp.getInput(1) && "missing input #1"); - // Initialize CuDNN convolution descriptor + // Lazy-initialize CuDNN convolution descriptor if (mConvDesc == nullptr) { const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().begin(), mOp.template get<ConvParam::StrideDims>().end()); const std::vector<int> paddings(DIM, 0); const std::vector<int> upscales(mOp.template get<ConvParam::DilationDims>().begin(), mOp.template get<ConvParam::DilationDims>().end()); + CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc)); CHECK_CUDNN_STATUS( cudnnSetConvolutionNdDescriptor(mConvDesc, DIM, @@ -94,10 +95,11 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { DataTypeToCudnn(mOp.getOutput(0)->dataType()))); } - // Initialize CuDNN filter descriptor + // Lazy-initialize CuDNN filter descriptor if (mFilterDesc == nullptr) { const std::vector<int> kernels(mOp.getInput(1)->dims().begin(), mOp.getInput(1)->dims().end()); + CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc)); CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc, DataTypeToCudnn(mOp.getInput(1)->dataType()), CUDNN_TENSOR_NCHW,