From 3b9e29ce9fcd6ec612482ab1efdd7c1658d38982 Mon Sep 17 00:00:00 2001 From: bhalimi <benjamin.halimi@cea.fr> Date: Fri, 14 Mar 2025 10:14:05 +0000 Subject: [PATCH] modify normalizeActivations() for MatMul support --- src/PTQ/PTQ.cpp | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 002f390..3a24097 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -796,26 +796,35 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m if (isMerging(node)) { - std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - - // Compute the max ratio ... - - double maxRatio = 0; - for (std::shared_ptr<Node> mergingNode : mergingNodes) + if (node->type() == "MatMul") { - double mergingNodeRatio = accumulatedRatios[mergingNode]; - if (mergingNodeRatio > maxRatio) - maxRatio = mergingNodeRatio; + double leftRatio = accumulatedRatios[node->getParent(0)]; + double rightRatio = accumulatedRatios[node->getParent(1)]; + accumulatedRatios[node] = leftRatio * rightRatio; } + else + { + std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - accumulatedRatios[node] = maxRatio; + // Compute the max ratio ... - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double mergingNodeRatio = accumulatedRatios[mergingNode]; - std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); - multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); - // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + double maxRatio = 0; + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + if (mergingNodeRatio > maxRatio) + maxRatio = mergingNodeRatio; + } + + accumulatedRatios[node] = maxRatio; + + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); + multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); + // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + } } } -- GitLab