From 6c170e5a708452bb4a58d336e8f81ae6ad0125cf Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 5 Mar 2025 13:03:28 +0000
Subject: [PATCH 01/20] add a recipe for replacing MatMuls with FCs

---
 include/aidge/recipes/QuantRecipes.hpp        |  7 ++
 .../recipes/pybind_QuantRecipes.cpp           |  2 +
 src/operator/PTQMetaOps.cpp                   |  2 -
 src/recipes/QuantRecipes.cpp                  | 70 +++++++++++++++++++
 4 files changed, 79 insertions(+), 2 deletions(-)

diff --git a/include/aidge/recipes/QuantRecipes.hpp b/include/aidge/recipes/QuantRecipes.hpp
index 39349f9..55bcb0d 100644
--- a/include/aidge/recipes/QuantRecipes.hpp
+++ b/include/aidge/recipes/QuantRecipes.hpp
@@ -40,6 +40,13 @@ namespace Aidge
      * @param graphView The GraphView to process.
      */
     void sanitizeNodeNames(std::shared_ptr<GraphView> graphView);
+
+    /**
+     * @brief Given a GraphView, replace all its MatMul nodes with equivalent FC ones
+     * @param graphView The GraphView to process.
+     */
+    void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView);
+
 }
 
 #endif /* AIDGE_QUANTIZATION_QUANTRECIPES_H_ */
diff --git a/python_binding/recipes/pybind_QuantRecipes.cpp b/python_binding/recipes/pybind_QuantRecipes.cpp
index 0b96aef..ab9aaa7 100644
--- a/python_binding/recipes/pybind_QuantRecipes.cpp
+++ b/python_binding/recipes/pybind_QuantRecipes.cpp
@@ -25,6 +25,8 @@ void init_QuantRecipes(py::module &m) {
     m.def("pop_softmax", &popSoftMax, py::arg("network"));
     m.def("insert_batchnorm_nodes", &insertBatchNormNodes, py::arg("network"));
     m.def("sanitize_node_names", &sanitizeNodeNames, py::arg("network"));
+    m.def("replace_matmul_with_fc", &replaceMatMulWithFC, py::arg("network"));
+
 }
 
 } // namespace Aidge
diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
index f86d454..c70a772 100644
--- a/src/operator/PTQMetaOps.cpp
+++ b/src/operator/PTQMetaOps.cpp
@@ -70,8 +70,6 @@ static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, st
     return mulNode;
 }
 
-
-
 void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)
 {
     if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer")
diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index f03eb46..1f498de 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -11,10 +11,13 @@
 
 
 #include "aidge/operator/Conv.hpp"
+#include "aidge/operator/Transpose.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 //#include "aidge/quantization/PTQ/PTQ.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
 #include "aidge/graph/Node.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/graph/Matching.hpp"
 
 
 namespace Aidge 
@@ -121,4 +124,71 @@ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView)
     }
 }
 
+void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)");
+
+    for (const auto& match : matches)
+    {
+        auto node =  match.graph->rootNode(); 
+
+        std::pair<bool, bool> inputsAreProducers = {false, false};
+
+        if (node->getParent(0))
+            inputsAreProducers.first  = (node->getParent(0)->type() == "Producer");
+
+        if (node->getParent(1))
+            inputsAreProducers.second = (node->getParent(1)->type() == "Producer");
+
+        if (inputsAreProducers.first && inputsAreProducers.second)
+        { 
+            Log::warn(" Both input nodes of MatMul operator are Producers, it should be constant folded ! ");
+        }
+        else if (inputsAreProducers.first && !inputsAreProducers.second) 
+        {
+            Log::warn(" This input setup is not supported yet ! ");
+        }
+        else if (!inputsAreProducers.first && inputsAreProducers.second) 
+        {
+            // If the weight tensor is of rank 2, replace the MatMul with an FC !
+
+            std::shared_ptr<OperatorTensor> matMulOp = std::static_pointer_cast<OperatorTensor> (node->getOperator());
+            
+            std::shared_ptr<Tensor> weight = matMulOp->getInput(1);
+
+            if (weight->dims().size() == 2)
+            {
+                std::size_t inChannels  = weight->dims()[0];
+                std::size_t outChannels = weight->dims()[1];
+            
+                Log::notice(" ### channels = {} {} ", inChannels, outChannels);
+            
+                std::string name = node->name() + "_FC";
+            
+                std::shared_ptr<Node> FCNode = FC(inChannels, outChannels, true, name);
+
+                std::shared_ptr<OperatorTensor> FCOp = std::static_pointer_cast<OperatorTensor> (FCNode->getOperator());
+                
+                // Transpose the weights
+
+                auto transposeOp = Transpose_Op({1, 0});
+                transposeOp.setDataType(weight->dataType());
+                transposeOp.setBackend(weight->backend());
+            
+                transposeOp.associateInput(0, weight);
+                transposeOp.forward();
+                auto tranposedWeight = transposeOp.getOutput(0);
+ 
+                // Fill the FC weights
+
+                *FCOp->getInput(1) = *tranposedWeight;
+
+                // Replace the MatMul with the FC
+
+                bool success = graphView->replace({node, node->getParent(1)}, {FCNode, FCNode->getParent(1)});
+            }
+        }
+    }
+}
+
 }
\ No newline at end of file
-- 
GitLab


From 38e276af6dfa3abc686b2d7ccfcd41dbcd7b1e72 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 5 Mar 2025 13:10:20 +0000
Subject: [PATCH 02/20] minor changes

---
 src/recipes/QuantRecipes.cpp | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index 1f498de..04555aa 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -130,15 +130,15 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
 
     for (const auto& match : matches)
     {
-        auto node =  match.graph->rootNode(); 
+        auto MatMulNode =  match.graph->rootNode(); 
 
         std::pair<bool, bool> inputsAreProducers = {false, false};
 
-        if (node->getParent(0))
-            inputsAreProducers.first  = (node->getParent(0)->type() == "Producer");
+        if (MatMulNode->getParent(0))
+            inputsAreProducers.first  = (MatMulNode->getParent(0)->type() == "Producer");
 
-        if (node->getParent(1))
-            inputsAreProducers.second = (node->getParent(1)->type() == "Producer");
+        if (MatMulNode->getParent(1))
+            inputsAreProducers.second = (MatMulNode->getParent(1)->type() == "Producer");
 
         if (inputsAreProducers.first && inputsAreProducers.second)
         { 
@@ -152,7 +152,7 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
         {
             // If the weight tensor is of rank 2, replace the MatMul with an FC !
 
-            std::shared_ptr<OperatorTensor> matMulOp = std::static_pointer_cast<OperatorTensor> (node->getOperator());
+            std::shared_ptr<OperatorTensor> matMulOp = std::static_pointer_cast<OperatorTensor> (MatMulNode->getOperator());
             
             std::shared_ptr<Tensor> weight = matMulOp->getInput(1);
 
@@ -163,7 +163,7 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
             
                 Log::notice(" ### channels = {} {} ", inChannels, outChannels);
             
-                std::string name = node->name() + "_FC";
+                std::string name = MatMulNode->name() + "_FC";
             
                 std::shared_ptr<Node> FCNode = FC(inChannels, outChannels, true, name);
 
@@ -177,15 +177,15 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
             
                 transposeOp.associateInput(0, weight);
                 transposeOp.forward();
-                auto tranposedWeight = transposeOp.getOutput(0);
+                auto transposedWeight = transposeOp.getOutput(0);
  
                 // Fill the FC weights
 
-                *FCOp->getInput(1) = *tranposedWeight;
+                *FCOp->getInput(1) = *transposedWeight;
 
                 // Replace the MatMul with the FC
 
-                bool success = graphView->replace({node, node->getParent(1)}, {FCNode, FCNode->getParent(1)});
+                graphView->replace({MatMulNode, MatMulNode->getParent(1)}, {FCNode, FCNode->getParent(1)});
             }
         }
     }
-- 
GitLab


From a818dd58be4200f8eed7fbd463b8d27d9ebd50bf Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 5 Mar 2025 13:14:00 +0000
Subject: [PATCH 03/20] minor changes (case)

---
 src/recipes/QuantRecipes.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index 04555aa..795b3ca 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -130,15 +130,15 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
 
     for (const auto& match : matches)
     {
-        auto MatMulNode =  match.graph->rootNode(); 
+        auto matMulNode =  match.graph->rootNode(); 
 
         std::pair<bool, bool> inputsAreProducers = {false, false};
 
-        if (MatMulNode->getParent(0))
-            inputsAreProducers.first  = (MatMulNode->getParent(0)->type() == "Producer");
+        if (matMulNode->getParent(0))
+            inputsAreProducers.first  = (matMulNode->getParent(0)->type() == "Producer");
 
-        if (MatMulNode->getParent(1))
-            inputsAreProducers.second = (MatMulNode->getParent(1)->type() == "Producer");
+        if (matMulNode->getParent(1))
+            inputsAreProducers.second = (matMulNode->getParent(1)->type() == "Producer");
 
         if (inputsAreProducers.first && inputsAreProducers.second)
         { 
@@ -152,7 +152,7 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
         {
             // If the weight tensor is of rank 2, replace the MatMul with an FC !
 
-            std::shared_ptr<OperatorTensor> matMulOp = std::static_pointer_cast<OperatorTensor> (MatMulNode->getOperator());
+            std::shared_ptr<OperatorTensor> matMulOp = std::static_pointer_cast<OperatorTensor> (matMulNode->getOperator());
             
             std::shared_ptr<Tensor> weight = matMulOp->getInput(1);
 
@@ -163,7 +163,7 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
             
                 Log::notice(" ### channels = {} {} ", inChannels, outChannels);
             
-                std::string name = MatMulNode->name() + "_FC";
+                std::string name = matMulNode->name() + "_FC";
             
                 std::shared_ptr<Node> FCNode = FC(inChannels, outChannels, true, name);
 
@@ -185,7 +185,7 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
 
                 // Replace the MatMul with the FC
 
-                graphView->replace({MatMulNode, MatMulNode->getParent(1)}, {FCNode, FCNode->getParent(1)});
+                graphView->replace({matMulNode, matMulNode->getParent(1)}, {FCNode, FCNode->getParent(1)});
             }
         }
     }
-- 
GitLab


From 73c86bc45aeefacfc6d8ebb43e17482daa1cc100 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 5 Mar 2025 15:22:42 +0000
Subject: [PATCH 04/20] add support of MatMuls which have a parameter input

---
 src/PTQ/PTQ.cpp | 37 ++++++++++++++++++++++++++++++++-----
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 22c3438..638b6fd 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -29,6 +29,7 @@
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ArgMax.hpp"
 #include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/MatMul.hpp"
 
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
@@ -37,8 +38,29 @@ namespace Aidge
 {
 
 bool isAffine(std::shared_ptr<Node> node)
-{
-    return (affineNodeTypes.find(node->type()) != affineNodeTypes.end());
+{   
+    if (affineNodeTypes.find(node->type()) != affineNodeTypes.end())
+        return true;
+
+    // Check if the MatMul has a parameter on input 1 (in that case it's an Affine node)
+
+    if (node->type() == "MatMul") {
+        std::shared_ptr<Node> rightParent = node->getParent(1);
+        if (rightParent) {
+            bool hasWeight = rightParent->type() == "Producer";
+            if (hasWeight) {
+                return true;
+            } else if (rightParent->type() == "Mul") {
+                bool hasScaledWeight = rightParent->attributes()->hasAttr("quantization.ptq.isProducerScaling");
+                hasScaledWeight &= (rightParent->getParent(0)->type() == "Producer");
+                if (hasScaledWeight) {
+                    return true;
+                }
+            }
+        }
+    }
+
+    return false;
 }
 
 bool isSeamless(std::shared_ptr<Node> node)
@@ -58,14 +80,17 @@ bool isNotQuantized(std::shared_ptr<Node> node)
 
 bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 {
-    std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"});
+    std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"});
+
+    std::set<std::string> specialNodeTypes({"MatMul", "ReLU", "Producer"});
 
     std::set<std::string> notQuantizedNodesTypes;
 
     for (std::shared_ptr<Node> node : graphView->getNodes())
     {
-        bool isOther = otherNodeTypes.find(node->type()) != otherNodeTypes.end();
-        if (!isOther && !isAffine(node) && !isSeamless(node) && !isMerging(node) && !isNotQuantized(node)) {
+        bool isRemoved = removedNodeTypes.find(node->type()) != removedNodeTypes.end(); 
+        bool isSpecial = specialNodeTypes.find(node->type()) != specialNodeTypes.end();
+        if (!isRemoved && !isSpecial && !isAffine(node) && !isSeamless(node) && !isMerging(node) && !isNotQuantized(node)) {
             Log::warn(" GraphView can't be quantized : node type {} is not supported !", node->type());
             return false;
         }
@@ -87,6 +112,8 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 void prepareNetwork(std::shared_ptr<GraphView> graphView)
 {
     removeFlatten(graphView);
+
+    // XXX remove this !
     sanitizeNodeNames(graphView);
 
     bool containsBatchNorm = false;
-- 
GitLab


From fd28511098b2c651f8d07be9c03ff8358184dd8e Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Thu, 6 Mar 2025 14:54:55 +0000
Subject: [PATCH 05/20] code improvements

---
 include/aidge/quantization/QAT/QAT_LSQ.hpp |  50 ++--
 src/PTQ/CLE.cpp                            |  27 +-
 src/PTQ/PTQ.cpp                            |  85 ++++--
 src/QAT/QAT_LSQ.cpp                        | 322 ++++++++++-----------
 4 files changed, 263 insertions(+), 221 deletions(-)

diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp
index b1e7b6f..7919b1a 100644
--- a/include/aidge/quantization/QAT/QAT_LSQ.hpp
+++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp
@@ -9,30 +9,30 @@
  *
  ********************************************************************************/
 
- #ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
- #define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
- 
- #include <cstddef>  // std::size_t
- #include <memory>
- 
- #include "aidge/data/Tensor.hpp"
- #include "aidge/graph/GraphView.hpp"
- 
- namespace Aidge {
- namespace QuantLSQ {
- 
- /**
-  * @brief Given a GraphView with parameters properly initialized, insert
-  * the LSQ quantizer nodes, and setup the adjustment their step-sizes.
-  * @param graphView The GraphView containing the network to quantize.
-  * @param nbBits Number of quantization bits.
-  */
- 
- void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
- 
- }  // namespace QuantLSQ
- }  // namespace Aidge
- 
- #endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */
+#ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
+#define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
+
+#include <cstddef>  // std::size_t
+#include <memory>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+
+namespace Aidge {
+namespace QuantLSQ {
+
+/**
+ * @brief Given a GraphView with parameters properly initialized, insert
+ * the LSQ quantizer nodes, and set up the adjustment of their step-sizes.
+ * @param graphView The GraphView containing the network to quantize.
+ * @param nbBits Number of quantization bits.
+ */
+
+void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
+
+}  // namespace QuantLSQ
+}  // namespace Aidge
+
+#endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */
  
  
\ No newline at end of file
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 7115a2f..5ffc8eb 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -109,11 +109,15 @@ static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 
     return localFlatTensor.get<double>(maxIndex);
 }
-//Function used to extraxt the local tensor (from a ProducerScalingNode)
-std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) {
-    if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) {
+
+// TODO(review): clarify why a forward pass is required here instead of reading the Producer tensor directly
+// Function used to extract the local tensor (from a ProducerScalingNode)
+std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) 
+{
+    if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) 
+    {
         std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator());
-        operatorTensor->forward();// We need the forward pass to compute the scaled value of the Tensor
+        operatorTensor->forward(); // We need the forward pass to compute the scaled value of the Tensor
         return operatorTensor->getOutput(0);
     } else {
         return getWeightTensor(node);
@@ -129,16 +133,16 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
     for (std::shared_ptr<Node> node : nodeVector)
     {
         if (node->getChildren().size() > 1) {
-            Log::notice(" Network have multiple branches, skipping the CLE ... ");
+            Log::warn(" Network have multiple branches, skipping the CLE ... ");
             return;
         }
         if (isNotQuantized(node)) {
-            Log::notice(" Network contains non linear nodes, skipping the CLE ... ");
+            Log::warn(" Network contains non linear nodes, skipping the CLE ... ");
             return;
         }
     }
 
-    Log::info(" Applying the Cross-Layer Equalization ... ");
+    Log::notice(" Applying the Cross-Layer Equalization ... ");
 
     // Get the vector of affine nodes
 
@@ -148,13 +152,14 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
             affineNodeVector.push_back(node);
 
     double maxRangeDelta;
-
     do
     {
         maxRangeDelta = 0.0;
         
         for (size_t i = 0; i < (affineNodeVector.size() - 1); i++)
         {
+            // Log::notice(" node index : {} ", i);
+
             std::shared_ptr<Node> n1 = affineNodeVector[i];
             std::shared_ptr<Node> n2 = affineNodeVector[i+1];
 
@@ -168,8 +173,12 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
             double s2 = std::sqrt(r1 * r2) / r2;
 
             insertScalingBelowProducer(n1->getParent(1), s1, graphView);
+
+            if (n1->type() != "MatMul") // TODO : enhance this !
+                if (n1->getParent(2))
+                    insertScalingBelowProducer(n1->getParent(2), s1, graphView);
+
             insertScalingBelowProducer(n2->getParent(1), s2, graphView);
-            insertScalingBelowProducer(n1->getParent(2), s1, graphView);
 
             double rangeDelta = std::abs(r1 - r2);
             if (rangeDelta > maxRangeDelta)
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 638b6fd..9d59502 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -37,11 +37,20 @@
 namespace Aidge
 {
 
+static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr)
+{
+    return node->attributes()->hasAttr("quantization.ptq." + attr);
+}
+
 bool isAffine(std::shared_ptr<Node> node)
 {   
     if (affineNodeTypes.find(node->type()) != affineNodeTypes.end())
         return true;
 
+    if ((node->type() == "MatMul") && hasAttr(node, "isWeighted"))
+        return true;
+
+    /*
     // Check if the MatMul has a parameter on input 1 (in that case it's an Affine node)
 
     if (node->type() == "MatMul") {
@@ -59,7 +68,7 @@ bool isAffine(std::shared_ptr<Node> node)
             }
         }
     }
-
+*/
     return false;
 }
 
@@ -111,23 +120,50 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 
 void prepareNetwork(std::shared_ptr<GraphView> graphView)
 {
-    removeFlatten(graphView);
-
     // XXX remove this !
+
     sanitizeNodeNames(graphView);
 
-    bool containsBatchNorm = false;
+    // remove the flatten nodes
+
+    removeFlatten(graphView);
+
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
+    // tag the weighted nodes
+
     for (std::shared_ptr<Node> node : nodeVector)
+    {
+        bool isWeighted = isAffine(node);
+        if (node->type() == "MatMul") {
+            std::shared_ptr<Node> parent = node->getParent(1);
+            if (parent) {
+                if (parent->type() == "Producer") {
+                    isWeighted = true;
+                }
+            }
+        }
+
+        if (isWeighted) {
+            node->attributes()->addAttr("quantization.ptq.isWeighted", 0.0);
+        }
+    }
+
+    // fuse the batchnorms
+
+    bool containsBatchNorm = false;
+    for (std::shared_ptr<Node> node : nodeVector) {
         if (node->type() == "BatchNorm") {
             containsBatchNorm = true;
             break;
         }
+    }
 
     if (containsBatchNorm)
         fuseBatchNorm(graphView);   
 
+    // pop the softmax
+
     popSoftMax(graphView);
 }
 
@@ -148,8 +184,8 @@ static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> paren
 
 void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff)
 {
-    AIDGE_ASSERT(node->type() == "Mul" && (node->attributes()->hasAttr("quantization.ptq.isProducerScaling") || node->attributes()->hasAttr("quantization.ptq.isScaling")),
-    "Cannot update the scaling factor on Node of type {} with no scaling tag", node->type());
+    AIDGE_ASSERT(node->type() == "Mul" && hasAttr(node, "isProducerScaling") || hasAttr(node, "isScaling"),
+        "Cannot update the scaling factor on Node of type {} with no scaling tag", node->type());
     
     auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
 
@@ -194,7 +230,7 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n
 
 bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView)
 {
-    if (node->attributes()->hasAttr("quantization.ptq.isProducerScaling") && node->type() != "Round")
+    if (hasAttr(node, "isProducerScaling") && node->type() != "Round")
     {
         std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
         roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
@@ -260,7 +296,7 @@ static std::vector<std::shared_ptr<Node>> removeProdScalingNodes(std::vector<std
 {
     std::vector<std::shared_ptr<Node>> remainingNodes;
     for (std::shared_ptr<Node> node : nodeVector)
-        if (!node->attributes()->hasAttr("quantization.ptq.isProducerScaling"))
+        if (!hasAttr(node, "isProducerScaling"))
             remainingNodes.push_back(node);
 
     return remainingNodes;
@@ -354,14 +390,14 @@ static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vec
 
 bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scalingFactor, std::shared_ptr<GraphView> graphView)
 {
-    if (producerNode->attributes()->hasAttr("quantization.ptq.isProducerRounding"))
+    if (hasAttr(producerNode, "isProducerRounding"))
     {
         // In this case we 'bump' the node to the one above him (an actual ProducerScaling)
         // because the round node is not usable (only used when SSA is enabled)
         producerNode = producerNode->getParent(0);
     }
 
-    if (producerNode->attributes()->hasAttr("quantization.ptq.isProducerScaling"))
+    if (hasAttr(producerNode, "isProducerScaling"))
     {
         // We accumulate the previous scaling factors by multiplying the SF of the ProducerScaling node 
         // (adding new nodes each time would make the graph unusable)
@@ -426,7 +462,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView)
 static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> node)
 {
     std::shared_ptr<Node> currNode = node;
-    while(!currNode->attributes()->hasAttr("quantization.ptq.isScaling"))
+    while(!hasAttr(currNode, "isScaling"))
     {
         if (currNode->getParents().size() == 0)
         {
@@ -511,7 +547,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
     for (std::shared_ptr<Node> node : nodeVector)
     {
         // Scaling nodes still have a ratio of 1, so they are seamless ...
-        if (node->type() == "ReLU" || node->attributes()->hasAttr("quantization.ptq.isScaling") || isSeamless(node))
+        if (node->type() == "ReLU" || hasAttr(node, "isScaling") || isSeamless(node))
         {
             if (node != firstNode)
             {
@@ -618,7 +654,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
     std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
     for (std::shared_ptr<Node> node : nodeSet)
     {
-        if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
+        if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
         {
             std::shared_ptr<Operator> nodeOperator = node->getOperator();
             std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -640,7 +676,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
     // std::shared_ptr<Node> inputNode = getFirstNode(graphView);
 
     for (std::shared_ptr<Node> node : nodeSet)
-        if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
+        if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
             valueRanges.insert(std::make_pair(node, 0));
 
     if (useCuda)
@@ -667,7 +703,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
         std::unordered_map<std::shared_ptr<Node>, double> sampleRanges;
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
             {
                 std::shared_ptr<Operator> nodeOperator = node->getOperator();
                 std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -689,7 +725,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
 
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
                 if (sampleRanges[node] > valueRanges[node])
                     valueRanges[node] = sampleRanges[node];
         }
@@ -735,7 +771,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
 
         // Use the Scaling nodes to rescale the ranges ...
 
-        if (node->attributes()->hasAttr("quantization.ptq.isScaling")) 
+        if (hasAttr(node, "isScaling")) 
         {
             std::shared_ptr<Node> prevNode = node->getParent(0);
 
@@ -828,7 +864,7 @@ std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(
             signMap[node].second = false;
         } 
 
-        if (node->attributes()->hasAttr("quantization.ptq.isScaling")) 
+        if (hasAttr(node, "isScaling")) 
         {
             signMap[node].second = false;
 
@@ -875,7 +911,7 @@ std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(
                 // Arbitration : Signed type wins !
                 for(std::shared_ptr<Node> parent : parentNodes)
                 {
-                    while (!parent->attributes()->hasAttr("quantization.ptq.isScaling"))
+                    while (!hasAttr(parent, "isScaling"))
                     {
                         signMap[parent] = std::make_pair(false, false);
                         // We are on a branch so nodes always have 1 parent ...
@@ -1016,7 +1052,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
         
         // Handle the Scaling Nodes ...
 
-        if (node->attributes()->hasAttr("quantization.ptq.isScaling"))
+        if (hasAttr(node, "isScaling"))
         {
             // Don't touch the scalings that precede non-linearities ...
 
@@ -1150,7 +1186,7 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 static void printScalingFactors(std::shared_ptr<GraphView> graphView)
 {
     for (auto node : retrieveNodeVector(graphView))
-        if (node->attributes()->hasAttr("quantization.ptq.isScaling") || node->type() == "Quantizer")
+        if (hasAttr(node, "isScaling") || node->type() == "Quantizer")
         {
             double scalingFactor = getScalingFactor(node);
             Log::info(" {:.6f} ({})", scalingFactor, node->name());
@@ -1190,6 +1226,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     Log::notice(" Inserting the scaling nodes ...");
     insertScalingNodes(graphView);
 
+    // TODO : double check this !
     crossLayerEqualization(graphView);
 
     Log::notice(" Normalizing the parameters ...");
@@ -1198,13 +1235,9 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     Log::notice(" Computing the value ranges ...");
     std::unordered_map<std::shared_ptr<Node>, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
 
-    //Log::info(" === RANGES (BEFORE ADJUST) ===");
-
     Log::notice(" Optimizing the clipping values ...");
     valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose);
 
-    //Log:debug("=== RANGES (AFTER ADJUST) ===");
-    //printRanges(graphView, valueRanges);
     Log::notice(" Normalizing the activations ...");
     normalizeActivations(graphView, valueRanges);
 
@@ -1256,7 +1289,7 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
         if (node->type() == "FC" || node->type() == "Conv2D") {
             std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
             //rescaleTensor(biasTensor, 0);
-            insertScalingBelowProducer(node->getParent(2),0,graphView);
+            insertScalingBelowProducer(node->getParent(2), 0, graphView);
         }
     }
 }
diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp
index dcac681..6eae077 100644
--- a/src/QAT/QAT_LSQ.cpp
+++ b/src/QAT/QAT_LSQ.cpp
@@ -9,164 +9,164 @@
  *
  ********************************************************************************/
 
- #include "aidge/quantization/QAT/QAT_LSQ.hpp"
- #include "aidge/operator/LSQ.hpp"
- #include "aidge/operator/ReLU.hpp"
- 
- 
- #include "aidge/data/Tensor.hpp"
- #include "aidge/graph/GraphView.hpp"
- #include "aidge/scheduler/SequentialScheduler.hpp"
- #include "aidge/scheduler/Scheduler.hpp"
- #include "aidge/graph/Matching.hpp"
- #include "aidge/recipes/QuantRecipes.hpp"
- 
- 
- namespace Aidge 
- {
- 
- static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
- {
-     auto valueTensor = (*tensor).abs().mean();
-     std::shared_ptr<Tensor> fallback;
-     const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
-     return localTensor.get<float>(0);
- }
- 
- static float getTensorStd(std::shared_ptr<Tensor> tensor)
- {
-     auto valueTensor = (*tensor);
-     
-     auto skewedTensor = valueTensor - valueTensor.mean();
-     auto squaredTensor = skewedTensor * skewedTensor;
-     auto varianceTensor = squaredTensor.mean();
- 
-     std::shared_ptr<Tensor> fallback;
-     auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
-     
-     float variance = localTensor.get<float>(0);
-     return std::sqrt(variance);
- }
- 
- 
- // INIT THE STEP SIZE OF A QUANTIZER NODE
- 
- static bool initStepSize(std::shared_ptr<Node> quantizer)
- {
-     const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
- 
-     // This formula is the one proposed in the paper ...
- 
-     // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
-     // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
- 
-     // .. but this formula seems to work better !!!
- 
-     float inputStd = getTensorStd(quantizerOp->getInput(0));
-     float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
- 
-     // TODO : use the scalar constructor
-     auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); 
- 
-     // XXX Manage backend here ?
-     stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
-     stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
- 
-     auto stepSizeProducer = quantizer->getParent(1);
- 
-     stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
- 
-     Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
- 
-     return false;
- }
- 
- static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
- {
-     const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
- 
-     for (const auto& match : matches) 
-     {
-         auto linearNode = match.graph->rootNode(); 
- 
-         // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
- 
-         std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
-         std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
- 
-         // Create the input quantizer node
- 
-         auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);  
-         auto quantizerNode = LSQ(signedRange, quantizerName);
- 
-         // Init the step-size using the node call stack
- 
-         quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
- 
-         // Absorb the ReLU when possible ...
- 
-         bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]);  // XXX is this safe ?
- 
-         if (nodeHasParent) 
-         {
-             bool allParentsAreReLU = true;
-             for (auto parentNode : linearNode->getParents())
-                 if (parentNode->type() != "ReLU")
-                     allParentsAreReLU = false;
- 
-             if (allParentsAreReLU) {
-                 auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
-                 quantizerOp->range() = unsignedRange;
-             }
- 
-             // TODO : remove the ReLUs when possible
-         }
- 
-         // Insert the quantizer in the graphView ...
-         // (We need to handle the case where the linear node is the first one)
- 
-         if (nodeHasParent) {
-             graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
-         } else {
-             quantizerNode->addChild(graphView);
-             graphView->add(quantizerNode);
-         }
-     }
- }
- 
- // PARAM QUANTIZERS INSERTION
- 
- static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
- {
-     const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
- 
-     std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
- 
-     for (const auto& match : matches) 
-     {       
-         auto linearNode = match.graph->rootNode(); 
- 
-         // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
- 
-         // TODO : double check this, and use createUniqueName()
-         auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);  
-         auto quantizerNode = LSQ(signedRange, quantizerName); 
- 
-         // Init the step-size using the node call stack
- 
-         quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
- 
-         // Insert the quantizer in the graphView
- 
-         graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
-     }
- }
- 
- void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
- {
-     sanitizeNodeNames(graphView);
-     setupInputQuantizers(graphView, nbBits);
-     setupParamQuantizers(graphView, nbBits);
- }
- 
- }
\ No newline at end of file
+#include "aidge/quantization/QAT/QAT_LSQ.hpp"
+#include "aidge/operator/LSQ.hpp"
+#include "aidge/operator/ReLU.hpp"
+
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/scheduler/SequentialScheduler.hpp"
+#include "aidge/scheduler/Scheduler.hpp"
+#include "aidge/graph/Matching.hpp"
+#include "aidge/recipes/QuantRecipes.hpp"
+
+
+namespace Aidge 
+{
+
+static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor).abs().mean();
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    return localTensor.get<float>(0);
+}
+
+static float getTensorStd(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor);
+    
+    auto skewedTensor = valueTensor - valueTensor.mean();
+    auto squaredTensor = skewedTensor * skewedTensor;
+    auto varianceTensor = squaredTensor.mean();
+
+    std::shared_ptr<Tensor> fallback;
+    auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    
+    float variance = localTensor.get<float>(0);
+    return std::sqrt(variance);
+}
+
+
+// INIT THE STEP SIZE OF A QUANTIZER NODE
+
+static bool initStepSize(std::shared_ptr<Node> quantizer)
+{
+    const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
+
+    // This formula is the one proposed in the paper ...
+
+    // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
+    // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
+
+    // .. but this formula seems to work better !!!
+
+    float inputStd = getTensorStd(quantizerOp->getInput(0));
+    float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
+
+    // TODO : use the scalar constructor
+    auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); 
+
+    // XXX Manage backend here ?
+    stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
+    stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
+
+    auto stepSizeProducer = quantizer->getParent(1);
+
+    stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
+
+    Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
+
+    return false;
+}
+
+static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
+
+    for (const auto& match : matches) 
+    {
+        auto linearNode = match.graph->rootNode(); 
+
+        // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
+
+        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
+
+        // Create the input quantizer node
+
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);  
+        auto quantizerNode = LSQ(signedRange, quantizerName);
+
+        // Init the step-size using the node call stack
+
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
+
+        // Absorb the ReLU when possible ...
+
+        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]);  // XXX is this safe ?
+
+        if (nodeHasParent) 
+        {
+            bool allParentsAreReLU = true;
+            for (auto parentNode : linearNode->getParents())
+                if (parentNode->type() != "ReLU")
+                    allParentsAreReLU = false;
+
+            if (allParentsAreReLU) {
+                auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
+                quantizerOp->range() = unsignedRange;
+            }
+
+            // TODO : remove the ReLUs when possible
+        }
+
+        // Insert the quantizer in the graphView ...
+        // (We need to handle the case where the linear node is the first one)
+
+        if (nodeHasParent) {
+            graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
+        } else {
+            quantizerNode->addChild(graphView);
+            graphView->add(quantizerNode);
+        }
+    }
+}
+
+// PARAM QUANTIZERS INSERTION
+
+static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
+
+    std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+
+    for (const auto& match : matches) 
+    {       
+        auto linearNode = match.graph->rootNode(); 
+
+        // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
+
+        // TODO : double check this, and use createUniqueName()
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);  
+        auto quantizerNode = LSQ(signedRange, quantizerName); 
+
+        // Init the step-size using the node call stack
+
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
+
+        // Insert the quantizer in the graphView
+
+        graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
+    }
+}
+
+void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    sanitizeNodeNames(graphView);
+    setupInputQuantizers(graphView, nbBits);
+    setupParamQuantizers(graphView, nbBits);
+}
+
+}
\ No newline at end of file
-- 
GitLab


From b983d4befdcaa32d9480d300cccb90bb2a2ad851 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Thu, 6 Mar 2025 15:01:06 +0000
Subject: [PATCH 06/20] remove commented code

---
 src/PTQ/PTQ.cpp | 19 -------------------
 1 file changed, 19 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 9d59502..ce86a31 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -50,25 +50,6 @@ bool isAffine(std::shared_ptr<Node> node)
     if ((node->type() == "MatMul") && hasAttr(node, "isWeighted"))
         return true;
 
-    /*
-    // Check if the MatMul has a parameter on input 1 (in that case it's an Affine node)
-
-    if (node->type() == "MatMul") {
-        std::shared_ptr<Node> rightParent = node->getParent(1);
-        if (rightParent) {
-            bool hasWeight = rightParent->type() == "Producer";
-            if (hasWeight) {
-                return true;
-            } else if (rightParent->type() == "Mul") {
-                bool hasScaledWeight = rightParent->attributes()->hasAttr("quantization.ptq.isProducerScaling");
-                hasScaledWeight &= (rightParent->getParent(0)->type() == "Producer");
-                if (hasScaledWeight) {
-                    return true;
-                }
-            }
-        }
-    }
-*/
     return false;
 }
 
-- 
GitLab


From a0c3e6037ed633700c4d75c63e9b493671a26bce Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Thu, 6 Mar 2025 15:15:36 +0000
Subject: [PATCH 07/20] make use of getOrderedNodes()

---
 src/PTQ/PTQ.cpp        | 17 ++++-------------
 src/QAT/QAT_FixedQ.cpp |  4 ++--
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index ce86a31..df2d5d2 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -312,25 +312,16 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
 
 std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose)
 {
-    std::vector<std::shared_ptr<Node>> nodeVector;
-
-    SequentialScheduler scheduler(graphView);
-
-    if (newSchedule)
-    {
-        scheduler.resetScheduling();
-        scheduler.generateScheduling(); // old way : scheduler.forward(); 
-    }
-
-    nodeVector = scheduler.getStaticScheduling();
+    std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes();
+   
+    fixScheduling(nodeVector); 
 
-    fixScheduling(nodeVector);
     nodeVector = removeMatchingNodes(nodeVector, "Producer");
     nodeVector = removeProdScalingNodes(nodeVector);
 
     if (verbose) 
     {
-        Log::info("NB OF NODES = {}", nodeVector.size());
+        Log::info(" NB OF NODES = {}", nodeVector.size());
         for (std::shared_ptr<Node> node : nodeVector)
             Log::info("{} {}", node->type(), node->name());
     }
diff --git a/src/QAT/QAT_FixedQ.cpp b/src/QAT/QAT_FixedQ.cpp
index 6ada532..8e9adb2 100644
--- a/src/QAT/QAT_FixedQ.cpp
+++ b/src/QAT/QAT_FixedQ.cpp
@@ -154,8 +154,8 @@ void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView)
 {
     SequentialScheduler scheduler(graphView);
     scheduler.generateScheduling();
-    auto s = scheduler.getStaticScheduling();
-    for (std::shared_ptr<Node> node : s)
+    auto nodeVector = graphView->getOrderedNodes();
+    for (std::shared_ptr<Node> node : nodeVector)
         Log::info(" name : {} ", node->name());
 }
 
-- 
GitLab


From 294e7d6e08f804041ceea6b3e8f229796498ce62 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Thu, 6 Mar 2025 16:14:04 +0000
Subject: [PATCH 08/20] add addAttr() helper and rename getUniqueChildren() to getUniqueChild()

---
 src/PTQ/PTQ.cpp | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index df2d5d2..121c4a1 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -42,6 +42,11 @@ static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr)
     return node->attributes()->hasAttr("quantization.ptq." + attr);
 }
 
+static void addAttr(std::shared_ptr<Aidge::Node> node, std::string attr, double value = 0.0)
+{
+    node->attributes()->addAttr("quantization.ptq." + attr, value);
+}
+
 bool isAffine(std::shared_ptr<Node> node)
 {   
     if (affineNodeTypes.find(node->type()) != affineNodeTypes.end())
@@ -126,7 +131,7 @@ void prepareNetwork(std::shared_ptr<GraphView> graphView)
         }
 
         if (isWeighted) {
-            node->attributes()->addAttr("quantization.ptq.isWeighted", 0.0);
+            addAttr(node, "isWeighted");
         }
     }
 
