From fe97cdf9e5cc59f6813d225075626bd1f98d5847 Mon Sep 17 00:00:00 2001 From: bhalimi <benjamin.halimi@cea.fr> Date: Mon, 13 May 2024 13:40:46 +0000 Subject: [PATCH] bugfix + support of MSE cliping --- include/aidge/QuantPTQ.hpp | 11 +- python_binding/pybind_QuantPTQ.cpp | 13 +- src/QuantPTQ.cpp | 195 +++++++++++++++++++++++++---- 3 files changed, 191 insertions(+), 28 deletions(-) diff --git a/include/aidge/QuantPTQ.hpp b/include/aidge/QuantPTQ.hpp index 499ded0d..fdc9300e 100644 --- a/include/aidge/QuantPTQ.hpp +++ b/include/aidge/QuantPTQ.hpp @@ -54,9 +54,9 @@ namespace Aidge { * @brief Normalize the activations of each affine node so that it become equal to one. * This is done by reconfiguring the scaling nodes, as well as rescaling the weights and biases tensors. * @param graphView The GraphView containing the affine nodes. - * @param inputDataSet The input dataset on which the value ranges are computed. + * @param valueRanges The node output value ranges computed over the calibration dataset. */ - void normalizeActivations(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet); + void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, float> valueRanges); /** @@ -72,7 +72,7 @@ namespace Aidge { * @param nbBits The desired number of bits of the quantization. * @param inputDataSet The input dataset on which the value ranges are computed. */ - void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet); + void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool OptimizeCliping); /** * @brief Compute the weight ranges of every affine node. Provided for debuging purposes. @@ -88,6 +88,11 @@ namespace Aidge { void clearBiases(std::shared_ptr<GraphView> graphView); void devPTQ(std::shared_ptr<GraphView> graphView); + + std::map<std::string, std::vector<int>> computeScalingHistograms(std::map<std::string, float> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet); + + float computeBestClipping(std::vector<int> histogram, std::uint8_t nbBits); + } #endif /* AIDGE_QUANTIZATION_QUANTPTQ_H_ */ diff --git a/python_binding/pybind_QuantPTQ.cpp b/python_binding/pybind_QuantPTQ.cpp index 1a9f08e9..6927513c 100644 --- a/python_binding/pybind_QuantPTQ.cpp +++ b/python_binding/pybind_QuantPTQ.cpp @@ -57,14 +57,14 @@ void init_QuantPTQ(py::module &m) { :rtype: dict )mydelimiter"); - m.def("normalize_activations", &normalizeActivations, py::arg("network"), py::arg("input_dataset"), + m.def("normalize_activations", &normalizeActivations, py::arg("network"), py::arg("value_ranges"), R"mydelimiter( Normalize the activations of each affine node so that it become equal to one. This is done by reconfiguring the scaling nodes, as well as rescaling the weights and biases tensors. :param network: The GraphView containing the affine nodes. :type network: :py:class:`aidge_core.GraphView` - :param input_dataset: The input dataset on which the value ranges are computed. - :type input_dataset: list of :py:class:`aidge_core.Tensor` + :param value_ranges: The node output value ranges computed over the calibration dataset. + :type value_ranges: list of float. )mydelimiter"); m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), @@ -76,7 +76,7 @@ void init_QuantPTQ(py::module &m) { :type nb_bits: int )mydelimiter"); - m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), + m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("optimize_cliping") = false, R"mydelimiter( Main quantization routine. Performs every step of the quantization pipeline. :param network: The GraphView to be quantized. @@ -85,6 +85,8 @@ void init_QuantPTQ(py::module &m) { :type nb_bits: int :param input_dataset: The input dataset on which the value ranges are computed. :type input_dataset: list of :py:class:`aidge_core.Tensor` + :param optimize_cliping: Whether to optimize the cliping values or not. + :type optimize_cliping: bool )mydelimiter"); m.def("get_weight_ranges", &getWeightRanges, py::arg("network"), @@ -103,6 +105,9 @@ void init_QuantPTQ(py::module &m) { :type network: :py:class:`aidge_core.GraphView` )mydelimiter"); + m.def("compute_scaling_histograms", &computeScalingHistograms, py::arg("value_ranges"), py::arg("nb_bins"), py::arg("network"), py::arg("input_dataset"), "compute scaling histogram"); + m.def("compute_best_clipping", &computeBestClipping, py::arg("histogram"), py::arg("nb_bits"), "compute the best clipping for an histogram"); + m.def("dev_ptq", &devPTQ, py::arg("network"), "dev ptq"); } diff --git a/src/QuantPTQ.cpp b/src/QuantPTQ.cpp index 0f77a055..cff72b72 100644 --- a/src/QuantPTQ.cpp +++ b/src/QuantPTQ.cpp @@ -78,7 +78,7 @@ static bool isSeamless(std::shared_ptr<Node> node) bool checkArchitecture(std::shared_ptr<GraphView> graphView) { - std::set<std::string> otherNodeTypes({"Add", "Concat", "Softmax", "ReLU", "Producer"}); + std::set<std::string> otherNodeTypes({"Flatten", "Add", "Concat", "Softmax", "ReLU", "Producer"}); for (std::shared_ptr<Node> node : graphView->getNodes()) { @@ -232,10 +232,15 @@ void appendIdentity(std::shared_ptr<GraphView> graphView) { std::vector<std::shared_ptr<Node>> extractNodeVector(std::shared_ptr<GraphView> graphView, bool verbose) { + std::vector<std::shared_ptr<Node>> nodeVector; + SequentialScheduler scheduler(graphView); scheduler.forward(); + nodeVector = scheduler.getStaticScheduling(); - std::vector<std::shared_ptr<Node>> nodeVector = scheduler.getStaticScheduling(); + //graphView->forwardDims(); + //scheduler.generateScheduling(); + //nodeVector = scheduler.getStaticScheduling(); fixScheduling(nodeVector); @@ -248,6 +253,9 @@ std::vector<std::shared_ptr<Node>> extractNodeVector(std::shared_ptr<GraphView> Log::info("{} {}", node->type(), node->name()); } + //for (auto node : nodeVector) + // std::cout << node->type() << std::endl; + return nodeVector; } @@ -329,7 +337,14 @@ static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> mergin { std::shared_ptr<Node> currNode = mergingNode; while(currNode->type() != "Scaling") + { + if (currNode->getParents().size() == 0) + { + Log::warn(" Warning : No previous Scaling node were found ! "); + break; + } currNode = currNode->getParents()[0]; + } return currNode; } @@ -439,7 +454,7 @@ std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); for (std::shared_ptr<Node> node : nodeSet) { - if (node->type() != "Producer") + if (node->type() == "Scaling") // XXX (node->type() != "Producer") { std::shared_ptr<Operator> nodeOperator = node->getOperator(); std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); @@ -459,33 +474,29 @@ std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); for (std::shared_ptr<Node> node : nodeSet) - if (node->type() != "Producer") + if (node->type() == "Scaling") // XXX (node->type() != "Producer") valueRanges.insert(std::make_pair(node->name(), 0)); //int i = 0; - for (std::shared_ptr<Tensor> sample : inputDataSet) // XXX HERE const + for (std::shared_ptr<Tensor> sample : inputDataSet) { std::map<std::string, float> sampleRanges = computeRanges(graphView, sample); for (std::shared_ptr<Node> node : nodeSet) { - if (node->type() != "Producer") + if (node->type() == "Scaling") // XXX (node->type() != "Producer") { std::string nodeName = node->name(); if (sampleRanges[nodeName] > valueRanges[nodeName]) valueRanges[nodeName] = sampleRanges[nodeName]; } } - } + } return valueRanges; } -void normalizeActivations(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet) +void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, float> valueRanges) { - // EXTRACT VALUE RANGE MAP //////////////////////////////////////////////// - - std::map<std::string, float> valueRanges = computeRanges(graphView, inputDataSet); - // CREATE THE SCALING FACTOR MAP ////////////////////////////////////////// std::vector<std::shared_ptr<Node>> nodeVector = extractNodeVector(graphView, false); @@ -519,7 +530,8 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::vector<std: std::shared_ptr<Node> prevNode = node->getParent(0); float prevScalingFactor = scalingFactors[prevNode->name()]; - float scalingFactor = valueRanges[node->name()]; // XXX HERE !!! + // XXX HERE : valueRanges must contains all the scaling nodes !!! + float scalingFactor = valueRanges[node->name()]; std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (node->getOperator()); scalingOperator->getAttr<float>("scalingFactor") /= (scalingFactor / prevScalingFactor); @@ -554,11 +566,10 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::vector<std: // Ensure that the adding node does not overflow ... if (node->type() == "Add") { std::shared_ptr<Node> maxNode = mergingNodes[maxNodeIndex]; - maxScaling /= valueRanges[maxNode->name()] / valueRanges[node->name()]; + maxScaling /= valueRanges[getPreviousScalingNode(maxNode)->name()]; + maxScaling *= valueRanges[getPreviousScalingNode(node)->name()]; } - // Log::info(" MAX SCALING : {} ", maxScaling); - scalingFactors[node->name()] = maxScaling; for (std::shared_ptr<Node> mergingNode : mergingNodes) @@ -622,9 +633,138 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ } } -void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet) +std::map<std::string, std::vector<int>> computeScalingHistograms(std::map<std::string, float> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet) { - Log::info(" === QUANT PTQ 0.2.7 === "); + std::cout << " COMPUTING HISTOGRAMS ... " << std::endl; + + std::map<std::string, std::vector<int>> histograms; + + SequentialScheduler scheduler(graphView); + + std::shared_ptr<Node> inputNode = getFirstNode(graphView); + + // Setup the histograms ... + + for (std::shared_ptr<Node> node : graphView->getNodes()) + { + if (node->type() == "Scaling") + { + std::vector<int> histogram; + for (int i = 0; i < nbBins; i++) + histogram.push_back(0); + + histograms.insert(std::make_pair(node->name(), histogram)); + } + } + + // Fill the histograms ... + + for (std::shared_ptr<Tensor> inputTensor : inputDataSet) + { + // Setup the input + std::shared_ptr<Node> inputProducer = inputNode->getParent(0); + inputProducer->getOperator()->setOutput(0, inputTensor); + + // Forward ... + scheduler.forward(); + + // Gather values ... + for (std::shared_ptr<Node> node : graphView->getNodes()) + { + if (node->type() == "Scaling") + { + float valueRange = valueRanges[node->name()]; + + std::shared_ptr<Operator> nodeOperator = node->getOperator(); + std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); + + float * castedTensor = static_cast<float *> (valueTensor->getImpl()->rawPtr()); + for(std::size_t i = 0; i < valueTensor->size(); i++) + { + int bin = std::round(std::abs(castedTensor[i] / valueRange * nbBins)); + histograms[node->name()][bin]++; + } + } + } + } + + return histograms; +} + +float computeBestClipping(std::vector<int> histogram, std::uint8_t nbBits) +{ + //std::cout << " TEST " << std::endl; + + int nbBins = histogram.size(); + int nbIter = 100; + int signedMax = (1 << (nbBits - 1)) - 1; + + // Compute the cumulative approximation error : + // At each iteration we test a clipping candidate and loop over + // the histogram to accumulate the squared error + + std::vector<float> clippingErrors; + for (int it = 0; it < nbIter; it++) + { + // Compute the rounding cost of this particular clipping ... + float acc = 0.0; + float clipping = it / static_cast<float> (nbIter); + for (int bin = 0; bin < nbBins; bin++) + { + float value = (bin + 0.5) / nbBins; + float scaling = signedMax / clipping; + float rounded = std::round(value * scaling) / scaling; + float clipped = std::min(clipping, rounded); + + float approxError = (clipped - value); + acc += (approxError * approxError) * histogram[bin]; + } + clippingErrors.push_back(acc); + } + + //for (int it = 0; it < nbIter; it++) + // std::cout << " it : " << it << " " << clippingErrors[it] << std::endl; + + float bestClipping = 0.0; + float minClippingError = clippingErrors[0]; + for (int it = 0; it < nbIter; it++) + if (clippingErrors[it] < minClippingError) + { + bestClipping = it / static_cast<float> (nbIter); + minClippingError = clippingErrors[it]; + } + + return bestClipping; +} + +void adjustRanges(std::map<std::string, float>& valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet) +{ + //std::cout << " BEFORE CLIPING : " << std::endl; + //std::map<std::string, float>::iterator it; + //for (it = valueRanges.begin(); it != valueRanges.end(); it++) + // std::cout << it->first << " " << it->second << std::endl; + + int nbBins = 2000; // XXX FIX THIS !!! + + std::map<std::string, std::vector<int>> histograms = computeScalingHistograms(valueRanges, nbBins, graphView, inputDataSet); + + for (std::shared_ptr<Node> node : graphView->getNodes()) + if (node->type() == "Scaling") + { + std::vector<int> histogram = histograms[node->name()]; + float cliping = computeBestClipping(histogram, nbBits); + //std::cout << " cliping " << node->name() << " " << cliping << std::endl; + valueRanges[node->name()] *= cliping; + } + + //std::cout << " AFTER CLIPING : " << std::endl; + //for (it = valueRanges.begin(); it != valueRanges.end(); it++) + // std::cout << it->first << " " << it->second << std::endl; +} + +void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool OptimizeCliping) +{ + Log::info(" === QUANT PTQ 0.2.8 === "); if (!checkArchitecture(graphView)) return; @@ -641,8 +781,17 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, Log::info(" Normalizing the parameters ..."); normalizeParameters(graphView); + Log::info(" Computing the value ranges ..."); + std::map<std::string, float> valueRanges = computeRanges(graphView, inputDataSet); + + if (OptimizeCliping) + { + Log::info(" Optimizing the cliping values ..."); + adjustRanges(valueRanges, nbBits, graphView, inputDataSet); + } + Log::info(" Normalizing the activations ..."); - normalizeActivations(graphView, inputDataSet); + normalizeActivations(graphView, valueRanges); Log::info(" Quantizing the normalized network ..."); quantizeNormalizedNetwork(graphView, nbBits); @@ -682,10 +831,14 @@ void devPTQ(std::shared_ptr<GraphView> graphView) for (std::shared_ptr<Node> node : graphView->getNodes()) std::cout << " ### node : " << node->type() << std::endl; } - - } +/* + std::shared_ptr<Node> lastScalingNode = getPreviousScalingNode(getLastNode(graphView)); + std::cout << " LAST SCALING NODE : " << lastScalingNode->name() << " " << lastScalingNode->type() << std::endl; + std::vector<int> histogram = nodeHistogram(lastScalingNode, graphView, inputDataSet); + histograms.insert(std::make_pair(lastScalingNode->name(), histogram)); +*/ /* std::map<std::string, std::vector<int>> getValueHistograms(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> inputTensor, std::map<std::string, float> valueRanges) -- GitLab