diff --git a/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp index 3aa43251b6b6da5db75b2d9b90837f06942f581c..a1ac156424e657be66423ecae2d260bc962ef894 100644 --- a/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp +++ b/include/aidge/backend/cuda/operator/SqrtImpl_CUDA_kernels.hpp @@ -25,16 +25,16 @@ namespace Aidge { template <class T> -void sqrtForward(const T* input, - T* output, +void sqrtForward(const T* input, + T* output, int size, const T alpha, const T beta); template <class T> -void sqrtBackward(const T* input, - const T* outputGrad, - T* inputGrad, +void sqrtBackward(const T* input, + const T* outputGrad, + T* inputGrad, int size, const T alpha, const T beta); diff --git a/src/operator/SqrtImpl.cpp b/src/operator/SqrtImpl.cpp index 562179de0beb2abbab28bf490c7de3251bc84773..c1eccd107f421592619474fbb4c641a0f0b958bc 100644 --- a/src/operator/SqrtImpl.cpp +++ b/src/operator/SqrtImpl.cpp @@ -20,7 +20,7 @@ #include "aidge/operator/Sqrt.hpp" #include "aidge/utils/Types.h" -void Aidge::SqrtImpl_cuda::forward() +void Aidge::SqrtImpl_cuda::forward() { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); @@ -44,7 +44,7 @@ void Aidge::SqrtImpl_cuda::forward() } template <class T> -void Aidge::SqrtImpl_cuda::forward_(const Tensor& input) +void Aidge::SqrtImpl_cuda::forward_(const Tensor& input) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); @@ -54,14 +54,14 @@ void Aidge::SqrtImpl_cuda::forward_(const Tensor& input) const T * inputPtr = static_cast<const T*>(input.getImpl()->rawPtr()); T * outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr()); - Aidge::sqrtForward<T>(inputPtr, + Aidge::sqrtForward<T>(inputPtr, outputPtr, static_cast<int>(op.getOutput(0)->size()), alpha, beta); } -void Aidge::SqrtImpl_cuda::backward() +void Aidge::SqrtImpl_cuda::backward() { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); @@ -87,7 +87,7 @@ void Aidge::SqrtImpl_cuda::backward() } template <class T> -void Aidge::SqrtImpl_cuda::backward_(const Tensor& input, const Tensor& output_grad) +void Aidge::SqrtImpl_cuda::backward_(const Tensor& input, const Tensor& output_grad) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); @@ -98,9 +98,9 @@ void Aidge::SqrtImpl_cuda::backward_(const Tensor& input, const Tensor& output_g const T * outputGradPtr = static_cast<const T*>(output_grad.getImpl()->rawPtr()); T * inputGradPtr = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr()); - Aidge::sqrtBackward<T>(inputPtr, - outputGradPtr, - inputGradPtr, + Aidge::sqrtBackward<T>(inputPtr, + outputGradPtr, + inputGradPtr, static_cast<int>(op.getOutput(0)->size()), alpha, beta); diff --git a/src/operator/SqrtImpl_CUDA_kernels.cu b/src/operator/SqrtImpl_CUDA_kernels.cu index c4146c23e05258733b630ce6ecaa995b7b8bb321..7af45c3ea9a0b8fb8a4ceab4cb9944f38fab3111 100644 --- a/src/operator/SqrtImpl_CUDA_kernels.cu +++ b/src/operator/SqrtImpl_CUDA_kernels.cu @@ -48,10 +48,10 @@ __device__ half mul_helper<half>(half a, half b) { // Forward Kernel template <class T> __global__ void sqrtCUDAForwardKernel(const T* input, - T* output, - int size, - const T alpha, - const T beta) + T* output, + int size, + const T alpha, + const T beta) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -62,9 +62,9 @@ __global__ void sqrtCUDAForwardKernel(const T* input, template <class T> void Aidge::sqrtForward(const T* input, - T* output, - int size, - const T alpha, + T* output, + int size, + const T alpha, const T beta) { const int blockSize = 256; @@ -76,12 +76,12 @@ void Aidge::sqrtForward(const T* input, // Backward Kernel template <class T> -__global__ void sqrtCUDABackwardKernel(const T* input, - const T* outputGrad, - T* inputGrad, +__global__ void sqrtCUDABackwardKernel(const T* input, + const T* outputGrad, + T* inputGrad, int size, const T alpha, - const T beta) + const T beta) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -93,9 +93,9 @@ __global__ void sqrtCUDABackwardKernel(const T* input, } template <class T> -void Aidge::sqrtBackward(const T* input, - const T* outputGrad, - T* inputGrad, +void Aidge::sqrtBackward(const T* input, + const T* outputGrad, + T* inputGrad, int size, const T alpha, const T beta) @@ -113,4 +113,4 @@ template void Aidge::sqrtForward<half>(const half* input, half* output, int size template void Aidge::sqrtBackward<double>(const double* input, const double* outputGrad, double* inputGrad, int size, const double alpha, const double beta); template void Aidge::sqrtBackward<float>(const float* input, const float* outputGrad, float* inputGrad, int size, const float alpha, const float beta); -template void Aidge::sqrtBackward<half>(const half* input, const half* outputGrad, half* inputGrad, int size, const half alpha, const half beta); \ No newline at end of file +template void Aidge::sqrtBackward<half>(const half* input, const half* outputGrad, half* inputGrad, int size, const half alpha, const half beta);