Commit bff7cdcf authored by Olivier BICHLER

Added Conv backward prototype (UNTESTED)

parent 4e1f953b
@@ -34,8 +34,11 @@ private:
    cudnnConvolutionDescriptor_t mConvDesc = nullptr;
    cudnnFilterDescriptor_t mFilterDesc = nullptr;
    cudnnConvolutionFwdAlgo_t mFwdAlgo;
    cudnnConvolutionBwdFilterAlgo_t mBwdFilterAlgo;
    cudnnConvolutionBwdDataAlgo_t mBwdDataAlgo;
    size_t mWorkspaceSize = 0;
    void* mFwdWorkspace = nullptr;
    void* mBwdWorkspace = nullptr;
    std::shared_ptr<Tensor> mInput0Fallback;
    std::shared_ptr<Tensor> mInput1Fallback;
    std::shared_ptr<Tensor> mInput2Fallback;
@@ -49,10 +52,12 @@ public:
public:
    void forward();
    void backward();

    ~ConvImpl_cuda();

private:
    template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
    template <class T> void backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
};

namespace {
...
@@ -157,6 +157,176 @@ void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& inp
    }
}
template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::backward() {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);

    // FIXME: uncomment the following code once memory handling works
    assert(mOp.getRawInput(0) && "missing input #0");
    assert(mOp.getRawInput(1) && "missing input #1");

    // Convert input data (no overhead if not needed!)
    const auto& input0 = op.getInput(0)->ref(mInput0Fallback, *op.getOutput(0));
    const auto& input1 = op.getInput(1)->ref(mInput1Fallback, *op.getOutput(0));
    const auto& input2 = op.getInput(2)->ref(mInput2Fallback, *op.getOutput(0));

    // Set backward algorithms and allocate the required workspace
    if (mBwdWorkspace == nullptr) {
        // Find the best CuDNN backward filter algorithm (the one with the lowest compute time)
        int maxAlgoIterations = 0;
        cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(CudaContext::cudnnHandle(),
                                                           &maxAlgoIterations);
        assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardFilterAlgorithm");

        int returnAlgoCounts = 0;
        std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> returnBwdFilterAlgo(maxAlgoIterations);
        CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardFilterAlgorithm(
            CudaContext::cudnnHandle(),
            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
            mConvDesc,
            mFilterDesc,
            maxAlgoIterations,
            &returnAlgoCounts,
            &returnBwdFilterAlgo[0]));
        mBwdFilterAlgo = returnBwdFilterAlgo[0].algo;

        // Find the best CuDNN backward data algorithm
        maxAlgoIterations = 0;
        cudnnGetConvolutionBackwardDataAlgorithmMaxCount(CudaContext::cudnnHandle(),
                                                         &maxAlgoIterations);
        assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardDataAlgorithm");

        returnAlgoCounts = 0;
        std::vector<cudnnConvolutionBwdDataAlgoPerf_t> returnBwdDataAlgo(maxAlgoIterations);
        CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardDataAlgorithm(
            CudaContext::cudnnHandle(),
            mFilterDesc,
            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
            mConvDesc,
            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
            maxAlgoIterations,
            &returnAlgoCounts,
            &returnBwdDataAlgo[0]));
        mBwdDataAlgo = returnBwdDataAlgo[0].algo;

        // Allocate the workspace required by the chosen CuDNN backward algorithms
        size_t workspaceSize = 0;
        CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardFilterWorkspaceSize(
            CudaContext::cudnnHandle(),
            // same arguments as cudnnFindConvolutionBackwardFilterAlgorithm() -->
            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
            mConvDesc,
            mFilterDesc,
            // <--
            mBwdFilterAlgo,
            &workspaceSize));

        size_t workspaceSizeData = 0;
        CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardDataWorkspaceSize(
            CudaContext::cudnnHandle(),
            // same arguments as cudnnFindConvolutionBackwardDataAlgorithm() -->
            mFilterDesc,
            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
            mConvDesc,
            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
            // <--
            mBwdDataAlgo,
            &workspaceSizeData));

        // Both backward calls share a single buffer (also reused by forward):
        // size it for the larger of the two requirements.
        if (workspaceSizeData > workspaceSize)
            workspaceSize = workspaceSizeData;

        if (workspaceSize > mWorkspaceSize) {
            if (mFwdWorkspace != nullptr) {
                cudaFree(mFwdWorkspace);
            }
            CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, workspaceSize));
            mWorkspaceSize = workspaceSize;
        }

        mBwdWorkspace = mFwdWorkspace;
    }

    // Do the actual backward computation
    // Template is only for scaling parameters, which are always in float
    // except when the convolution is performed in double precision.
    if (op.getOutput(0)->dataType() == DataType::Float64) {
        backward_<double>(input0, input1, input2);
    }
    else {
        backward_<float>(input0, input1, input2);
    }
}
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::ConvImpl_cuda<DIM>::backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);

    std::shared_ptr<Tensor> gradOutputFallback;
    const auto& gradOutput = op.getOutput(0)->grad()->refCastFrom(gradOutputFallback, *(op.getInput(0)->grad()));

    const T alpha = 1.0f;
    const T beta = 0.0f;

    // Weights gradient
    CHECK_CUDNN_STATUS(cudnnConvolutionBackwardFilter(
        CudaContext::cudnnHandle(),
        &alpha,
        std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
        input0.getImpl()->rawPtr(),
        std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
        gradOutput.getImpl()->rawPtr(),
        mConvDesc,
        mBwdFilterAlgo,
        mBwdWorkspace,
        mWorkspaceSize,
        &beta,
        mFilterDesc,
        op.getInput(1)->grad()->getImpl()->rawPtr()));

    // Input gradient
    CHECK_CUDNN_STATUS(cudnnConvolutionBackwardData(
        CudaContext::cudnnHandle(),
        &alpha,
        mFilterDesc,
        input1.getImpl()->rawPtr(),
        std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
        gradOutput.getImpl()->rawPtr(),
        mConvDesc,
        mBwdDataAlgo,
        mBwdWorkspace,
        mWorkspaceSize,
        &beta,
        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
        op.getInput(0)->grad()->getImpl()->rawPtr()));

    // Bias gradient (if there is a bias)
    if (mOp.getRawInput(2) && input2.size() > 0) {
        // Bias tensor needs to have the same number of dims as the output tensor for cudnnConvolutionBackwardBias()
        std::vector<DimSize_t> gradBiasDims(DIM+2, 1);
        gradBiasDims[1] = op.getInput(2)->grad()->size();

        // Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
        Tensor gradBias(op.getInput(2)->grad()->dataType());
        gradBias.setBackend("cuda");
        gradBias.resize(gradBiasDims);
        // TODO: find a more elegant solution(?)

        CHECK_CUDNN_STATUS(cudnnConvolutionBackwardBias(CudaContext::cudnnHandle(),
            &alpha,
            std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
            gradOutput.getImpl()->rawPtr(),
            &beta,
            std::dynamic_pointer_cast<TensorImpl_cuda_>(gradBias.getImpl())->getCudnnTensorDesc(gradBias),
            op.getInput(2)->grad()->getImpl()->rawPtr()));
    }
}
template <Aidge::DimIdx_t DIM>
Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
    if (mConvDesc != nullptr) {
...