diff --git a/CMakeLists.txt b/CMakeLists.txt
index e31509842a876151c31473d89f7e242f61617544..a8243358d6d506eec7248f2363ac5159ac2f298d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,6 +50,7 @@ target_link_libraries(${module_name}
     PUBLIC
         _aidge_core # _ is added because we link the target not the project
         CUDA::cudart
+        CUDA::cublas
         cudnn
 )
 
diff --git a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
index a4f3e4ad59c66379f404a704b8f4110f25200a4f..ffbb06752d39c14dc15f24967a394ed6acb9521b 100644
--- a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
@@ -16,20 +16,21 @@
 #include <cfloat>
 #include <cuda.h>
 #include <cuda_runtime_api.h>
+#include <cuda_fp16.h>
 
 #include "aidge/data/Data.hpp"
-#include "aidge/backend/cuda/operator/FCImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 
 namespace Aidge {
-template<class T>
-void fc_forward_cuda(std::size_t nbInputs, std::size_t inChannels, std::size_t outChannels, bool noBias, const void *input, const void *weights, const void *bias, void *output);
-
-namespace {
-static Registrar<FCImplForward_cuda> registrarFCImpl2DForward_cuda_Float32({DataType::Float32}, Aidge::fc_forward_cuda<float>);
-static Registrar<FCImplForward_cuda> registrarFCImpl2DForward_cuda_Float64({DataType::Float64}, Aidge::fc_forward_cuda<double>);
-} // namespace
-
+template <class T>
+cublasStatus_t cublasGemm(cublasHandle_t handle,
+                          cublasOperation_t transa, cublasOperation_t transb,
+                          int m, int n, int k,
+                          const T *alpha,
+                          const T *A, int lda,
+                          const T *B, int ldb,
+                          const T *beta,
+                          T *C, int ldc);
 }
 
 #endif /* AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_ */
\ No newline at end of file
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index 3beeab373ddaea3563e4c4f1e6644e37769acfed..02c74dc30ac03dc3146e22b2209771f7539c8b8d 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -23,6 +23,7 @@
 
 #include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
 
 void Aidge::FCImpl_cuda::forward() {
     assert(mOp.getRawInput(0) && "missing input #0");
@@ -35,8 +36,9 @@ void Aidge::FCImpl_cuda::forward() {
     const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
 
     const auto& fcOp = static_cast<const FC_Op&>(mOp);
-    std::size_t outChannels = static_cast<std::size_t>(fcOp.template getAttr<FCAttr::OutChannels>());
     bool noBias = fcOp.template getAttr<FCAttr::NoBias>();
+    std::size_t outChannels = static_cast<std::size_t>(fcOp.template getAttr<FCAttr::OutChannels>());
+
     if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
         forward_<double>(input0, input1, input2, noBias, outChannels);
     }
@@ -48,13 +50,61 @@ void Aidge::FCImpl_cuda::forward() {
 
 template<class T>
 void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, bool noBias, std::size_t outChannels) {
-    Aidge::fc_forward_cuda<T>(
-        input0.dims()[0],
-        input0.size() / input0.dims()[0],
-        outChannels,
-        noBias,
-        input0.getImpl()->rawPtr(),
-        input1.getImpl()->rawPtr(),
-        input2.getImpl()->rawPtr(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+
+    const T * input = static_cast<const T*>(input0.getImpl()->rawPtr());
+    const T * weights = static_cast<const T*>(input1.getImpl()->rawPtr());
+    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+
+    int n = outChannels;
+    int m = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->size()/n;
+    int k = input0.size()/m;
+    int lda = k;
+    int ldb = k;
+    int ldc = n;
+    const T alpha = 1.0;
+    const T beta = 0.0;
+    CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
+                                   CUBLAS_OP_T,
+                                   CUBLAS_OP_N,
+                                   n,
+                                   m,
+                                   k,
+                                   &alpha,
+                                   weights,
+                                   lda,
+                                   input,
+                                   ldb,
+                                   &beta,
+                                   output,
+                                   ldc));
+
+    if(!noBias){
+        T* onesVector;
+        CHECK_CUDA_STATUS(cudaMalloc((void**)&onesVector, m * sizeof(T)));
+        // Fill the vector with ones
+        std::vector<T> onesVec(m, 1.0f);
+        CHECK_CUDA_STATUS(cudaMemcpy(onesVector,
+                                     &onesVec[0],
+                                     m * sizeof(T),
+                                     cudaMemcpyHostToDevice));
+        const T * biases = static_cast<const T*>(input2.getImpl()->rawPtr());
+
+        CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
+                                       CUBLAS_OP_N,
+                                       CUBLAS_OP_N,
+                                       n,
+                                       m,
+                                       1,
+                                       &alpha,
+                                       biases,
+                                       n,
+                                       onesVector,
+                                       1,
+                                       &alpha,
+                                       output,
+                                       n));
+
+        cudaFree(onesVector);
+    }
+
 }
\ No newline at end of file
diff --git a/src/operator/FCImpl_CUDA_kernels.cu b/src/operator/FCImpl_CUDA_kernels.cu
index f263f150c93c6882dad5047f789bdc5103360a6f..a30519ebe4bf262b87bb9e07342f18525df2e8f4 100644
--- a/src/operator/FCImpl_CUDA_kernels.cu
+++ b/src/operator/FCImpl_CUDA_kernels.cu
@@ -12,37 +12,65 @@
 
 #include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
 
-template<class T>
-__global__
-void fc_forward_cuda_kernel(std::size_t nbInputs, std::size_t inChannels, std::size_t outChannels, bool noBias, const T* input, const T* weights, const T* bias, T *output)
-{
-    const std::size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+namespace Aidge{
 
-    for(std::size_t batch=idx; batch<nbInputs; ++batch)
-    {
-        for (std::size_t out = 0; out < outChannels; ++out) {
-            T sum = 0;
-            for (std::size_t in = 0; in < inChannels; ++in) {
-                sum += input[batch * inChannels + in] * weights[out * inChannels + in];
-            }
-            output[batch * outChannels + out] = sum + (noBias ? 0 : bias[out]);
-        }
-    }
+template <>
+cublasStatus_t cublasGemm<__half>(cublasHandle_t handle,
+                                  cublasOperation_t transa, cublasOperation_t transb,
+                                  int m, int n, int k,
+                                  const __half *alpha,
+                                  const __half *A, int lda,
+                                  const __half *B, int ldb,
+                                  const __half *beta,
+                                  __half *C, int ldc)
+{
+    return cublasHgemm(handle,
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
 
-namespace Aidge{
-template<class T>
-void fc_forward_cuda(std::size_t nbInputs, std::size_t inChannels, std::size_t outChannels, bool noBias, const void* input_, const void* weights_, const void* bias_, void* output_)
+template <>
+cublasStatus_t cublasGemm<float>(cublasHandle_t handle,
+                                 cublasOperation_t transa, cublasOperation_t transb,
+                                 int m, int n, int k,
+                                 const float *alpha,
+                                 const float *A, int lda,
+                                 const float *B, int ldb,
+                                 const float *beta,
+                                 float *C, int ldc)
 {
-    const T* input = static_cast<const T*>(input_);
-    const T* weights = static_cast<const T*>(weights_);
-    const T* bias = static_cast<const T*>(bias_);
-    T * output = static_cast<T*>(output_);
-
-    const dim3 blocksPerGrid = {(static_cast<unsigned int>(inChannels) + 255) / 256, 1, static_cast<unsigned int>(outChannels)};
-    const dim3 threadsPerBlocks = {256, 1, 1};
-    fc_forward_cuda_kernel<<<blocksPerGrid, threadsPerBlocks>>>(nbInputs, inChannels, outChannels, noBias, input, weights, bias, output);
-
-    CHECK_CUDA_STATUS(cudaPeekAtLastError());
+    return cublasSgemm(handle,
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
+
+template <>
+cublasStatus_t cublasGemm<double>(cublasHandle_t handle,
+                                  cublasOperation_t transa, cublasOperation_t transb,
+                                  int m, int n, int k,
+                                  const double *alpha,
+                                  const double *A, int lda,
+                                  const double *B, int ldb,
+                                  const double *beta,
+                                  double *C, int ldc)
+{
+    return cublasDgemm(handle,
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
+}
\ No newline at end of file
diff --git a/unit_tests/Test_FCImpl.cpp b/unit_tests/Test_FCImpl.cpp
index 05624d96b17a73ce84e9978599da849ad20c2764..54e37db15ded5546eb8fc3caacff9bae238b452c 100644
--- a/unit_tests/Test_FCImpl.cpp
+++ b/unit_tests/Test_FCImpl.cpp
@@ -76,7 +76,6 @@ TEST_CASE("[gpu/operator] FC(forward)", "[FC][GPU]") {
 
     for(int i = 0; i < myOutput->size(); i++){
         const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
-        std::cout << "targetOutput " << targetOutput << ", out " << computedOutput[i]<<std::endl;
        REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
     }
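
Note on the GEMM mapping used by the patch (a minimal standalone sketch, not part of the change; the matrix names, sizes and the file name fc_gemm_demo.cu below are illustrative only): cuBLAS is column-major, so the row-major input X (batch x inChannels) and row-major weights W (outChannels x inChannels), reinterpreted column-major, are X^T and W^T. Calling GEMM with op(A)=T on the weights computes C = W * X^T = Y^T (outChannels x batch), whose memory layout is exactly row-major Y. The bias is then broadcast over the batch by a second GEMM acting as a rank-1 update with beta = 1, which is what the `if(!noBias)` branch does with its device vector of ones.

// fc_gemm_demo.cu -- hypothetical standalone example; build with: nvcc fc_gemm_demo.cu -lcublas
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    const int batch = 2, inCh = 3, outCh = 2;            // m, k, n in the patch's notation
    // Row-major host data: X is batch x inCh, W is outCh x inCh, b has outCh entries.
    const std::vector<float> X = {1, 2, 3,
                                  4, 5, 6};
    const std::vector<float> W = {1, 0, 1,
                                  2, 1, 0};
    const std::vector<float> b = {0.5f, -1.0f};
    std::vector<float> Y(batch * outCh, 0.0f);
    const std::vector<float> ones(batch, 1.0f);

    float *dX, *dW, *dB, *dY, *dOnes;
    cudaMalloc(&dX, X.size() * sizeof(float));
    cudaMalloc(&dW, W.size() * sizeof(float));
    cudaMalloc(&dB, b.size() * sizeof(float));
    cudaMalloc(&dY, Y.size() * sizeof(float));
    cudaMalloc(&dOnes, batch * sizeof(float));
    cudaMemcpy(dX, X.data(), X.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dW, W.data(), W.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, b.data(), b.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dOnes, ones.data(), batch * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float one = 1.0f, zero = 0.0f;

    // Column-major C (outCh x batch) = op(A) * op(B) with A = weights (op = T) and B = input
    // (op = N); leading dimensions are inCh for A and B, outCh for C -- same pattern as the patch.
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                outCh, batch, inCh,
                &one, dW, inCh, dX, inCh,
                &zero, dY, outCh);

    // Bias broadcast as a rank-1 update: Y += b * ones^T, with beta = 1 to keep the GEMM result.
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                outCh, batch, 1,
                &one, dB, outCh, dOnes, 1,
                &one, dY, outCh);

    cudaMemcpy(Y.data(), dY, Y.size() * sizeof(float), cudaMemcpyDeviceToHost);

    // Compare against a plain host-side reference loop.
    for (int i = 0; i < batch; ++i)
        for (int o = 0; o < outCh; ++o) {
            float ref = b[o];
            for (int c = 0; c < inCh; ++c)
                ref += X[i * inCh + c] * W[o * inCh + c];
            std::printf("Y[%d][%d] = %g (reference %g)\n", i, o, Y[i * outCh + o], ref);
        }

    cublasDestroy(handle);
    cudaFree(dX); cudaFree(dW); cudaFree(dB); cudaFree(dY); cudaFree(dOnes);
    return 0;
}

Using GEMM for the bias avoids a custom broadcast kernel: with alpha = beta = 1 the rank-1 product b * ones^T is accumulated onto the output already produced by the first GEMM, one column (one batch element) at a time.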