diff --git a/include/aidge/backend/cuda/operator/PadImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/PadImpl_CUDA_kernels.hpp index 11ddb0ea8b0e6603bf009c4ae0a7fa3247a8904f..b52d9883fa0acd320396bb358f253dcf62fea638 100644 --- a/include/aidge/backend/cuda/operator/PadImpl_CUDA_kernels.hpp +++ b/include/aidge/backend/cuda/operator/PadImpl_CUDA_kernels.hpp @@ -32,6 +32,8 @@ namespace Aidge unsigned int padType, T padValue, const T *input, - T *outputs); + T *outputs, + const T alpha, + const T beta); } #endif /* AIDGE_CUDA_OPERATOR_PADIMPL_KERNELS_H_ */ \ No newline at end of file diff --git a/include/aidge/backend/cuda/operator/ReduceImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ReduceImpl_CUDA_kernels.hpp index 9d352b8b1d14aeaa4230accd7aa81c279c18b7a8..bd9d4804330344e10cda9beffa595881d996ce9d 100644 --- a/include/aidge/backend/cuda/operator/ReduceImpl_CUDA_kernels.hpp +++ b/include/aidge/backend/cuda/operator/ReduceImpl_CUDA_kernels.hpp @@ -25,6 +25,8 @@ namespace Aidge const std::vector<std::size_t>& outputDims, const std::vector<int>& axes, const std::vector<std::size_t>& factors, - int outSize); + int outSize, + const T alpha, + const T beta); } #endif /* AIDGE_CUDA_OPERATOR_REDUCEIMPL_KERNEL_H_ */ \ No newline at end of file diff --git a/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp index 26f609cab8137a16fb5cf682561f237aaab74530..3aa43251b6b6da5db75b2d9b90837f06942f581c 100644 --- a/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp +++ b/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp @@ -25,10 +25,19 @@ namespace Aidge { template <class T> -void sqrtForward(const T* input, T* output, int size); +void sqrtForward(const T* input, + T* output, + int size, + const T alpha, + const T beta); template <class T> -void sqrtBackward(const T* input, const T* outputGrad, T* inputGrad, int size); +void sqrtBackward(const T* input, + const T* outputGrad, + T* inputGrad, + int size, + const T alpha, + const T beta); } #endif /* AIDGE_CUDA_OPERATOR_SQRTIMPL_KERNEL_H_ */ diff --git a/src/operator/AddImpl.cpp b/src/operator/AddImpl.cpp index de7ea925554906ea5fe1e5dcba268b17a06a47bd..8771a79e938dff893d5295bd847567a0dcb18f32 100644 --- a/src/operator/AddImpl.cpp +++ b/src/operator/AddImpl.cpp @@ -155,10 +155,12 @@ void Aidge::AddImpl_cuda::backward() { } template <class T> -void Aidge::AddImpl_cuda::backward_(const Tensor& outputGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) { +void Aidge::AddImpl_cuda::backward_(const Tensor& outputGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate for (std::size_t i = 0; i < inputsDims.size(); i++) { diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp index d1270ee4b0a556e1053f3cfde8d71ec5efbee279..854171017899c7ea52a20f59e9181e8008a4d3ad 100644 --- a/src/operator/AvgPoolingImpl.cpp +++ b/src/operator/AvgPoolingImpl.cpp @@ -97,11 +97,13 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::backward() { template <Aidge::DimIdx_t DIM> template <class T> -void Aidge::AvgPoolingImpl_cuda<DIM>::backward_(const Tensor& output_grad) { +void Aidge::AvgPoolingImpl_cuda<DIM>::backward_(const Tensor& output_grad) +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const T alpha = 1.0f; - const T beta = 0.0f; + const T beta = 1.0f; // accumulate + CHECK_CUDNN_STATUS( cudnnPoolingBackward(CudaContext::cudnnHandle(), mAvgPoolingDesc, diff --git a/src/operator/BatchNormImpl.cpp b/src/operator/BatchNormImpl.cpp index 8a17600aaf05dad74c031716404f15b272ecc408..f72e0abee0e925aaa213265670cc77f3ca1e13b3 100644 --- a/src/operator/BatchNormImpl.cpp +++ b/src/operator/BatchNormImpl.cpp @@ -193,9 +193,9 @@ template <class T> void Aidge::BatchNormImpl_cuda<DIM>::backward_(const Tensor& input0, const Tensor& outputGrad, const Tensor& weights) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate const typename Cuda::cudnn_scaling_type<T>::type alphaData = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type betaData = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type betaData = 1.0f; // accumulate cudnnTensorDescriptor_t scaleBiasDesc; // For scale, bias, var and mean, if we have a 1D tensor, the dim should go on the channels diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp index 24e01db03692ffaa884b31a224a1947a9e1645a0..076dccab3e52cc458b7b95788890e7fb600e4e49 100644 --- a/src/operator/ConvImpl.cpp +++ b/src/operator/ConvImpl.cpp @@ -265,7 +265,7 @@ void Aidge::ConvImpl_cuda<DIM>::backward_(const Tensor& input0, const Tensor& in const auto& gradOutput = op.getOutput(0)->grad()->refCastFrom(gradOutputFallback, *(op.getInput(0)->grad())); const T alpha = 1.0f; - const T beta = 0.0f; + const T beta = 1.0f; // accumulate CHECK_CUDNN_STATUS(cudnnConvolutionBackwardFilter( CudaContext::cudnnHandle(), diff --git a/src/operator/DivImpl.cpp b/src/operator/DivImpl.cpp index 0326a60c1a3aabf43ca3a1d892328991d6d72366..8f5fdc717dd2337a1324a0c1be4887133bb70492 100644 --- a/src/operator/DivImpl.cpp +++ b/src/operator/DivImpl.cpp @@ -108,6 +108,6 @@ template <class T> void Aidge::DivImpl_cuda::backward_(const Tensor& outGrad) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate // TODO } \ No newline at end of file diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp index 1a7bb8edb51312d08467354e20723ad19176bfee..55cb31b7492956a2c722e775c225276d22fbdf4e 100644 --- a/src/operator/FCImpl.cpp +++ b/src/operator/FCImpl.cpp @@ -116,6 +116,7 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co } void Aidge::FCImpl_cuda::backward() { + AIDGE_ASSERT(mOp.getRawInput(0), "missing input #0"); AIDGE_ASSERT(mOp.getRawInput(1), "missing input #1"); AIDGE_ASSERT(mOp.getRawInput(2), "missing input #2"); @@ -146,9 +147,11 @@ template<class T> void Aidge::FCImpl_cuda::backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, std::size_t outChannels) { const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; - const typename Cuda::cudnn_scaling_type<T>::type betaData = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate + const typename Cuda::cudnn_scaling_type<T>::type betaData = 1.0f; // accumulate + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const T * input = static_cast<const T*>(input0.getImpl()->rawPtr()); const T * weights = static_cast<const T*>(input1.getImpl()->rawPtr()); const T * outputGrad = static_cast<const T*>(op.getOutput(0)->grad()->getImpl()->rawPtr()); @@ -175,7 +178,8 @@ void Aidge::FCImpl_cuda::backward_(const Tensor& input0, const Tensor& input1, c weightsGrad, m)); - if(!input2.empty()){ + if (!input2.empty()) + { T * biasGrad = static_cast<T*>(op.getInput(2)->grad()->getImpl()->rawPtr()); T* onesVector; CHECK_CUDA_STATUS(cudaMalloc((void**)&onesVector, m * sizeof(T))); @@ -200,6 +204,7 @@ void Aidge::FCImpl_cuda::backward_(const Tensor& input0, const Tensor& input1, c 1)); CHECK_CUDA_STATUS(cudaFree(onesVector)); } + // Performing inputGrad = (weights) * (outputGrad) CHECK_CUBLAS_STATUS(cublasGemm( CudaContext::cublasHandle(), diff --git a/src/operator/GlobalAveragePoolingImpl.cpp b/src/operator/GlobalAveragePoolingImpl.cpp index 8c83d477094d9cce41807d888cca57bd614e9cc6..c409c84a4eef466e43fa8dd2e2f138bb55158a0d 100644 --- a/src/operator/GlobalAveragePoolingImpl.cpp +++ b/src/operator/GlobalAveragePoolingImpl.cpp @@ -88,7 +88,7 @@ void Aidge::GlobalAveragePoolingImpl_cuda::backward_(const Tensor& output_grad) const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const T alpha = 1.0f; - const T beta = 0.0f; + const T beta = 1.0f; // accumulate CHECK_CUDNN_STATUS( cudnnPoolingBackward(CudaContext::cudnnHandle(), mGlobalAveragePoolingDesc, diff --git a/src/operator/LnImpl.cpp b/src/operator/LnImpl.cpp index ed09ed45f5006c3760376a9d6f44f29d05bcfabe..7f0ac34d262f2c903e08dd93194cf9901da6282a 100644 --- a/src/operator/LnImpl.cpp +++ b/src/operator/LnImpl.cpp @@ -47,8 +47,7 @@ void Aidge::LnImpl_cuda::forward_(const Tensor& input) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const T * inputPtr = static_cast<const T*>(input.getImpl()->rawPtr()); T * outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr()); - - + Aidge::lnForward<T>(inputPtr, outputPtr, static_cast<int>(op.getOutput(0)->size())); } diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp index 39050635102ebebaed8192cb4bb338e2bc31d5e8..19aacb5076e6ca32241eac6efa8b83bbadbcd456 100644 --- a/src/operator/MaxPoolingImpl.cpp +++ b/src/operator/MaxPoolingImpl.cpp @@ -102,7 +102,7 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::backward_(const Tensor& output_grad) { const MaxPooling_Op<DIM>& op_ = static_cast<const MaxPooling_Op<DIM>&>(mOp); const T alpha = 1.0f; - const T beta = 0.0f; + const T beta = 1.0f; // accumulate CHECK_CUDNN_STATUS( cudnnPoolingBackward(CudaContext::cudnnHandle(), mMaxPoolingDesc, diff --git a/src/operator/MulImpl.cpp b/src/operator/MulImpl.cpp index af87251e8f29eded7d24cca2f08b880557ebb482..aa9b4c74785d3d5785f9d9d62d1a72503f8be104 100644 --- a/src/operator/MulImpl.cpp +++ b/src/operator/MulImpl.cpp @@ -172,10 +172,10 @@ void Aidge::MulImpl_cuda::backward() { template <class T> void Aidge::MulImpl_cuda::backward_(const Tensor& outputGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; - + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate // Create a Tensor descriptor with the broadcasted dims and strides cudnnTensorDescriptor_t tensorDesc0, tensorDesc1; diff --git a/src/operator/PadImpl.cpp b/src/operator/PadImpl.cpp index 3606ba66d002f1467aa65771015cab02c066d5a5..0b17332d84c9b7eccf864ab99c3f1bb453640aa4 100644 --- a/src/operator/PadImpl.cpp +++ b/src/operator/PadImpl.cpp @@ -60,7 +60,12 @@ void Aidge::PadImpl_cuda<DIM>::forward_(const Tensor &input) { const auto outDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(); const T *inputPtr = static_cast<const T *>(input.getImpl()->rawPtr()); + + const T alpha = 1.0f; + const T beta = 0.0f; + T *output = static_cast<T *>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); + Aidge::cudaPadding(CudaContext::getDeviceProp(), outDims[1], outDims[3], @@ -74,7 +79,9 @@ void Aidge::PadImpl_cuda<DIM>::forward_(const Tensor &input) mPadType, static_cast<T>(mPadVal), inputPtr, - output); + output, + alpha, + beta); } template <Aidge::DimIdx_t DIM> @@ -116,7 +123,12 @@ void Aidge::PadImpl_cuda<DIM>::backward_(const Tensor &outGrad) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const auto inputGradDims = op.getInput(0)->grad()->dims(); + + const T alpha = 1.0f; + const T beta = 1.0f; // accumulate + T *inputGrad = static_cast<T *>(op.getInput(0)->grad()->getImpl()->rawPtr()); + Aidge::cudaPadding(CudaContext::getDeviceProp(), inputGradDims[1], inputGradDims[3], @@ -130,7 +142,9 @@ void Aidge::PadImpl_cuda<DIM>::backward_(const Tensor &outGrad) mPadType, static_cast<T>(mPadVal), static_cast<const T *>(outGrad.getImpl()->rawPtr()), - inputGrad); + inputGrad, + alpha, + beta); } // Template declarations diff --git a/src/operator/PadImpl_CUDA_kernels.cu b/src/operator/PadImpl_CUDA_kernels.cu index a20a4c10a6cb5e783a09868389b8f968bc0f42a3..0628751311ab69d6b19cc4ff870f93f6dae2cf5a 100644 --- a/src/operator/PadImpl_CUDA_kernels.cu +++ b/src/operator/PadImpl_CUDA_kernels.cu @@ -23,7 +23,9 @@ __global__ void cudaPadding_kernel(unsigned int nbOutputs, unsigned int padType, T padValue, const T *input, - T *outputs) + T *outputs, + const T alpha, + const T beta) { const unsigned int inputOffset = (blockIdx.z * blockDim.z + threadIdx.z) * nbChannels * inputWidth * inputHeight; @@ -48,8 +50,8 @@ __global__ void cudaPadding_kernel(unsigned int nbOutputs, if (ix >= 0 && ix < (int)inputWidth && iy >= 0 && iy < (int)inputHeight) { - outputValue = input[ix + - iy * inputWidth + ch * inputWidth * inputHeight + inputOffset]; + int inputIndex = ix + iy * inputWidth + ch * inputWidth * inputHeight + inputOffset; + outputValue = input[inputIndex]; } } else if (padType == 1) // Edge padding @@ -57,8 +59,8 @@ __global__ void cudaPadding_kernel(unsigned int nbOutputs, int ix = max(0, min((int)inputWidth - 1, (int)ox - leftPad)); int iy = max(0, min((int)inputHeight - 1, (int)oy - topPad)); - outputValue = input[ix + - iy * inputWidth + ch * inputWidth * inputHeight + inputOffset]; + int inputIndex = ix + iy * inputWidth + ch * inputWidth * inputHeight + inputOffset; + outputValue = input[inputIndex]; } else if (padType == 2) // Reflect padding { @@ -74,18 +76,22 @@ __global__ void cudaPadding_kernel(unsigned int nbOutputs, if (iy >= (int)inputHeight) iy = (int)inputHeight - iy; - outputValue = input[ix + - iy * inputWidth + ch * inputWidth * inputHeight + inputOffset]; + int inputIndex = ix + iy * inputWidth + ch * inputWidth * inputHeight + inputOffset; + outputValue = input[inputIndex]; } else if (padType == 3) // Wrap padding { int ix = (inputWidth + (int)ox - leftPad) % inputWidth; int iy = (inputHeight + (int)oy - topPad) % inputHeight; - outputValue = input[ix + - iy * inputWidth + ch * inputWidth * inputHeight + inputOffset]; + int inputIndex = ix + iy * inputWidth + ch * inputWidth * inputHeight + inputOffset; + outputValue = input[inputIndex]; } - outputs[ox + oy * outputWidth + ch * outputWidth * outputHeight + outputOffset] = outputValue; + + int outputIndex = ox + oy * outputWidth + ch * outputWidth * outputHeight + outputOffset; + + // old : outputs[outputIndex] = outputValue; + outputs[outputIndex] = alpha * outputValue + beta * outputs[outputIndex]; } } } @@ -105,7 +111,9 @@ void Aidge::cudaPadding(const cudaDeviceProp &deviceProp, unsigned int padType, double padValue, const double *input, - double *outputs) + double *outputs, + const double alpha, + const double beta) { const unsigned int maxSize = (unsigned int)deviceProp.maxThreadsPerBlock; const unsigned int prefMultiple = (unsigned int)deviceProp.warpSize; @@ -131,7 +139,9 @@ void Aidge::cudaPadding(const cudaDeviceProp &deviceProp, padType, padValue, input, - outputs); + outputs, + alpha, + beta); CHECK_CUDA_STATUS(cudaPeekAtLastError()); } @@ -149,7 +159,9 @@ void Aidge::cudaPadding(const cudaDeviceProp &deviceProp, unsigned int padType, float padValue, const float *input, - float *outputs) + float *outputs, + const float alpha, + const float beta) { const unsigned int maxSize = (unsigned int)deviceProp.maxThreadsPerBlock; const unsigned int prefMultiple = (unsigned int)deviceProp.warpSize; @@ -175,7 +187,9 @@ void Aidge::cudaPadding(const cudaDeviceProp &deviceProp, padType, padValue, input, - outputs); + outputs, + alpha, + beta); CHECK_CUDA_STATUS(cudaPeekAtLastError()); } @@ -193,7 +207,9 @@ void Aidge::cudaPadding(const cudaDeviceProp &deviceProp, unsigned int padType, half padValue, const half *input, - half *outputs) + half *outputs, + const half alpha, + const half beta) { const unsigned int maxSize = (unsigned int)deviceProp.maxThreadsPerBlock; const unsigned int prefMultiple = (unsigned int)deviceProp.warpSize; @@ -219,6 +235,8 @@ void Aidge::cudaPadding(const cudaDeviceProp &deviceProp, padType, padValue, input, - outputs); + outputs, + alpha, + beta); CHECK_CUDA_STATUS(cudaPeekAtLastError()); } \ No newline at end of file diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp index 80d52045e832b42a95b6d7448f2016530bb9d1ac..db2739290c1deab2995c360573afae410d2870b8 100644 --- a/src/operator/ReLUImpl.cpp +++ b/src/operator/ReLUImpl.cpp @@ -94,10 +94,13 @@ void Aidge::ReLUImpl_cuda::backward() { } template <class T> -void Aidge::ReLUImpl_cuda::backward_(const Tensor& output_grad) { +void Aidge::ReLUImpl_cuda::backward_(const Tensor& output_grad) +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate + CHECK_CUDNN_STATUS( cudnnActivationBackward(CudaContext::cudnnHandle(), mReLUDesc, diff --git a/src/operator/ReduceImpl_CUDA_kernels.cu b/src/operator/ReduceImpl_CUDA_kernels.cu index 7002e34116d2c1050987dc0cb93dbf7339a7ea93..4ce42389624fcfb88b0d6eb35a746a24504ac35a 100644 --- a/src/operator/ReduceImpl_CUDA_kernels.cu +++ b/src/operator/ReduceImpl_CUDA_kernels.cu @@ -12,7 +12,18 @@ #include "aidge/backend/cuda/operator/ReduceImpl_CUDA_kernels.hpp" template <typename T> -__global__ void duplicateElements(const T* input, T* output, const std::size_t* shape, const std::size_t* new_shape, const int* axes, const std::size_t* factors, int num_dims, int num_axes) { +__global__ void duplicateElements( + const T* input, + T* output, + const std::size_t* shape, + const std::size_t* new_shape, + const int* axes, + const std::size_t* factors, + int num_dims, + int num_axes, + const T alpha, + const T beta) +{ int idx = blockIdx.x * blockDim.x + threadIdx.x; int input_size = 1; int output_size = 1; @@ -55,15 +66,25 @@ __global__ void duplicateElements(const T* input, T* output, const std::size_t* output_stride *= new_shape[i]; } - output[out_linear_idx] = input[in_linear_idx]; + // old : output[out_linear_idx] = input[in_linear_idx]; + output[out_linear_idx] = alpha * input[in_linear_idx] + beta * output[out_linear_idx]; delete[] out_idx; delete[] in_idx; } template <typename T> -void Aidge::ReduceBackward(const T* input, T* output, const std::vector<std::size_t>& inputDims, const std::vector<std::size_t>& outputDims, const std::vector<int>& axes, const std::vector<std::size_t>& factors, int outSize) { - +void Aidge::ReduceBackward( + const T* input, + T* output, + const std::vector<std::size_t>& inputDims, + const std::vector<std::size_t>& outputDims, + const std::vector<int>& axes, + const std::vector<std::size_t>& factors, + int outSize, + const T alpha, + const T beta) +{ std::size_t* d_shape; std::size_t* d_new_shape; int* d_axes; @@ -81,7 +102,18 @@ void Aidge::ReduceBackward(const T* input, T* output, const std::vector<std::siz int blockSize = 256; int numBlocks = (outSize + blockSize - 1) / blockSize; - duplicateElements<<<numBlocks, blockSize>>>(input, output, d_shape, d_new_shape, d_axes, d_factors, static_cast<int>(inputDims.size()), static_cast<int>(axes.size())); + duplicateElements<<<numBlocks, blockSize>>> ( + input, + output, + d_shape, + d_new_shape, + d_axes, + d_factors, + static_cast<int>(inputDims.size()), + static_cast<int>(axes.size()), + alpha, + beta); + cudaFree(d_shape); cudaFree(d_new_shape); cudaFree(d_axes); @@ -95,7 +127,9 @@ template void Aidge::ReduceBackward(const double* input, const std::vector<std::size_t>& outputDims, const std::vector<int>& axes, const std::vector<std::size_t>& factors, - int outSize); + int outSize, + const double alpha, + const double beta); template void Aidge::ReduceBackward(const float* input, float* output, @@ -103,7 +137,10 @@ template void Aidge::ReduceBackward(const float* input, const std::vector<std::size_t>& outputDims, const std::vector<int>& axes, const std::vector<std::size_t>& factors, - int outSize); + int outSize, + const float alpha, + const float beta); + template void Aidge::ReduceBackward(const half* input, half* output, @@ -111,4 +148,6 @@ template void Aidge::ReduceBackward(const half* input, const std::vector<std::size_t>& outputDims, const std::vector<int>& axes, const std::vector<std::size_t>& factors, - int outSize); + int outSize, + const half alpha, + const half beta); diff --git a/src/operator/ReduceMeanImpl.cpp b/src/operator/ReduceMeanImpl.cpp index 645929355d9c9036503ae8a90043573ed0aef4b1..2746d4c36fd2d2a5cfe8196d4b091c67ce0f2324 100644 --- a/src/operator/ReduceMeanImpl.cpp +++ b/src/operator/ReduceMeanImpl.cpp @@ -179,9 +179,15 @@ void Aidge::ReduceMeanImpl_cuda::backward() { template <class T> void Aidge::ReduceMeanImpl_cuda::backward_(const Tensor& outGrad, const std::vector<int>& axes) { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + // const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; // const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + + const T alpha = 1.0f; + const T beta = 1.0f; // accumulate + const T * outputGrad = static_cast<const T*>(op.getOutput(0)->grad()->getImpl()->rawPtr()); T * inputGrad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr()); @@ -196,5 +202,7 @@ void Aidge::ReduceMeanImpl_cuda::backward_(const Tensor& outGrad, const std::vec op.getInput(0)->grad()->dims(), axes, factors, - static_cast<int>(op.getInput(0)->grad()->size())); + static_cast<int>(op.getInput(0)->grad()->size()), + alpha, + beta); } diff --git a/src/operator/ReduceSumImpl.cpp b/src/operator/ReduceSumImpl.cpp index 84658cae495fefb1b893b78e1515e42a7d1f65f7..e8c5b1e98d10d40dc01157465ba21f3a5330ced4 100644 --- a/src/operator/ReduceSumImpl.cpp +++ b/src/operator/ReduceSumImpl.cpp @@ -178,7 +178,11 @@ void Aidge::ReduceSumImpl_cuda::backward() { } template <class T> -void Aidge::ReduceSumImpl_cuda::backward_(const Tensor& outGrad, const std::vector<int>& axes) { +void Aidge::ReduceSumImpl_cuda::backward_(const Tensor& outGrad, const std::vector<int>& axes) +{ + const T alpha = 1.0f; + const T beta = 1.0f; // accumulate + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const T * outputGrad = static_cast<const T*>(op.getOutput(0)->grad()->getImpl()->rawPtr()); @@ -195,5 +199,7 @@ void Aidge::ReduceSumImpl_cuda::backward_(const Tensor& outGrad, const std::vect op.getInput(0)->grad()->dims(), axes, factors, - static_cast<int>(op.getInput(0)->grad()->size())); + static_cast<int>(op.getInput(0)->grad()->size()), + alpha, + beta); } diff --git a/src/operator/ReshapeImpl.cpp b/src/operator/ReshapeImpl.cpp index 783e244057b0fc42a782fd363c3a99aa6d73b46b..49f732e120fd5de9454b47828caaa5b2ce6f5c58 100644 --- a/src/operator/ReshapeImpl.cpp +++ b/src/operator/ReshapeImpl.cpp @@ -32,10 +32,11 @@ void Aidge::ReshapeImpl_cuda::forward() { } void Aidge::ReshapeImpl_cuda::backward() { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); AIDGE_ASSERT(op.getOutput(0)->grad(), "missing output grad #0"); const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad()); - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->grad() -> getImpl() -> setRawPtr(output_grad.getImpl()->rawPtr(), output_grad.getImpl()->size()); + std::static_pointer_cast<Tensor> (mOp.getRawInput(0))->grad()->getImpl()->setRawPtr(output_grad.getImpl()->rawPtr(), output_grad.getImpl()->size()); } diff --git a/src/operator/SigmoidImpl.cpp b/src/operator/SigmoidImpl.cpp index 386cd9d821b3019cf8f0de2cc757ae514446f1a6..f6b0695cc71ffb82d6a1514195b13d23bae4a213 100644 --- a/src/operator/SigmoidImpl.cpp +++ b/src/operator/SigmoidImpl.cpp @@ -95,9 +95,12 @@ void Aidge::SigmoidImpl_cuda::backward() { template <class T> void Aidge::SigmoidImpl_cuda::backward_(const Tensor& output_grad) { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate + CHECK_CUDNN_STATUS( cudnnActivationBackward(CudaContext::cudnnHandle(), mSigmoidDesc, diff --git a/src/operator/SqrtImpl.cpp b/src/operator/SqrtImpl.cpp index 60498e2907b0953b756064d6d19aeb1667ea7575..562179de0beb2abbab28bf490c7de3251bc84773 100644 --- a/src/operator/SqrtImpl.cpp +++ b/src/operator/SqrtImpl.cpp @@ -20,7 +20,8 @@ #include "aidge/operator/Sqrt.hpp" #include "aidge/utils/Types.h" -void Aidge::SqrtImpl_cuda::forward() { +void Aidge::SqrtImpl_cuda::forward() +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); AIDGE_ASSERT(op.getInput(0), "missing input #0"); @@ -43,15 +44,25 @@ void Aidge::SqrtImpl_cuda::forward() { } template <class T> -void Aidge::SqrtImpl_cuda::forward_(const Tensor& input) { +void Aidge::SqrtImpl_cuda::forward_(const Tensor& input) +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + + const T alpha = 1.0f; + const T beta = 0.0f; + const T * inputPtr = static_cast<const T*>(input.getImpl()->rawPtr()); T * outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr()); - Aidge::sqrtForward<T>(inputPtr, outputPtr, static_cast<int>(op.getOutput(0)->size())); + Aidge::sqrtForward<T>(inputPtr, + outputPtr, + static_cast<int>(op.getOutput(0)->size()), + alpha, + beta); } -void Aidge::SqrtImpl_cuda::backward() { +void Aidge::SqrtImpl_cuda::backward() +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); AIDGE_ASSERT(op.getInput(0), "missing input #0"); @@ -76,11 +87,21 @@ void Aidge::SqrtImpl_cuda::backward() { } template <class T> -void Aidge::SqrtImpl_cuda::backward_(const Tensor& input, const Tensor& output_grad) { +void Aidge::SqrtImpl_cuda::backward_(const Tensor& input, const Tensor& output_grad) +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + + const T alpha = 1.0f; + const T beta = 1.0f; // accumulate + const T * inputPtr = static_cast<const T*>(input.getImpl()->rawPtr()); const T * outputGradPtr = static_cast<const T*>(output_grad.getImpl()->rawPtr()); T * inputGradPtr = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr()); - Aidge::sqrtBackward<T>(inputPtr, outputGradPtr, inputGradPtr, static_cast<int>(op.getOutput(0)->size())); + Aidge::sqrtBackward<T>(inputPtr, + outputGradPtr, + inputGradPtr, + static_cast<int>(op.getOutput(0)->size()), + alpha, + beta); } diff --git a/src/operator/SqrtImpl_CUDA_kernels.cu b/src/operator/SqrtImpl_CUDA_kernels.cu index b8da130e8b7cf4d7f94f2567ba77c7da363441ea..c4146c23e05258733b630ce6ecaa995b7b8bb321 100644 --- a/src/operator/SqrtImpl_CUDA_kernels.cu +++ b/src/operator/SqrtImpl_CUDA_kernels.cu @@ -47,46 +47,70 @@ __device__ half mul_helper<half>(half a, half b) { // Forward Kernel template <class T> -__global__ void sqrtCUDAForwardKernel(const T* input, T* output, int size) { +__global__ void sqrtCUDAForwardKernel(const T* input, + T* output, + int size, + const T alpha, + const T beta) +{ int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; - output[idx] = sqrt_helper(input[idx]); + output[idx] = alpha * sqrt_helper(input[idx]) + beta * output[idx]; } - template <class T> -void Aidge::sqrtForward(const T* input, T* output, int size) +void Aidge::sqrtForward(const T* input, + T* output, + int size, + const T alpha, + const T beta) { const int blockSize = 256; int numBlocks = (size + blockSize - 1) / blockSize; // Launch the kernel - sqrtCUDAForwardKernel<<<numBlocks, blockSize>>>(input, output, size); + sqrtCUDAForwardKernel<<<numBlocks, blockSize>>>(input, output, size, alpha, beta); }; // Backward Kernel template <class T> -__global__ void sqrtCUDABackwardKernel(const T* input, const T* outputGrad, T* inputGrad, int size) { +__global__ void sqrtCUDABackwardKernel(const T* input, + const T* outputGrad, + T* inputGrad, + int size, + const T alpha, + const T beta) +{ int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; - inputGrad[idx] = outputGrad[idx] / mul_helper(static_cast<T>(2), sqrt_helper(input[idx])); + T val = outputGrad[idx] / mul_helper(static_cast<T>(2), sqrt_helper(input[idx])); + + inputGrad[idx] = alpha * val + beta * inputGrad[idx]; } template <class T> -void Aidge::sqrtBackward(const T* input, const T* outputGrad, T* inputGrad, int size) +void Aidge::sqrtBackward(const T* input, + const T* outputGrad, + T* inputGrad, + int size, + const T alpha, + const T beta) { const int blockSize = 256; int numBlocks = (size + blockSize - 1) / blockSize; // Launch the kernel - sqrtCUDABackwardKernel<<<numBlocks, blockSize>>>(input, outputGrad, inputGrad, size); + sqrtCUDABackwardKernel<<<numBlocks, blockSize>>>(input, outputGrad, inputGrad, size, alpha, beta); }; -template void Aidge::sqrtForward<double>(const double* input, double* output, int size); -template void Aidge::sqrtForward<float>(const float* input, float* output, int size); -template void Aidge::sqrtForward<half>(const half* input, half* output, int size); -template void Aidge::sqrtBackward<double>(const double* input, const double* outputGrad, double* inputGrad, int size); -template void Aidge::sqrtBackward<float>(const float* input, const float* outputGrad, float* inputGrad, int size); -template void Aidge::sqrtBackward<half>(const half* input, const half* outputGrad, half* inputGrad, int size); \ No newline at end of file +template void Aidge::sqrtForward<double>(const double* input, double* output, int size, const double alpha, const double beta); +template void Aidge::sqrtForward<float>(const float* input, float* output, int size, const float alpha, const float beta); +template void Aidge::sqrtForward<half>(const half* input, half* output, int size, const half alpha, const half beta); + +template void Aidge::sqrtBackward<double>(const double* input, const double* outputGrad, double* inputGrad, int size, const double alpha, const double beta); +template void Aidge::sqrtBackward<float>(const float* input, const float* outputGrad, float* inputGrad, int size, const float alpha, const float beta); +template void Aidge::sqrtBackward<half>(const half* input, const half* outputGrad, half* inputGrad, int size, const half alpha, const half beta); \ No newline at end of file diff --git a/src/operator/SubImpl.cpp b/src/operator/SubImpl.cpp index a04a1c3018b0c9ba455d21ba563253eb3e004e10..249d95f5a03c17e96db41c924361be3de1cbc6b0 100644 --- a/src/operator/SubImpl.cpp +++ b/src/operator/SubImpl.cpp @@ -155,11 +155,17 @@ void Aidge::SubImpl_cuda::backward() { } template <class T> -void Aidge::SubImpl_cuda::backward_(const Tensor& outputGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) { +void Aidge::SubImpl_cuda::backward_( + const Tensor& outputGrad, + const std::vector<std::vector<int>>& inputsDims, + const std::vector<std::vector<int>>& inputsStrides) +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate const typename Cuda::cudnn_scaling_type<T>::type gamma = -1.0f; + for (std::size_t i = 0; i < inputsDims.size(); i++) { if (op.getInput(i)->size() == op.getOutput(0)->size()) diff --git a/src/operator/TanhImpl.cpp b/src/operator/TanhImpl.cpp index 96c0330febba35cfea04bbbac97d9308195d6309..2e61280ec83488cbba5b7b1fd23d49e70276790c 100644 --- a/src/operator/TanhImpl.cpp +++ b/src/operator/TanhImpl.cpp @@ -94,10 +94,13 @@ void Aidge::TanhImpl_cuda::backward() { } template <class T> -void Aidge::TanhImpl_cuda::backward_(const Tensor& output_grad) { +void Aidge::TanhImpl_cuda::backward_(const Tensor& output_grad) +{ const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f; - const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f; + const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate + CHECK_CUDNN_STATUS( cudnnActivationBackward(CudaContext::cudnnHandle(), mTanhDesc,