diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 002f3900886480c799417f25feb9bcd48ff08610..3a240971dd3465300c94885b8aed7d4fb0ef79b2 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -796,26 +796,35 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m if (isMerging(node)) { - std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - - // Compute the max ratio ... - - double maxRatio = 0; - for (std::shared_ptr<Node> mergingNode : mergingNodes) + if (node->type() == "MatMul") { - double mergingNodeRatio = accumulatedRatios[mergingNode]; - if (mergingNodeRatio > maxRatio) - maxRatio = mergingNodeRatio; + double leftRatio = accumulatedRatios[node->getParent(0)]; + double rightRatio = accumulatedRatios[node->getParent(1)]; + accumulatedRatios[node] = leftRatio * rightRatio; } + else + { + std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - accumulatedRatios[node] = maxRatio; + // Compute the max ratio ... - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double mergingNodeRatio = accumulatedRatios[mergingNode]; - std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); - multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); - // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + double maxRatio = 0; + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + if (mergingNodeRatio > maxRatio) + maxRatio = mergingNodeRatio; + } + + accumulatedRatios[node] = maxRatio; + + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); + multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); + // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + } } }