Skip to content
Snippets Groups Projects
Commit f8eb193f authored by Olivier BICHLER
Browse files

Adapted Conv impl to use cast/from features

parent 1ae93537
No related branches found
No related tags found
1 merge request: !4 Add Convert operator (a.k.a. Transmitter)
Pipeline #35587 failed
......@@ -49,7 +49,7 @@ public:
~ConvImpl_cuda();
private:
template <class T> void forward_();
template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
};
namespace {
......
......@@ -28,6 +28,15 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
// Convert input data (no overhead if not needed!)
// TODO: right now, if needed, memory will be allocated/deallocated at each
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
// Lazy-initialize CuDNN convolution descriptor
if (mConvDesc == nullptr) {
const Conv_Op<DIM>& convOp = static_cast<const Conv_Op<DIM>&>(mOp);
......@@ -48,11 +57,11 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
// Lazy-initialize CuDNN filter descriptor
if (mFilterDesc == nullptr) {
const std::vector<int> kernels(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims().begin(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims().end());
const std::vector<int> kernels(input1.dims().begin(), input1.dims().end());
CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType()),
DataTypeToCudnn(input1.dataType()),
CUDNN_TENSOR_NCHW,
kernels.size(),
&kernels[0]));
......@@ -72,7 +81,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
CudaContext::cudnnHandle(),
dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(),
dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
mFilterDesc,
mConvDesc,
dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
......@@ -86,7 +95,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
CudaContext::cudnnHandle(),
dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(),
dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
mFilterDesc,
mConvDesc,
dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
......@@ -101,26 +110,26 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
// Template is only for scaling parameters, which are always in float
// excepted when the convolution is performed in double precision.
if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
forward_<double>();
forward_<double>(input0, input1, input2);
}
else {
forward_<float>();
forward_<float>(input0, input1, input2);
}
}
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::ConvImpl_cuda<DIM>::forward_() {
void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
CHECK_CUDNN_STATUS(
cudnnConvolutionForward(CudaContext::cudnnHandle(),
&alpha,
dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
input0.getImpl()->rawPtr(),
mFilterDesc,
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
input1.getImpl()->rawPtr(),
mConvDesc,
mFwdAlgo,
mWorkspace,
......@@ -130,13 +139,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
// Add bias (if there is any)
if (mOp.getRawInput(2) && std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->size() > 0) {
if (mOp.getRawInput(2) && input2.size() > 0) {
// Bias tensor needs to have the same number of dims than output tensor for cudnnAddTensor()
std::vector<DimSize_t> biasDims(DIM+2, 1);
biasDims[1] = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->size();
biasDims[1] = input2.size();
// Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
Tensor bias(std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType());
Tensor bias(input2.dataType());
bias.setBackend("cuda");
bias.resize(biasDims);
// TODO: find a more elegant solution(?)
......@@ -144,7 +153,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
&alpha,
dynamic_cast<TensorImpl_cuda_*>(bias.getImpl().get())->getCudnnTensorDesc(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(),
input2.getImpl()->rawPtr(),
&alpha,
dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment