Skip to content
Snippets Groups Projects
Commit 3b9e29ce authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

modify normalizeActivations() for MatMul support

parent 2193615f
No related branches found
No related tags found
3 merge requests!54Update 0.3.1 -> 0.4.0,!49Forked from add_matmul (merged automatically),!45Add support for the MatMul operator
Pipeline #67828 passed
......@@ -796,26 +796,35 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
if (isMerging(node))
{
std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
// Compute the max ratio ...
double maxRatio = 0;
for (std::shared_ptr<Node> mergingNode : mergingNodes)
if (node->type() == "MatMul")
{
double mergingNodeRatio = accumulatedRatios[mergingNode];
if (mergingNodeRatio > maxRatio)
maxRatio = mergingNodeRatio;
double leftRatio = accumulatedRatios[node->getParent(0)];
double rightRatio = accumulatedRatios[node->getParent(1)];
accumulatedRatios[node] = leftRatio * rightRatio;
}
else
{
std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
accumulatedRatios[node] = maxRatio;
// Compute the max ratio ...
for (std::shared_ptr<Node> mergingNode : mergingNodes)
{
double mergingNodeRatio = accumulatedRatios[mergingNode];
std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio);
// Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
double maxRatio = 0;
for (std::shared_ptr<Node> mergingNode : mergingNodes)
{
double mergingNodeRatio = accumulatedRatios[mergingNode];
if (mergingNodeRatio > maxRatio)
maxRatio = mergingNodeRatio;
}
accumulatedRatios[node] = maxRatio;
for (std::shared_ptr<Node> mergingNode : mergingNodes)
{
double mergingNodeRatio = accumulatedRatios[mergingNode];
std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio);
// Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
}
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment