From 3b9e29ce9fcd6ec612482ab1efdd7c1658d38982 Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Fri, 14 Mar 2025 10:14:05 +0000
Subject: [PATCH] modify normalizeActivations() for MatMul support

---
 src/PTQ/PTQ.cpp | 41 +++++++++++++++++++++++++----------------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 002f390..3a24097 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -796,26 +796,35 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
         
         if (isMerging(node))
         {
-            std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
-
-            // Compute the max ratio ...
-
-            double maxRatio = 0;
-            for (std::shared_ptr<Node> mergingNode : mergingNodes)
+            if (node->type() == "MatMul")
             {
-                double mergingNodeRatio = accumulatedRatios[mergingNode];
-                if (mergingNodeRatio > maxRatio)
-                    maxRatio = mergingNodeRatio;
+                double leftRatio  = accumulatedRatios[node->getParent(0)];
+                double rightRatio = accumulatedRatios[node->getParent(1)];
+                accumulatedRatios[node] = leftRatio * rightRatio;
             }
+            else
+            {
+                std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
 
-            accumulatedRatios[node] = maxRatio;
+                // Compute the max ratio ...
 
-            for (std::shared_ptr<Node> mergingNode : mergingNodes)
-            {
-                double mergingNodeRatio = accumulatedRatios[mergingNode];
-                std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
-                multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio);
-                // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
+                double maxRatio = 0;
+                for (std::shared_ptr<Node> mergingNode : mergingNodes)
+                {
+                    double mergingNodeRatio = accumulatedRatios[mergingNode];
+                    if (mergingNodeRatio > maxRatio)
+                        maxRatio = mergingNodeRatio;
+                }
+
+                accumulatedRatios[node] = maxRatio;
+
+                for (std::shared_ptr<Node> mergingNode : mergingNodes)
+                {
+                    double mergingNodeRatio = accumulatedRatios[mergingNode];
+                    std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
+                    multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio);
+                    // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
+                }
             }
         }
 
-- 
GitLab