@@ -148,7 +153,7 @@ void prepareNetwork(std::shared_ptr<GraphView> graphView)
     popSoftMax(graphView);
 }
 
-static std::shared_ptr<Aidge::Node> getUniqueChildren(std::shared_ptr<Aidge::Node> node) 
+static std::shared_ptr<Aidge::Node> getUniqueChild(std::shared_ptr<Aidge::Node> node) 
 {
     std::set<std::shared_ptr<Aidge::Node>> childrenSet = node->getChildren();
     AIDGE_ASSERT(childrenSet.size() == 1, " Attempted to access to a unique child while the parent have multiple ones ! ");
@@ -218,7 +223,8 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
         roundNode->getOperator()->setBackend(node->getOperator()->backend());
 
         insertChildren(node, roundNode, graphView);
-        roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0);
+        addAttr(roundNode, "isProducerRounding");
+    
         return true;
     }
     return false;
@@ -345,8 +351,8 @@ static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vec
 {
     std::shared_ptr<Node> scalingNode = Mul(name);
   
-    for (std::string attr : attributes)
-        scalingNode->attributes()->addAttr("quantization.ptq." + attr, 0.0);
+    for (std::string a : attributes)
+        addAttr(scalingNode, a);
     
     // Add the scaling factor as a producer of the node
 
@@ -573,7 +579,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             // Revert the canceling by using the next scaling node
 
             accumulatedRatios[node] = prevRatio;
-            std::shared_ptr<Node> nextScalingNode = getUniqueChildren(node);
+            std::shared_ptr<Node> nextScalingNode = getUniqueChild(node);
             multiplyScalingFactor(nextScalingNode, prevRatio);
         }
 
@@ -991,7 +997,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             rescaling /= inputIsUnsigned  ? unsignedMax : signedMax;
             rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
             
-            std::shared_ptr<Node> scalingNode = getUniqueChildren(node); // TODO : assert if scalingNode is a Scaling ...
+            std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ...
 
             multiplyScalingFactor(scalingNode,rescaling) ;          
         }
@@ -1006,7 +1012,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             rescaling /= inputIsUnsigned  ? unsignedMax : signedMax;
             rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
 
-            std::shared_ptr<Node> scalingNode = getUniqueChildren(node); // TODO : assert if scalingNode is a Scaling ...
+            std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ...
         
             multiplyScalingFactor(scalingNode, rescaling) ;          
         }
@@ -1018,7 +1024,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             std::shared_ptr<Node> prevScalingNode = node->getParent(0);
             multiplyScalingFactor(prevScalingNode, rescaling);
 
-            std::shared_ptr<Node> nextScalingNode = getUniqueChildren(node);
+            std::shared_ptr<Node> nextScalingNode = getUniqueChild(node);
             multiplyScalingFactor(nextScalingNode, 1 / rescaling);
         }
         
@@ -1030,7 +1036,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
 
             bool precedesNonLinearNode = false;
             if (node->getChildren().size() == 1)
-                if (isNotQuantized(getUniqueChildren(node)))
+                if (isNotQuantized(getUniqueChild(node)))
                     precedesNonLinearNode = true; 
 
             if (!noQuant && !precedesNonLinearNode) 
@@ -1096,7 +1102,8 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
                 std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
                 std::shared_ptr<Node> mulNode = Mul(mulNodeName);
                 
-                mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0);
+                addAttr(mulNode, "isCompensation");
+
                 mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 mulNode->getOperator()->setBackend(node->getOperator()->backend());
 
-- 
GitLab


From 478b96c0fd8ebc1fad683d0578e3f12f58541fd7 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 11 Mar 2025 09:53:33 +0000
Subject: [PATCH 09/20] wip: add reorderMatMulInputs() recipe; fix LSQ CUDA workspace size

---
 include/aidge/recipes/QuantRecipes.hpp        |  7 ++
 .../recipes/pybind_QuantRecipes.cpp           |  6 +-
 src/backend/cuda/operator/LSQImpl.cpp         |  2 +-
 src/recipes/QuantRecipes.cpp                  | 68 ++++++++++++++++++-
 4 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/include/aidge/recipes/QuantRecipes.hpp b/include/aidge/recipes/QuantRecipes.hpp
index 55bcb0d..302b43a 100644
--- a/include/aidge/recipes/QuantRecipes.hpp
+++ b/include/aidge/recipes/QuantRecipes.hpp
@@ -47,6 +47,13 @@ namespace Aidge
      */
     void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView);
 
+    /**
+     * @brief Given a GraphView, set all it's MatMul weights to index 1 (required for the PTQ)
+     * This operation involve the insertion of Transpose nodes as well as the transposition of  
+     * the MatMul weight tensors.
+     * @param graphView The GraphView to process.
+     */
+    void reorderMatMulInputs(std::shared_ptr<GraphView> graphView);
 }
 
 #endif /* AIDGE_QUANTIZATION_QUANTRECIPES_H_ */
diff --git a/python_binding/recipes/pybind_QuantRecipes.cpp b/python_binding/recipes/pybind_QuantRecipes.cpp
index ab9aaa7..dd046d8 100644
--- a/python_binding/recipes/pybind_QuantRecipes.cpp
+++ b/python_binding/recipes/pybind_QuantRecipes.cpp
@@ -20,13 +20,13 @@ namespace py = pybind11;
 
 namespace Aidge {
 
-void init_QuantRecipes(py::module &m) {
-
+void init_QuantRecipes(py::module &m) 
+{
     m.def("pop_softmax", &popSoftMax, py::arg("network"));
     m.def("insert_batchnorm_nodes", &insertBatchNormNodes, py::arg("network"));
     m.def("sanitize_node_names", &sanitizeNodeNames, py::arg("network"));
     m.def("replace_matmul_with_fc", &replaceMatMulWithFC, py::arg("network"));
-
+    m.def("reorder_matmul_inputs", &reorderMatMulInputs, py::arg("network"));
 }
 
 } // namespace Aidge
diff --git a/src/backend/cuda/operator/LSQImpl.cpp b/src/backend/cuda/operator/LSQImpl.cpp
index fa45f21..2ef9ce0 100644
--- a/src/backend/cuda/operator/LSQImpl.cpp
+++ b/src/backend/cuda/operator/LSQImpl.cpp
@@ -57,7 +57,7 @@ void Aidge::LSQImpl_cuda::backward() {
         if (mWorkspace != nullptr) {
             cudaFree(mWorkspace);
         }
-        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, 8 * gra_int0->size())); // XXX This must be changed !!!
+        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, 4 * gra_int0->size())); // XXX This must be changed !!!
         mWorkspaceSize = gra_int0->size();
     }
 
diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index 795b3ca..f751f1f 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -9,9 +9,11 @@
  *
  ********************************************************************************/
 
-
+#include "aidge/graph/OpArgs.hpp"
+#include "aidge/operator/Producer.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/Transpose.hpp"
+#include "aidge/operator/MatMul.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 //#include "aidge/quantization/PTQ/PTQ.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
@@ -191,4 +193,68 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
     }
 }
 
+void reorderMatMulInputs(std::shared_ptr<GraphView> graphView)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)");
+
+    for (auto match : matches)
+    {
+        auto node = match.graph->rootNode();
+
+        // Check if the MatMul inputs have to be permuted
+
+        bool permuteInputs = false;
+
+        if (node->getParent(0))
+            if (node->getParent(0)->type() == "Producer")
+                permuteInputs = true;
+
+        if (node->getParent(1))
+            if (node->getParent(1)->type() == "Producer")
+                permuteInputs = false;
+
+        // Perform the permutation of the inputs ...
+
+        if (permuteInputs)
+        {
+            // Replace the MatMul node with the new micrograph ...
+
+            auto prevMatMul = node; 
+            auto prevTensor = (std::static_pointer_cast<OperatorTensor> (node->getOperator()))->getInput(0);
+
+            auto newMatMul = MatMul();
+            auto newDims = prevTensor->dims();
+            std::swap(newDims[0], newDims[1]);
+            auto newTensor = std::make_shared<Tensor>(newDims);
+            auto newProducer = Producer(newTensor, "");
+            newProducer->addChild(newMatMul, 0, 1);
+
+            auto prevMicroGraph = Sequential({prevMatMul});
+            prevMicroGraph->add(prevMatMul->getParent(0));
+
+            auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})});
+            newMicroGraph->add(newMatMul->getParent(1));
+
+            // TODO: do the graphView replace !!!
+
+            // Fill the new tensor with the transposed of the old one ...
+            
+            // TODO : use copyTranspose() instead !!!
+
+            auto transposeOp = Transpose_Op({1, 0});
+            transposeOp.setDataType(prevTensor->dataType());
+            transposeOp.setBackend(prevTensor->backend());
+        
+            transposeOp.associateInput(0, prevTensor);
+            transposeOp.forward();
+
+            *(newTensor) = *(transposeOp.getOutput(0));
+        }
+    }
+
+    // TODO : fold the Transpose operators when possible ...
+
+    // USE REGEXPS !!!
+}
+
 }
\ No newline at end of file
-- 
GitLab


From de5f4b7b1e107b91c0b143c1d0ec658511666c09 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 12 Mar 2025 13:21:22 +0000
Subject: [PATCH 10/20] enhance reorderMatMulInputs() and remove
 replaceMatMulWithFC()

---
 include/aidge/recipes/QuantRecipes.hpp        |   8 +-
 .../recipes/pybind_QuantRecipes.cpp           |   4 +-
 src/recipes/QuantRecipes.cpp                  | 129 +++++++++---------
 3 files changed, 66 insertions(+), 75 deletions(-)

diff --git a/include/aidge/recipes/QuantRecipes.hpp b/include/aidge/recipes/QuantRecipes.hpp
index 302b43a..1e78699 100644
--- a/include/aidge/recipes/QuantRecipes.hpp
+++ b/include/aidge/recipes/QuantRecipes.hpp
@@ -40,13 +40,7 @@ namespace Aidge
      * @param graphView The GraphView to process.
      */
     void sanitizeNodeNames(std::shared_ptr<GraphView> graphView);
-
-    /**
-     * @brief Given a GraphView, replace all it's MatMul nodes with equivalent FC ones
-     * @param graphView The GraphView to process.
-     */
-    void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView);
-
+    
     /**
      * @brief Given a GraphView, set all it's MatMul weights to index 1 (required for the PTQ)
      * This operation involve the insertion of Transpose nodes as well as the transposition of  
diff --git a/python_binding/recipes/pybind_QuantRecipes.cpp b/python_binding/recipes/pybind_QuantRecipes.cpp
index dd046d8..15257b0 100644
--- a/python_binding/recipes/pybind_QuantRecipes.cpp
+++ b/python_binding/recipes/pybind_QuantRecipes.cpp
@@ -18,14 +18,14 @@
 
 namespace py = pybind11;
 
-namespace Aidge {
+namespace Aidge 
+{
 
 void init_QuantRecipes(py::module &m) 
 {
     m.def("pop_softmax", &popSoftMax, py::arg("network"));
     m.def("insert_batchnorm_nodes", &insertBatchNormNodes, py::arg("network"));
     m.def("sanitize_node_names", &sanitizeNodeNames, py::arg("network"));
-    m.def("replace_matmul_with_fc", &replaceMatMulWithFC, py::arg("network"));
     m.def("reorder_matmul_inputs", &reorderMatMulInputs, py::arg("network"));
 }
 
diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index f751f1f..95d1e30 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -126,6 +126,68 @@ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView)
     }
 }
 
+void reorderMatMulInputs(std::shared_ptr<GraphView> graphView)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)");
+
+    for (auto match : matches)
+    {
+        auto node = match.graph->rootNode();
+
+        // Check if the MatMul inputs have to be permuted
+
+        bool permuteInputs = false;
+
+        if (node->getParent(0))
+            if (node->getParent(0)->type() == "Producer")
+                permuteInputs = true;
+
+        if (node->getParent(1))
+            if (node->getParent(1)->type() == "Producer")
+                permuteInputs = false;
+
+        // Perform the permutation of the inputs ...
+
+        if (permuteInputs)
+        {
+            auto prevMatMul = node; 
+            auto prevTensor = (std::static_pointer_cast<OperatorTensor> (node->getOperator()))->getInput(0);
+
+            // Create the new MatMul op and its Producer
+
+            auto newMatMul = MatMul();
+    
+            auto newDims = prevTensor->dims();
+            std::swap(newDims[0], newDims[1]);
+            auto newTensor = std::make_shared<Tensor>(newDims);
+            
+            newTensor->setDataType(prevTensor->dataType());
+            newTensor->setBackend(prevTensor->backend());
+            newTensor->copyTranspose(*prevTensor, std::vector<Aidge::DimSize_t>({1, 0}));
+
+            auto newProducer = Producer(newTensor, "");
+            newProducer->addChild(newMatMul, 0, 1);
+
+            // Replace the node by a micrograph
+
+            auto prevMicroGraph = Sequential({prevMatMul});
+            prevMicroGraph->add(prevMatMul->getParent(0));
+
+            auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})});
+            newMicroGraph->add(newMatMul->getParent(1));
+
+            graphView->replace(prevMicroGraph, newMicroGraph);
+        }
+    }
+
+    // TODO : fold the Transpose operators when possible ...
+
+    // USE REGEXPS !!!
+}
+
+}
+
+/*
 void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
 {
     const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)");
@@ -192,69 +254,4 @@ void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
         }
     }
 }
-
-void reorderMatMulInputs(std::shared_ptr<GraphView> graphView)
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)");
-
-    for (auto match : matches)
-    {
-        auto node = match.graph->rootNode();
-
-        // Check if the MatMul inputs have to be permuted
-
-        bool permuteInputs = false;
-
-        if (node->getParent(0))
-            if (node->getParent(0)->type() == "Producer")
-                permuteInputs = true;
-
-        if (node->getParent(1))
-            if (node->getParent(1)->type() == "Producer")
-                permuteInputs = false;
-
-        // Perform the permutation of the inputs ...
-
-        if (permuteInputs)
-        {
-            // Replace the MatMul node with the new micrograph ...
-
-            auto prevMatMul = node; 
-            auto prevTensor = (std::static_pointer_cast<OperatorTensor> (node->getOperator()))->getInput(0);
-
-            auto newMatMul = MatMul();
-            auto newDims = prevTensor->dims();
-            std::swap(newDims[0], newDims[1]);
-            auto newTensor = std::make_shared<Tensor>(newDims);
-            auto newProducer = Producer(newTensor, "");
-            newProducer->addChild(newMatMul, 0, 1);
-
-            auto prevMicroGraph = Sequential({prevMatMul});
-            prevMicroGraph->add(prevMatMul->getParent(0));
-
-            auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})});
-            newMicroGraph->add(newMatMul->getParent(1));
-
-            // TODO: do the graphView replace !!!
-
-            // Fill the new tensor with the transposed of the old one ...
-            
-            // TODO : use copyTranspose() instead !!!
-
-            auto transposeOp = Transpose_Op({1, 0});
-            transposeOp.setDataType(prevTensor->dataType());
-            transposeOp.setBackend(prevTensor->backend());
-        
-            transposeOp.associateInput(0, prevTensor);
-            transposeOp.forward();
-
-            *(newTensor) = *(transposeOp.getOutput(0));
-        }
-    }
-
-    // TODO : fold the Transpose operators when possible ...
-
-    // USE REGEXPS !!!
-}
-
-}
\ No newline at end of file
+*/
\ No newline at end of file
-- 
GitLab


From 2193615f3aaec6426ea9e80564d999ca644be96a Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 12 Mar 2025 15:18:10 +0000
Subject: [PATCH 11/20] integration of matmul (insert scalings + norm
 parameters)

---
 src/PTQ/PTQ.cpp | 75 ++++++++++++++++++++++++++++++++-----------------
 1 file changed, 49 insertions(+), 26 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 121c4a1..002f390 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -65,7 +65,13 @@ bool isSeamless(std::shared_ptr<Node> node)
 
 bool isMerging(std::shared_ptr<Node> node)
 {
-    return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end());
+    if (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end())
+        return true;
+
+    if ((node->type() == "MatMul") && !hasAttr(node, "isWeighted"))
+        return true;
+
+    return false;
 }
 
 bool isNotQuantized(std::shared_ptr<Node> node)
@@ -121,18 +127,16 @@ void prepareNetwork(std::shared_ptr<GraphView> graphView)
     for (std::shared_ptr<Node> node : nodeVector)
     {
         bool isWeighted = isAffine(node);
-        if (node->type() == "MatMul") {
+        if (node->type() == "MatMul") 
+        {
             std::shared_ptr<Node> parent = node->getParent(1);
-            if (parent) {
-                if (parent->type() == "Producer") {
+            if (parent) 
+                if (parent->type() == "Producer") 
                     isWeighted = true;
-                }
-            }
         }
 
-        if (isWeighted) {
+        if (isWeighted)
             addAttr(node, "isWeighted");
-        }
     }
 
     // fuse the batchnorms
@@ -492,7 +496,6 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
                 graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0);
                 graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer
             }
