diff --git a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
index ffbb06752d39c14dc15f24967a394ed6acb9521b..9c83332f9857ed802f5563faef558a7278d3e992 100644
--- a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
@@ -25,12 +25,12 @@ namespace Aidge {
 template <class T>
 cublasStatus_t cublasGemm(cublasHandle_t handle,
-                cublasOperation_t transa, cublasOperation_t transb,
-                int m, int n, int k,
-                const T *alpha,
-                const T *A, int lda,
-                const T *B, int ldb,
-                const T *beta,
-                T *C, int ldc);
+                          cublasOperation_t transa, cublasOperation_t transb,
+                          int m, int n, int k,
+                          const T *alpha,
+                          const T *A, int lda,
+                          const T *B, int ldb,
+                          const T *beta,
+                          T *C, int ldc);
 }
 
 #endif /* AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_ */
\ No newline at end of file
diff --git a/include/aidge/backend/cuda/utils/CudaUtils.hpp b/include/aidge/backend/cuda/utils/CudaUtils.hpp
index 2f66d0e778778400f0b7def345619d635cc37674..a505d331dc182b6e24857b0d1045282688fdf8d8 100644
--- a/include/aidge/backend/cuda/utils/CudaUtils.hpp
+++ b/include/aidge/backend/cuda/utils/CudaUtils.hpp
@@ -11,6 +11,8 @@
 #include <cuda.h>
 #include <cudnn.h>
 
+#include "aidge/data/half.hpp"
+
 #define CHECK_CUDNN_STATUS(status) \
     do { \
         const cudnnStatus_t e = (status); \
@@ -62,6 +64,29 @@
 namespace Aidge {
 
 namespace Cuda {
+    // CuDNN scaling parameters are typically "alpha" and "beta".
+    // Their type must be "float" for HALF and FLOAT (default template)
+    // and "double" for DOUBLE (specialized template)
+    template <class T>
+    struct cudnn_scaling_type {
+        typedef float type;
+    };
+
+    template <>
+    struct cudnn_scaling_type<double> {
+        typedef double type;
+    };
+
+    template <class T>
+    struct cuda_type {
+        typedef T type;
+    };
+
+    template <>
+    struct cuda_type<half_float::half> {
+        typedef __half type;
+    };
+
     const char* cublasGetErrorString(cublasStatus_t error);
 
     // Enable Peer-to-Peer communications between devices
diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp
index 4737e3ebf43622f81dbe8938c5dddcaaca94cb80..861533eced6112903ea288f09711f3a382db542c 100644
--- a/src/operator/AvgPoolingImpl.cpp
+++ b/src/operator/AvgPoolingImpl.cpp
@@ -10,17 +10,14 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/AvgPooling.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/utils/Types.h"
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
@@ -49,11 +46,18 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
             &strides[0]));
     }
 
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
-        forward_<double>(input);
-    }
-    else {
-        forward_<float>(input);
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input);
+            break;
+        case DataType::Float32:
+            forward_<float>(input);
+            break;
+        case DataType::Float16:
+            forward_<half>(input);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
 }
 
@@ -61,8 +65,8 @@ template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::AvgPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T alpha = 1.0f;
-    const T beta = 0.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
     CHECK_CUDNN_STATUS(
         cudnnPoolingForward(
             CudaContext::cudnnHandle(),
@@ -83,6 +87,5 @@ Aidge::AvgPoolingImpl_cuda<DIM>::~AvgPoolingImpl_cuda() {
     cudnnDestroyPoolingDescriptor(mAvgPoolingDesc);
 }
 
-
 // Template declarations
 template class Aidge::AvgPoolingImpl_cuda<2>;
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 19ce56bcb99f60e08427f8d9b110637c90582adf..b64b2e3d470fd8813938a800293d55b7ba9a4076 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -10,17 +10,13 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/Conv.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
-#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/Conv.hpp"
+#include "aidge/utils/Types.h"
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
@@ -106,11 +102,18 @@
     // Do the actual forward computation
     // Template is only for scaling parameters, which are always in float
     // except when the convolution is performed in double precision.
-    if (op.getOutput(0)->dataType() == DataType::Float64) {
-        forward_<double>(input0, input1, input2);
-    }
-    else {
-        forward_<float>(input0, input1, input2);
+    switch(op.getOutput(0)->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input0, input1, input2);
+            break;
+        case DataType::Float32:
+            forward_<float>(input0, input1, input2);
+            break;
+        case DataType::Float16:
+            forward_<half>(input0, input1, input2);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
 }
 
@@ -118,9 +121,8 @@ template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T alpha = 1.0f;
-    const T beta = 0.0f;
-
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
     CHECK_CUDNN_STATUS(cudnnConvolutionForward(CudaContext::cudnnHandle(),
                        &alpha,
                        std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index bb107a5c8044bd40c874493f275e20d02b7298d2..a8f8da8e114baaae3f7d09146cdd5b664150f260 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -15,15 +15,13 @@
 #include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/FC.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/FCImpl.hpp"
 #include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
-
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/utils/Types.h"
 
 void Aidge::FCImpl_cuda::forward() {
     assert(mOp.getRawInput(0) && "missing input #0");
@@ -39,11 +37,18 @@ void Aidge::FCImpl_cuda::forward() {
     bool noBias = fcOp.template getAttr<FCAttr::NoBias>();
     std::size_t outChannels = static_cast<std::size_t>(fcOp.template getAttr<FCAttr::OutChannels>());
 
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
-        forward_<double>(input0, input1, input2, noBias, outChannels);
-    }
-    else {
-        forward_<float>(input0, input1, input2, noBias, outChannels);
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input0, input1, input2, noBias, outChannels);
+            break;
+        case DataType::Float32:
+            forward_<float>(input0, input1, input2, noBias, outChannels);
+            break;
+        case DataType::Float16:
+            forward_<half>(input0, input1, input2, noBias, outChannels);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
 }
 
@@ -61,26 +66,26 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co
     int lda = k;
     int ldb = k;
     int ldc = n;
-    const T alpha = T(1.0);
-    const T beta = T(0.0);
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
     CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 n,
                 m,
                 k,
-                &alpha,
+                reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                 weights,
                 ldb,
                 input,
                 lda,
-                &beta,
+                reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&beta),
                 output,
                 ldc));
 
     if(!noBias){
         T* onesVector;
-        cudaMalloc((void**)&onesVector, m * sizeof(T));
+        CHECK_CUDA_STATUS(cudaMalloc((void**)&onesVector, m * sizeof(T)));
         // Fill the vector with ones
         std::vector<T> onesVec(m, T(1.0));
         CHECK_CUDA_STATUS(cudaMemcpy(onesVector,
@@ -95,12 +100,12 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co
                 n,
                 m,
                 1,
-                &alpha,
+                reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                 biases,
                 n,
                 onesVector,
                 1,
-                &alpha,
+                reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                 output,
                 n));
 
diff --git a/src/operator/FCImpl_CUDA_kernels.cu b/src/operator/FCImpl_CUDA_kernels.cu
index a30519ebe4bf262b87bb9e07342f18525df2e8f4..5139ac1d7edf61cf347870e6add2870b2792a0e5 100644
--- a/src/operator/FCImpl_CUDA_kernels.cu
+++ b/src/operator/FCImpl_CUDA_kernels.cu
@@ -16,61 +16,61 @@
 namespace Aidge{
 template <>
 cublasStatus_t cublasGemm<__half>(cublasHandle_t handle,
-                cublasOperation_t transa, cublasOperation_t transb,
-                int m, int n, int k,
-                const __half *alpha,
-                const __half *A, int lda,
-                const __half *B, int ldb,
-                const __half *beta,
-                __half *C, int ldc)
+                                  cublasOperation_t transa, cublasOperation_t transb,
+                                  int m, int n, int k,
+                                  const __half *alpha,
+                                  const __half *A, int lda,
+                                  const __half *B, int ldb,
+                                  const __half *beta,
+                                  __half *C, int ldc)
 {
     return cublasHgemm(handle,
-            transa, transb,
-            m, n, k,
-            alpha,
-            A, lda,
-            B, ldb,
-            beta,
-            C, ldc);
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
 
 template <>
 cublasStatus_t cublasGemm<float>(cublasHandle_t handle,
-                cublasOperation_t transa, cublasOperation_t transb,
-                int m, int n, int k,
-                const float *alpha,
-                const float *A, int lda,
-                const float *B, int ldb,
-                const float *beta,
-                float *C, int ldc)
+                                 cublasOperation_t transa, cublasOperation_t transb,
+                                 int m, int n, int k,
+                                 const float *alpha,
+                                 const float *A, int lda,
+                                 const float *B, int ldb,
+                                 const float *beta,
+                                 float *C, int ldc)
 {
     return cublasSgemm(handle,
-            transa, transb,
-            m, n, k,
-            alpha,
-            A, lda,
-            B, ldb,
-            beta,
-            C, ldc);
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
 
 template <>
 cublasStatus_t cublasGemm<double>(cublasHandle_t handle,
-                cublasOperation_t transa, cublasOperation_t transb,
-                int m, int n, int k,
-                const double *alpha,
-                const double *A, int lda,
-                const double *B, int ldb,
-                const double *beta,
-                double *C, int ldc)
+                                  cublasOperation_t transa, cublasOperation_t transb,
+                                  int m, int n, int k,
+                                  const double *alpha,
+                                  const double *A, int lda,
+                                  const double *B, int ldb,
+                                  const double *beta,
+                                  double *C, int ldc)
 {
     return cublasDgemm(handle,
-            transa, transb,
-            m, n, k,
-            alpha,
-            A, lda,
-            B, ldb,
-            beta,
-            C, ldc);
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
 }
\ No newline at end of file
diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp
index 9304160d0d509014785820745f61187fdf13c17e..19a567fe4b273e821f95a16989d0a09bd510fe07 100644
--- a/src/operator/MaxPoolingImpl.cpp
+++ b/src/operator/MaxPoolingImpl.cpp
@@ -10,17 +10,14 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/MaxPooling.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/MaxPooling.hpp"
"aidge/utils/Types.h" template <Aidge::DimIdx_t DIM> void Aidge::MaxPoolingImpl_cuda<DIM>::forward() { @@ -49,11 +46,18 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() { &strides[0])); } - if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) { - forward_<double>(input); - } - else { - forward_<float>(input); + switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { + case DataType::Float64: + forward_<double>(input); + break; + case DataType::Float32: + forward_<float>(input); + break; + case DataType::Float16: + forward_<half>(input); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); } } @@ -61,8 +65,8 @@ template <Aidge::DimIdx_t DIM> template <class T> void Aidge::MaxPoolingImpl_cuda<DIM>::forward_(const Tensor& input) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); - const T alpha = 1.0f; - const T beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; CHECK_CUDNN_STATUS( cudnnPoolingForward( CudaContext::cudnnHandle(), diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp index ed2e5d4a0ed14c68ff5329932e938876165e92e3..c880184bc51fce65710d02d9a483e43de6184d89 100644 --- a/src/operator/ReLUImpl.cpp +++ b/src/operator/ReLUImpl.cpp @@ -10,17 +10,14 @@ ********************************************************************************/ #include <cassert> -#include <chrono> // std::chrono::milliseconds -#include <numeric> // std::accumulate -#include <thread> // std::this_thread::sleep_for #include <vector> -#include "aidge/utils/Types.h" -#include "aidge/operator/ReLU.hpp" - #include "aidge/backend/cuda/data/TensorImpl.hpp" #include "aidge/backend/cuda/operator/ReLUImpl.hpp" #include "aidge/backend/cuda/utils/CudaContext.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/utils/Types.h" void Aidge::ReLUImpl_cuda::forward() { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); @@ -52,8 +49,8 @@ void Aidge::ReLUImpl_cuda::forward() { template <class T> void Aidge::ReLUImpl_cuda::forward_(const Tensor& input) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); - const T alpha = 1.0f; - const T beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; CHECK_CUDNN_STATUS( cudnnActivationForward(CudaContext::cudnnHandle(), mReLUDesc, diff --git a/unit_tests/Test_AvgPoolingImpl.cpp b/unit_tests/Test_AvgPoolingImpl.cpp index def2e93f4105eb107c30f7cce3a2a2038da12d58..9b7f898ecb1a321713dc114d6a03d057a810271f 100644 --- a/unit_tests/Test_AvgPoolingImpl.cpp +++ b/unit_tests/Test_AvgPoolingImpl.cpp @@ -12,9 +12,11 @@ #include <array> #include <catch2/catch_test_macros.hpp> +#include <cuda_fp16.h> #include "Test_cuda.hpp" +#include "aidge/data/half.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/backend/cpu.hpp" @@ -122,4 +124,38 @@ TEST_CASE("[gpu/operator] AvgPooling(forward)", "[AvgPooling][GPU]") { delete[] computedOutput; } + + SECTION("half") { + std::shared_ptr<Tensor> myInput2 = std::make_shared<Tensor>(Array4D<half_float::half,1,1,3,3> { //NCHW + { + { + {{half_float::half(0.3745), half_float::half(0.9507), half_float::half(0.7320)}, + {half_float::half(0.5987), half_float::half(0.1560), half_float::half(0.1560)}, + {half_float::half(0.0581), half_float::half(0.8662), 
+                     {half_float::half(0.0581), half_float::half(0.8662), half_float::half(0.6011)}}
+                }
+            }
+        });
+        myInput2->setBackend("cuda");
+
+        std::shared_ptr<Node> myAvgPool = AvgPooling({3,3}, "mycdw", {3,3});
+        auto op = std::static_pointer_cast<OperatorTensor>(myAvgPool -> getOperator());
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<half_float::half,1,1,1,1> {
+            {{{{(half_float::half(0.3745) + half_float::half(0.9507) + half_float::half(0.7320) + half_float::half(0.5987) + half_float::half(0.1560) + half_float::half(0.1560) + half_float::half(0.0581) + half_float::half(0.8662) + half_float::half(0.6011))/half_float::half(9.0)}}}}
+        });
+        op->associateInput(0,myInput2);
+        op->setDataType(DataType::Float16);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        myAvgPool->forward();
+
+        half* computedOutput = new half[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(half) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const half_float::half targetOutput = *(static_cast<half_float::half*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
 }
\ No newline at end of file
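
Note on the new type traits: they address two different call sites. `cudnn_scaling_type` picks the host type cuDNN expects for the alpha/beta scaling factors (float for HALF and FLOAT data, double for DOUBLE), while `cuda_type` maps the host storage type `half_float::half` to the device type `__half`, so that host-side scalars can be handed to cuBLAS, whose half GEMM takes `__half` for data and scaling factors alike. A minimal sketch of the resulting half-precision call pattern follows; the wrapper `runHalfGemm` and its buffer names are illustrative only, not part of this patch, and it assumes the buffers are device allocations of `__half`:

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    #include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
    #include "aidge/backend/cuda/utils/CudaUtils.hpp"
    #include "aidge/data/half.hpp"

    void runHalfGemm(cublasHandle_t handle, int m, int n, int k,
                     const __half* d_A, const __half* d_B, __half* d_C) {
        using T = half_float::half;
        // Scaling factors live on the host in the tensor's storage type...
        const T alpha = T(1.0);
        const T beta  = T(0.0);
        // ...and are reinterpreted as __half (= Cuda::cuda_type<T>::type),
        // which is layout-compatible: both are 16-bit IEEE-754 halves.
        // cublasGemm<__half> resolves to the cublasHgemm specialization
        // defined in FCImpl_CUDA_kernels.cu.
        CHECK_CUBLAS_STATUS(Aidge::cublasGemm<__half>(
            handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            reinterpret_cast<const Aidge::Cuda::cuda_type<T>::type*>(&alpha),
            d_A, m,
            d_B, k,
            reinterpret_cast<const Aidge::Cuda::cuda_type<T>::type*>(&beta),
            d_C, m));
    }

By contrast, the cuDNN entry points (cudnnConvolutionForward, cudnnPoolingForward, cudnnActivationForward) never take half scaling factors: for Float16 and Float32 tensors alpha/beta must be float, which is exactly what `cudnn_scaling_type<T>::type` yields in the forward_ implementations above.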