Commit d3798ad6 authored by Benjamin Halimi

set the LSQ op backward kernels to gradient accumulation mode

parent 11585297
3 merge requests: !54 Update 0.3.1 -> 0.4.0, !36 Global Quantization Improvements, !30 Support of the YOLOv3
Pipeline #62417 passed
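
For context, the diff below switches the LSQ backward kernels from overwriting the gradient buffers to accumulating into them. A minimal sketch of the idea follows; the names (ste_input_grad, rangeMin, rangeMax) are illustrative only and not the actual Aidge kernel signature:

    #include <cstddef>

    // Minimal sketch of the STE input-gradient update used by the LSQ backward
    // kernels: pass the gradient through when the scaled input lies inside the
    // quantization range, zero it otherwise.
    template <class GI>
    void ste_input_grad(const GI* input, const GI* grad_output, GI* grad_input,
                        const GI stepSize, const GI rangeMin, const GI rangeMax,
                        const std::size_t length)
    {
        for (std::size_t i = 0; i < length; ++i) {
            const GI fullPrecScale = input[i] / stepSize;
            const GI passThrough = (fullPrecScale <= rangeMin) ? GI(0.0) :
                                   (fullPrecScale >= rangeMax) ? GI(0.0) :
                                                                 GI(1.0);
            // Before this commit: grad_input[i] = grad_output[i] * passThrough;
            // After this commit the contribution is accumulated instead:
            grad_input[i] += grad_output[i] * passThrough;
        }
    }
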
@@ -67,16 +67,16 @@ void LSQImpl_cpu_backward_kernel(const std::size_t inputLength,
         const GI fullPrecScale_4 = input[4*i+3] / stepSize[0];
         /*****************Features Gradient Computation********************/
         // STE method is simply applied
-        grad_input[4*i] = grad_output[4*i]*((fullPrecScale_1 <= static_cast<GI>(range.first)) ? GI(0.0) :
+        grad_input[4*i] += grad_output[4*i]*((fullPrecScale_1 <= static_cast<GI>(range.first)) ? GI(0.0) :
                     (fullPrecScale_1 >= static_cast<GI>(range.second)) ? GI(0.0) :
                     GI(1.0));
-        grad_input[4*i+1] = grad_output[4*i+1]*((fullPrecScale_2 <= static_cast<GI>(range.first)) ? GI(0.0) :
+        grad_input[4*i+1] += grad_output[4*i+1]*((fullPrecScale_2 <= static_cast<GI>(range.first)) ? GI(0.0) :
                     (fullPrecScale_2 >= static_cast<GI>(range.second)) ? GI(0.0) :
                     GI(1.0));
-        grad_input[4*i+2] = grad_output[4*i+2]*((fullPrecScale_3 <= static_cast<GI>(range.first)) ? GI(0.0) :
+        grad_input[4*i+2] += grad_output[4*i+2]*((fullPrecScale_3 <= static_cast<GI>(range.first)) ? GI(0.0) :
                     (fullPrecScale_3 >= static_cast<GI>(range.second)) ? GI(0.0) :
                     GI(1.0));
-        grad_input[4*i+3] = grad_output[4*i+3]*((fullPrecScale_4 <= static_cast<GI>(range.first)) ? GI(0.0) :
+        grad_input[4*i+3] += grad_output[4*i+3]*((fullPrecScale_4 <= static_cast<GI>(range.first)) ? GI(0.0) :
                     (fullPrecScale_4 >= static_cast<GI>(range.second)) ? GI(0.0) :
                     GI(1.0));
@@ -105,7 +105,7 @@ void LSQImpl_cpu_backward_kernel(const std::size_t inputLength,
     // Process remaining
     for(unsigned int i=inputLength-inputLength%4; i<inputLength; ++i) {
         const GI fullPrecScale = input[i] / stepSize[0];
-        grad_input[i] = grad_output[i]*((fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) :
+        grad_input[i] += grad_output[i]*((fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) :
                     (fullPrecScale >= static_cast<GI>(range.second)) ? GI(0.0) :
                     GI(1.0));
         GI qData = fullPrecScale;
@@ -117,7 +117,7 @@ void LSQImpl_cpu_backward_kernel(const std::size_t inputLength,
     const GI gradScaleFactor = static_cast<GI>(1.0f / std::sqrt(inputLength * range.second));
     // 3rd: Multiply Step Size gradient with scale factor
-    grad_stepSize[0] = diffStepSize * gradScaleFactor;
+    grad_stepSize[0] += diffStepSize * gradScaleFactor;
 }
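
For reference (not part of the diff): the scale factor in this hunk is the step-size gradient scale used by LSQ,

    g = 1 / sqrt(N * Q_P)

where N is inputLength and Q_P is range.second, so the kernel now accumulates grad_stepSize[0] += g * diffStepSize instead of overwriting it.
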
@@ -84,10 +84,11 @@ __global__ void LSQImpl_cuda_backward_kernel_(const std::size_t inputLength,
         const GI fullPrecScale = input[i] / stepSize[0];
         /*****************************Data/Weights Gradient Computation************************/
-        // STE method is simply apply:
-        grad_input[i] = grad_output[i]*( (fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) :
-                                         (fullPrecScale >= static_cast<GI>(range.second)) ? GI(0.0) :
-                                          GI(1.0));
+        // STE method is simply applied :
+        // (we accumulate the gradient instead of replacing it)
+        grad_input[i] += grad_output[i] * ((fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) :
+                                           (fullPrecScale >= static_cast<GI>(range.second)) ? GI(0.0) :
+                                            GI(1.0));
         /*****************************Step Size Gradient Computation*************************/
         GI qData = fullPrecScale;
@@ -142,7 +143,9 @@ void Aidge::LSQImpl_cuda_backward_kernel(const std::size_t inputLength,
     // for simplicity and foolproof-ness
     thrust::device_ptr<GI> grad_workspacePtr(grad_workspace);
     thrust::device_ptr<GI> grad_stepSizePtr(grad_stepSize);
-    grad_stepSizePtr[0] = thrust::reduce(grad_workspacePtr, grad_workspacePtr + inputLength, GI(0.0));
+    // We accumulate the stepSize gradient instead of replacing it
+    grad_stepSizePtr[0] += thrust::reduce(grad_workspacePtr, grad_workspacePtr + inputLength, GI(0.0));
     //printf(" step grad = %f \n", (float) grad_stepSizePtr[0]);
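
One consequence of accumulation mode, stated here as an assumption about the calling side rather than something shown in this diff: the gradient buffers must start at zero before the first backward pass of a step, otherwise contributions from the previous step leak in. A minimal host-side sketch, assuming plain float buffers and a hypothetical helper name:

    #include <algorithm>
    #include <cstddef>

    // Sketch only: reset the gradient buffers before running the
    // accumulation-mode backward kernels. Buffer names are hypothetical;
    // in practice the framework's gradient-reset mechanism would do this.
    void resetLSQGradients(float* grad_input, const std::size_t inputLength,
                           float* grad_stepSize)
    {
        std::fill(grad_input, grad_input + inputLength, 0.0f);
        grad_stepSize[0] = 0.0f;
    }
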