-
         }
     }
 }
@@ -537,6 +540,8 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
         // Residual nodes should enter in this category but their ratio is 1 ...
         if (isAffine(node))
         {
+            Log::warn(" affine : {} ", node->name());
+
             // Rescale the weight tensor
             
             std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
@@ -585,31 +590,49 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
 
         if (isMerging(node))
         {
-            std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
+            if (node->type() == "MatMul")
+            {
+                Log::warn(" matmul : {} ", node->name());
 
-            // Compute the max ratio ...
+                // Multiply the input scaling factors !
 
-            double maxRatio = 0;
-            for (std::shared_ptr<Node> mergingNode : mergingNodes)
-            {
-                double merginNodeRatio = accumulatedRatios[mergingNode];
-                if (merginNodeRatio > maxRatio)
-                    maxRatio = merginNodeRatio;
+                double leftRatio  = accumulatedRatios[node->getParent(0)];
+                double rightRatio = accumulatedRatios[node->getParent(1)];
+
+                accumulatedRatios[node] = leftRatio * rightRatio;
             }
+            else
+            {
+                // Use a maximum arbitration !
 
-            accumulatedRatios[node] = maxRatio;
+                Log::warn(" merging : {} ", node->name());
 
-            // Rescale the previous scaling Nodes
-            for (std::shared_ptr<Node> mergingNode : mergingNodes)
-            {
-                double mergingNodeRatio = accumulatedRatios[mergingNode];
-                double rescaling = mergingNodeRatio / maxRatio;
+                std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
 
-                std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
+                // Compute the max ratio ...
 
-                multiplyScalingFactor(scalingNode, 1 / rescaling);
+                double maxRatio = 0;
+                for (std::shared_ptr<Node> mergingNode : mergingNodes)
+                {
+                    double merginNodeRatio = accumulatedRatios[mergingNode];
+                    if (merginNodeRatio > maxRatio)
+                        maxRatio = merginNodeRatio;
+                }
+
+                accumulatedRatios[node] = maxRatio;
+
+                // Rescale the previous scaling Nodes
+                for (std::shared_ptr<Node> mergingNode : mergingNodes)
+                {
+                    double mergingNodeRatio = accumulatedRatios[mergingNode];
+                    double rescaling = mergingNodeRatio / maxRatio;
+
+                    std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
 
-                accumulatedRatios[mergingNode] /= rescaling; // optional ...
+                    multiplyScalingFactor(scalingNode, 1 / rescaling);
+
+                    accumulatedRatios[mergingNode] /= rescaling; // optional ...
+                }
             }
         }
     }
-- 
GitLab


From 3b9e29ce9fcd6ec612482ab1efdd7c1658d38982 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Fri, 14 Mar 2025 10:14:05 +0000
Subject: [PATCH 12/20] modify normalizeActivations() for MatMul support

---
 src/PTQ/PTQ.cpp | 41 +++++++++++++++++++++++++----------------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 002f390..3a24097 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -796,26 +796,35 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
         
         if (isMerging(node))
         {
-            std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
-
-            // Compute the max ratio ...
-
-            double maxRatio = 0;
-            for (std::shared_ptr<Node> mergingNode : mergingNodes)
+            if (node->type() == "MatMul")
             {
-                double mergingNodeRatio = accumulatedRatios[mergingNode];
-                if (mergingNodeRatio > maxRatio)
-                    maxRatio = mergingNodeRatio;
+                double leftRatio  = accumulatedRatios[node->getParent(0)];
+                double rightRatio = accumulatedRatios[node->getParent(1)];
+                accumulatedRatios[node] = leftRatio * rightRatio;
             }
+            else
+            {
+                std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
 
-            accumulatedRatios[node] = maxRatio;
+                // Compute the max ratio ...
 
-            for (std::shared_ptr<Node> mergingNode : mergingNodes)
-            {
-                double mergingNodeRatio = accumulatedRatios[mergingNode];
-                std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
-                multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio);
-                // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
+                double maxRatio = 0;
+                for (std::shared_ptr<Node> mergingNode : mergingNodes)
+                {
+                    double mergingNodeRatio = accumulatedRatios[mergingNode];
+                    if (mergingNodeRatio > maxRatio)
+                        maxRatio = mergingNodeRatio;
+                }
+
+                accumulatedRatios[node] = maxRatio;
+
+                for (std::shared_ptr<Node> mergingNode : mergingNodes)
+                {
+                    double mergingNodeRatio = accumulatedRatios[mergingNode];
+                    std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
+                    multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio);
+                    // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
+                }
             }
         }
 
-- 
GitLab


From 9538457f878a811649c6472b281b33088fed9abd Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Fri, 14 Mar 2025 12:59:11 +0000
Subject: [PATCH 13/20] edit quantizeNormalizedNetwork() (MatMul support)

---
 src/PTQ/PTQ.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 3a24097..8a36e4f 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -1046,6 +1046,10 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
 
             std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ...
         
+            // TODO : double check this ...
+            if (node->type() == "MatMul")
+                rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
+
             multiplyScalingFactor(scalingNode, rescaling) ;          
         }
 
-- 
GitLab


From 58ab5a51d16fefc5fe7f31cbb0c0bc569b59b843 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Fri, 14 Mar 2025 15:09:08 +0000
Subject: [PATCH 14/20] remove commented code

---
 src/recipes/QuantRecipes.cpp | 69 ------------------------------------
 1 file changed, 69 deletions(-)

diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index 95d1e30..c184882 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -186,72 +186,3 @@ void reorderMatMulInputs(std::shared_ptr<GraphView> graphView)
 }
 
 }
-
-/*
-void replaceMatMulWithFC(std::shared_ptr<GraphView> graphView)
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)");
-
-    for (const auto& match : matches)
-    {
-        auto matMulNode =  match.graph->rootNode(); 
-
-        std::pair<bool, bool> inputsAreProducers = {false, false};
-
-        if (matMulNode->getParent(0))
-            inputsAreProducers.first  = (matMulNode->getParent(0)->type() == "Producer");
-
-        if (matMulNode->getParent(1))
-            inputsAreProducers.second = (matMulNode->getParent(1)->type() == "Producer");
-
-        if (inputsAreProducers.first && inputsAreProducers.second)
-        { 
-            Log::warn(" Both input nodes of MatMul operator are Producers, it should be constant folded ! ");
-        }
-        else if (inputsAreProducers.first && !inputsAreProducers.second) 
-        {
-            Log::warn(" This input setup is not supported yet ! ");
-        }
-        else if (!inputsAreProducers.first && inputsAreProducers.second) 
-        {
-            // If the weight tensor is of rank 2, replace the MatMul with an FC !
-
-            std::shared_ptr<OperatorTensor> matMulOp = std::static_pointer_cast<OperatorTensor> (matMulNode->getOperator());
-            
-            std::shared_ptr<Tensor> weight = matMulOp->getInput(1);
-
-            if (weight->dims().size() == 2)
-            {
-                std::size_t inChannels  = weight->dims()[0];
-                std::size_t outChannels = weight->dims()[1];
-            
-                Log::notice(" ### channels = {} {} ", inChannels, outChannels);
-            
-                std::string name = matMulNode->name() + "_FC";
-            
-                std::shared_ptr<Node> FCNode = FC(inChannels, outChannels, true, name);
-
-                std::shared_ptr<OperatorTensor> FCOp = std::static_pointer_cast<OperatorTensor> (FCNode->getOperator());
-                
-                // Transpose the weights
-
-                auto transposeOp = Transpose_Op({1, 0});
-                transposeOp.setDataType(weight->dataType());
-                transposeOp.setBackend(weight->backend());
-            
-                transposeOp.associateInput(0, weight);
-                transposeOp.forward();
-                auto transposedWeight = transposeOp.getOutput(0);
- 
-                // Fill the FC weights
-
-                *FCOp->getInput(1) = *transposedWeight;
-
-                // Replace the MatMul with the FC
-
-                graphView->replace({matMulNode, matMulNode->getParent(1)}, {FCNode, FCNode->getParent(1)});
-            }
-        }
-    }
-}
-*/
\ No newline at end of file
-- 
GitLab


From ffd24c34ee85aa5e5d897cad2bde1e2b87d67db3 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Mon, 17 Mar 2025 15:56:53 +0000
Subject: [PATCH 15/20] fix the LSQ workspace cuda allocation

---
 src/backend/cuda/operator/LSQImpl.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/backend/cuda/operator/LSQImpl.cpp b/src/backend/cuda/operator/LSQImpl.cpp
index 2ef9ce0..5f9f032 100644
--- a/src/backend/cuda/operator/LSQImpl.cpp
+++ b/src/backend/cuda/operator/LSQImpl.cpp
@@ -53,11 +53,11 @@ void Aidge::LSQImpl_cuda::backward() {
     std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();    
 
     if (gra_int0->size() > mWorkspaceSize) {
-        // std::cout << " reallocation " << sizeof(gra_int0) << " " << gra_int0->size() << std::endl;
         if (mWorkspace != nullptr) {
             cudaFree(mWorkspace);
         }
-        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, 4 * gra_int0->size())); // XXX This must be changed !!!
+        std::size_t sizeofData = getDataTypeBitWidth(gra_int0->dataType()) / 8;
+        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, sizeofData * gra_int0->size()));
         mWorkspaceSize = gra_int0->size();
     }
 
