diff --git a/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp
index aa54029ea29bc46809f227038a1a23d91bc161ee..4d320e8d29bd79e2d904486a3295ca05b567c1f1 100644
--- a/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp
@@ -68,7 +68,7 @@ void ILayerNormforward(const T* input, T* output, double SF, const T* weight_raw
  * @param size: Number of elements in the input tensor
  */
 template <class T>
-__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size);
+__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size, const T alpha, const T beta);
 
 /**
  * @brief Wrapper function to execute ILayerNormbackward_
@@ -85,7 +85,7 @@ __global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_t
  * @param size: Number of elements in the input tensor
  */
 template <class T>
-void ILayerNormbackward(const T* input_tensor, const T* output_grad, const T* output_tensor,const T* mean,const T* var, const T* weight, const T* bias, T* input_grad, T* weight_grad, T* bias_grad, size_t size);
+void ILayerNormbackward(const T* input_tensor, const T* output_grad, const T* output_tensor,const T* mean,const T* var, const T* weight, const T* bias, T* input_grad, T* weight_grad, T* bias_grad, size_t size, const T alpha, const T beta);
 
 }
diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
index 14268521451a631ccb9194d44ed7543af8d494f5..f598ba29c623a9146966634dfdebccd89d9ae331 100644
--- a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
@@ -61,7 +61,7 @@ void ShiftGELUforward(const T* input, T* output, double SF,int N, int output_bit
  * @param size: Number of elements in the input tensor
  */
 template <class T>
-__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size);
+__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size, const T alpha, const T beta);
 
 /**
  * @brief Wrapper function to execute ShiftGELUbackward_
@@ -71,7 +71,7 @@ __global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const
  * @param size: Number of elements in the input tensor
  */
 template <class T>
-void ShiftGELUbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size);
+void ShiftGELUbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size, const T alpha, const T beta);
 
 }
diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
index 037a7cbb6362a8eca5a9e6f5a277b29a6a6bd907..43bfe7903f3ae8c0c59bc0deb286d5725990cc58 100644
--- a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
@@ -61,7 +61,7 @@ void ShiftMaxforward(const T* input, T* output, double SF,int N, int output_bits
  * @param dims: Dimensions of input tensor
  */
 template <class T>
-__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims);
+__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims, const T alpha, const T beta);
 
 /**
  * @brief Wrapper function to execute ShiftMaxbackward_
@@ -72,7 +72,7 @@ __global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T
  * @param dims: Dimensions of input tensor
  */
 template <class T>
-void ShiftMaxbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size, std::vector<long unsigned int> dims);
+void ShiftMaxbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size, std::vector<long unsigned int> dims, const T alpha, const T beta);
 
 }
diff --git a/src/operator/ILayerNormImpl.cpp b/src/operator/ILayerNormImpl.cpp
index 47dd1d5d1a3f127c9e08788f605796020a7814a7..c887e79b88987a8c858975ceac78c61e1b68ce84 100644
--- a/src/operator/ILayerNormImpl.cpp
+++ b/src/operator/ILayerNormImpl.cpp
@@ -119,6 +119,9 @@ void Aidge::ILayerNormImpl_cuda::backward_(const Tensor& output_grad) {
     size_t size = output_grad.size();
     std::vector<DimSize_t> dims_input = output_grad.dims();
 
+    const T alpha = 1.0f;
+    const T beta = 1.0f; // accumulate
+
     const T * output = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
     T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
@@ -200,5 +203,5 @@ void Aidge::ILayerNormImpl_cuda::backward_(const Tensor& output_grad) {
     const T* var_ = flat_vars.data();
     const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
 
-    ILayerNormbackward(output, output_grad_raw, input, mean_, var_, weight, bias, input_grad, weight_grad, bias_grad, size);
+    ILayerNormbackward(output, output_grad_raw, input, mean_, var_, weight, bias, input_grad, weight_grad, bias_grad, size, alpha, beta);
 }
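The backward call sites above hard-code alpha = 1 and beta = 1, so each pass blends the freshly computed gradient into whatever the gradient buffer already holds, following the cuBLAS/cuDNN scaling convention out = alpha * result + beta * out. A minimal sketch of that convention in isolation (the kernel name is illustrative, not part of this patch):

__global__ void axpby_(float* out, const float* result, int size, float alpha, float beta)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < size) {
        // beta == 0.f overwrites out; beta == 1.f accumulates into it
        out[i] = alpha * result[i] + beta * out[i];
    }
}

With beta = 1, the gradient buffers must be zero-initialized before the first backward pass, otherwise stale values are folded into the result.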
diff --git a/src/operator/ILayerNormImpl_CUDA_kernels.cu b/src/operator/ILayerNormImpl_CUDA_kernels.cu
index f4ecff9cc105c9ae4e22eb94ed060a4a8bf6454b..4ddf3ff5cc384551d9c63be1cebb8bb2519c6838 100644
--- a/src/operator/ILayerNormImpl_CUDA_kernels.cu
+++ b/src/operator/ILayerNormImpl_CUDA_kernels.cu
@@ -61,8 +61,7 @@ __global__ void ILayerNormforward_(T* input, double SF, int* dims, int* quantize
     {
         k = floorf((k + floorf(sum / k))/2);
     }
-
-    long factor = (((1u << 31) - 1) / k);
+    int factor = (((1u << 31) - 1) / k);
     for (int i = 0; i < dims[3]; i++) {
         int idx = maxIdx + i;
         square_tensor[idx] = (biase[idx]/weight[idx])/new_SF;
@@ -184,7 +183,7 @@ void ILayerNormforward<double>(const double* input, double* output, double SF, c
 }
 
 template <class T>
-__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size)
+__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size, const T alpha, const T beta)
 {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < size) {
@@ -193,14 +192,14 @@ __global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_t
         T d_mean = d_norm * -1 / sqrtf(var[i] + 1e-6) + d_var * -2 * mean[i] / size;
         T d_input = d_norm / sqrtf(var[i] + 1e-6) + d_var * 2 * (input_tensor[i] - mean[i]) / size + d_mean / size;
 
-        input_grad[i] = d_input;
-        weight_grad[i] = output_grad[i] * output_tensor[i];
-        bias_grad[i] = output_grad[i];
+        input_grad[i] = alpha * d_input + beta * input_grad[i];
+        weight_grad[i] = alpha * output_grad[i] * output_tensor[i] + beta * weight_grad[i];
+        bias_grad[i] = alpha * output_grad[i] + beta * bias_grad[i];
     }
 }
 
 template <>
-void ILayerNormbackward<float>(const float* input_tensor, const float* output_grad, const float* output_tensor,const float* mean,const float* var, const float* weight, const float* bias, float* input_grad, float* weight_grad, float* bias_grad, std::size_t size)
+void ILayerNormbackward<float>(const float* input_tensor, const float* output_grad, const float* output_tensor,const float* mean,const float* var, const float* weight, const float* bias, float* input_grad, float* weight_grad, float* bias_grad, std::size_t size, const float alpha, const float beta)
 {
     float* input_cuda_tensor;
     cudaMalloc(&input_cuda_tensor,size*sizeof(float));
@@ -244,7 +243,7 @@ void ILayerNormbackward<float>(const float* input_tensor, const float* output_gr
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size);
+    ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size, alpha, beta);
 
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
@@ -270,7 +269,7 @@ void ILayerNormbackward<float>(const float* input_tensor, const float* output_gr
 }
 
 template <>
-void ILayerNormbackward<double>(const double* input_tensor, const double* output_grad, const double* output_tensor,const double* mean,const double* var, const double* weight, const double* bias, double* input_grad, double* weight_grad, double* bias_grad, std::size_t size)
+void ILayerNormbackward<double>(const double* input_tensor, const double* output_grad, const double* output_tensor,const double* mean,const double* var, const double* weight, const double* bias, double* input_grad, double* weight_grad, double* bias_grad, std::size_t size, const double alpha, const double beta)
 {
     double* input_cuda_tensor;
     cudaMalloc(&input_cuda_tensor,size*sizeof(double));
@@ -314,7 +313,7 @@ void ILayerNormbackward<double>(const double* input_tensor, const double* output
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size);
+    ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size, alpha, beta);
 
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
@@ -340,4 +339,4 @@ void ILayerNormbackward<double>(const double* input_tensor, const double* output
     cudaFree(bias_grad_);
 }
 
-}
+}
\ No newline at end of file
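Each backward wrapper above follows the same host-side pattern: cudaMalloc staging buffers, cudaMemcpy the inputs in, launch, cudaDeviceSynchronize, check cudaGetLastError, copy the gradients back, cudaFree. Only the launch itself is checked; the allocations and copies are not. A small guard such as the hypothetical CUDA_CHECK below (an assumption, not part of this patch) would surface those failures as well:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical helper: abort with a readable message on any failing CUDA runtime call.
#define CUDA_CHECK(call)                                                  \
    do {                                                                  \
        cudaError_t err_ = (call);                                        \
        if (err_ != cudaSuccess) {                                        \
            std::fprintf(stderr, "CUDA error %s at %s:%d\n",              \
                         cudaGetErrorString(err_), __FILE__, __LINE__);   \
            std::exit(EXIT_FAILURE);                                      \
        }                                                                 \
    } while (0)

// e.g. CUDA_CHECK(cudaMalloc(&input_cuda_tensor, size * sizeof(float)));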
diff --git a/src/operator/ShiftGELUImpl.cpp b/src/operator/ShiftGELUImpl.cpp
index c2774804d04a422aefd0c66ed0d1fc1d949b1f06..0b5b2c3b74fbb15b9d232ff06e3ae87e92233a08 100644
--- a/src/operator/ShiftGELUImpl.cpp
+++ b/src/operator/ShiftGELUImpl.cpp
@@ -109,11 +109,14 @@ void Aidge::ShiftGELUImpl_cuda::backward_(const Tensor& output_grad) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T * input = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
 
+    const T alpha = 1.0f;
+    const T beta = 1.0f; // accumulate
+
     size_t size = output_grad.size();
 
     T * output = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
     const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
 
-    ShiftGELUbackward(input, output_grad_raw, output, size);
+    ShiftGELUbackward(input, output_grad_raw, output, size, alpha, beta);
 }
\ No newline at end of file
diff --git a/src/operator/ShiftGELUImpl_CUDA_kernels.cu b/src/operator/ShiftGELUImpl_CUDA_kernels.cu
index aabd89c04e960f9f19eca69247173168d3eaf71e..58578836a0e4b2149d7e19cce3a9c8595a17b782 100644
--- a/src/operator/ShiftGELUImpl_CUDA_kernels.cu
+++ b/src/operator/ShiftGELUImpl_CUDA_kernels.cu
@@ -178,7 +178,7 @@ void ShiftGELUforward<double>(const double* input, double* output, double SF,int
 }
 
 template <class T>
-__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size) {
+__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size, const T alpha, const T beta) {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     if (index < size) {
@@ -189,12 +189,12 @@ __global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const
         float pdf = exp(-0.5 * x * x) / sqrt(2.0 * M_PI);
         float dx = pdf + x * cdf;
         float backprop_grad = grad * dx;
-        input_grad[index] = backprop_grad;
+        input_grad[index] = alpha * backprop_grad + beta * input_grad[index];
     }
 }
 
 template <>
-void ShiftGELUbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size)
+void ShiftGELUbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size, const float alpha, const float beta)
 {
     float* output_cuda_tensor;
     cudaMalloc(&output_cuda_tensor,size*sizeof(float));
@@ -210,7 +210,7 @@ void ShiftGELUbackward<float>(const float* output_tensor, const float* output_gr
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ShiftGELUbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size);
+    ShiftGELUbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size,alpha,beta);
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
@@ -224,7 +224,7 @@ void ShiftGELUbackward<float>(const float* output_tensor, const float* output_gr
 }
 
 template <>
-void ShiftGELUbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size)
+void ShiftGELUbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size, const double alpha, const double beta)
 {
     double* output_cuda_tensor;
     cudaMalloc(&output_cuda_tensor,size*sizeof(double));
@@ -240,7 +240,7 @@ void ShiftGELUbackward<double>(const double* output_tensor, const double* output
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ShiftGELUbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size);
+    ShiftGELUbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size,alpha,beta);
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
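The ShiftGELU backward path differentiates the exact (erf-based) GELU, GELU(x) = x * Phi(x), whose derivative is Phi(x) + x * phi(x), with Phi the standard normal CDF and phi its PDF. A host-side reference sketch of that derivative, handy for unit-checking the kernel (illustrative only, not part of this patch):

#include <cmath>

// Reference gradient of exact GELU(x) = x * Phi(x):
// d/dx GELU(x) = Phi(x) + x * phi(x)
float geluGradRef(float x, float upstream)
{
    const float cdf = 0.5f * (1.0f + std::erf(x / std::sqrt(2.0f)));                // Phi(x)
    const float pdf = std::exp(-0.5f * x * x) / std::sqrt(2.0f * float(M_PI));      // phi(x)
    return upstream * (cdf + x * pdf);
}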
diff --git a/src/operator/ShiftMaxImpl.cpp b/src/operator/ShiftMaxImpl.cpp
index 1134cc5d6b99e53eb492c82e32d811bc0bcba0e0..ebe0bdc3f15b033fecfc6abcc72b75015808aa6d 100644
--- a/src/operator/ShiftMaxImpl.cpp
+++ b/src/operator/ShiftMaxImpl.cpp
@@ -110,12 +110,15 @@ void Aidge::ShiftMaxImpl_cuda::backward_(const Tensor& output_grad) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T * output_tensor = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
 
+    const T alpha = 1.0f;
+    const T beta = 1.0f; // accumulate
+
     size_t size = output_grad.size();
     std::vector<DimSize_t> dims_output = output_grad.dims();
 
     T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
     const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
 
-    ShiftMaxbackward(output_tensor, output_grad_raw, input_grad, size, dims_output);
+    ShiftMaxbackward(output_tensor, output_grad_raw, input_grad, size, dims_output, alpha, beta);
 }
\ No newline at end of file
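The ShiftMax backward kernel below applies the standard softmax Jacobian-vector product along the last axis, dx_i = y_i * (dy_i - sum_j y_j * dy_j), before the alpha/beta blend. A host-side reference for a single row, useful for unit-checking the kernel (illustrative, not part of this patch):

#include <cstddef>
#include <vector>

// Reference softmax backward for one row: dx_i = y_i * (dy_i - sum_j y_j * dy_j)
std::vector<float> softmaxGradRef(const std::vector<float>& y, const std::vector<float>& dy)
{
    float dot = 0.0f;                       // sum_j y_j * dy_j
    for (std::size_t j = 0; j < y.size(); ++j)
        dot += y[j] * dy[j];
    std::vector<float> dx(y.size());
    for (std::size_t i = 0; i < y.size(); ++i)
        dx[i] = y[i] * (dy[i] - dot);
    return dx;
}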
diff --git a/src/operator/ShiftMaxImpl_CUDA_kernels.cu b/src/operator/ShiftMaxImpl_CUDA_kernels.cu
index 169770c237cd80a8c3357dbd483251b480808b1a..0dd462c0b33456faf48c9dca2d14038cd4863788 100644
--- a/src/operator/ShiftMaxImpl_CUDA_kernels.cu
+++ b/src/operator/ShiftMaxImpl_CUDA_kernels.cu
@@ -190,7 +190,7 @@ void ShiftMaxforward<double>(const double* input, double* output, double SF, int
 }
 
 template <class T>
-__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims) {
+__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims, const T alpha, const T beta) {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     if (index < dims[0] * dims[1] * dims[2] * dims[3]) {
         int w = (index / dims[3]) % dims[2];
@@ -201,12 +201,12 @@ __global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T
         for (int i = 0; i < dims[3]; ++i) {
             sum += output_tensor[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i] * output_grad[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i];
         }
-        input_grad[index] = output_tensor[index] * (output_grad[index] - sum);
+        input_grad[index] = alpha * output_tensor[index] * (output_grad[index] - sum) + beta * input_grad[index];
     }
 }
 
 template <>
-void ShiftMaxbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size, std::vector<long unsigned int> dims)
+void ShiftMaxbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size, std::vector<long unsigned int> dims, const float alpha, const float beta)
 {
     int dims_input_cuda[4] = {1, 1, 1, 1};
     for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
@@ -231,7 +231,7 @@ void ShiftMaxbackward<float>(const float* output_tensor, const float* output_gra
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ShiftMaxbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_);
+    ShiftMaxbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_,alpha,beta);
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
@@ -247,7 +247,7 @@ void ShiftMaxbackward<float>(const float* output_tensor, const float* output_gra
 }
 
 template <>
-void ShiftMaxbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size, std::vector<long unsigned int> dims)
+void ShiftMaxbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size, std::vector<long unsigned int> dims, const double alpha, const double beta)
 {
     int dims_input_cuda[4] = {1, 1, 1, 1};
     for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
@@ -272,7 +272,7 @@ void ShiftMaxbackward<double>(const double* output_tensor, const double* output_
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ShiftMaxbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_);
+    ShiftMaxbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_,alpha,beta);
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
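Since every call site in this patch fixes alpha = 1.0f and beta = 1.0f, gradients always accumulate across backward passes, and the buffers must be zeroed before the first one. A caller wanting plain overwrite semantics would pass beta = 0. A hypothetical call-site sketch (the accumulateGrads flag is illustrative, not an Aidge API):

// Hypothetical call site: select overwrite vs accumulate semantics per pass.
const float alpha = 1.0f;
const float beta  = accumulateGrads ? 1.0f : 0.0f; // 0: overwrite, 1: accumulate
ShiftGELUbackward<float>(input, output_grad_raw, output, size, alpha, beta);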