diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index 968f9b5cd2d6c93892052a2e740e11388d5edef6..d9b944e33cc5706bb8f62ddeb1553ace0619245d 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -112,9 +112,13 @@ namespace Aidge {
 
     void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff);
 
+    /**
+     * @brief Prepare a network before the quantization is applied to it, by removing, replacing
+     * or fusing the nodes that are not supported by the PTQ pipeline.
+     * @param graphView The network to prepare for the quantization
+     */
     void prepareNetwork(std::shared_ptr<GraphView> graphView);
 
-
     /**
      * @brief Insert a scaling node after each affine node of the GraphView.
      * Also insert a scaling node in every purely residual branches.
diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp
index 41a1d24ff76e30d81d078f111f636ed5f3f97eca..4107c555e6b356401ed06057c9a06463084b74bd 100644
--- a/src/PTQ/Clipping.cpp
+++ b/src/PTQ/Clipping.cpp
@@ -18,7 +18,7 @@
 
 namespace Aidge
 {
-    
+
 std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda)
 {
     if (useCuda)
@@ -196,7 +196,6 @@ double computeKLClipping(std::vector<int> refHistogram, std::uint8_t nbBits)
     return bestClipping;
 }
 
-
 std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda, bool verbose)
 {
     double clipping = 1.0f;
@@ -231,7 +230,6 @@ std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clipping
         }
     }
 
-
     return valueRanges;
 }
 
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 3775d18989306ab871702a65f2e5e5d67c5d9c9f..9e2a62c8b975bbca4f63c90d500324a75e492e64 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -9,36 +9,31 @@
  *
  ********************************************************************************/
 
-#include "aidge/quantization/PTQ/CLE.hpp"
-#include "aidge/quantization/PTQ/Clipping.hpp"
-#include "aidge/quantization/PTQ/PTQ.hpp"
-#include "aidge/operator/PTQMetaOps.hpp"
-
-#include "aidge/data/Tensor.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/graph/Node.hpp"
-#include "aidge/scheduler/SequentialScheduler.hpp"
-#include "aidge/scheduler/Scheduler.hpp"
-#include "aidge/utils/Log.hpp"
-#include "aidge/operator/MetaOperator.hpp"
-
-#include "aidge/operator/Producer.hpp"
-#include "aidge/operator/Mul.hpp"
-#include "aidge/operator/Round.hpp"
-#include "aidge/operator/ReLU.hpp"
-#include "aidge/operator/BatchNorm.hpp"
-#include "aidge/operator/Conv.hpp"
-#include "aidge/operator/ArgMax.hpp"
-#include "aidge/operator/Reshape.hpp"
-#include "aidge/operator/MatMul.hpp"
-#include "aidge/operator/Cast.hpp"
-
-
-
-#include "aidge/recipes/Recipes.hpp"
-#include "aidge/recipes/QuantRecipes.hpp"
-#include "aidge/operator/MetaOperator.hpp"
-
+ #include "aidge/quantization/PTQ/CLE.hpp"
+ #include "aidge/quantization/PTQ/Clipping.hpp"
+ #include "aidge/quantization/PTQ/PTQ.hpp"
+ #include "aidge/operator/PTQMetaOps.hpp"
+
+ #include "aidge/data/Tensor.hpp"
+ #include "aidge/graph/GraphView.hpp"
+ #include "aidge/graph/Node.hpp"
+ #include "aidge/scheduler/SequentialScheduler.hpp"
+ #include "aidge/scheduler/Scheduler.hpp"
+ #include "aidge/utils/Log.hpp"
+
+ #include "aidge/operator/Producer.hpp"
+ #include "aidge/operator/Mul.hpp"
+ #include "aidge/operator/Round.hpp"
+ #include "aidge/operator/ReLU.hpp"
+ #include "aidge/operator/BatchNorm.hpp"
+ #include "aidge/operator/Conv.hpp"
+ #include "aidge/operator/ArgMax.hpp"
+ #include "aidge/operator/Reshape.hpp"
+ #include "aidge/operator/MatMul.hpp"
+
+ #include "aidge/recipes/Recipes.hpp"
+ #include "aidge/recipes/QuantRecipes.hpp"
+ #include "aidge/operator/MetaOperator.hpp"
 
 namespace Aidge
 {
@@ -85,20 +80,13 @@ bool isNotQuantized(std::shared_ptr<Node> node)
     return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
 }
 
-std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView)
-{
-    return graphView->getOrderedInputs()[0].first;
-}
 void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType)
 {
-    for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes())
-    {
-        for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++)
-        {
-            inputNode->getOperator()->resetInput(index);
-        }
-    }
+    for (std::shared_ptr<Aidge::Node> node : graphView->inputNodes())
+        for (Aidge::IOIndex_t i = node->getFirstFreeDataInput(); i < node->getNbFreeDataInputs(); i++)
+            node->getOperator()->resetInput(i);
 }
+
 bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 {
     std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"});
@@ -253,59 +241,6 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n
     graphView->add(newNode);
 }
 
-void applyConstFold(std::shared_ptr<GraphView> &graphView)
-{
-    for (const std::shared_ptr<Node> node : graphView->getNodes())
-    {
-        if (node->type() == "Producer" )
-        {
-            const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
-            producer->constant() = true;
-        }
-    }
-    constantFolding(graphView);
-}
-
-bool castQuantizedGraph(std::shared_ptr<GraphView> &graphView, Aidge::DataType targetType, bool singleShift,bool bitshiftRounding)
-{
-    //We need a deepcopy of the graphs nodes since we will replace some nodes
-    std::vector<std::shared_ptr<Node>> nodeVector(graphView->getNodes().begin(), graphView->getNodes().end());
-
-    for (std::shared_ptr<Node> node : nodeVector)
-    {
-        if (node->type() == "Round" && node->attributes()->hasAttr("quantization.ptq.isProducerRounding"))
-        {
-            std::shared_ptr<Aidge::Node> castNode = Cast(targetType,node->name() + "_Cast");
-            castNode->getOperator()->setDataType(targetType);
-            castNode->getOperator()->setBackend(node->getOperator()->backend());
-            insertChildren(node,castNode,graphView);
-            castNode->attributes()->addAttr("quantization.ptq.isProducerCasting",0.0);
-            node->getOperator()->setDataType(DataType::Float64);
-        }
-        else if(node->type() == "Quantizer")
-        {
-            if(singleShift)
-            {
-                std::shared_ptr<Node> newBitShiftQuantizer = BitShiftQuantizer(node,targetType,bitshiftRounding,node->name()+"_BitShift_Quantizer");
-                newBitShiftQuantizer->getOperator()->setBackend(node->getOperator()->backend());
-                graphView->replace({node},{newBitShiftQuantizer});
-
-            }
-            else //If single shift is not enabled we keep using the alternative Int Quantizer (which cast the data before and after the regular Quantizer Operations)
-            {
-                std::shared_ptr<Node> newIntQuantizer = IntQuantizer(node,targetType,node->name());
-                newIntQuantizer->getOperator()->setBackend(node->getOperator()->backend());
-                graphView->replace({node},{newIntQuantizer});
-            }
-        }
-        else if (node->type() != "Producer" &&
-                 !node->attributes()->hasAttr("quantization.ptq.isProducerScaling"))
-        {
-            node->getOperator()->setDataType(targetType);
-        }
-    }
-    return true;
-}
 
 void foldProducerQuantizers(std::shared_ptr<GraphView> graphView)
 {
@@ -316,8 +251,6 @@ void foldProducerQuantizers(std::shared_ptr<GraphView> graphView)
 
     for (std::shared_ptr<Node> quantizer : producerQuantizers)
     {
-        // Log::notice(" Quantizer : {} {} ", quantizer->name(), quantizer->type());
-
         // Set the param producer to be constant
 
         auto paramProducer = quantizer->getParent(0);
@@ -335,7 +268,6 @@ void foldProducerQuantizers(std::shared_ptr<GraphView> graphView)
         {
             auto producerOp = std::static_pointer_cast<Producer_Op>(producer->getOperator());
             producerOp->constant() = true;
-            //Log::notice("node : {} ", producer->name());
         }
 
         expandMetaOp(quantizer); // mandatory for now !!!
@@ -511,6 +443,10 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView>
 
     return nodeVector;
 }
+static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
+{
+    return retrieveNodeVector(graphView)[0];
+}
 
 // TODO : enhance this by modifying OperatorImpl in "core" ...
 static DataType getDataType(std::shared_ptr<Node> node)
@@ -672,7 +608,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
 
     // ITERATE OVER THE GRAPH /////////////////////////////////////////////
 
-    std::shared_ptr<Node> firstNode =getFirstNode(graphView);
+    std::shared_ptr<Node> firstNode = getFirstNode(graphView);
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
@@ -853,7 +789,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
 
 void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
 {
-    std::shared_ptr<Node> firstNode = getFirstNode(graphView);    
+    std::shared_ptr<Node> firstNode = getFirstNode(graphView);
 
     // CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
 
@@ -953,7 +889,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
 
 std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
 {
-    std::shared_ptr<Node> firstNode = getFirstNode(graphView);    
+    std::shared_ptr<Node> firstNode = getFirstNode(graphView);
 
     std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
 
@@ -1333,7 +1269,7 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std:
 
 void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> calibrationSet, DataType targetType, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
 {
-    Log::notice(" === QUANT PTQ 0.2.21 === ");
+    Log::notice(" === QUANT PTQ 0.3.0 === ");
 
     graphView->setBackend("cpu");
 
@@ -1384,33 +1320,16 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
         Log::notice(" Folding the Producer's Quantizers ...");
         foldProducerQuantizers(graphView);
     }
-    if( targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16)
-    {
-        AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
-        Log::notice("Starting to cast operators into the desired type ...");
-        castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding);
-
-        graphView->updateInputsOutputs();
-        clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType
-    }
-    else
-    {
-        setupDataType(graphView, inputDataSet, targetType);
-    }
-    if(foldGraph)
-    {
-        Log::notice("Applying constant folding recipe to the graph ...");
-        applyConstFold(graphView);
-    }
-    //Mandatory to handle all of the newly added connections!
-    graphView->updateInputsOutputs();
-
-    //Clearing input nodes
-    Log::notice("Clearing all input nodes ...");
+
+    // TODO ...
+    // Log::notice(" Clearing the input nodes ...");
 
     if (verbose)
         printScalingFactors(graphView);
-
+
+    if (useCuda)
+        graphView->setBackend("cuda");
+
     Log::notice(" Reseting the scheduler ...");
     SequentialScheduler scheduler(graphView);
     scheduler.resetScheduling();
diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
index 8686242668e610923abb20c5c002eda0292b2157..82b2e501dc275b361899a0ae8284f8a5409d32dc 100644
--- a/src/operator/PTQMetaOps.cpp
+++ b/src/operator/PTQMetaOps.cpp
@@ -33,13 +33,7 @@
 #include "aidge/utils/Log.hpp"
 
 namespace Aidge
-{
-static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType)
-{
-    std::shared_ptr<Node> mulNode = nullptr;
-    for(std::shared_ptr<Node> node : graphView->getNodes())
-        if (node->type() == nodeType)
-            mulNode = node;
+{
 
 static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr)
 {