From ccea932f276aad2ed919951693f7d7628cb02472 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 15 Jan 2025 13:18:27 +0000
Subject: [PATCH] set the CLE data types to double
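
Switch the CLE helpers (rescaleTensor, getTensorAbsoluteMax) and the
public crossLayerEqualization entry point from float to double, so the
raw tensor data is read and written through double pointers and the
scaling factors are computed in double precision.

For reference, the equalization step rescales each pair of consecutive
affine nodes (n1, n2) by

    s1 = sqrt(r1 * r2) / r1
    s2 = sqrt(r1 * r2) / r2

where r1 and r2 are the absolute maxima of their weight tensors, so
both rescaled weight ranges end up at sqrt(r1 * r2) (e.g. r1 = 4 and
r2 = 1 give s1 = 0.5 and s2 = 2, i.e. a common range of 2). The outer
loop keeps iterating while the largest |r1 - r2| over all pairs
(maxRangeDelta) stays above the targetDelta stopping criterion
documented in CLE.hpp.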

---
 include/aidge/quantization/PTQ/CLE.hpp |  2 +-
 src/PTQ/CLE.cpp                        | 24 ++++++++++++------------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/include/aidge/quantization/PTQ/CLE.hpp b/include/aidge/quantization/PTQ/CLE.hpp
index d94b6e9..77eaf7f 100644
--- a/include/aidge/quantization/PTQ/CLE.hpp
+++ b/include/aidge/quantization/PTQ/CLE.hpp
@@ -30,7 +30,7 @@ namespace Aidge
      * @param graphView The GraphView to process.
      * @param targetDelta the stopping criterion (typical value : 0.01)
      */
-    void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta = 0.01);
+    void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta = 0.01);
 
 }
 
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 1d5ccc7..2c81815 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -32,23 +32,23 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
     return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
 }
 
-static void rescaleTensor(std::shared_ptr<Tensor> tensor, float scaling)
+static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
 {
     // Get the tensor data pointer
-    float * castedTensor = static_cast <float *> (tensor->getImpl()->rawPtr());
+    double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr());
 
     // Rescale the tensor
     for(std::size_t i = 0; i < tensor->size(); i++)
         castedTensor[i] *= scaling;
 }
 
-static float getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
+static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
 {
     // Get the tensor data pointer and edit it
-    float * castedTensor = static_cast<float*>(tensor->getImpl()->rawPtr());
+    double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr());
 
     // Get the tensor absolute max value
-    float maxValue = 0.0f;
+    double maxValue = 0.0;
     for(std::size_t i = 0; i < tensor->size(); ++i) {
         if(std::fabs(castedTensor[i]) > maxValue) {
             maxValue = std::fabs(castedTensor[i]);
@@ -57,7 +57,7 @@ static float getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
     return maxValue;
 }
 
-void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta)
+void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta)
 {
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
@@ -79,7 +79,7 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDe
         if (isAffine(node))
             affineNodeVector.push_back(node);
 
-    float maxRangeDelta;
+    double maxRangeDelta;
 
     do 
     {
@@ -94,18 +94,18 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDe
             std::shared_ptr<Node> n1 = affineNodeVector[i];
             std::shared_ptr<Node> n2 = affineNodeVector[i+1];
 
-            float r1 = getTensorAbsoluteMax(getWeightTensor(n1));
-            float r2 = getTensorAbsoluteMax(getWeightTensor(n2));
+            double r1 = getTensorAbsoluteMax(getWeightTensor(n1));
+            double r2 = getTensorAbsoluteMax(getWeightTensor(n2));
 
-            float s1 = std::sqrt(r1 * r2) / r1;
-            float s2 = std::sqrt(r1 * r2) / r2;
+            double s1 = std::sqrt(r1 * r2) / r1;
+            double s2 = std::sqrt(r1 * r2) / r2;
 
             rescaleTensor(getWeightTensor(n1), s1);
             rescaleTensor(getWeightTensor(n2), s2);
 
             rescaleTensor(getBiasTensor(n1), s1);
 
-            float rangeDelta = std::abs(r1 - r2);
+            double rangeDelta = std::abs(r1 - r2);
             if (rangeDelta > maxRangeDelta)
                 maxRangeDelta = rangeDelta;
         }
-- 
GitLab