diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp index 5a15f06ccd7922fc931cd1d05af507588cec5162..8bce79300a7113bf6525db72159c0052f4b01c97 100644 --- a/include/aidge/quantization/PTQ/PTQ.hpp +++ b/include/aidge/quantization/PTQ/PTQ.hpp @@ -96,15 +96,6 @@ namespace Aidge { */ void insertScalingBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView); - /** - * @brief Inserts a rounding node below the given producer (also below its ows producerScaling) node in the graph view. - * - * @param node A shared pointer to the producer node where the rounding node will be inserted. - * @param graphView A shared pointer to the graph view in which the nodes are located. - * @return True if the rounding node was successfully inserted; False otherwise. - */ - bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView); - /** * @brief Determine whether an input GraphView can be quantized or not. * @param graphView The GraphView to be checked. diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index f622ae5c7ec02cb7dd8f3c53287774b3f1a8e59a..fc278220565a45de5b0f7db2e115f636fd0e6807 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -238,26 +238,6 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n graphView->add(newNode); } -bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView) -{ -/* - if (hasAttr(node, "isProducerScaling") && node->type() != "Round") - { - std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); - roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - roundNode->getOperator()->setBackend(determineBackend(node)); - - insertChildren(node, roundNode, graphView); - addAttr(roundNode, "isProducerRounding"); - - return true; - } - return false; -*/ - Log::warn(" ROUND : DUMMY ! "); - return true; -} - double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { // get the abs tensor @@ -1081,7 +1061,6 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ if (optimizeSigns) { -/* double rescaling = 1.0; bool inputIsUnsigned = signMap[node].first; @@ -1090,16 +1069,11 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - // XXX OK - //double currScalingFactor = getScalingFactor(quantizerNode); - //updateScalingFactor(quantizerNode, currScalingFactor * rescaling); + // XXX XXX XXX multiplyScalingFactor(node, rescaling); - // XXX XXX XXX HERE : Fix this !!! - - if(outputIsUnsigned) - setClipRange(quantizerNode, 0, unsignedMax); -*/ + if (outputIsUnsigned) + setClipRange(node, 0, unsignedMax); } } } @@ -1119,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u // preceded by an Weighted node (that is not forking), and insert // a mul node (Compensation) before it if so ... - if (node->type() == "Quantizer") + if (node->type() == "BaseQuantizer") { // Note : this works because a Quantizer has only one Parent ... @@ -1167,6 +1141,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u std::shared_ptr<Node> coeffQuantizer = mulNode->getParent(1); appendRoundClip(coeffQuantizer, -(signedMax + 1), signedMax); } + } } } @@ -1178,7 +1153,7 @@ static void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView for (std::shared_ptr<Node> node : nodeVector) { - if (node->type() == "Quantizer") + if (node->type() == "BaseQuantizer") { std::shared_ptr<Node> linearNode = node->getParent(0); @@ -1202,10 +1177,10 @@ static void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView static void printScalingFactors(std::shared_ptr<GraphView> graphView) { for (auto node : retrieveNodeVector(graphView)) - if (hasAttr(node, "isScaling") || node->type() == "Quantizer") + if (hasAttr(node, "isScaling") || node->type() == "BaseQuantizer") { double scalingFactor = getScalingFactor(node); - Log::info(" {:.6f} ({})", scalingFactor, node->name()); + Log::notice(" SCALING FACTOR : {} ({})", scalingFactor, node->name()); } } @@ -1280,6 +1255,9 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, scheduler.resetScheduling(); Log::notice(" Network is quantized !"); + + // XXX XXX XXX + printScalingFactors(graphView); } std::unordered_map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView)