diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 42f6e84052c6fb6170d270fbc6a0ca3467095776..19ce56bcb99f60e08427f8d9b110637c90582adf 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -24,14 +24,16 @@
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
     // FIXME: uncomment the following code once memory handling will work
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
 
     // Convert input data (no overhead if not needed!)
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(mInput0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(mInput1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(mInput2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = op.getInput(0)->refCastFrom(mInput0Fallback, *op.getOutput(0));
+    const auto& input1 = op.getInput(1)->refCastFrom(mInput1Fallback, *op.getOutput(0));
+    const auto& input2 = op.getInput(2)->refCastFrom(mInput2Fallback, *op.getOutput(0));
 
     // Lazy-initialize CuDNN convolution descriptor
     if (mConvDesc == nullptr) {
@@ -41,14 +43,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         const std::vector<int> upscales(convOp.template getAttr<ConvAttr::DilationDims>().begin(), convOp.template getAttr<ConvAttr::DilationDims>().end());
 
         CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
-        CHECK_CUDNN_STATUS(
-            cudnnSetConvolutionNdDescriptor(mConvDesc,
-                DIM,
-                &paddings[0],
-                &strides[0],
-                &upscales[0],
-                CUDNN_CROSS_CORRELATION,
-                DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType())));
+        CHECK_CUDNN_STATUS(cudnnSetConvolutionNdDescriptor(mConvDesc,
+            DIM,
+            &paddings[0],
+            &strides[0],
+            &upscales[0],
+            CUDNN_CROSS_CORRELATION,
+            DataTypeToCudnn(op.getOutput(0)->dataType())));
     }
 
     // Lazy-initialize CuDNN filter descriptor
@@ -57,10 +58,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
         CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
         CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
-                                                      DataTypeToCudnn(input1.dataType()),
-                                                      CUDNN_TENSOR_NCHW,
-                                                      kernels.size(),
-                                                      &kernels[0]));
+            DataTypeToCudnn(input1.dataType()),
+            CUDNN_TENSOR_NCHW,
+            kernels.size(),
+            &kernels[0]));
     }
 
     // Set forward algorithm and allocate the required workspace
@@ -76,14 +77,14 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
         CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
-            CudaContext::cudnnHandle(),
-            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(input0),
-            mFilterDesc,
-            mConvDesc,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(*std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))),
-            maxAlgoIterations,
-            &returnAlgoCounts,
-            &returnFwdAlgo[0]));
+            CudaContext::cudnnHandle(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            mFilterDesc,
+            mConvDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            maxAlgoIterations,
+            &returnAlgoCounts,
+            &returnFwdAlgo[0]));
 
         mFwdAlgo = returnFwdAlgo[0].algo;
 
         // Allocate the workspace required by the chosen CuDNN forward algorithm
@@ -91,10 +92,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
         CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
             CudaContext::cudnnHandle(),
-            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(input0),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
             mFilterDesc,
             mConvDesc,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(*std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
             mFwdAlgo,
             &workspaceSize));
 
@@ -105,7 +106,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     // Do the actual forward computation
     // Template is only for scaling parameters, which are always in float
    // excepted when the convolution is performed in double precision.
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
        forward_<double>(input0, input1, input2);
     }
     else {
@@ -116,12 +117,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T alpha = 1.0f;
     const T beta = 0.0f;
 
     CHECK_CUDNN_STATUS(cudnnConvolutionForward(CudaContext::cudnnHandle(),
         &alpha,
-        dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(input0),
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
         input0.getImpl()->rawPtr(),
         mFilterDesc,
         input1.getImpl()->rawPtr(),
@@ -130,8 +132,8 @@ void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& inp
         mFwdWorkspace,
         mWorkspaceSize,
         &beta,
-        dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(*std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+        op.getOutput(0)->getImpl()->rawPtr()));
 
     // Add bias (if there is any)
     if (mOp.getRawInput(2) && input2.size() > 0) {
@@ -147,11 +149,11 @@ void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& inp
 
         CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
             &alpha,
-            dynamic_cast<TensorImpl_cuda_*>(bias.getImpl().get())->getCudnnTensorDesc(bias),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(bias.getImpl())->getCudnnTensorDesc(bias),
             input2.getImpl()->rawPtr(),
             &alpha,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(*std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))),
-            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            op.getOutput(0)->getImpl()->rawPtr()));
     }
 }
 