diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index d162421417e8e81fb6f42415137aca1777ff9a5f..7be15a1aa225a75dddeb794c9d565082f01625d3 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -277,31 +277,36 @@ void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType { if (singleShift) { - // Remove the round nodes (that cannot round integers) - - std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); - for (std::shared_ptr<Node> node : nodes) - if (node->type() == "BaseQuantizer") - removeRound(node); - // Replace the scaling nodes with bit-shifts (activations only) - nodes = graphView->getNodes(); // must be called again because of removeRound() + std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); // must be called again because of removeRound() ! for (std::shared_ptr<Node> node : nodes) - if (node->type() == "BaseQuantizer" && hasAttr(node, "isActivationScaling")) - replaceScalingWithBitShift(node); - - // Cast all the graph tensors to integers + { + if (node->type() == "BaseQuantizer") + { + if (hasAttr(node, "isActivationScaling")) + { + removeRound(node); + replaceScalingWithBitShift(node); + } + else if (hasAttr(node, "isProducerScaling")) + castQuantizerIOs(node, targetType); + } + } - graphView->setDataType(targetType); + // Cast the nodes (excepted the producers and quantizers) to integer precision + nodes = graphView->getNodes(); + for (std::shared_ptr<Node> node : nodes) + if (node->type() != "Producer" && !hasAttr(node, "isProducerScaling")) // TODO : double check this ! + node->getOperator()->setDataType(targetType); } else { - // Set all the nodes, excepted the quantizers, to have integer IOs + // Set the nodes (excepted the producers and quantizers) to have integer IOs std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); for (std::shared_ptr<Node> node : nodes) - if (node->type() != "BaseQuantizer") + if (node->type() != "BaseQuantizer" && node->type() != "Producer") node->getOperator()->setDataType(targetType); // Cast the quantizers input and outputs by inserting Cast nodes @@ -310,6 +315,8 @@ void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType if (node->type() == "BaseQuantizer") castQuantizerIOs(node, targetType); } + + // XXX graphView->updateInputsOutputs(); } double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) @@ -1124,7 +1131,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ } } -static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant) +void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant) { // XXX Use the signMap to increase the resolution when possible ... double signedMax = (1 << (nbBits - 1)) - 1; @@ -1191,7 +1198,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u } } -static void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView) +void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView) { std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); @@ -1262,7 +1269,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, insertScalingNodes(graphView); // TODO : double check the CLE ... - crossLayerEqualization(graphView); + crossLayerEqualization(graphView); // XXX XXX XXX Log::notice(" Normalizing the parameters ..."); normalizeParameters(graphView); diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index 442a3c383bd18fb1883ce7aad2561a140a0836df..27cee8272b76ef1f449c074570d0691782459a4d 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -78,7 +78,7 @@ std::shared_ptr<Node> BaseQuantizer(double scalingFactor, const std::string& nam // alternative : capture the Producer ... // std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); - std::shared_ptr<Node> quantizerNode = MetaOperator("BaseQuantizer", graphView, {}, name); // XXX alternative prototype -> + std::shared_ptr<Node> quantizerNode = MetaOperator("BaseQuantizer", graphView, {}, name); // an simpler prototype exists ... return quantizerNode; } @@ -164,7 +164,7 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl copyDynamicAttributes(quantizer, newQuantizer); GraphView::replace({quantizer}, {newQuantizer}); - // XXX : replace the old pointer with the new one (by reference) + // replace the old pointer with the new one (by reference) quantizer = newQuantizer; } @@ -246,8 +246,6 @@ void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max) maxProducer->getOperator()->setOutput(0, newMaxTensor); } -// XXX TODO : manage the datatype / backend - void removeRound(std::shared_ptr<Node>& quantizer) { // Retreive a clone of the microGraph @@ -275,7 +273,7 @@ void removeRound(std::shared_ptr<Node>& quantizer) copyDynamicAttributes(quantizer, newQuantizer); GraphView::replace({quantizer}, {newQuantizer}); - // XXX : replace the old pointer with the new one (by reference) + // replace the old pointer with the new one (by reference) quantizer = newQuantizer; } @@ -301,12 +299,21 @@ void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer) auto dataType = mulOp->getOutput(0)->dataType(); auto backend = mulOp->getOutput(0)->backend(); - // compute the shift value (ALL OF THIS MUST BE REWORKED !) + // compute the shift value - double scaling = getScalingFactor(quantizer); - int bitShiftAmount = std::round(std::log2(scaling)); + double scalingFactor = getScalingFactor(quantizer); + int bitShiftAmount = -std::round(std::log2(scalingFactor)); auto bitShiftDirection = BitShift_Op::BitShiftDirection::right; - bool bitShiftRounding = false; // XXX XXX XXX use an argument !!! + + Log::notice(" SHIFT AMOUNT = {} ({})", bitShiftAmount, scalingFactor); + + if (bitShiftAmount < 0 ) + { + bitShiftDirection = BitShift_Op::BitShiftDirection::left; + bitShiftAmount = -bitShiftAmount; + } + + bool bitShiftRounding = true; // XXX use an argument !!! // create the replacement bit-shift nodes @@ -331,7 +338,7 @@ void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer) copyDynamicAttributes(quantizer, newQuantizer); GraphView::replace({quantizer}, {newQuantizer}); - // XXX : replace the old pointer with the new one (by reference) + // replace the old pointer with the new one (by reference) quantizer = newQuantizer; } @@ -370,7 +377,7 @@ void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType external copyDynamicAttributes(quantizer, newQuantizer); GraphView::replace({quantizer}, {newQuantizer}); - // XXX : replace the old pointer with the new one (by reference) + // replace the old pointer with the new one (by reference) quantizer = newQuantizer; }