/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <vector>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/ConvImpl.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/utils/Types.h"
template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::forward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
    // FIXME: uncomment the following code once memory handling works
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
// Convert input data (no overhead if not needed!)
const auto& input0 = op.getInput(0)->refCastFrom(mInput0Fallback, *op.getOutput(0));
const auto& input1 = op.getInput(1)->refCastFrom(mInput1Fallback, *op.getOutput(0));
const auto& input2 = op.getInput(2)->refCastFrom(mInput2Fallback, *op.getOutput(0));
// Lazy-initialize CuDNN convolution descriptor
if (mConvDesc == nullptr) {
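        // Padding is always zero here: Aidge is expected to handle padding
        // through a dedicated Pad operator upstream of the convolution.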
const std::vector<int> paddings(DIM, 0);
std::vector<int> strides, upscales;
if (mDepthWise) {
const ConvDepthWise_Op<DIM>& convDWOp = static_cast<const ConvDepthWise_Op<DIM>&>(mOp);
            strides = std::vector<int>(convDWOp.template getAttr<ConvDepthWiseAttr::StrideDims>().begin(),
                                       convDWOp.template getAttr<ConvDepthWiseAttr::StrideDims>().end());
            upscales = std::vector<int>(convDWOp.template getAttr<ConvDepthWiseAttr::DilationDims>().begin(),
                                        convDWOp.template getAttr<ConvDepthWiseAttr::DilationDims>().end());
}
else {
const Conv_Op<DIM>& convOp = static_cast<const Conv_Op<DIM>&>(mOp);
            strides = std::vector<int>(convOp.template getAttr<ConvAttr::StrideDims>().begin(),
                                       convOp.template getAttr<ConvAttr::StrideDims>().end());
            upscales = std::vector<int>(convOp.template getAttr<ConvAttr::DilationDims>().begin(),
                                        convOp.template getAttr<ConvAttr::DilationDims>().end());
}
CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
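        // CUDNN_CROSS_CORRELATION matches the usual deep-learning convolution
        // (no kernel flipping); the compute type follows the output data type.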
CHECK_CUDNN_STATUS(cudnnSetConvolutionNdDescriptor(mConvDesc,
DIM,
&paddings[0],
&strides[0],
&upscales[0],
CUDNN_CROSS_CORRELATION,
DataTypeToCudnn(op.getOutput(0)->dataType())));
}
// Lazy-initialize CuDNN filter descriptor
if (mFilterDesc == nullptr) {
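        // The filter descriptor is built directly from the weight tensor dims,
        // assuming NCHW layout.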
const std::vector<int> kernels(input1.dims().begin(), input1.dims().end());
CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
DataTypeToCudnn(input1.dataType()),
CUDNN_TENSOR_NCHW,
kernels.size(),
&kernels[0]));
}
// Set forward algorithm and allocate the required workspace
if (mFwdWorkspace == nullptr) {
// Find the best CuDNN forward algorithm (the one with the lowest compute time)
int maxAlgoIterations = 0;
cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(),
&maxAlgoIterations);
assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionForwardAlgorithm");
int returnAlgoCounts = 0;
std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
CudaContext::cudnnHandle(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
mFilterDesc,
mConvDesc,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
maxAlgoIterations,
&returnAlgoCounts,
&returnFwdAlgo[0]));
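        // Results are sorted by compute time, fastest first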
mFwdAlgo = returnFwdAlgo[0].algo;
// Allocate the workspace required by the chosen CuDNN forward algorithm
size_t workspaceSize = 0;
CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
CudaContext::cudnnHandle(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
mFilterDesc,
mConvDesc,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
mFwdAlgo,
&workspaceSize));
CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, workspaceSize));
mWorkspaceSize = workspaceSize;
}
// Do the actual forward computation
    // Template is only for scaling parameters, which are always in float
    // except when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
forward_<double>(input0, input1, input2);
}
else {
forward_<float>(input0, input1, input2);
}
}
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
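    // cuDNN computes output = alpha * conv(input, filter) + beta * output,
    // so alpha = 1 and beta = 0 simply overwrite the output tensor.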
CHECK_CUDNN_STATUS(cudnnConvolutionForward(CudaContext::cudnnHandle(),
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
input0.getImpl()->rawPtr(),
mFilterDesc,
input1.getImpl()->rawPtr(),
mConvDesc,
mFwdAlgo,
mFwdWorkspace,
mWorkspaceSize,
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
op.getOutput(0)->getImpl()->rawPtr()));
// Add bias (if there is any)
if (mOp.getRawInput(2) && input2.size() > 0) {
        // Bias tensor needs to have the same number of dims as the output tensor for cudnnAddTensor()
std::vector<DimSize_t> biasDims(DIM+2, 1);
biasDims[1] = input2.size();
// Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
Tensor bias(input2.dataType());
bias.setBackend("cuda");
bias.resize(biasDims);
// TODO: find a more elegant solution(?)
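        // cudnnAddTensor() broadcasts dimensions of size 1, so a 1xCx1x...x1
        // descriptor adds the bias across the batch and spatial dimensions.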
CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(bias.getImpl())->getCudnnTensorDesc(bias),
input2.getImpl()->rawPtr(),
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
op.getOutput(0)->getImpl()->rawPtr()));
}
}
template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::backward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
    // FIXME: uncomment the following code once memory handling works
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
// Convert input data (no overhead if not needed!)
const auto& input0 = op.getInput(0)->ref(mInput0Fallback, *op.getOutput(0));
const auto& input1 = op.getInput(1)->ref(mInput1Fallback, *op.getOutput(0));
const auto& input2 = op.getInput(2)->ref(mInput2Fallback, *op.getOutput(0));
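    // Note: mConvDesc and mFilterDesc are assumed to have been created by a
    // prior call to forward().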
    // Set backward algorithms and allocate the required workspace
if (mBwdWorkspace == nullptr) {
// Find the best CuDNN backward algorithm (the one with the lowest compute time)
int maxAlgoIterations = 0;
cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(CudaContext::cudnnHandle(),
&maxAlgoIterations);
assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardFilterAlgorithm");
int returnAlgoCounts = 0;
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> returnBwdFilterAlgo(maxAlgoIterations);
CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardFilterAlgorithm(
CudaContext::cudnnHandle(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
mConvDesc,
mFilterDesc,
maxAlgoIterations,
&returnAlgoCounts,
&returnBwdFilterAlgo[0]));
mBwdFilterAlgo = returnBwdFilterAlgo[0].algo;
maxAlgoIterations = 0;
cudnnGetConvolutionBackwardDataAlgorithmMaxCount(CudaContext::cudnnHandle(),
&maxAlgoIterations);
assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardDataAlgorithm");
returnAlgoCounts = 0;
std::vector<cudnnConvolutionBwdDataAlgoPerf_t> returnBwdDataAlgo(maxAlgoIterations);
CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardDataAlgorithm(
CudaContext::cudnnHandle(),
mFilterDesc,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
mConvDesc,
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
maxAlgoIterations,
&returnAlgoCounts,
&returnBwdDataAlgo[0]));
mBwdDataAlgo = returnBwdDataAlgo[0].algo;
        // Allocate the workspace required by the chosen CuDNN backward algorithms
size_t workspaceSize = 0;
CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardFilterWorkspaceSize(
CudaContext::cudnnHandle(),
// same arguments as cudnnGetConvolutionBackwardFilterAlgorithm()
// -->
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
mConvDesc,
mFilterDesc,
// <--
mBwdFilterAlgo,
&workspaceSize));
size_t workspaceSizeData = 0;
CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardDataWorkspaceSize(
CudaContext::cudnnHandle(),
// same arguments as cudnnGetConvolutionBackwardDataAlgorithm() -->
mFilterDesc,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
mConvDesc,
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
// <--
mBwdDataAlgo,
&workspaceSizeData));
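        // The workspace is shared between the backward filter and backward data
        // passes (and with forward), so it must fit the largest requirement.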
if (workspaceSizeData > workspaceSize)
workspaceSize = workspaceSizeData;
if (workspaceSize > mWorkspaceSize) {
if (mFwdWorkspace != nullptr) {
cudaFree(mFwdWorkspace);
}
CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, workspaceSize));
mWorkspaceSize = workspaceSize;
}
mBwdWorkspace = mFwdWorkspace;
}
// Do the actual backward computation
    // Template is only for scaling parameters, which are always in float
    // except when the convolution is performed in double precision.
if (op.getOutput(0)->dataType() == DataType::Float64) {
backward_<double>(input0, input1, input2);
}
else {
backward_<float>(input0, input1, input2);
}
}
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::ConvImpl_cuda<DIM>::backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
std::shared_ptr<Tensor> gradOutputFallback;
const auto& gradOutput = op.getOutput(0)->grad()->refCastFrom(gradOutputFallback, *(op.getInput(0)->grad()));
const T alpha = 1.0f;
const T beta = 0.0f;
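    // With beta = 0, existing gradient values are overwritten, not accumulated.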
CHECK_CUDNN_STATUS(cudnnConvolutionBackwardFilter(
CudaContext::cudnnHandle(),
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
input0.getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
gradOutput.getImpl()->rawPtr(),
mConvDesc,
mBwdFilterAlgo,
mBwdWorkspace,
mWorkspaceSize,
&beta,
mFilterDesc,
op.getInput(1)->grad()->getImpl()->rawPtr()));
CHECK_CUDNN_STATUS(cudnnConvolutionBackwardData(
CudaContext::cudnnHandle(),
&alpha,
mFilterDesc,
input1.getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
gradOutput.getImpl()->rawPtr(),
mConvDesc,
mBwdDataAlgo,
mBwdWorkspace,
mWorkspaceSize,
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
op.getInput(0)->grad()->getImpl()->rawPtr()));
// Add bias (if there is any)
if (mOp.getRawInput(2) && input2.size() > 0) {
        // Bias tensor needs to have the same number of dims as the output tensor for cudnnAddTensor()
std::vector<DimSize_t> gradBiasDims(DIM+2, 1);
gradBiasDims[1] = op.getInput(2)->grad()->size();
// Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
Tensor gradBias(op.getInput(2)->grad()->dataType());
gradBias.setBackend("cuda");
gradBias.resize(gradBiasDims);
// TODO: find a more elegant solution(?)
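        // cudnnConvolutionBackwardBias() sums the output gradient over the
        // batch and spatial dimensions to produce the bias gradient.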
CHECK_CUDNN_STATUS(cudnnConvolutionBackwardBias(CudaContext::cudnnHandle(),
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
gradOutput.getImpl()->rawPtr(),
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(gradBias.getImpl())->getCudnnTensorDesc(gradBias),
op.getInput(2)->grad()->getImpl()->rawPtr()));
}
}
template <Aidge::DimIdx_t DIM>
Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
if (mConvDesc != nullptr) {
cudnnDestroyConvolutionDescriptor(mConvDesc);
}
if (mFilterDesc != nullptr) {
cudnnDestroyFilterDescriptor(mFilterDesc);
}
if (mFwdWorkspace != nullptr) {
cudaFree(mFwdWorkspace);
}
}
// Explicit template instantiations
template class Aidge::ConvImpl_cuda<1>;
template class Aidge::ConvImpl_cuda<2>;