From e87d3bfb3271034b7c619c1970ac6d928fc60e5d Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 29 Apr 2025 09:16:51 +0000
Subject: [PATCH] handle halfway rounding modes + scaling factor setter

---
 include/aidge/operator/PTQMetaOps.hpp |  7 +++
 src/PTQ/PTQ.cpp                       | 10 ++---
 src/operator/PTQMetaOps.cpp           | 65 +++++++++++++++------------
 3 files changed, 49 insertions(+), 33 deletions(-)

diff --git a/include/aidge/operator/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp
index 9618220..8768b71 100644
--- a/include/aidge/operator/PTQMetaOps.hpp
+++ b/include/aidge/operator/PTQMetaOps.hpp
@@ -66,6 +66,13 @@ namespace Aidge {
      */
     void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType);
 
+    /**
+     * @brief Given a Quantizer, set the coefficient of it's Mul node.
+     * @param quantizer The quantizer containing the multiplicative node.
+     * @param value The new value of the multiplicative coefficient.
+     */
+    void setScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double value);
+
     /**
      * @brief Given a Quantizer, retreive the coefficient of it's Mul node.
      * @param quantizer The quantizer containing the multiplicative coefficient.
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 9e2a62c..bab86d0 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -1226,18 +1226,18 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView)
             std::shared_ptr<Node> linearNode = node->getParent(0);
 
             double base = getScalingFactor(node);
-            double approx = std::pow(2, std::ceil(std::log2(base)));
-            double ratio = approx / base;
+            double approx = std::pow(2, static_cast<int>(std::ceil(std::log2(base))));
 
             // set the scaling factor value to the approximation ...
 
-            multiplyScalingFactor(node, ratio);
+            setScalingFactor(node, approx);
 
             // compensate the ratio using the previous node scaling factors ...
 
-            multiplyScalingFactor(linearNode->getParent(1), 1.0 / ratio);
+            double ratio = base / approx;
+            multiplyScalingFactor(linearNode->getParent(1), ratio);
             if (nodeHasBias(linearNode))
-                multiplyScalingFactor(linearNode->getParent(2), 1.0 / ratio);
+                multiplyScalingFactor(linearNode->getParent(2), ratio);
         }
     }
 }
diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
index 82b2e50..c83e600 100644
--- a/src/operator/PTQMetaOps.cpp
+++ b/src/operator/PTQMetaOps.cpp
@@ -82,7 +82,35 @@ std::shared_ptr<Node> Quantizer(double scalingFactor, const std::string& name)
     return quantizer;
 }
 
-void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
+
+double getScalingFactor(std::shared_ptr<Node> quantizer)
+{
+    // Retreive the previous microGraph
+
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph();
+
+    // Get the Mul node from the microGraph
+
+    std::shared_ptr<Node> mulNode = nullptr;
+    for (auto node : microGraph->getNodes())
+        if (node->type() == "Mul")
+            mulNode = node;
+
+    auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); 
+
+    // Retreive the scaling factor
+
+    auto scalingFactorTensor = mulOp->getInput(1);
+
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+    double scalingFactor = localTensor.get<double>(0);
+
+    return scalingFactor;
+}
+
+void setScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double value)
 {
     auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
 
@@ -104,7 +132,7 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
 
     // Create the new scaling factor tensor
 
-    std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(prevScalingFactor * coeff);
+    std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(value);
     newScalingFactorTensor->setBackend(scalingFactorTensor->backend());
     newScalingFactorTensor->setDataType(scalingFactorTensor->dataType());
 
@@ -114,6 +142,12 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
     producer->getOperator()->setOutput(0, newScalingFactorTensor);
 }
 
+void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
+{
+    double prevScalingFactor = getScalingFactor(quantizer);
+    setScalingFactor(quantizer, coeff * prevScalingFactor);
+}
+
 void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax)
 {
     // Retreive a clone of the microGraph
@@ -131,7 +165,7 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl
 
     // append round
 
-    auto roundNode = Round(quantizer->name() + "_RoundQuant");
+    auto roundNode = Round(Round_Op::HalfwayRounding::NextInteger, quantizer->name() + "_RoundQuant");
     outputNode->addChild(roundNode, 0, 0);
     microGraph->add(roundNode);
 
@@ -168,32 +202,7 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl
     quantizer = newQuantizer;
 }
 
-double getScalingFactor(std::shared_ptr<Node> quantizer)
-{
-    // Retreive the previous microGraph
 
-    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
-    auto microGraph = quantizerOp->getMicroGraph();
-
-    // Get the Mul node from the microGraph
-
-    std::shared_ptr<Node> mulNode = nullptr;
-    for (auto node : microGraph->getNodes())
-        if (node->type() == "Mul")
-            mulNode = node;
-
-    auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); 
-
-    // Retreive the scaling factor
-
-    auto scalingFactorTensor = mulOp->getInput(1);
-
-    std::shared_ptr<Tensor> fallback;
-    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
-    double scalingFactor = localTensor.get<double>(0);
-
-    return scalingFactor;
-}
 
 void setClipRange(std::shared_ptr<Node> quantizer, double min, double max)
 {
-- 
GitLab