From 2bcba4fa0d36d0a2c82a50c833c38ee00c59b77e Mon Sep 17 00:00:00 2001 From: bhalimi <benjamin.halimi@cea.fr> Date: Thu, 27 Feb 2025 10:55:19 +0000 Subject: [PATCH 1/4] rework several parts of the PTQ code --- src/PTQ/PTQ.cpp | 343 +++++++++++++++++++++++++++--------------------- 1 file changed, 193 insertions(+), 150 deletions(-) diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index fe717bb..637407c 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -50,6 +50,7 @@ bool isMerging(std::shared_ptr<Node> node) { return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end()); } + static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) { int index = 0; @@ -58,27 +59,27 @@ static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> paren return index; } - -void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node,double coeff) +void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff) { AIDGE_ASSERT(node->type() == "Mul" && (node->attributes()->hasAttr("quantization.ptq.isProducerScaling") || node->attributes()->hasAttr("quantization.ptq.isScaling")), "Cannot update the scaling factor on Node of type {} with no scaling tag",node->type()); auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1); + std::shared_ptr<Tensor> fallback; const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); double previousScalingFactor = localTensor.get<double>(0); - std::shared_ptr<Tensor> finalTensor = std::make_shared<Tensor>(Array1D<double, 1> {previousScalingFactor * coeff}); - node->input(1).first->getOperator()->setOutput(0, finalTensor); + + std::shared_ptr<Tensor> resultTensor = std::make_shared<Tensor>(Array1D<double, 1> {previousScalingFactor * coeff}); + node->input(1).first->getOperator()->setOutput(0, resultTensor); } -/* Util function to insert a node below another one already connected */ -void insertNodeBetween(std::shared_ptr<Node> parent, - std::shared_ptr<Node> newNode, - std::shared_ptr<GraphView> graphView) + +// Utility function that insert a node below another one already connected +static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> newNode, std::shared_ptr<GraphView> graphView) { // Checking the parents always have at least 1 children - AIDGE_ASSERT(parent->getChildren().size() > 0, "The parent node must have at least one child to insert a new node."); + AIDGE_ASSERT(parent->getChildren().size() > 0, " Parent node must have at least one child to insert a new node ! 
"); - // Retrieve children connection indexes + // Retreive children connection indexes std::vector<std::shared_ptr<Node>> nextNodes = parent->getChildren(0); std::vector<int> inputIndices(nextNodes.size()); for (std::size_t i = 0; i < nextNodes.size(); i++) { @@ -99,54 +100,20 @@ void insertNodeBetween(std::shared_ptr<Node> parent, graphView->add(newNode); } -bool insertRoundBelowProducer(std::shared_ptr<Node> node,std::shared_ptr<GraphView> graphView) +bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView) { - if(node->attributes()->hasAttr("quantization.ptq.isProducerScaling") && node->type() != "Round") + if (node->attributes()->hasAttr("quantization.ptq.isProducerScaling") && node->type() != "Round") { std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) roundNode->getOperator()->setBackend(node->getOperator()->backend()); - insertNodeBetween(node,roundNode,graphView); - roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding",0.0); + insertChildren(node, roundNode, graphView); + roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0); return true; } return false; } -bool insertScalingBelowProducer(std::shared_ptr<Node> node,double scalingFactor, std::shared_ptr<GraphView> graphView) -{ - if(node->attributes()->hasAttr("quantization.ptq.isProducerRounding")) - { - //In this case we 'bump' the node to the one above him (an actual ProducerScaling) - // because the round node is not usable (only used when SSA is enabled) - node = node->getParent(0); - } - if(node->attributes()->hasAttr("quantization.ptq.isProducerScaling")) - { - // We accumulate the multiples scaling factors by multiplying the SF of the ProducerScaling node - // (adding new nodes each time would make the graph unusable) - multiplyScalingFactor(node,scalingFactor); - return true; - } - AIDGE_ASSERT(node->type() == "Producer","Cannot apply a scaling factor on node of type: {} which is not a producer", node->type()); - std::string scalingNodeName = makeUniqueName(node->name() + "_ProducerScaling", graphView); - - std::shared_ptr<Aidge::Node> scalingNode = Mul(scalingNodeName); - scalingNode->attributes()->addAttr("quantization.ptq.isProducerScaling",0.0); - - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); - std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "Factor"); - scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); - graphView->add(scalingFactorProducer); - - scalingNode->getOperator()->setDataType(DataType::Float64); - std::string producerBackend = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0)->backend(); - scalingNode->getOperator()->setBackend(producerBackend); - - insertNodeBetween(node, scalingNode, graphView); - - return true; -} bool checkArchitecture(std::shared_ptr<GraphView> graphView) { @@ -212,6 +179,7 @@ static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::s return remainingNodes; } + static std::vector<std::shared_ptr<Node>> removeProdScalingNodes(std::vector<std::shared_ptr<Node>> nodeVector) { std::vector<std::shared_ptr<Node>> remainingNodes; @@ -291,8 +259,7 @@ void prepareNetwork(std::shared_ptr<GraphView> graphView) std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); for (std::shared_ptr<Node> node : nodeVector) - if (node->type() 
== "BatchNorm") - { + if (node->type() == "BatchNorm") { containsBatchNorm = true; break; } @@ -310,8 +277,58 @@ static DataType getDataType(std::shared_ptr<Node> node) return op->getOutput(0)->dataType(); } +static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vector<std::string> attributes, double value, std::shared_ptr<GraphView> graphView) +{ + std::shared_ptr<Node> scalingNode = Mul(name); + + for (std::string attr : attributes) + scalingNode->attributes()->addAttr("quantization.ptq." + attr, 0.0); + + // Add the scaling factor as a producer of the node + + std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {value}); + std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "ScalingFactor"); + + scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); + + graphView->add(scalingFactorProducer); + + return scalingNode; +} + +bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scalingFactor, std::shared_ptr<GraphView> graphView) +{ + if (producerNode->attributes()->hasAttr("quantization.ptq.isProducerRounding")) + { + // In this case we 'bump' the node to the one above him (an actual ProducerScaling) + // because the round node is not usable (only used when SSA is enabled) + producerNode = producerNode->getParent(0); + } + + if (producerNode->attributes()->hasAttr("quantization.ptq.isProducerScaling")) + { + // We accumulate the previous scaling factors by multiplying the SF of the ProducerScaling node + // (adding new nodes each time would make the graph unusable) + multiplyScalingFactor(producerNode, scalingFactor); + return true; + } + + AIDGE_ASSERT(producerNode->type() == "Producer", " Cannot apply a scaling factor on node of type: {} which is not a Producer", producerNode->type()); + + std::string scalingNodeName = makeUniqueName(producerNode->name() + "_ProducerScaling", graphView); + std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor, graphView); + + scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) + auto producerOp = std::static_pointer_cast<OperatorTensor>(producerNode->getOperator()); + scalingNode->getOperator()->setBackend(producerOp->getOutput(0)->backend()); + + insertChildren(producerNode, scalingNode, graphView); + + return true; +} + // XXX HERE : Branches containing only Seamless nodes should be considered as residual too !!! -void insertResidualNodes(std::shared_ptr<GraphView> graphView) +void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) { // TODO: double check this ... @@ -330,92 +347,91 @@ void insertResidualNodes(std::shared_ptr<GraphView> graphView) if (parentIsForking) { // temporary verbose ... 
+ Log::info(" ### found residual branch at index {}", i); Log::info(" ### inserting multiplicative node ..."); std::string residualNodeName = makeUniqueName(parentNode->name() + "_Res", graphView); - std::shared_ptr<Node> residualNode = Mul(residualNodeName); - residualNode->attributes()->addAttr("quantization.ptq.isScaling", 0.0); - residualNode->attributes()->addAttr("quantization.ptq.isResidual", 0.0); - - //Adding the SF as a producer of the node - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {1.0}); - std::shared_ptr<Node> scalingFactorProducer = addProducer(residualNode, 1, {1}, "ScalingFactor"); - scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); + std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0, graphView); residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) residualNode->getOperator()->setBackend(parentNode->getOperator()->backend()); graphView->insertParent(node, residualNode, i, 0, 0); - graphView->add(scalingFactorProducer); } } } } } +static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> node) +{ + std::shared_ptr<Node> currNode = node; + while(!currNode->attributes()->hasAttr("quantization.ptq.isScaling")) + { + if (currNode->getParents().size() == 0) + { + Log::warn(" Warning : No previous Scaling node were found ! "); + break; + } + currNode = currNode->getParents()[0]; + } + return currNode; +} void insertScalingNodes(std::shared_ptr<GraphView> graphView) { - insertResidualNodes(graphView); + insertResidualScalingNodes(graphView); std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); for (std::shared_ptr<Node> parentNode : nodeSet) { - if (isAffine(parentNode) || isMerging(parentNode)) + if (isAffine(parentNode) || isMerging(parentNode) || (parentNode->type() == "Sigmoid")) { std::string scalingNodeName = makeUniqueName(parentNode->name() + "_Scaling", graphView); - //std::shared_ptr<Node> scalingNode = Scaling(1.0, scalingNodeName); - - //Adding Mul operator with tag "quantization.ptq.isScaling" - std::shared_ptr<Aidge::Node> scalingNode = Mul(scalingNodeName); - scalingNode->attributes()->addAttr("quantization.ptq.isScaling",0.0); - - //Adding the SF as a producer of the node - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {1.0}); - std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "ScalingFactor"); - scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); + std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0, graphView); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) scalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); - if (parentNode->getChildren().size() > 0) - { - insertNodeBetween(parentNode,scalingNode,graphView); - graphView->add(scalingFactorProducer); - } - else - { + if (parentNode->getChildren().size() > 0) { + insertChildren(parentNode, scalingNode, graphView); + } else { // Log::info(" last node reached ! 
"); parentNode->addChild(scalingNode, 0, 0); - graphView->add(scalingFactorProducer); graphView->add(scalingNode); } - } - } -} -static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> mergingNode) -{ - std::shared_ptr<Node> currNode = mergingNode; - while(!currNode->attributes()->hasAttr("quantization.ptq.isScaling")) - { - if (currNode->getParents().size() == 0) - { - Log::warn(" Warning : No previous Scaling node were found ! "); - break; + // Non linear function handling starts here ! + + if (parentNode->type() == "Sigmoid") + { + // If the parent is a forking Scaling node, we need an extra Scaling + // node to completely isolate the non linearity ... + + std::shared_ptr<Node> prevScalingNode = getPreviousScalingNode(parentNode); + bool prevScalingNodeIsForking = (prevScalingNode->getChildren().size() > 1); + + if (prevScalingNodeIsForking) + { + std::string prevScalingNodeName = makeUniqueName(parentNode->name() + "_PrevScaling", graphView); + prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0, graphView); + + prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) + prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + + graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); + } + } } - currNode = currNode->getParents()[0]; } - return currNode; } // XXX double check this ! static bool nodeHasBias(std::shared_ptr<Node> node) { - if (node->getParents().size() == 3) - { + if (node->getParents().size() == 3) { std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); if (biasTensor) return true; @@ -453,19 +469,19 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) if (isAffine(node)) { // Rescale the weight tensor + std::shared_ptr<Tensor> weightTensor = getWeightTensor(node); double scaling = getTensorAbsoluteMax(weightTensor); double ratio = 1.0 / scaling; + //rescaleTensor(weightTensor, ratio); insertScalingBelowProducer(node->getParent(1), ratio, graphView); // Accumulate the ratio - if (node == firstNode) - { + + if (node == firstNode) { accumulatedRatios[node] = ratio; - } - else - { + } else { std::shared_ptr<Node> prevNode = node->getParent(0); accumulatedRatios[node] = accumulatedRatios[prevNode] * ratio; } @@ -480,11 +496,30 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) } } + if (node->type() == "Sigmoid") + { + // Gather the previous scaling factor + + std::shared_ptr<Node> prevScalingNode = getPreviousScalingNode(node); + double prevRatio = accumulatedRatios[prevScalingNode]; + + // Cancel the accumulated ratio + + multiplyScalingFactor(prevScalingNode, 1 / prevRatio); + + // Revert the canceling by using the next scaling node + + accumulatedRatios[node] = prevRatio; + std::shared_ptr<Node> nextScalingNode = node->getChildren(0)[0]; + multiplyScalingFactor(nextScalingNode, prevRatio); + } + if (isMerging(node)) { std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); // Compute the max ratio ... + double maxRatio = 0; for (std::shared_ptr<Node> mergingNode : mergingNodes) { @@ -503,7 +538,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); - multiplyScalingFactor(scalingNode,1/rescaling); + multiplyScalingFactor(scalingNode, 1 / rescaling); accumulatedRatios[mergingNode] /= rescaling; // optional ... 
} @@ -963,39 +998,47 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u for (std::shared_ptr<Node> node : nodeVector) { - // A merging node is always followed by a Quantizer node at this point + // The appropriate strategy is to check if the Quantizer is not + // preceded by an Weighted node (that is not forking), and insert + // a coeff node (Compensation) if so ... + + if (node->type() == "Quantizer") + { + // Note : this works because a Quantizer has only one Parent ... - if (node->type() == "Quantizer" && (node->attributes()->hasAttr("quantization.ptq.isResidual") || !isAffine(node->getParent(0)))) - { + std::shared_ptr<Node> parentNode = node->getParent(0); + bool parentHasWeight = isAffine(parentNode); + bool parentIsForking = (parentNode->getChildren().size() > 1); - // check if the Quantizer is a residual one, and insert a compensation node if so ... - // create and insert the multplicative node before the Quantizer + if (parentIsForking || !parentHasWeight) // insert a Compensation Node ... + { + // Create and insert the multplicative node before the Quantizer - std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView); - std::shared_ptr<Node> mulNode = Mul(mulNodeName); - - mulNode->attributes()->addAttr("quantization.ptq.isCompensation",0.0); - mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - mulNode->getOperator()->setBackend(node->getOperator()->backend()); + std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView); + std::shared_ptr<Node> mulNode = Mul(mulNodeName); + + mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0); + mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) + mulNode->getOperator()->setBackend(node->getOperator()->backend()); - graphView->insertParent(node, mulNode, 0, 0, 0); + graphView->insertParent(node, mulNode, 0, 0, 0); - // Add the coeff producer to the multiplier node + // Add the coeff producer to the multiplier node - std::shared_ptr<Node> coeffProducer = addProducer(mulNode, 1, {1}, ""); - std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(Array1D<double, 1> {signedMax}); - coeffProducer->getOperator()->setOutput(0, coeffTensor); + std::shared_ptr<Node> coeffProducer = addProducer(mulNode, 1, {1}, ""); + std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(Array1D<double, 1> {signedMax}); + coeffProducer->getOperator()->setOutput(0, coeffTensor); - coeffProducer->getOperator()->setDataType(DataType::Float64); - coeffProducer->getOperator()->setBackend(node->getOperator()->backend()); + coeffProducer->getOperator()->setDataType(DataType::Float64); + coeffProducer->getOperator()->setBackend(node->getOperator()->backend()); - graphView->add(coeffProducer); // needed ? + graphView->add(coeffProducer); // needed ? 
- // Adapt the scaling factor value accordingly + // Adapt the scaling factor value accordingly - double currScalingFactor = getScalingFactor(node); - updateScalingFactor(node, currScalingFactor / signedMax); - + double currScalingFactor = getScalingFactor(node); + updateScalingFactor(node, currScalingFactor / signedMax); + } } } } @@ -1006,33 +1049,33 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool for (std::shared_ptr<Node> node : nodeVector) { - if (isAffine(node) || (node->type() == "Mul" && node->attributes()->hasAttr("quantization.ptq.isCompensation"))) + if (node->type() == "Quantizer") { - std::shared_ptr<Node> scalingNode = (*node->getChildren().begin()); // TODO : use index = 0 + std::shared_ptr<Node> linearNode = node->getParent(0); - double base = getScalingFactor(scalingNode); + double base = getScalingFactor(node); double approx = std::pow(2, std::ceil(std::log2(base))); - updateScalingFactor(scalingNode,approx); + updateScalingFactor(node, approx); double ratio = base / approx; - insertScalingBelowProducer(node->getParent(1),ratio,graphView); + insertScalingBelowProducer(linearNode->getParent(1), ratio, graphView); if (!noQuant) - insertRoundBelowProducer(node->getParent(1),graphView); + insertRoundBelowProducer(linearNode->getParent(1), graphView); - if (nodeHasBias(node)) + if (nodeHasBias(linearNode)) { - insertScalingBelowProducer(node->getParent(2),ratio,graphView); - + insertScalingBelowProducer(linearNode->getParent(2), ratio, graphView); if (!noQuant) - insertRoundBelowProducer(node->getParent(2),graphView); + insertRoundBelowProducer(linearNode->getParent(2), graphView); } } } } + static void printScalingFactors(std::shared_ptr<GraphView> graphView) { for (auto node : retrieveNodeVector(graphView)) @@ -1060,48 +1103,48 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std: void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose) { - Log::info(" === QUANT PTQ 0.2.21 === "); + Log::notice(" === QUANT PTQ 0.2.21 === "); graphView->setBackend("cpu"); - DataType initialDataType = (inputDataSet[0])->dataType(); - setupDataType(graphView, inputDataSet, DataType::Float64); - if (!checkArchitecture(graphView)) return; - Log::info(" Preparing the network for the PTQ ... "); + DataType initialDataType = (inputDataSet[0])->dataType(); + setupDataType(graphView, inputDataSet, DataType::Float64); + + Log::notice(" Preparing the network for the PTQ ... 
"); prepareNetwork(graphView); - Log::info(" Inserting the scaling nodes ..."); + Log::notice(" Inserting the scaling nodes ..."); insertScalingNodes(graphView); crossLayerEqualization(graphView); - Log::info(" Normalizing the parameters ..."); + Log::notice(" Normalizing the parameters ..."); normalizeParameters(graphView); - Log::info(" Computing the value ranges ..."); + Log::notice(" Computing the value ranges ..."); std::unordered_map<std::shared_ptr<Node>, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda); //Log::info(" === RANGES (BEFORE ADJUST) ==="); - Log::info(" Optimizing the clipping values ..."); + Log::notice(" Optimizing the clipping values ..."); valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose); //Log:debug("=== RANGES (AFTER ADJUST) ==="); //printRanges(graphView, valueRanges); - Log::info(" Normalizing the activations ..."); + Log::notice(" Normalizing the activations ..."); normalizeActivations(graphView, valueRanges); - Log::info(" Quantizing the normalized network ..."); + Log::notice(" Quantizing the normalized network ..."); quantizeNormalizedNetwork(graphView, nbBits, noQuant, optimizeSigns, verbose); if (singleShift) { - Log::info( " Inserting the compensation nodes ..."); + Log::notice( " Inserting the compensation nodes ..."); insertCompensationNodes(graphView, nbBits); - Log::info(" Performing the Single-Shift approximation ..."); + Log::notice(" Performing the Single-Shift approximation ..."); performSingleShiftApproximation(graphView, noQuant); } @@ -1111,11 +1154,11 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, if (useCuda) graphView->setBackend("cuda"); - Log::info(" Reseting the scheduler ..."); + Log::notice(" Reseting the scheduler ..."); SequentialScheduler scheduler(graphView); scheduler.resetScheduling(); - Log::info(" Network is quantized !"); + Log::notice(" Network is quantized !"); } std::unordered_map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView) -- GitLab From 0c0e9122886b4800831bdf25e26ed5fab28f65db Mon Sep 17 00:00:00 2001 From: bhalimi <benjamin.halimi@cea.fr> Date: Thu, 27 Feb 2025 13:43:57 +0000 Subject: [PATCH 2/4] minor changes --- src/PTQ/PTQ.cpp | 69 +++++++++++++++++++------------------------------ 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 637407c..14cc8cc 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -291,6 +291,7 @@ static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vec scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); + graphView->add(scalingNode); graphView->add(scalingFactorProducer); return scalingNode; @@ -653,14 +654,14 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m { std::shared_ptr<Node> firstNode = getFirstNode(graphView); - // CREATE THE SCALING FACTOR MAP ////////////////////////////////////////// + // CREATE THE ACCUMULATED RATIO MAP /////////////////////////////////////// std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); - std::unordered_map<std::shared_ptr<Node>, double> scalingFactors; + std::unordered_map<std::shared_ptr<Node>, double> accumulatedRatios; for (std::shared_ptr<Node> node : nodeVector) - scalingFactors.insert(std::make_pair(node, 1.0)); + accumulatedRatios.insert(std::make_pair(node, 1.0)); // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// @@ -670,45 
+671,32 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m if (isAffine(node) || isSeamless(node) || node->type() == "ReLU") { - if (node == firstNode) - { - scalingFactors[node] = 1.0; - } - else - { + if (node == firstNode) { + accumulatedRatios[node] = 1.0; + } else { std::shared_ptr<Node> prevNode = node->getParent(0); - scalingFactors[node] = scalingFactors[prevNode]; + accumulatedRatios[node] = accumulatedRatios[prevNode]; } } - // Here prevNode is either a 'Affine' or a 'Merging' - // => do not split the cases, just handle the bias ... + // Use the Scaling nodes to rescale the ranges ... if (node->attributes()->hasAttr("quantization.ptq.isScaling")) { - - // retrieve the previous scaling factor ... std::shared_ptr<Node> prevNode = node->getParent(0); - double prevScalingFactor = scalingFactors[prevNode]; - // ValueRanges must contains all the scaling nodes !!! - double scalingFactor = valueRanges[node]; + double prevRatio = accumulatedRatios[prevNode]; + double nodeRange = valueRanges[node]; - multiplyScalingFactor(node, 1 / (scalingFactor / prevScalingFactor)); + multiplyScalingFactor(node, prevRatio / nodeRange); - scalingFactors[node] = scalingFactor; + accumulatedRatios[node] = nodeRange; // If prevNode is Affine, fix the bias ... - if (isAffine(prevNode)) - { - - bool prevNodeHasBias = nodeHasBias(prevNode); - if (prevNodeHasBias) - { - std::shared_ptr<Tensor> biasTensor = getBiasTensor(prevNode); - //rescaleTensor(biasTensor, 1.0 / prevScalingFactor); - insertScalingBelowProducer(prevNode->getParent(2), 1.0 / prevScalingFactor, graphView); + if (isAffine(prevNode)) { + if (nodeHasBias(prevNode)) { + insertScalingBelowProducer(prevNode->getParent(2), 1.0 / prevRatio, graphView); } } } @@ -719,27 +707,24 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m { std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - // Compute the max scaling ... - double maxScaling = 0; - for (std::size_t i = 0; i < mergingNodes.size(); i++) + // Compute the max ratio ... + + double maxRatio = 0; + for (std::shared_ptr<Node> mergingNode : mergingNodes) { - double mergingNodeScaling = scalingFactors[mergingNodes[i]]; - if (mergingNodeScaling > maxScaling) { - maxScaling = mergingNodeScaling; - } + double mergingNodeRatio = accumulatedRatios[mergingNode]; + if (mergingNodeRatio > maxRatio) + maxRatio = mergingNodeRatio; } - scalingFactors[node] = maxScaling; + accumulatedRatios[node] = maxRatio; for (std::shared_ptr<Node> mergingNode : mergingNodes) { - double mergingNodeScaling = scalingFactors[mergingNode]; - double rescaling = mergingNodeScaling / maxScaling; - + double mergingNodeRatio = accumulatedRatios[mergingNode]; std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); - //Log::info(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); - - multiplyScalingFactor(scalingNode, rescaling) ; + multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); + // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); } } } -- GitLab From 851d95fd51bfc3e0c4c4471ce42493791e6e6e32 Mon Sep 17 00:00:00 2001 From: bhalimi <benjamin.halimi@cea.fr> Date: Fri, 28 Feb 2025 15:13:49 +0000 Subject: [PATCH 3/4] add support for not-quantized operators (Sigmoid, Tanh, ...) 
--- include/aidge/quantization/PTQ/PTQ.hpp | 12 ++++ src/PTQ/CLE.cpp | 20 ++++-- src/PTQ/PTQ.cpp | 95 ++++++++++++++++++-------- 3 files changed, 91 insertions(+), 36 deletions(-) diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp index 4bfe65f..f55894c 100644 --- a/include/aidge/quantization/PTQ/PTQ.hpp +++ b/include/aidge/quantization/PTQ/PTQ.hpp @@ -41,6 +41,11 @@ namespace Aidge { */ static const std::set<std::string> mergingNodeTypes({"Add", "Concat", "Sub"}); + /** + * @brief Set of the types of the nodes that won't be quanized + */ + static const std::set<std::string> notQuantizedNodeTypes({"Sigmoid", "Tanh"}); + /** * @brief Determine if a node contains an affine transform (that is Y = A.X + B) * @param node The node to be checked @@ -62,6 +67,13 @@ namespace Aidge { */ bool isMerging(std::shared_ptr<Node> node); + /** + * @brief Determine if a node contains an operator that won't be quantized + * @param node The node to be checked + * @return True if the node is not quantized, else false. + */ + bool isNotQuantized(std::shared_ptr<Node> node); + /** * @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes. * @param graphView The graphView containing the nodes diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp index 2738f8a..7115a2f 100644 --- a/src/PTQ/CLE.cpp +++ b/src/PTQ/CLE.cpp @@ -125,14 +125,20 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); // Check if the CLE can be applied ... + for (std::shared_ptr<Node> node : nodeVector) - if (node->getChildren().size() > 1) - { - Log::notice("Network have multiple branches, skipping the CLE ... "); + { + if (node->getChildren().size() > 1) { + Log::notice(" Network have multiple branches, skipping the CLE ... "); return; } + if (isNotQuantized(node)) { + Log::notice(" Network contains non linear nodes, skipping the CLE ... "); + return; + } + } - Log::info("Applying the Cross-Layer Equalization ... "); + Log::info(" Applying the Cross-Layer Equalization ... 
"); // Get the vector of affine nodes @@ -161,9 +167,9 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD double s1 = std::sqrt(r1 * r2) / r1; double s2 = std::sqrt(r1 * r2) / r2; - insertScalingBelowProducer(n1->getParent(1),s1,graphView); - insertScalingBelowProducer(n2->getParent(1),s2,graphView); - insertScalingBelowProducer(n1->getParent(2),s1,graphView); + insertScalingBelowProducer(n1->getParent(1), s1, graphView); + insertScalingBelowProducer(n2->getParent(1), s2, graphView); + insertScalingBelowProducer(n1->getParent(2), s1, graphView); double rangeDelta = std::abs(r1 - r2); if (rangeDelta > maxRangeDelta) diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 14cc8cc..f16001b 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -51,6 +51,11 @@ bool isMerging(std::shared_ptr<Node> node) return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end()); } +bool isNotQuantized(std::shared_ptr<Node> node) +{ + return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); +} + static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) { int index = 0; @@ -62,7 +67,8 @@ static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> paren void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff) { AIDGE_ASSERT(node->type() == "Mul" && (node->attributes()->hasAttr("quantization.ptq.isProducerScaling") || node->attributes()->hasAttr("quantization.ptq.isScaling")), - "Cannot update the scaling factor on Node of type {} with no scaling tag",node->type()); + "Cannot update the scaling factor on Node of type {} with no scaling tag", node->type()); + auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1); std::shared_ptr<Tensor> fallback; @@ -122,7 +128,7 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView) for (std::shared_ptr<Node> node : graphView->getNodes()) { bool isOther = otherNodeTypes.find(node->type()) != otherNodeTypes.end(); - if (!isOther && !isAffine(node) && !isSeamless(node) && !isMerging(node)) { + if (!isOther && !isAffine(node) && !isSeamless(node) && !isMerging(node) && !isNotQuantized(node)) { Log::warn(" GraphView can't be quantized : node type {} is not supported !", node->type()); return false; } @@ -277,7 +283,7 @@ static DataType getDataType(std::shared_ptr<Node> node) return op->getOutput(0)->dataType(); } -static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vector<std::string> attributes, double value, std::shared_ptr<GraphView> graphView) +static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vector<std::string> attributes, double value) { std::shared_ptr<Node> scalingNode = Mul(name); @@ -291,8 +297,7 @@ static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vec scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); - graphView->add(scalingNode); - graphView->add(scalingFactorProducer); + // XXX graphView->add(scalingNode); return scalingNode; } @@ -317,13 +322,14 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali AIDGE_ASSERT(producerNode->type() == "Producer", " Cannot apply a scaling factor on node of type: {} which is not a Producer", producerNode->type()); std::string scalingNodeName = makeUniqueName(producerNode->name() + "_ProducerScaling", graphView); - std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor, 
graphView); + std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) auto producerOp = std::static_pointer_cast<OperatorTensor>(producerNode->getOperator()); scalingNode->getOperator()->setBackend(producerOp->getOutput(0)->backend()); insertChildren(producerNode, scalingNode, graphView); + graphView->add(scalingNode->getParent(1)); // add the scaling factor producer return true; } @@ -353,12 +359,14 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) Log::info(" ### inserting multiplicative node ..."); std::string residualNodeName = makeUniqueName(parentNode->name() + "_Res", graphView); - std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0, graphView); + std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0); residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) residualNode->getOperator()->setBackend(parentNode->getOperator()->backend()); graphView->insertParent(node, residualNode, i, 0, 0); + graphView->add(residualNode->getParent(1)); // add the scaling factor producer + } } } @@ -388,10 +396,10 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) for (std::shared_ptr<Node> parentNode : nodeSet) { - if (isAffine(parentNode) || isMerging(parentNode) || (parentNode->type() == "Sigmoid")) + if (isAffine(parentNode) || isMerging(parentNode) || isNotQuantized(parentNode)) { std::string scalingNodeName = makeUniqueName(parentNode->name() + "_Scaling", graphView); - std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0, graphView); + std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) scalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); @@ -404,27 +412,23 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) graphView->add(scalingNode); } - // Non linear function handling starts here ! + graphView->add(scalingNode->getParent(1)); // add the scaling factor producer - if (parentNode->type() == "Sigmoid") - { - // If the parent is a forking Scaling node, we need an extra Scaling - // node to completely isolate the non linearity ... + // In the case the node is a non-linear operator we want to add an extra + // scaling node before it to rescale it's input ... 
- std::shared_ptr<Node> prevScalingNode = getPreviousScalingNode(parentNode); - bool prevScalingNodeIsForking = (prevScalingNode->getChildren().size() > 1); + if (isNotQuantized(parentNode)) + { + std::string prevScalingNodeName = makeUniqueName(parentNode->name() + "_PrevScaling", graphView); + std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0); - if (prevScalingNodeIsForking) - { - std::string prevScalingNodeName = makeUniqueName(parentNode->name() + "_PrevScaling", graphView); - prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0, graphView); - - prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) + prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); - graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); - } + graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); + graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer } + } } } @@ -497,7 +501,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) } } - if (node->type() == "Sigmoid") + if (isNotQuantized(node)) { // Gather the previous scaling factor @@ -727,6 +731,20 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); } } + + if (isNotQuantized(node)) + { + std::shared_ptr<Node> prevScalingNode = node->getParent(0); + double prevRatio = accumulatedRatios[prevScalingNode]; + Log::notice(" prev ratio : {} ", prevRatio); + + // This causes the previous range to not full fill the [-1, 1] interval !!! + // It could be avoided by systematicly add an extra Scaling node before each + // non linearity ... + + multiplyScalingFactor(prevScalingNode, prevRatio); + } + } } @@ -821,7 +839,7 @@ std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap( } } - if (node->type() == "ReLU" || isSeamless(node)) + if (node->type() == "ReLU" || isSeamless(node) || isNotQuantized(node)) { // Thoses nodes always have a single parent std::shared_ptr<Node> parent = node->getParent(0); @@ -934,12 +952,30 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ multiplyScalingFactor(scalingNode, rescaling) ; } + + if (isNotQuantized(node)) + { + double rescaling = 1 / signedMax; // XXX handle the signs !!! + + std::shared_ptr<Node> prevScalingNode = node->getParent(0); + multiplyScalingFactor(prevScalingNode, rescaling); + + std::shared_ptr<Node> nextScalingNode = node->getChildren(0)[0]; + multiplyScalingFactor(nextScalingNode, 1 / rescaling); + } // Handle the Scaling Nodes ... if (node->attributes()->hasAttr("quantization.ptq.isScaling")) { - if (!noQuant) + // Don't touch the scalings that precede non-linearities ... + + bool precedesNonLinearNode = false; + if (node->getChildren().size() > 0) + if (isNotQuantized(node->getChildren(0)[0])) + precedesNonLinearNode = true; + + if (!noQuant && !precedesNonLinearNode) { // Replace the Scaling Node by a Quantizer @@ -997,7 +1033,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u if (parentIsForking || !parentHasWeight) // insert a Compensation Node ... 
{ - // Create and insert the multplicative node before the Quantizer + // Create and insert the multiplicative node before the Quantizer std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView); std::shared_ptr<Node> mulNode = Mul(mulNodeName); @@ -1105,6 +1141,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, insertScalingNodes(graphView); crossLayerEqualization(graphView); + Log::notice(" Normalizing the parameters ..."); normalizeParameters(graphView); -- GitLab From abb4162e1baab1dacff87ddd5762211a4135f0a8 Mon Sep 17 00:00:00 2001 From: bhalimi <benjamin.halimi@cea.fr> Date: Fri, 28 Feb 2025 16:37:12 +0000 Subject: [PATCH 4/4] minor changes --- src/PTQ/PTQ.cpp | 107 +++++++++++++++++++++++++++++------------------- 1 file changed, 65 insertions(+), 42 deletions(-) diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index f16001b..22c3438 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -56,6 +56,61 @@ bool isNotQuantized(std::shared_ptr<Node> node) return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); } +bool checkArchitecture(std::shared_ptr<GraphView> graphView) +{ + std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"}); + + std::set<std::string> notQuantizedNodesTypes; + + for (std::shared_ptr<Node> node : graphView->getNodes()) + { + bool isOther = otherNodeTypes.find(node->type()) != otherNodeTypes.end(); + if (!isOther && !isAffine(node) && !isSeamless(node) && !isMerging(node) && !isNotQuantized(node)) { + Log::warn(" GraphView can't be quantized : node type {} is not supported !", node->type()); + return false; + } + + if (isNotQuantized(node)) + notQuantizedNodesTypes.insert(node->type()); + } + + if (!notQuantizedNodesTypes.empty()) { + std::string tokens; + for (std::string s : notQuantizedNodesTypes) + tokens += (s + " "); + Log::warn(" Network contains non-linear nodes that won't be quantized : {}", tokens); + } + + return true; +} + +void prepareNetwork(std::shared_ptr<GraphView> graphView) +{ + removeFlatten(graphView); + sanitizeNodeNames(graphView); + + bool containsBatchNorm = false; + std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); + + for (std::shared_ptr<Node> node : nodeVector) + if (node->type() == "BatchNorm") { + containsBatchNorm = true; + break; + } + + if (containsBatchNorm) + fuseBatchNorm(graphView); + + popSoftMax(graphView); +} + +static std::shared_ptr<Aidge::Node> getUniqueChildren(std::shared_ptr<Aidge::Node> node) +{ + std::set<std::shared_ptr<Aidge::Node>> childrenSet = node->getChildren(); + AIDGE_ASSERT(childrenSet.size() == 1, " Attempted to access to a unique child while the parent have multiple ones ! "); + return *(childrenSet.begin()); +} + static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) { int index = 0; @@ -83,9 +138,11 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff) static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> newNode, std::shared_ptr<GraphView> graphView) { // Checking the parents always have at least 1 children + AIDGE_ASSERT(parent->getChildren().size() > 0, " Parent node must have at least one child to insert a new node ! 
"); // Retreive children connection indexes + std::vector<std::shared_ptr<Node>> nextNodes = parent->getChildren(0); std::vector<int> inputIndices(nextNodes.size()); for (std::size_t i = 0; i < nextNodes.size(); i++) { @@ -93,11 +150,13 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n } // Disconnect childs from parent + for (std::shared_ptr<Node> nextNode : nextNodes) { parent->removeChild(nextNode, 0); } // Insert the new node between the child and the parent + parent->addChild(newNode, 0, 0); for (std::size_t i = 0; i < nextNodes.size(); i++) { newNode->addChild(nextNodes[i], 0, inputIndices[i]); @@ -121,22 +180,6 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV return false; } -bool checkArchitecture(std::shared_ptr<GraphView> graphView) -{ - std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"}); - - for (std::shared_ptr<Node> node : graphView->getNodes()) - { - bool isOther = otherNodeTypes.find(node->type()) != otherNodeTypes.end(); - if (!isOther && !isAffine(node) && !isSeamless(node) && !isMerging(node) && !isNotQuantized(node)) { - Log::warn(" GraphView can't be quantized : node type {} is not supported !", node->type()); - return false; - } - } - - return true; -} - static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { // get the abs tensor @@ -256,26 +299,6 @@ static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView) return retrieveNodeVector(graphView)[0]; } -void prepareNetwork(std::shared_ptr<GraphView> graphView) -{ - removeFlatten(graphView); - sanitizeNodeNames(graphView); - - bool containsBatchNorm = false; - std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); - - for (std::shared_ptr<Node> node : nodeVector) - if (node->type() == "BatchNorm") { - containsBatchNorm = true; - break; - } - - if (containsBatchNorm) - fuseBatchNorm(graphView); - - popSoftMax(graphView); -} - // TODO : enhance this by modifying OperatorImpl in "core" ... static DataType getDataType(std::shared_ptr<Node> node) { @@ -515,7 +538,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) // Revert the canceling by using the next scaling node accumulatedRatios[node] = prevRatio; - std::shared_ptr<Node> nextScalingNode = node->getChildren(0)[0]; + std::shared_ptr<Node> nextScalingNode = getUniqueChildren(node); multiplyScalingFactor(nextScalingNode, prevRatio); } @@ -933,7 +956,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // TODO : assert if scalingNode is a Scaling ... + std::shared_ptr<Node> scalingNode = getUniqueChildren(node); // TODO : assert if scalingNode is a Scaling ... multiplyScalingFactor(scalingNode,rescaling) ; } @@ -948,7 +971,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // TODO : assert if scalingNode is a Scaling ... + std::shared_ptr<Node> scalingNode = getUniqueChildren(node); // TODO : assert if scalingNode is a Scaling ... 
multiplyScalingFactor(scalingNode, rescaling) ; } @@ -960,7 +983,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ std::shared_ptr<Node> prevScalingNode = node->getParent(0); multiplyScalingFactor(prevScalingNode, rescaling); - std::shared_ptr<Node> nextScalingNode = node->getChildren(0)[0]; + std::shared_ptr<Node> nextScalingNode = getUniqueChildren(node); multiplyScalingFactor(nextScalingNode, 1 / rescaling); } @@ -971,8 +994,8 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ // Don't touch the scalings that precede non-linearities ... bool precedesNonLinearNode = false; - if (node->getChildren().size() > 0) - if (isNotQuantized(node->getChildren(0)[0])) + if (node->getChildren().size() == 1) + if (isNotQuantized(getUniqueChildren(node))) precedesNonLinearNode = true; if (!noQuant && !precedesNonLinearNode) -- GitLab
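
For reference, here is a minimal calling sketch of the public entry point that this series reworks. Only the quantizeNetwork() signature and the role of the noQuant / singleShift flags are taken from the diffs above; the wrapper function, the non-PTQ include paths and the Clipping::MAX enumerator are illustrative assumptions, not part of the patches.

#include <memory>
#include <vector>

#include "aidge/quantization/PTQ/PTQ.hpp"   // quantizeNetwork(), checkArchitecture(), ... (path shown in the diff)
#include "aidge/graph/GraphView.hpp"        // assumed aidge_core header layout
#include "aidge/data/Tensor.hpp"            // assumed aidge_core header layout

// Quantize an already-built GraphView using a set of calibration tensors.
void runPostTrainingQuantization(std::shared_ptr<Aidge::GraphView> graphView,
                                 std::vector<std::shared_ptr<Aidge::Tensor>> calibrationSet)
{
    Aidge::quantizeNetwork(graphView,
                           8,                     // nbBits
                           calibrationSet,        // inputDataSet, used to compute the value ranges
                           Aidge::Clipping::MAX,  // clippingMode (assumed enumerator spelling)
                           false,                 // noQuant     : false -> insert the Round / Quantizer nodes
                           true,                  // optimizeSigns
                           false,                 // singleShift : false -> skip the power-of-two approximation
                           false,                 // useCuda
                           false);                // verbose
}

Setting singleShift to true instead would trigger insertCompensationNodes() and performSingleShiftApproximation() from the last patch, which round every remaining scaling factor to a power of two.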