Skip to content
Snippets Groups Projects
Commit ccea932f authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

set the CLE data types to double

parent 8c892145
No related branches found
No related tags found
3 merge requests!54Update 0.3.1 -> 0.4.0,!36Global Quantization Improvements,!32Fix the PTQ NaN bug
Pipeline #63029 passed
......@@ -30,7 +30,7 @@ namespace Aidge
* @param graphView The GraphView to process.
* @param targetDelta the stopping criterion (typical value : 0.01)
*/
void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta = 0.01);
void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta = 0.01);
}
......
......@@ -32,23 +32,23 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
}
static void rescaleTensor(std::shared_ptr<Tensor> tensor, float scaling)
static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
{
// Get the tensor data pointer
float * castedTensor = static_cast <float *> (tensor->getImpl()->rawPtr());
double * castedTensor = static_cast<double *> (tensor->getImpl()->rawPtr());
// Rescale the tensor
for(std::size_t i = 0; i < tensor->size(); i++)
castedTensor[i] *= scaling;
}
static float getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
{
// Get the tensor data pointer and edit it
float * castedTensor = static_cast<float*>(tensor->getImpl()->rawPtr());
double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr());
// Get the tensor absolute max value
float maxValue = 0.0f;
double maxValue = 0.0f;
for(std::size_t i = 0; i < tensor->size(); ++i) {
if(std::fabs(castedTensor[i]) > maxValue) {
maxValue = std::fabs(castedTensor[i]);
......@@ -57,7 +57,7 @@ static float getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
return maxValue;
}
void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta)
void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta)
{
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
......@@ -79,7 +79,7 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDe
if (isAffine(node))
affineNodeVector.push_back(node);
float maxRangeDelta;
double maxRangeDelta;
do
{
......@@ -94,18 +94,18 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDe
std::shared_ptr<Node> n1 = affineNodeVector[i];
std::shared_ptr<Node> n2 = affineNodeVector[i+1];
float r1 = getTensorAbsoluteMax(getWeightTensor(n1));
float r2 = getTensorAbsoluteMax(getWeightTensor(n2));
double r1 = getTensorAbsoluteMax(getWeightTensor(n1));
double r2 = getTensorAbsoluteMax(getWeightTensor(n2));
float s1 = std::sqrt(r1 * r2) / r1;
float s2 = std::sqrt(r1 * r2) / r2;
double s1 = std::sqrt(r1 * r2) / r1;
double s2 = std::sqrt(r1 * r2) / r2;
rescaleTensor(getWeightTensor(n1), s1);
rescaleTensor(getWeightTensor(n2), s2);
rescaleTensor(getBiasTensor(n1), s1);
float rangeDelta = std::abs(r1 - r2);
double rangeDelta = std::abs(r1 - r2);
if (rangeDelta > maxRangeDelta)
maxRangeDelta = rangeDelta;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment