diff --git a/include/aidge/backend/cpu/operator/LSQImpl_kernels.hpp b/include/aidge/backend/cpu/operator/LSQImpl_kernels.hpp index ddb820997837ec9b3603c6007497c8161145d587..1ed05e232ba9f8332c372a9524edd26fc7d9c45a 100644 --- a/include/aidge/backend/cpu/operator/LSQImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/LSQImpl_kernels.hpp @@ -67,16 +67,16 @@ void LSQImpl_cpu_backward_kernel(const std::size_t inputLength, const GI fullPrecScale_4 = input[4*i+3] / stepSize[0]; /*****************Features Gradient Computation********************/ // STE method is simply applied - grad_input[4*i] = grad_output[4*i]*((fullPrecScale_1 <= static_cast<GI>(range.first)) ? GI(0.0) : + grad_input[4*i] += grad_output[4*i]*((fullPrecScale_1 <= static_cast<GI>(range.first)) ? GI(0.0) : (fullPrecScale_1 >= static_cast<GI>(range.second)) ? GI(0.0) : GI(1.0)); - grad_input[4*i+1] = grad_output[4*i+1]*((fullPrecScale_2 <= static_cast<GI>(range.first)) ? GI(0.0) : + grad_input[4*i+1] += grad_output[4*i+1]*((fullPrecScale_2 <= static_cast<GI>(range.first)) ? GI(0.0) : (fullPrecScale_2 >= static_cast<GI>(range.second)) ? GI(0.0) : GI(1.0)); - grad_input[4*i+2] = grad_output[4*i+2]*((fullPrecScale_3 <= static_cast<GI>(range.first)) ? GI(0.0) : + grad_input[4*i+2] += grad_output[4*i+2]*((fullPrecScale_3 <= static_cast<GI>(range.first)) ? GI(0.0) : (fullPrecScale_3 >= static_cast<GI>(range.second)) ? GI(0.0) : GI(1.0)); - grad_input[4*i+3] = grad_output[4*i+3]*((fullPrecScale_4 <= static_cast<GI>(range.first)) ? GI(0.0) : + grad_input[4*i+3] += grad_output[4*i+3]*((fullPrecScale_4 <= static_cast<GI>(range.first)) ? GI(0.0) : (fullPrecScale_4 >= static_cast<GI>(range.second)) ? 
GI(0.0) : GI(1.0)); @@ -105,7 +105,7 @@ void LSQImpl_cpu_backward_kernel(const std::size_t inputLength, // Process remaining for(unsigned int i=inputLength-inputLength%4; i<inputLength; ++i) { const GI fullPrecScale = input[i] / stepSize[0]; - grad_input[i] = grad_output[i]*((fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) : + grad_input[i] += grad_output[i]*((fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) : (fullPrecScale >= static_cast<GI>(range.second)) ? GI(0.0) : GI(1.0)); GI qData = fullPrecScale; @@ -117,7 +117,7 @@ void LSQImpl_cpu_backward_kernel(const std::size_t inputLength, const GI gradScaleFactor = static_cast<GI>(1.0f / std::sqrt(inputLength * range.second)); // 3rd: Multiply Step Size gradient with scale factor - grad_stepSize[0] = diffStepSize * gradScaleFactor; + grad_stepSize[0] += diffStepSize * gradScaleFactor; } diff --git a/src/backend/cuda/operator/LSQImpl_CUDA_kernels.cu b/src/backend/cuda/operator/LSQImpl_CUDA_kernels.cu index 0d5490946af3a4ab172bafc13d9af8c191695b84..96065e41376a1facee8a05260f33a1ce68ceb92a 100644 --- a/src/backend/cuda/operator/LSQImpl_CUDA_kernels.cu +++ b/src/backend/cuda/operator/LSQImpl_CUDA_kernels.cu @@ -84,10 +84,11 @@ __global__ void LSQImpl_cuda_backward_kernel_(const std::size_t inputLength, const GI fullPrecScale = input[i] / stepSize[0]; /*****************************Data/Weights Gradient Computation************************/ - // STE method is simply apply: - grad_input[i] = grad_output[i]*( (fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) : - (fullPrecScale >= static_cast<GI>(range.second)) ? GI(0.0) : - GI(1.0)); + // STE method is simply applied: + // (we accumulate the gradient instead of replacing it) + grad_input[i] += grad_output[i] * ((fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) : + (fullPrecScale >= static_cast<GI>(range.second)) ? 
GI(0.0) : + GI(1.0)); /*****************************Step Size Gradient Computation*************************/ GI qData = fullPrecScale; @@ -142,7 +143,9 @@ void Aidge::LSQImpl_cuda_backward_kernel(const std::size_t inputLength, // for simplicity and foolproof-ness thrust::device_ptr<GI> grad_workspacePtr(grad_workspace); thrust::device_ptr<GI> grad_stepSizePtr(grad_stepSize); - grad_stepSizePtr[0] = thrust::reduce(grad_workspacePtr, grad_workspacePtr + inputLength, GI(0.0)); + + // We accumulate the stepSize gradient instead of replacing it + grad_stepSizePtr[0] += thrust::reduce(grad_workspacePtr, grad_workspacePtr + inputLength, GI(0.0)); //printf(" step grad = %f \n", (float) grad_stepSizePtr[0]);