From ace61c0d289111c31f3fb11c550e4b1637022a13 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Wed, 27 Mar 2024 14:56:33 +0100 Subject: [PATCH] minor cleanups --- .../backend/cuda/operator/AvgPoolingImpl.hpp | 1 + .../aidge/backend/cuda/operator/FCImpl.hpp | 4 +++- .../cuda/operator/FCImpl_CUDA_kernels.hpp | 6 ----- .../backend/cuda/operator/MaxPoolingImpl.hpp | 1 + .../aidge/backend/cuda/operator/ReLUImpl.hpp | 1 + src/operator/AvgPoolingImpl.cpp | 3 +-- src/operator/FCImpl.cpp | 22 +++++++++-------- src/operator/MaxPoolingImpl.cpp | 3 +-- src/operator/ReLUImpl.cpp | 24 ++++++++++++------- 9 files changed, 36 insertions(+), 29 deletions(-) diff --git a/include/aidge/backend/cuda/operator/AvgPoolingImpl.hpp b/include/aidge/backend/cuda/operator/AvgPoolingImpl.hpp index f32dcb1..bec3728 100644 --- a/include/aidge/backend/cuda/operator/AvgPoolingImpl.hpp +++ b/include/aidge/backend/cuda/operator/AvgPoolingImpl.hpp @@ -33,6 +33,7 @@ private: // CuDNN specific variables cudnnPoolingDescriptor_t mAvgPoolingDesc = nullptr; cudnnPoolingMode_t mMode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + std::shared_ptr<Tensor> mInputFallback; public: AvgPoolingImpl_cuda(const AvgPooling_Op<DIM> &op) : OperatorImpl(op) {} diff --git a/include/aidge/backend/cuda/operator/FCImpl.hpp b/include/aidge/backend/cuda/operator/FCImpl.hpp index ee2f1c5..023757b 100644 --- a/include/aidge/backend/cuda/operator/FCImpl.hpp +++ b/include/aidge/backend/cuda/operator/FCImpl.hpp @@ -32,7 +32,9 @@ class FCImplForward_cuda : public Registrable<FCImplForward_cuda, void(std::size_t , std::size_t, std::size_t, bool, const void* , const void* , const void* , void*)> {}; class FCImpl_cuda : public OperatorImpl { private: - // CuDNN specific variables + std::shared_ptr<Tensor> mInput0Fallback; + std::shared_ptr<Tensor> mInput1Fallback; + std::shared_ptr<Tensor> mInput2Fallback; public: diff --git a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp index 9c83332..9084e01 100644 --- a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp +++ b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp @@ -32,11 +32,5 @@ cublasStatus_t cublasGemm(cublasHandle_t handle, const T *B, int ldb, const T *beta, T *C, int ldc); -// cublasGemm(cublasContext*&, cublasOperation_t, cublasOperation_t, int&, int&, int&, - // const type*, - // const __half*&, int&, - // const __half*&, int&, - // const type*, - // __half*&, int&)’ } #endif /* AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_ */ \ No newline at end of file diff --git a/include/aidge/backend/cuda/operator/MaxPoolingImpl.hpp b/include/aidge/backend/cuda/operator/MaxPoolingImpl.hpp index 9da9773..9216eca 100644 --- a/include/aidge/backend/cuda/operator/MaxPoolingImpl.hpp +++ b/include/aidge/backend/cuda/operator/MaxPoolingImpl.hpp @@ -33,6 +33,7 @@ private: // CuDNN specific variables cudnnPoolingDescriptor_t mMaxPoolingDesc = nullptr; cudnnPoolingMode_t mMode = CUDNN_POOLING_MAX; + std::shared_ptr<Tensor> mInputFallback; public: MaxPoolingImpl_cuda(const MaxPooling_Op<DIM> &op) : OperatorImpl(op) {} diff --git a/include/aidge/backend/cuda/operator/ReLUImpl.hpp b/include/aidge/backend/cuda/operator/ReLUImpl.hpp index 3b6cbcc..27d0a61 100644 --- a/include/aidge/backend/cuda/operator/ReLUImpl.hpp +++ b/include/aidge/backend/cuda/operator/ReLUImpl.hpp @@ -35,6 +35,7 @@ private: #else cudnnActivationMode_t mReLUDesc = nullptr; #endif + std::shared_ptr<Tensor> mInputFallback; public: ReLUImpl_cuda(const ReLU_Op &op) : OperatorImpl(op) {} diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp index 861533e..eb9cc6a 100644 --- a/src/operator/AvgPoolingImpl.cpp +++ b/src/operator/AvgPoolingImpl.cpp @@ -25,8 +25,7 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() { assert(mOp.getRawInput(0) && "missing input #0"); - std::shared_ptr<Tensor> inputFallback; - const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); + const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0)); // Lazy-initialize CuDNN AvgPooling descriptor if (mAvgPoolingDesc == nullptr) { diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp index a8f8da8..8b60f7f 100644 --- a/src/operator/FCImpl.cpp +++ b/src/operator/FCImpl.cpp @@ -28,15 +28,14 @@ void Aidge::FCImpl_cuda::forward() { assert(mOp.getRawInput(1) && "missing input #1"); assert(mOp.getRawInput(2) && "missing input #2"); - std::shared_ptr<Tensor> inputFallback, input1Fallback, input2Fallback; - const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); - const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); - const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); - const auto& fcOp = static_cast<const FC_Op&>(mOp); bool noBias = fcOp.template getAttr<FCAttr::NoBias>(); std::size_t outChannels = static_cast<std::size_t>(fcOp.template getAttr<FCAttr::OutChannels>()); + const auto& input0 = fcOp.getInput(0)->refCastFrom(mInput0Fallback, *fcOp.getOutput(0)); + const auto& input1 = fcOp.getInput(1)->refCastFrom(mInput1Fallback, *fcOp.getOutput(0)); + const auto& input2 = fcOp.getInput(2)->refCastFrom(mInput2Fallback, *fcOp.getOutput(0)); + switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { case DataType::Float64: forward_<double>(input0, input1, input2, noBias, outChannels); @@ -55,17 +54,19 @@ void Aidge::FCImpl_cuda::forward() { template<class T> void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, bool noBias, std::size_t outChannels) { - const T * input = static_cast<const T*>(input0.getImpl()->rawPtr()); const T * weights = static_cast<const T*>(input1.getImpl()->rawPtr()); T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); + // Performing output = T(weights) * input + // [n x m] = [n x k] * [k x m] + // cublas is column-major so instead of transposing inputs, computing output [m x n] and transposing output, we compute output as [n x m] int n = outChannels; int m = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->size()/n; int k = input0.size()/m; - int lda = k; - int ldb = k; - int ldc = n; + int lda = k; // leading dimension of weights + int ldb = k; // leading dimension of input + int ldc = n; // leading dimension of output const T alpha = 1.0f; const T beta = 0.0f; CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(), @@ -93,7 +94,8 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co m * sizeof(T), cudaMemcpyHostToDevice)); const T * biases = static_cast<const T*>(input2.getImpl()->rawPtr()); - + // Performing output = biases * onesVector + output + // [n x m] = [n x 1] * [1 x m] + [n x m] CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp index 19a567f..b8d7c81 100644 --- a/src/operator/MaxPoolingImpl.cpp +++ b/src/operator/MaxPoolingImpl.cpp @@ -25,8 +25,7 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() { assert(mOp.getRawInput(0) && "missing input #0"); - std::shared_ptr<Tensor> inputFallback; - const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(op.getRawOutput(0))); + const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0)); // Lazy-initialize CuDNN MaxPooling descriptor if (mMaxPoolingDesc == nullptr) { diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp index c880184..2ebd6b2 100644 --- a/src/operator/ReLUImpl.cpp +++ b/src/operator/ReLUImpl.cpp @@ -24,8 +24,7 @@ void Aidge::ReLUImpl_cuda::forward() { assert(mOp.getRawInput(0) && "missing input #0"); - std::shared_ptr<Tensor> inputFallback; - const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(op.getRawOutput(0))); + const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0)); // Lazy-initialize CuDNN ReLU descriptor if (mReLUDesc == nullptr) { @@ -38,11 +37,18 @@ void Aidge::ReLUImpl_cuda::forward() { #endif } - if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) { - forward_<double>(input); - } - else { - forward_<float>(input); + switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { + case DataType::Float64: + forward_<double>(input); + break; + case DataType::Float32: + forward_<float>(input); + break; + case DataType::Float16: + forward_<half>(input); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); } } @@ -64,7 +70,9 @@ void Aidge::ReLUImpl_cuda::forward_(const Tensor& input) { Aidge::ReLUImpl_cuda::~ReLUImpl_cuda() { if (mReLUDesc != nullptr) { - cudnnDestroyActivationDescriptor(mReLUDesc); + #if CUDNN_VERSION >= 5000 + cudnnDestroyActivationDescriptor(mReLUDesc); + #endif } } -- GitLab