-- 
GitLab


From 718f6a1f0b9c59a8d65b05f9c8b9c0ca83a784f5 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Mon, 17 Mar 2025 16:01:23 +0000
Subject: [PATCH 16/20] minor change

---
 src/QAT/QAT_FixedQ.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/QAT/QAT_FixedQ.cpp b/src/QAT/QAT_FixedQ.cpp
index 8e9adb2..b51123c 100644
--- a/src/QAT/QAT_FixedQ.cpp
+++ b/src/QAT/QAT_FixedQ.cpp
@@ -152,8 +152,6 @@ void QuantFixedQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView,
 
 void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView) 
 {
-    SequentialScheduler scheduler(graphView);
-    scheduler.generateScheduling();
     auto nodeVector = graphView->getOrderedNodes();
     for (std::shared_ptr<Node> node : nodeVector)
         Log::info(" name : {} ", node->name());
-- 
GitLab


From f3d41082d3bc8c48d1e60414d080dc0de36193c8 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 18 Mar 2025 10:20:16 +0000
Subject: [PATCH 17/20] handle the MatMuls in prepareNetwork()

---
 src/PTQ/PTQ.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 8a36e4f..ea592a6 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -122,6 +122,12 @@ void prepareNetwork(std::shared_ptr<GraphView> graphView)
 
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
+    // handle the MatMuls
+
+    reorderMatMulInputs(graphView);
+
+    matMulToFC(graphView);
+
     // tag the weighted nodes
 
     for (std::shared_ptr<Node> node : nodeVector)
-- 
GitLab


From c3b64dd1aada2be79f85a7c0cff066c3224d0c0b Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 18 Mar 2025 11:00:35 +0000
Subject: [PATCH 18/20] modify prepareNetwork()

---
 src/PTQ/PTQ.cpp              | 11 +++--------
 src/recipes/QuantRecipes.cpp |  3 +++
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index ea592a6..ee17991 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -112,24 +112,19 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 
 void prepareNetwork(std::shared_ptr<GraphView> graphView)
 {
-    // XXX remove this !
-
-    sanitizeNodeNames(graphView);
-
     // remove the flatten nodes
 
     removeFlatten(graphView);
 
-    std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
-
     // handle the MatMuls
 
     reorderMatMulInputs(graphView);
-
-    matMulToFC(graphView);
+    // matMulToFC(graphView); // not working properly atm !
 
     // tag the weighted nodes
 
+    std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
+
     for (std::shared_ptr<Node> node : nodeVector)
     {
         bool isWeighted = isAffine(node);
diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index c184882..1806e3d 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -176,6 +176,9 @@ void reorderMatMulInputs(std::shared_ptr<GraphView> graphView)
             auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})});
             newMicroGraph->add(newMatMul->getParent(1));
 
+            newMicroGraph->setDataType(prevTensor->dataType());
+            newMicroGraph->setBackend(prevTensor->backend());
+
             graphView->replace(prevMicroGraph, newMicroGraph);
         }
     }
-- 
GitLab


From 6ab743834cdccc9b714e4515fa69a6c1a05942cc Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Tue, 18 Mar 2025 16:09:56 +0000
Subject: [PATCH 19/20] minor changes

---
 src/PTQ/CLE.cpp                       | 2 +-
 src/backend/cuda/operator/LSQImpl.cpp | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 5ffc8eb..f70793c 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -174,7 +174,7 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
 
             insertScalingBelowProducer(n1->getParent(1), s1, graphView);
 
-            if (n1->type() != "MatMul") // TODO : enhance this !
+            if (n1->type() != "MatMul") // TODO : exclude every node that we can't call getParent(2) on !
                 if (n1->getParent(2))
                     insertScalingBelowProducer(n1->getParent(2), s1, graphView);
 
diff --git a/src/backend/cuda/operator/LSQImpl.cpp b/src/backend/cuda/operator/LSQImpl.cpp
index 5f9f032..bb30cc1 100644
--- a/src/backend/cuda/operator/LSQImpl.cpp
+++ b/src/backend/cuda/operator/LSQImpl.cpp
@@ -56,8 +56,8 @@ void Aidge::LSQImpl_cuda::backward() {
         if (mWorkspace != nullptr) {
             cudaFree(mWorkspace);
         }
-        std::size_t sizeofData = getDataTypeBitWidth(gra_int0->dataType()) / 8;
-        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, sizeofData * gra_int0->size()));
+        std::size_t sizeOfData = getDataTypeBitWidth(gra_int0->dataType()) / 8;
+        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, sizeOfData * gra_int0->size()));
         mWorkspaceSize = gra_int0->size();
     }
 
-- 
GitLab


From 541e72994529dc63d09547ea3ce592b0db4f8338 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 19 Mar 2025 10:22:56 +0000
Subject: [PATCH 20/20] avoid getTensorAbsoluteMax() redefinition (+ minor
 changes)

---
 include/aidge/quantization/PTQ/PTQ.hpp |  6 +++
 src/PTQ/CLE.cpp                        | 67 ++++----------------------
 src/PTQ/PTQ.cpp                        |  8 +--
 3 files changed, 16 insertions(+), 65 deletions(-)

diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index f55894c..1c91180 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -74,6 +74,12 @@ namespace Aidge {
      */
     bool isNotQuantized(std::shared_ptr<Node> node);
 
+    /**
+     * @brief Compute the absolute max of a tensor
+     * @param tensor The Tensor to process
+     */
+    double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor);
+
     /**
      * @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes.
      * @param graphView The graphView containing the nodes
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index f70793c..57787a8 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -52,62 +52,14 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
     return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
 }
 
-static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
+static bool nodeHasBias(std::shared_ptr<Node> node)
 {
-    auto mulOp = Mul_Op();
-    mulOp.setDataType(tensor->dataType());
-    mulOp.setBackend(tensor->backend());
-
-    std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(Aidge::Array1D<double, 1> {scaling});
-    scalingTensor->setDataType(tensor->dataType());
-    scalingTensor->setBackend(tensor->backend());
-
-    mulOp.associateInput(0, tensor);
-    mulOp.associateInput(1, scalingTensor);
-
-    mulOp.forward();
-    
-    auto outTensor = mulOp.getOutput(0);
-    *tensor = *outTensor;
-    //tensor->copyCast(*outTensor);
-}
-
-// TODO : make the retreival of argmax values backend independant (refCastFrom)
-static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
-{
-    // get the abs tensor
-    std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR
-    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
-
-    // flatten the abs tensor
-
-    std::int64_t nbElement = tensor->size();
-
-    auto reshapeOp = Reshape_Op({nbElement});
-    reshapeOp.setDataType(tensor->dataType());
-    reshapeOp.setBackend(tensor->backend());
-
-    reshapeOp.associateInput(0, absTensor);
-    reshapeOp.forward();
-    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
-    const Tensor& localFlatTensor = flatTensor->refCastFrom(fallback, DataType::Float64, "cpu"); 
-
-    // Get the argmax
-
-    auto argmaxOp = ArgMax_Op(0, true, false);
-    argmaxOp.setDataType(tensor->dataType());
-    argmaxOp.setBackend(tensor->backend());
-
-    argmaxOp.associateInput(0, flatTensor);
-    argmaxOp.forward();
-
-    const Tensor& argMaxTensor = argmaxOp.getOutput(0)->refCastFrom(fallback, DataType::Float64, "cpu"); 
-
-    // Return the max
-
-    int maxIndex = std::round(argMaxTensor.get<double>(0));
-
-    return localFlatTensor.get<double>(maxIndex);
+    if (node->getParents().size() == 3) {
+        std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
+        if (biasTensor)
+            return true;
+    }
+    return false;
 }
 
 // What is this thing ???
@@ -174,9 +126,8 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
 
             insertScalingBelowProducer(n1->getParent(1), s1, graphView);
 
-            if (n1->type() != "MatMul") // TODO : exclude every node that we can't call getParent(2) on !
-                if (n1->getParent(2))
-                    insertScalingBelowProducer(n1->getParent(2), s1, graphView);
+            if (nodeHasBias(n1))
+                insertScalingBelowProducer(n1->getParent(2), s1, graphView);
 
             insertScalingBelowProducer(n2->getParent(1), s2, graphView);
 
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 91a003d..0eecc45 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -266,7 +266,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
     return false;
 }
 
-static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
+double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
     // get the abs tensor
     std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR
@@ -571,8 +571,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
         // Residual nodes should enter in this category but their ratio is 1 ...
         if (isAffine(node))
         {
-            Log::warn(" affine : {} ", node->name());
-
             // Rescale the weight tensor
             
             std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
@@ -623,8 +621,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
         {
             if (node->type() == "MatMul")
             {
-                Log::warn(" matmul : {} ", node->name());
-
                 // Multiply the input scaling factors !
 
                 double leftRatio  = accumulatedRatios[node->getParent(0)];
@@ -636,8 +632,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             {
                 // Use a maximum arbitration !
 
-                Log::warn(" merging : {} ", node->name());
-
                 std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
 
                 // Compute the max ratio ...
-- 
GitLab