From 60a15d6715af18c3b092cc44c0542391f928ab19 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Fri, 4 Apr 2025 15:28:36 +0000
Subject: [PATCH] folding of the producer's quantizers

---
 include/aidge/quantization/PTQ/PTQ.hpp |  6 ++--
 python_binding/pybind_PTQ.cpp          |  6 +++-
 src/PTQ/PTQ.cpp                        | 47 +++++++++++++++++++++++++-
 3 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index 8bce793..f398ea3 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -173,9 +173,11 @@ namespace Aidge {
      * @param noQuant Whether to apply the rounding operations or not.
      * @param optimizeSigns Whether to take account of the IO signs of the operators or not.
      * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights.
-     * @param verbose Whether to print internal informations about the quantization process.
+     * @param useCuda Whether to use the CUDA backend for performing the activation calibration or not.
+     * @param foldGraph Whether to fold the weight quantizers after the quantization or not.
+     * @param verbose Whether to print internal information about the quantization process or not.
      */
-    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose);
+    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose);
 
     /**
      * @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index 12d1434..5f21b68 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -93,7 +93,7 @@ void init_PTQ(py::module &m) {
     :type verbose: bool
     )mydelimiter");
 
-    m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false,  py::arg("use_cuda") = false, py::arg("verbose") = false,
+    m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false,  py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false,
     R"mydelimiter(
     Main quantization routine. Performs every step of the quantization pipeline.
     :param network: The GraphView to be quantized.
@@ -110,6 +110,10 @@ void init_PTQ(py::module &m) {
     :type optimize_signs: bool
     :param single_shift: Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights.
     :type single_shift: bool
+    :param use_cuda: Whether to use the CUDA backend for performing the activation calibration or not.
+    :type use_cuda: bool
+    :param fold_graph: Whether to fold the weight quantizers after the quantization or not.
+    :type fold_graph: bool
     :param verbose: Whether to print internal informations about the quantization process.
     :type verbose: bool
     )mydelimiter");
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 00956e9..e9bc406 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -238,6 +238,44 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n
     graphView->add(newNode);
 }
 
+// Fold the producer-side quantizers: mark their producers as constant, expand the quantizer meta-operators, then run constant folding on the graph.
+void foldProducerQuantizers(std::shared_ptr<GraphView> graphView)
+{
+    std::vector<std::shared_ptr<Node>> quantizers;
+    for (std::shared_ptr<Node> node : graphView->getNodes())
+        if (hasAttr(node, "isProducerScaling"))
+            quantizers.push_back(node);
+
+    for (std::shared_ptr<Node> quantizer : quantizers)
+    {
+        // Log::notice(" Quantizer : {} {} ", quantizer->name(), quantizer->type());
+
+        // Set the param producer to be constant
+
+        auto paramProducer = quantizer->getParent(0);
+        auto paramProducerOp = std::static_pointer_cast<Producer_Op>(paramProducer->getOperator());
+        paramProducerOp->constant() = true;
+
+        // Set the internal producers of the quantizer to be constant
+
+        auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+
+        auto microGraph = quantizerOp->getMicroGraph();
+
+        for (auto producer : microGraph->getNodes())
+            if (producer->type() == "Producer")
+            {
+                auto producerOp = std::static_pointer_cast<Producer_Op>(producer->getOperator());
+                producerOp->constant() = true;
+                //Log::notice("node : {} ", producer->name());
+            }
+
+        expandMetaOp(quantizer); // mandatory for now: presumably constantFolding() does not descend into meta-operators — TODO confirm
+    }
+
+    constantFolding(graphView);
+}
+
 double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
     // get the abs tensor
@@ -1167,7 +1205,7 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std:
         tensor->setDataType(dataType);
 }
 
-void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose)
+void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
 {
     Log::notice(" === QUANT PTQ 0.2.21 === ");
 
@@ -1212,6 +1250,13 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
         performSingleShiftApproximation(graphView);
     }
 
+    if (foldGraph)
+    {
+        Log::notice(" Folding producer's quantizers ...");
+        foldProducerQuantizers(graphView);
+        Log::notice(" Producer quantizers folded ");
+    }
+
     if (verbose)
         printScalingFactors(graphView);
 
-- 
GitLab