diff --git a/include/aidge/operator/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp index 97e601b44cb1f70771a248b68a4f5c5017d3fd3e..05e026ff38ef85e70c12b8cb7690f1f3ddf31862 100644 --- a/include/aidge/operator/PTQMetaOps.hpp +++ b/include/aidge/operator/PTQMetaOps.hpp @@ -21,9 +21,10 @@ namespace Aidge { std::shared_ptr<Aidge::Node> BaseQuantizer(double scalingFactor, const std::string& name); void multiplyScalingFactor(std::shared_ptr<Aidge::Node> scalingNode, double coeff); - void appendRoundClip(std::shared_ptr<Node> metaOpNode, double clipMin, double clipMax); - + void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax); + void removeRound(std::shared_ptr<Node>& quantizer); + void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer); /// @brief Quantizer acts as a meta-operator to handle scaling operations in the PTQ, replacing the Scaling Operator. /// This operator is composed of a sequence of [Mul] -> [Clip] -> [Round] operations. diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp index f398ea3c9a20c4b3b113539f1e523d6eeee5e5f0..7db4f2704fb7a912e45cdb7a8e31252da248e94c 100644 --- a/include/aidge/quantization/PTQ/PTQ.hpp +++ b/include/aidge/quantization/PTQ/PTQ.hpp @@ -179,6 +179,13 @@ namespace Aidge { */ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose); + /** + * @brief Take a quantized graphview and cast it to true integer precision. Also optionally replaces the scaling + * nodes contained in the activation quantizers with bit-shift nodes. + * @param graphView The GraphView to edit. 
+ */ + void castQuantizedNetwork(std::shared_ptr<GraphView> graphView /*, Aidge::DataType targetType, bool singleShift, bool bitShiftRounding*/); + /** * @brief Compute the weight ranges of every affine node. Provided for debugging purposes. * @param graphView The GraphView containing the affine nodes. diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp index 5f21b68fd2ab7b5bf1321c79e2b5afea9333f080..e7ed07c3258d7cf57ffbb4e0bbd92c11f4f7b12b 100644 --- a/python_binding/pybind_PTQ.cpp +++ b/python_binding/pybind_PTQ.cpp @@ -93,7 +93,9 @@ void init_PTQ(py::module &m) { :type verbose: bool )mydelimiter"); - m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false, + m.def("cast_quantized_network", &castQuantizedNetwork, py::arg("network")); + + m.def("quantize_network", &quantizeNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false, R"mydelimiter( Main quantization routine. Performs every step of the quantization pipeline. :param network: The GraphView to be quantized. 
@@ -217,7 +219,6 @@ void init_PTQ(py::module &m) { )mydelimiter"); m.def("prepare_network", &prepareNetwork, py::arg("network"), "prepare the network for the PTQ"); - } } // namespace Aidge diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 0a1f3f17100ccce994942d17bf4f2f75b85f62d0..c2e8fb50c7715d9dc76e304be6da0ed785c527b1 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -207,7 +207,6 @@ static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> paren return index; } -// Utility function that insert a node below another one already connected static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> newNode, std::shared_ptr<GraphView> graphView) { // Checking the parents always have at least 1 children @@ -275,10 +274,24 @@ void foldProducerQuantizers(std::shared_ptr<GraphView> graphView) constantFolding(graphView); } +void castQuantizedNetwork(std::shared_ptr<GraphView> graphView /*, Aidge::DataType targetType, bool singleShift, bool bitShiftRounding*/) +{ + std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); + for (std::shared_ptr<Node> node : nodes) + if (node->type() == "BaseQuantizer") + removeRound(node); + + nodes = graphView->getNodes(); // must be called again because of removeRound() + for (std::shared_ptr<Node> node : nodes) + if (node->type() == "BaseQuantizer" && hasAttr(node, "isActivationScaling")) + replaceScalingWithBitShift(node); +} + double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { + std::shared_ptr<Tensor> fallback; + // get the abs tensor - std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs()); @@ -353,7 +366,7 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> { std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes(); - // Remove duplicate nodes. Is it still needed ??? + // Remove duplicate nodes. XXX Is it still needed ??? 
fixScheduling(nodeVector); @@ -471,7 +484,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) // XXX XXX XXX std::shared_ptr<Node> residualNode = BaseQuantizer(1.0, residualNodeName); addAttr(residualNode, "isActivationScaling"); - addAttr(residualNode, "isResidual"); + // XXX addAttr(residualNode, "isResidual"); residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) residualNode->getOperator()->setBackend(determineBackend(parentNode)); @@ -1060,7 +1073,11 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ if (!noQuant && !precedesNonLinearNode) { - // Old : Replace the Scaling Node by a Quantizer + // we need to gather the sign information before we modify + // the node pointer with appendRoundClip() ... + + bool inputIsUnsigned = signMap[node].first; + bool outputIsUnsigned = signMap[node].second; appendRoundClip(node, -(signedMax + 1), signedMax); @@ -1068,17 +1085,14 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ { double rescaling = 1.0; - bool inputIsUnsigned = signMap[node].first; - bool outputIsUnsigned = signMap[node].second; - rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? 
unsignedMax : signedMax; // XXX XXX XXX multiplyScalingFactor(node, rescaling); - + if (outputIsUnsigned) - setClipRange(node, 0, unsignedMax); + setClipRange(node, 0, unsignedMax); } } } @@ -1113,7 +1127,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView); std::shared_ptr<Node> mulNode = Mul(mulNodeName); - addAttr(mulNode, "isCompensation"); + // XXX XXX XXX addAttr(mulNode, "isCompensation"); mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) mulNode->getOperator()->setBackend(determineBackend(node)); @@ -1170,7 +1184,7 @@ static void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView multiplyScalingFactor(node, ratio); - // compensate the ratio using the previous node weigths ... + // compensate the ratio using the previous node scaling factors ... multiplyScalingFactor(linearNode->getParent(1), 1.0 / ratio); if (nodeHasBias(linearNode)) diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index 4134d16cba524e8f243b7578b4e62e0fc1bb7376..58563a2e7231ede1dd3b097faeca084c2ac6c478 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -44,6 +44,16 @@ static void addAttr(std::shared_ptr<Aidge::Node> node, std::string attr, double node->attributes()->addAttr("quantization.ptq." 
+ attr, value); } +// XXX TODO : rework this +static void copyDynamicAttributes(std::shared_ptr<Aidge::Node> prevNode, std::shared_ptr<Aidge::Node> newNode) +{ + if (hasAttr(prevNode, "isProducerScaling")) + addAttr(newNode, "isProducerScaling"); + + if (hasAttr(prevNode, "isActivationScaling")) + addAttr(newNode, "isActivationScaling"); +} + std::shared_ptr<Node> BaseQuantizer(double scalingFactor, const std::string& name) { std::shared_ptr<Node> mulNode = Mul(name + "_MulQuant"); @@ -102,10 +112,9 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> scalingNode, double coef auto producer = mulNode->getParent(1); producer->getOperator()->setOutput(0, newScalingFactorTensor); - // XXX prev way : mulNode->input(1).first->getOperator()->setOutput(0, resultTensor); } -std::shared_ptr<Node> appendRoundClip(std::shared_ptr<Node> quantizer, double clipMin, double clipMax) +void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax) { // Retreive a clone of the microGraph @@ -154,21 +163,15 @@ std::shared_ptr<Node> appendRoundClip(std::shared_ptr<Node> quantizer, double cl // Copy the flags - if (hasAttr(quantizer, "isProducerScaling")) - addAttr(newQuantizer, "isProducerScaling"); - - if (hasAttr(quantizer, "isActivationScaling")) - addAttr(newQuantizer, "isActivationScaling"); - + copyDynamicAttributes(quantizer, newQuantizer); + // replace the previous quantizer with the new one GraphView::replace({quantizer}, {newQuantizer}); - // TODO : replace the old pointer with the new one (by reference) - - // quantizer = newQuantizer; + // XXX : replace the old pointer with the new one (by reference) - return newQuantizer; + quantizer = newQuantizer; } void updateScalingFactor(std::shared_ptr<Node> quantizerNode, double scalingFactor) @@ -248,7 +251,46 @@ void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max) maxProducer->getOperator()->setOutput(0, newMaxTensor); } -std::shared_ptr<Node> 
replaceScalingWithBitShift(std::shared_ptr<Node> quantizer) +// XXX +void removeRound(std::shared_ptr<Node>& quantizer) +{ + // Retrieve a clone of the microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph()->clone(); + + // retrieve the Round node (if any) + + std::shared_ptr<Node> roundNode = nullptr; + for (auto node : microGraph->getNodes()) + if (node->type() == "Round") + roundNode = node; + + if (roundNode == nullptr) + return; + + // remove the Round node + + microGraph->replace({roundNode}, {}); + + // Create the new meta-operator + + std::shared_ptr<Node> newQuantizer = MetaOperator("BaseQuantizer", microGraph, {}, quantizer->name()); + + // Copy the flags + + copyDynamicAttributes(quantizer, newQuantizer); + + // replace the previous quantizer with the new one + + GraphView::replace({quantizer}, {newQuantizer}); + + // XXX : replace the old pointer with the new one (by reference) + + quantizer = newQuantizer; +} + +void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer) { // Retreive a clone of the microGraph @@ -299,21 +341,15 @@ std::shared_ptr<Node> replaceScalingWithBitShift(std::shared_ptr<Node> quantize // Copy the flags - if (hasAttr(quantizer, "isProducerScaling")) - addAttr(newQuantizer, "isProducerScaling"); - - if (hasAttr(quantizer, "isActivationScaling")) - addAttr(newQuantizer, "isActivationScaling"); + copyDynamicAttributes(quantizer, newQuantizer); // replace the previous quantizer with the new one GraphView::replace({quantizer}, {newQuantizer}); - // TODO : replace the old pointer with the new one (by reference) - - // quantizer = newQuantizer; + // XXX : replace the old pointer with the new one (by reference) - return newQuantizer; + quantizer = newQuantizer; } } \ No newline at end of file