From bff7cdcfaf186aad47c3674b5d27efd8512c9f03 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 21 Mar 2024 10:23:20 +0100
Subject: [PATCH] Added Conv backward prototype (UNTESTED)

---
 .../aidge/backend/cuda/operator/ConvImpl.hpp  |   5 +
 src/operator/ConvImpl.cpp                     | 170 ++++++++++++++++++
 2 files changed, 175 insertions(+)

diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index 65180fd..1f692b0 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -34,8 +34,11 @@ private:
     cudnnConvolutionDescriptor_t mConvDesc = nullptr;
     cudnnFilterDescriptor_t mFilterDesc = nullptr;
     cudnnConvolutionFwdAlgo_t mFwdAlgo;
+    cudnnConvolutionBwdFilterAlgo_t mBwdFilterAlgo;
+    cudnnConvolutionBwdDataAlgo_t mBwdDataAlgo;
     size_t mWorkspaceSize = 0;
     void* mFwdWorkspace = nullptr;
+    void* mBwdWorkspace = nullptr;
     std::shared_ptr<Tensor> mInput0Fallback;
     std::shared_ptr<Tensor> mInput1Fallback;
     std::shared_ptr<Tensor> mInput2Fallback;
@@ -49,10 +52,12 @@ public:
 
 public:
     void forward();
+    void backward();
     ~ConvImpl_cuda();
 
 private:
     template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
+    template <class T> void backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
 };
 
 namespace {
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 19ce56b..e2ee633 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -157,6 +157,176 @@ void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& inp
     }
 }
 
+template <Aidge::DimIdx_t DIM>
+void Aidge::ConvImpl_cuda<DIM>::backward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
+    // FIXME: uncomment the following code once memory handling will work
+    assert(mOp.getRawInput(0) && "missing input #0");
+    assert(mOp.getRawInput(1) && "missing input #1");
+
+    // Convert input data (no overhead if not needed!)
+    const auto& input0 = op.getInput(0)->ref(mInput0Fallback, *op.getOutput(0));
+    const auto& input1 = op.getInput(1)->ref(mInput1Fallback, *op.getOutput(0));
+    const auto& input2 = op.getInput(2)->ref(mInput2Fallback, *op.getOutput(0));
+
+    // Set forward algorithm and allocate the required workspace
+    if (mBwdWorkspace == nullptr) {
+        // Find the best CuDNN backward algorithm (the one with the lowest compute time)
+        int maxAlgoIterations = 0;
+        cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(CudaContext::cudnnHandle(),
+                                                    &maxAlgoIterations);
+        assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardFilterAlgorithm");
+
+        int returnAlgoCounts = 0;
+        std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> returnBwdFilterAlgo(maxAlgoIterations);
+
+        CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardFilterAlgorithm(
+            CudaContext::cudnnHandle(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            mFilterDesc,
+            maxAlgoIterations,
+            &returnAlgoCounts,
+            &returnBwdFilterAlgo[0]));
+
+        mBwdFilterAlgo = returnBwdFilterAlgo[0].algo;
+
+        maxAlgoIterations = 0;
+        cudnnGetConvolutionBackwardDataAlgorithmMaxCount(CudaContext::cudnnHandle(),
+                                                    &maxAlgoIterations);
+        assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardDataAlgorithm");
+
+        returnAlgoCounts = 0;
+        std::vector<cudnnConvolutionBwdDataAlgoPerf_t> returnBwdDataAlgo(maxAlgoIterations);
+
+        CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardDataAlgorithm(
+            CudaContext::cudnnHandle(),
+            mFilterDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            maxAlgoIterations,
+            &returnAlgoCounts,
+            &returnBwdDataAlgo[0]));
+
+        mBwdDataAlgo = returnBwdDataAlgo[0].algo;
+
+        // Allocate the workspace required by the chosen CuDNN backward algorithm
+        size_t workspaceSize = 0;
+        CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+            CudaContext::cudnnHandle(),
+            // same arguments as cudnnGetConvolutionBackwardFilterAlgorithm()
+            // -->
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            mFilterDesc,
+            // <--
+            mBwdFilterAlgo,
+            &workspaceSize));
+
+        size_t workspaceSizeData = 0;
+        CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardDataWorkspaceSize(
+            CudaContext::cudnnHandle(),
+            // same arguments as cudnnGetConvolutionBackwardDataAlgorithm() -->
+            mFilterDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            // <--
+            mBwdDataAlgo,
+            &workspaceSizeData));
+
+        if (workspaceSizeData > workspaceSize)
+            workspaceSize = workspaceSizeData;
+
+        if (workspaceSize > mWorkspaceSize) {
+            if (mFwdWorkspace != nullptr) {
+                cudaFree(mFwdWorkspace);
+            }
+            CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, workspaceSize));
+            mWorkspaceSize = workspaceSize;
+        }
+
+        mBwdWorkspace = mFwdWorkspace;
+    }
+
+    // Do the actual backward computation
+    // Template is only for scaling parameters, which are always in float
+    // excepted when the convolution is performed in double precision.
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
+        backward_<double>(input0, input1, input2);
+    }
+    else {
+        backward_<float>(input0, input1, input2);
+    }
+}
+
+template <Aidge::DimIdx_t DIM>
+template <class T>
+void Aidge::ConvImpl_cuda<DIM>::backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
+    std::shared_ptr<Tensor> gradOutputFallback;
+    const auto& gradOutput = op.getOutput(0)->grad()->refCastFrom(gradOutputFallback, *(op.getInput(0)->grad()));
+
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
+
+    CHECK_CUDNN_STATUS(cudnnConvolutionBackwardFilter(
+        CudaContext::cudnnHandle(),
+        &alpha,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+        input0.getImpl()->rawPtr(),
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
+        gradOutput.getImpl()->rawPtr(),
+        mConvDesc,
+        mBwdFilterAlgo,
+        mBwdWorkspace,
+        mWorkspaceSize,
+        &beta,
+        mFilterDesc,
+        op.getInput(1)->grad()->getImpl()->rawPtr()));
+
+    CHECK_CUDNN_STATUS(cudnnConvolutionBackwardData(
+        CudaContext::cudnnHandle(),
+        &alpha,
+        mFilterDesc,
+        input1.getImpl()->rawPtr(),
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
+        gradOutput.getImpl()->rawPtr(),
+        mConvDesc,
+        mBwdDataAlgo,
+        mBwdWorkspace,
+        mWorkspaceSize,
+        &beta,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
+        op.getInput(0)->grad()->getImpl()->rawPtr()));
+
+    // Add bias (if there is any)
+    if (mOp.getRawInput(2) && input2.size() > 0) {
+        // Bias tensor needs to have the same number of dims than output tensor for cudnnAddTensor()
+        std::vector<DimSize_t> gradBiasDims(DIM+2, 1);
+        gradBiasDims[1] = op.getInput(2)->grad()->size();
+
+        // Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
+        Tensor gradBias(op.getInput(2)->grad()->dataType());
+        gradBias.setBackend("cuda");
+        gradBias.resize(gradBiasDims);
+        // TODO: find a more elegant solution(?)
+
+        CHECK_CUDNN_STATUS(cudnnConvolutionBackwardBias(CudaContext::cudnnHandle(),
+            &alpha,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
+            gradOutput.getImpl()->rawPtr(),
+            &beta,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(gradBias.getImpl())->getCudnnTensorDesc(gradBias),
+            op.getInput(2)->grad()->getImpl()->rawPtr()));
+    }
+}
+
 template <Aidge::DimIdx_t DIM>
 Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
     if (mConvDesc != nullptr) {
-- 
GitLab