diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp index 31b80adbae0602211fa5c11873875a1a10eb40db..0ad995b24082782c611b93fbcc040d1319a7362f 100644 --- a/include/aidge/backend/cuda/operator/ConvImpl.hpp +++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp @@ -49,7 +49,7 @@ public: ~ConvImpl_cuda(); private: - template <class T> void forward_(); + template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2); }; namespace { diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp index 88b5e7deb26c6b5cff7ea7a0466102f7c3c6a7f3..95a6aa523257f9fcd527460315dfc66e846ff714 100644 --- a/src/operator/ConvImpl.cpp +++ b/src/operator/ConvImpl.cpp @@ -28,6 +28,15 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { assert(mOp.getRawInput(0) && "missing input #0"); assert(mOp.getRawInput(1) && "missing input #1"); + // Convert input data (no overhead if not needed!) + // TODO: right now, if needed, memory will be allocated/deallocated at each + // call to forward(). We might put the following shared_ptr as members of + // this class to avoid that. 
+ std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback; + const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); + const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); + const auto& input2 = (mOp.getRawInput(2)) ? std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))) : *(input2Fallback = std::make_shared<Tensor>()); + // Lazy-initialize CuDNN convolution descriptor if (mConvDesc == nullptr) { const Conv_Op<DIM>& convOp = static_cast<const Conv_Op<DIM>&>(mOp); @@ -48,11 +57,11 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { // Lazy-initialize CuDNN filter descriptor if (mFilterDesc == nullptr) { - const std::vector<int> kernels(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims().begin(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims().end()); + const std::vector<int> kernels(input1.dims().begin(), input1.dims().end()); CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc)); CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc, - DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType()), + DataTypeToCudnn(input1.dataType()), CUDNN_TENSOR_NCHW, kernels.size(), &kernels[0])); @@ -72,7 +81,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm( CudaContext::cudnnHandle(), - dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(), + dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(), mFilterDesc, mConvDesc, dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(), @@ -86,7 +95,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { 
CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize( CudaContext::cudnnHandle(), - dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(), + dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(), mFilterDesc, mConvDesc, dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(), @@ -101,26 +110,26 @@ void Aidge::ConvImpl_cuda<DIM>::forward() { // Template is only for scaling parameters, which are always in float // excepted when the convolution is performed in double precision. if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) { - forward_<double>(); + forward_<double>(input0, input1, input2); } else { - forward_<float>(); + forward_<float>(input0, input1, input2); } } template <Aidge::DimIdx_t DIM> template <class T> -void Aidge::ConvImpl_cuda<DIM>::forward_() { +void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) { const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; CHECK_CUDNN_STATUS( cudnnConvolutionForward(CudaContext::cudnnHandle(), &alpha, - dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), + dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(), + input0.getImpl()->rawPtr(), mFilterDesc, - std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(), + input1.getImpl()->rawPtr(), mConvDesc, mFwdAlgo, mWorkspace, @@ -130,13 +139,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() { std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr())); // Add bias (if there is any) - if (mOp.getRawInput(2) && 
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->size() > 0) { + if (mOp.getRawInput(2) && input2.size() > 0) { // Bias tensor needs to have the same number of dims than output tensor for cudnnAddTensor() std::vector<DimSize_t> biasDims(DIM+2, 1); - biasDims[1] = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->size(); + biasDims[1] = input2.size(); // Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc()) - Tensor bias(std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType()); + Tensor bias(input2.dataType()); bias.setBackend("cuda"); bias.resize(biasDims); // TODO: find a more elegant solution(?) @@ -144,7 +153,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() { CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(), &alpha, dynamic_cast<TensorImpl_cuda_*>(bias.getImpl().get())->getCudnnTensorDesc(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(), + input2.getImpl()->rawPtr(), &alpha, dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));