diff --git a/include/aidge/quantization/PTQ/Clipping.hpp b/include/aidge/quantization/PTQ/Clipping.hpp
index 3f65c42eb2032da10c4d337b53fb1bdd08a7aa55..159b64f12f8c6ae2bb3e88592b29f211e15fa614 100644
--- a/include/aidge/quantization/PTQ/Clipping.hpp
+++ b/include/aidge/quantization/PTQ/Clipping.hpp
@@ -13,7 +13,7 @@
 #define AIDGE_QUANTIZATION_QUANTIZATION_PTQ_CLIP_H_
 
 #include <cstdint>  // std::uint8_t
-#include <map>
+#include <unordered_map>
 #include <memory>
 #include <string>
 #include <vector>
@@ -36,7 +36,7 @@ namespace Aidge
      * @param inputDataSet The input dataset, consisting of a vector of input samples.
-     * @return A map associating each node name to it's corresponding activation histogram.
+     * @return A map associating each node to its corresponding activation histogram.
      */
-    std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda);
+    std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda);
 
     /**
      * @brief Given an input activation histogram, compute the optimal clipping value in the sense of the Lp norm.
@@ -67,7 +67,7 @@
      * @param verbose Whether to print the clipping values or not.
      * @return The corrected map associating each provided node to its clipped range.
      */
-    std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::string, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose);
+    std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose);
 }
diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index 1d1b71ba7501580ea99103d351eafac9a7f793d2..4bfe65fd5f6514f7bdc939b583142d6f8e107099 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -13,7 +13,7 @@
 #define AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQ_H_
 
 #include <cstdint>  // std::uint8_t
-#include <map>
+#include <unordered_map>
 #include <memory>
 #include <set>
 #include <string>
@@ -129,7 +129,7 @@ namespace Aidge {
-     * @param scalingNodesOnly Whether to restrain the retreival of the ranges to scaling nodes only or not.
-     * @return A map associating each affine node name to it's corresponding output range.
+     * @param scalingNodesOnly Whether to restrict the retrieval of the ranges to scaling nodes only or not.
+     * @return A map associating each affine node to its corresponding output range.
      */
-    std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda);
+    std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda);
 
     /**
      * @brief Normalize the activations of each affine node so that they fit in the [-1:1] range.
@@ -137,7 +137,7 @@
      * @param graphView The GraphView containing the affine nodes.
      * @param valueRanges The node output value ranges computed over the calibration dataset.
      */
-    void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, double> valueRanges);
+    void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges);
 
     /**
      * @brief For each node, compute the sign of its input and output values.
@@ -146,7 +146,7 @@ namespace Aidge {
      * @param verbose Whether to print the sign map or not.
-     * @return A map associating a pair of sign to each node of the GraphView (a sign for the input and one for the output).
+     * @return A map associating a pair of signs to each node of the GraphView (a sign for the input and one for the output).
      */
-    std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose);
+    std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose);
 
     /**
      * @brief Quantize an already normalized (in terms of parameters and activations) network.
@@ -176,7 +176,7 @@ namespace Aidge {
      * @param graphView The GraphView containing the affine nodes.
-     * @return A map associating each affine node name to it's corresponding weight range.
+     * @return A map associating each affine node name to its corresponding weight range.
      */
-    std::map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView);
+    std::unordered_map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView);
 
     /**
      * @brief Clear the affine nodes biases. Provided for debugging purposes.
diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp
index 4970be07fae8737a1c2863600757bb81ff3a65f9..b1e7b6fcf99a50e707da2fdc7f7c35cdb2d778f7 100644
--- a/include/aidge/quantization/QAT/QAT_LSQ.hpp
+++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp
@@ -9,36 +9,30 @@
  *
 ********************************************************************************/
 
-#ifndef AIDGE_QUANTIZATION_QAT_LSQ_H_
-#define AIDGE_QUANTIZATION_QAT_LSQ_H_
-
-#include "aidge/graph/Node.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/data/Tensor.hpp"
-
-namespace Aidge {
-namespace QuantLSQ {
-
-/**
- * @brief Insert the LSQ quantizer nodes in a given GraphView
- * @param graphView The GraphView containing the graph to quantize.
- * @param nbBits Number of quantization bits.
- * @param span Fixed output span of the quantizers.
- */
-void insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float step_size);
-
-/**
- * @brief Given a GraphView with parameters properly initialized and some calibration data,
- * insert the LSQ quantizer nodes, and adjust their step-sizes.
- * @param graphView The GraphView containing the graph to quantize.
- * @param nbBits Number of quantization bits.
- * @param calibrationData Calibration data used to adjust the spans.
- * @param scale Multiplicative constant applied to the spans.
- */
-void insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData);
-
-}
-}
-
-#endif /* AIDGE_QUANTIZATION_QAT_LSQ_H_ */
-
+#ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
+#define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
+
+#include <cstddef>  // std::size_t
+#include <memory>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+
+namespace Aidge {
+namespace QuantLSQ {
+
+/**
+ * @brief Given a GraphView with parameters properly initialized, insert
+ * the LSQ quantizer nodes and set up their step-sizes.
+ * @param graphView The GraphView containing the network to quantize.
+ * @param nbBits Number of quantization bits.
+ */
+void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
+
+} // namespace QuantLSQ
+} // namespace Aidge
+
+#endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */
\ No newline at end of file
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index ae0a0def28a861e2fc207adbc27c6af47dc0ded8..12d14340f9353114d06121fa8f1e1fd4f050e3f4 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -174,7 +174,6 @@ void init_PTQ(py::module &m) {
     :rtype: dict
     )mydelimiter");
 
-
     m.def("compute_sign_map", &computeSignMap, py::arg("network"), py::arg("verbose") = false,
     R"mydelimiter(
     For each node, compute the sign of its input and output values.
diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp
index 206985efe4558a84ce1ed67a1264bd6902213764..dd118dccc24dca71185c9401a924fbae0d22cc6c 100644
--- a/python_binding/pybind_QAT_LSQ.cpp
+++ b/python_binding/pybind_QAT_LSQ.cpp
@@ -9,22 +9,21 @@
  *
 ********************************************************************************/
 
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-
-#include "aidge/quantization/QAT/QAT_LSQ.hpp"
-#include "aidge/graph/GraphView.hpp"
-
-namespace py = pybind11;
-
-namespace Aidge {
-
-void init_QAT_LSQ(py::module &m) {
-
-    auto mQuantLSQ = m.def_submodule("lsq");
-
-    mQuantLSQ.def("insert_quantizers", &QuantLSQ::insertQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("step_size"));
-
-    mQuantLSQ.def("insert_and_init_quantizers", &QuantLSQ::insertAndInitQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_data"));
-}
-} // namespace Aidge
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "aidge/quantization/QAT/QAT_LSQ.hpp"
+#include "aidge/graph/GraphView.hpp"
+
+namespace py = pybind11;
+
+namespace Aidge {
+
+void init_QAT_LSQ(py::module &m) {
+
+    auto mQuantLSQ = m.def_submodule("lsq");
+
+    mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits"));
+}
+} // namespace Aidge
\ No newline at end of file
diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp
index a4e7fed921604fcf9d18c6e50991220c4785f3bb..5bd2a7da90f8ddd2d5d3903e4a4479e7654233e5 100644
--- a/src/PTQ/Clipping.cpp
+++ b/src/PTQ/Clipping.cpp
@@ -19,7 +19,7 @@
 namespace Aidge
 {
 
-std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda)
+std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda)
 {
     if (useCuda)
         graphView->setBackend("cuda");
@@ -28,7 +28,7 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string,
     // Log::debug(" COMPUTING HISTOGRAMS ... ");
 
-    std::map<std::string, std::vector<int>> histograms;
+    std::unordered_map<std::shared_ptr<Node>, std::vector<int>> histograms;
 
     SequentialScheduler scheduler(graphView);
     scheduler.resetScheduling();
@@ -37,14 +37,14 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string,
     for (std::shared_ptr<Node> node : graphView->getNodes())
     {
-        bool isInsideRanges = (valueRanges.find(node->name()) != valueRanges.end());
+        bool isInsideRanges = (valueRanges.find(node) != valueRanges.end());
         if (isInsideRanges)
         {
             std::vector<int> histogram;
             for (int i = 0; i < nbBins; i++)
                 histogram.push_back(0);
 
-            histograms.insert(std::make_pair(node->name(), histogram));
+            histograms.insert(std::make_pair(node, histogram));
         }
     }
@@ -69,10 +69,10 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string,
     for (std::shared_ptr<Node> node : graphView->getNodes())
     {
-        bool isInsideRanges = (valueRanges.find(node->name()) != valueRanges.end());
+        bool isInsideRanges = (valueRanges.find(node) != valueRanges.end());
         if (isInsideRanges)
         {
-            double valueRange = valueRanges[node->name()];
+            double valueRange = valueRanges[node];
 
             std::shared_ptr<Operator> nodeOperator = node->getOperator();
             std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -82,7 +82,7 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string,
             double * castedTensor = static_cast<double *> (valueTensor->getImpl()->rawPtr());
 
-            std::vector<int> nodeHistogram = histograms[node->name()];
+            std::vector<int> nodeHistogram = histograms[node];
             for(std::size_t i = 0; i < valueTensor->size(); i++)
             {
                 std::size_t bin = std::round(std::abs(castedTensor[i] / valueRange * nbBins));
@@ -90,7 +90,7 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string,
                 nodeHistogram[bin]++;
             }
 
-            histograms[node->name()] = nodeHistogram;
+            histograms[node] = nodeHistogram;
 
             if (useCuda)
                 valueTensor->setBackend("cuda");
@@ -207,7 +207,7 @@ double computeKLClipping(std::vector<int> refHistogram, std::uint8_t nbBits)
 }
 
-std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::string, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose)
+std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose)
 {
     double clipping = 1.0f;
 
@@ -218,13 +218,13 @@ std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::
     if (verbose)
         Log::info(" === CLIPPING VALUES === ");
 
-    std::map<std::string, std::vector<int>> histograms = computeHistograms(valueRanges, nbBins, graphView, inputDataSet, useCuda);
+    std::unordered_map<std::shared_ptr<Node>, std::vector<int>> histograms = computeHistograms(valueRanges, nbBins, graphView, inputDataSet, useCuda);
 
     for (std::shared_ptr<Node> node : graphView->getNodes())
     {
         if (node->attributes()->hasAttr("quantization.ptq.isScaling"))
         {
-            std::vector<int> histogram = histograms[node->name()];
+            std::vector<int> histogram = histograms[node];
 
             if (clippingMode == Clipping::MSE)
                 clipping = computeMEClipping(histogram, nbBits, 2.0);
@@ -236,7 +236,7 @@ std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::
             if (verbose)
                 Log::info(" {:.6f} ({})", clipping, node->name());
 
-            valueRanges[node->name()] *= clipping;
+            valueRanges[node] *= clipping;
         }
     }
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index f03fc7bcea039a1939e116cc842f7062f28c5cae..fe717bbdbae7d5bb3bd74fe65124bfce8f59da2c 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -21,7 +21,6 @@
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/utils/Log.hpp"
-
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/Mul.hpp"
 #include "aidge/operator/Round.hpp"
@@ -31,11 +30,9 @@
 #include "aidge/operator/ArgMax.hpp"
 #include "aidge/operator/Reshape.hpp"
-
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
-
 namespace Aidge
 {
@@ -432,11 +429,9 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
 
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
-    std::map<std::string, double> accumulatedRatios;
+    std::unordered_map<std::shared_ptr<Node>, double> accumulatedRatios;
     for (std::shared_ptr<Node> node : nodeVector)
-    {
-        accumulatedRatios.insert(std::make_pair(node->name(), 1.0));
-    }
+        accumulatedRatios.insert(std::make_pair(node, 1.0));
 
     // ITERATE OVER THE GRAPH /////////////////////////////////////////////////
@@ -450,7 +445,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
         if (node != firstNode)
         {
             std::shared_ptr<Node> prevNode = node->getParent(0);
-            accumulatedRatios[node->name()] = accumulatedRatios[prevNode->name()];
+            accumulatedRatios[node] = accumulatedRatios[prevNode];
         }
     }
@@ -462,17 +457,17 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             double scaling = getTensorAbsoluteMax(weightTensor);
             double ratio = 1.0 / scaling;
             //rescaleTensor(weightTensor, ratio);
-            insertScalingBelowProducer(node->getParent(1),ratio,graphView);
+            insertScalingBelowProducer(node->getParent(1), ratio, graphView);
 
             // Accumulate the ratio
             if (node == firstNode)
             {
-                accumulatedRatios[node->name()] = ratio;
+                accumulatedRatios[node] = ratio;
             }
             else
             {
                 std::shared_ptr<Node> prevNode = node->getParent(0);
-                accumulatedRatios[node->name()] = accumulatedRatios[prevNode->name()] * ratio;
+                accumulatedRatios[node] = accumulatedRatios[prevNode] * ratio;
             }
 
             // Handle the bias ..
@@ -480,8 +475,8 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             if (nodeHasBias(node))
             {
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                //rescaleTensor(biasTensor, accumulatedRatios[node->name()] );
-                insertScalingBelowProducer(node->getParent(2),accumulatedRatios[node->name()],graphView);
+                //rescaleTensor(biasTensor, accumulatedRatios[node] );
+                insertScalingBelowProducer(node->getParent(2), accumulatedRatios[node], graphView);
             }
         }
@@ -493,33 +488,33 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             double maxRatio = 0;
             for (std::shared_ptr<Node> mergingNode : mergingNodes)
             {
-                double merginNodeRatio = accumulatedRatios[mergingNode->name()];
+                double merginNodeRatio = accumulatedRatios[mergingNode];
                 if (merginNodeRatio > maxRatio)
                     maxRatio = merginNodeRatio;
             }
 
-            accumulatedRatios[node->name()] = maxRatio;
+            accumulatedRatios[node] = maxRatio;
 
             // Rescale the previous scaling Nodes
             for (std::shared_ptr<Node> mergingNode : mergingNodes)
             {
-                double mergingNodeRatio = accumulatedRatios[mergingNode->name()];
+                double mergingNodeRatio = accumulatedRatios[mergingNode];
                 double rescaling = mergingNodeRatio / maxRatio;
 
                 std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
 
                 multiplyScalingFactor(scalingNode,1/rescaling);
 
-                accumulatedRatios[mergingNode->name()] /= rescaling; // optional ...
+                accumulatedRatios[mergingNode] /= rescaling; // optional ...
             }
         }
     }
 }
 
 // XXX TODO : take care of the CUDA backend for this too !!!
-std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> inputTensor, bool scalingNodesOnly)
+std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> inputTensor, bool scalingNodesOnly)
 {
-    std::map<std::string, double> valueRanges;
+    std::unordered_map<std::shared_ptr<Node>, double> valueRanges;
 
     SequentialScheduler scheduler(graphView);
     scheduler.resetScheduling();
@@ -540,23 +535,23 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
             double range = getTensorAbsoluteMax(valueTensor);
 
             // Associate the value to the scaling node ...
-            valueRanges.insert(std::make_pair(node->name(), range));
+            valueRanges.insert(std::make_pair(node, range));
         }
     }
 
     return valueRanges;
 }
 
-std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda)
+std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda)
 {
-    std::map<std::string, double> valueRanges;
+    std::unordered_map<std::shared_ptr<Node>, double> valueRanges;
 
     std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
     // std::shared_ptr<Node> inputNode = getFirstNode(graphView);
 
     for (std::shared_ptr<Node> node : nodeSet)
         if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
-            valueRanges.insert(std::make_pair(node->name(), 0));
+            valueRanges.insert(std::make_pair(node, 0));
 
     if (useCuda)
         graphView->setBackend("cuda");
@@ -579,7 +574,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
         // Gather the sample ranges ...
-        std::map<std::string, double> sampleRanges;
+        std::unordered_map<std::shared_ptr<Node>, double> sampleRanges;
         for (std::shared_ptr<Node> node : nodeSet)
         {
             if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
@@ -593,7 +588,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
                 double range = getTensorAbsoluteMax(valueTensor);
 
                 // Associate the value to the scaling node ...
-                sampleRanges.insert(std::make_pair(node->name(), range));
+                sampleRanges.insert(std::make_pair(node, range));
 
                 if (useCuda)
                     valueTensor->setBackend("cuda");
@@ -605,11 +600,8 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
         for (std::shared_ptr<Node> node : nodeSet)
         {
             if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
-            {
-                std::string nodeName = node->name();
-                if (sampleRanges[nodeName] > valueRanges[nodeName])
-                    valueRanges[nodeName] = sampleRanges[nodeName];
-            }
+                if (sampleRanges[node] > valueRanges[node])
+                    valueRanges[node] = sampleRanges[node];
         }
 
         if (useCuda)
@@ -622,7 +614,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
     return valueRanges;
 }
 
-void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, double> valueRanges)
+void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
 {
     std::shared_ptr<Node> firstNode = getFirstNode(graphView);
@@ -630,10 +622,10 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
-    std::map<std::string, double> scalingFactors;
+    std::unordered_map<std::shared_ptr<Node>, double> scalingFactors;
 
     for (std::shared_ptr<Node> node : nodeVector)
-        scalingFactors.insert(std::make_pair(node->name(), 1.0));
+        scalingFactors.insert(std::make_pair(node, 1.0));
 
     // ITERATE OVER THE GRAPH /////////////////////////////////////////////////
@@ -645,16 +637,15 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
         {
             if (node == firstNode)
             {
-                scalingFactors[node->name()] = 1.0;
+                scalingFactors[node] = 1.0;
             }
             else
             {
                 std::shared_ptr<Node> prevNode = node->getParent(0);
-                scalingFactors[node->name()] = scalingFactors[prevNode->name()];
+                scalingFactors[node] = scalingFactors[prevNode];
             }
         }
-
         // Here prevNode is either a 'Affine' or a 'Merging'
         // => do not split the cases, just handle the bias ...
 
@@ -663,14 +654,14 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
             // retrieve the previous scaling factor ...
             std::shared_ptr<Node> prevNode = node->getParent(0);
-            double prevScalingFactor = scalingFactors[prevNode->name()];
+            double prevScalingFactor = scalingFactors[prevNode];
 
             // ValueRanges must contain all the scaling nodes !!!
-            double scalingFactor = valueRanges[node->name()];
+            double scalingFactor = valueRanges[node];
 
-            multiplyScalingFactor(node,1/(scalingFactor / prevScalingFactor));
+            multiplyScalingFactor(node, 1 / (scalingFactor / prevScalingFactor));
 
-            scalingFactors[node->name()] = scalingFactor;
+            scalingFactors[node] = scalingFactor;
 
             // If prevNode is Affine, fix the bias ...
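The `computeRanges()` hunks above implement a simple aggregation: for each calibration sample, measure every node's absolute output maximum (`sampleRanges`), then keep the element-wise running maximum across samples (`valueRanges`). A self-contained sketch of that pattern, with simplified stand-in types (not the Aidge classes):

```cpp
#include <algorithm>
#include <iostream>
#include <memory>
#include <unordered_map>

struct Node {};  // stand-in for Aidge::Node
using Ranges = std::unordered_map<std::shared_ptr<Node>, double>;

// Merge one sample's per-node absolute maxima into the accumulated ranges.
void mergeSampleRanges(Ranges& valueRanges, const Ranges& sampleRanges)
{
    for (const auto& [node, range] : sampleRanges)
        valueRanges[node] = std::max(valueRanges[node], range);
}

int main()
{
    auto n1 = std::make_shared<Node>();
    auto n2 = std::make_shared<Node>();

    Ranges valueRanges = {{n1, 0.0}, {n2, 0.0}};

    // Two "samples" worth of per-node absolute output maxima ...
    mergeSampleRanges(valueRanges, {{n1, 2.5}, {n2, 0.7}});
    mergeSampleRanges(valueRanges, {{n1, 1.9}, {n2, 1.3}});

    std::cout << valueRanges[n1] << " " << valueRanges[n2] << '\n';  // 2.5 1.3
}
```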
@@ -682,12 +673,12 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
                 {
                     std::shared_ptr<Tensor> biasTensor = getBiasTensor(prevNode);
                     //rescaleTensor(biasTensor, 1.0 / prevScalingFactor);
-                    insertScalingBelowProducer(prevNode->getParent(2),1.0 / prevScalingFactor,graphView);
+                    insertScalingBelowProducer(prevNode->getParent(2), 1.0 / prevScalingFactor, graphView);
                 }
             }
         }
 
-        // Merging nodes handling : use a maximum arbritation ...
+        // Merging nodes handling : use a maximum arbitration ...
         if (isMerging(node))
         {
@@ -697,38 +688,38 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
             double maxScaling = 0;
             for (std::size_t i = 0; i < mergingNodes.size(); i++)
             {
-                double merginNodeScaling = scalingFactors[mergingNodes[i]->name()];
-                if (merginNodeScaling > maxScaling) {
-                    maxScaling = merginNodeScaling;
+                double mergingNodeScaling = scalingFactors[mergingNodes[i]];
+                if (mergingNodeScaling > maxScaling) {
+                    maxScaling = mergingNodeScaling;
                 }
             }
 
-            scalingFactors[node->name()] = maxScaling;
+            scalingFactors[node] = maxScaling;
 
             for (std::shared_ptr<Node> mergingNode : mergingNodes)
             {
-                double mergingNodeScaling = scalingFactors[mergingNode->name()];
+                double mergingNodeScaling = scalingFactors[mergingNode];
                 double rescaling = mergingNodeScaling / maxScaling;
 
                 std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
 
                 //Log::info(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
 
-                multiplyScalingFactor(scalingNode,rescaling) ;
+                multiplyScalingFactor(scalingNode, rescaling);
             }
         }
     }
 }
 
-std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
+std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
 {
     std::shared_ptr<Node> firstNode = getFirstNode(graphView);
 
-    std::map<std::string, std::pair<bool, bool>> signMap;
+    std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
 
     std::pair<bool, bool> unsignedPair(true, true);
     for (std::shared_ptr<Node> node : graphView->getNodes())
         if (node->type() != "Producer")
-            signMap.insert(std::make_pair(node->name(), unsignedPair));
+            signMap.insert(std::make_pair(node, unsignedPair));
 
     // ITERATE OVER THE GRAPH
@@ -741,17 +732,17 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
         if (isAffine(node))
         {
             // Affine nodes always have a single parent
-            if (!isFirstNode)
-                signMap[node->name()].first = signMap[node->getParent(0)->name()].second;
+            if (!isFirstNode)
+                signMap[node].first = signMap[node->getParent(0)].second;
             else
-                signMap[node->name()].first = false;
+                signMap[node].first = false;
 
-            signMap[node->name()].second = false;
+            signMap[node].second = false;
         }
 
         if (node->attributes()->hasAttr("quantization.ptq.isScaling"))
         {
-            signMap[node->name()].second = false;
+            signMap[node].second = false;
 
             // Scaling nodes always have a single parent
             std::shared_ptr<Node> parent = node->getParent(0);
@@ -764,14 +755,14 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
             // Correct the previous single node (when it is an Affine node) ...
             if (allChildrenAreReLU)
                 if (isAffine(parent) || isMerging(parent))
-                    signMap[parent->name()].second = true;
+                    signMap[parent].second = true;
 
             // Maintain unsigned output
-            if (signMap[parent->name()].second)
-                signMap[node->name()].second = true;
+            if (signMap[parent].second)
+                signMap[node].second = true;
 
             // Set the link ...
-            signMap[node->name()].first = signMap[parent->name()].second;
+            signMap[node].first = signMap[parent].second;
         }
 
         if (isMerging(node))
@@ -782,15 +773,15 @@
             bool allParentAreSigned = true;
             bool allParentAreUnsigned = true;
             for(std::shared_ptr<Node> parent : parentNodes)
             {
-                bool parentSign = signMap[parent->name()].second;
+                bool parentSign = signMap[parent].second;
                 allParentAreSigned &= !parentSign;
                 allParentAreUnsigned &= parentSign;
             }
 
             if (allParentAreSigned)
-                signMap[node->name()] = std::make_pair(false, false);
+                signMap[node] = std::make_pair(false, false);
             else if (allParentAreUnsigned)
-                signMap[node->name()] = std::make_pair(true, true);
+                signMap[node] = std::make_pair(true, true);
             else
             {
                 // Arbitration : Signed type wins !
@@ -798,15 +789,15 @@
                 {
                     while (!parent->attributes()->hasAttr("quantization.ptq.isScaling"))
                     {
-                        signMap[parent->name()] = std::make_pair(false, false);
+                        signMap[parent] = std::make_pair(false, false);
 
                         // We are on a branch so nodes always have 1 parent ...
                         parent = parent->getParent(0);
                     }
 
-                    signMap[parent->name()].second = false;
+                    signMap[parent].second = false;
                 }
 
-                signMap[node->name()].first = false;
+                signMap[node].first = false;
             }
         }
@@ -816,8 +807,8 @@
             std::shared_ptr<Node> parent = node->getParent(0);
             if (parent)
             {
-                signMap[node->name()].first = signMap[parent->name()].second;
-                signMap[node->name()].second = signMap[node->name()].first;
+                signMap[node].first = signMap[parent].second;
+                signMap[node].second = signMap[node].first;
             }
         }
@@ -829,7 +820,7 @@
     {
         Log::info(" === SIGN MAP === ");
         for (std::shared_ptr<Node> node : nodeVector)
-            Log::info(" {}{} | {}", static_cast<int>(signMap[node->name()].first), static_cast<int>(signMap[node->name()].second), node->name());
+            Log::info(" {}{} | {}", static_cast<int>(signMap[node].first), static_cast<int>(signMap[node].second), node->name());
     }
 
     // SANITY CHECK (TEMPORARY)
@@ -838,7 +829,7 @@
     {
         for (std::shared_ptr<Node> child : node->getChildren())
         {
-            if (signMap[node->name()].second != signMap[child->name()].first)
+            if (signMap[node].second != signMap[child].first)
                 Log::error(" computeSignMap : link is not sane ! ({} -> {})", node->name(), child->name());
         }
     }
@@ -846,18 +837,15 @@
     return signMap;
 }
 
-
 void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant, bool optimizeSigns, bool verbose)
 {
     if (optimizeSigns && noQuant)
-    {
-        AIDGE_THROW_OR_ABORT(std::runtime_error,"Signs optimization can not be applied if network is not fully quantized ...");
-    }
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Sign optimization cannot be applied if the network is not fully quantized ...");
 
     double signedMax = (1 << (nbBits - 1)) - 1;
     double unsignedMax = (1 << nbBits) - 1;
 
-    std::map<std::string, std::pair<bool, bool>> signMap;
+    std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
 
     if (optimizeSigns)
         signMap = computeSignMap(graphView, verbose);
@@ -866,7 +854,7 @@
         std::pair<bool, bool> signedPair(false, false);
         for (std::shared_ptr<Node> node : graphView->getNodes())
             if (node->type() != "Producer")
-                signMap.insert(std::make_pair(node->name(), signedPair));
+                signMap.insert(std::make_pair(node, signedPair));
     }
 
     // ITERATE OVER THE GRAPH /////////////////////////////////////////////////
@@ -887,7 +875,7 @@
             // Rescale the bias tensor
             if (nodeHasBias(node))
             {
-                bool inputIsUnsigned = signMap[node->name()].first;
+                bool inputIsUnsigned = signMap[node].first;
                 double rescaling = inputIsUnsigned ? unsignedMax * signedMax : signedMax * signedMax;
 
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
@@ -901,13 +889,13 @@
             double rescaling = 1.0 / signedMax;
 
-            bool inputIsUnsigned = signMap[node->name()].first;
-            bool outputIsUnsigned = signMap[node->name()].second;
+            bool inputIsUnsigned = signMap[node].first;
+            bool outputIsUnsigned = signMap[node].second;
 
             rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
             rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
 
-            std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ...
+            std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // TODO : assert if scalingNode is a Scaling ...
 
             multiplyScalingFactor(scalingNode,rescaling) ;
         }
@@ -916,15 +904,15 @@
         {
             double rescaling = 1.0;
 
-            bool inputIsUnsigned = signMap[node->name()].first;
-            bool outputIsUnsigned = signMap[node->name()].second;
+            bool inputIsUnsigned = signMap[node].first;
+            bool outputIsUnsigned = signMap[node].second;
 
             rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
             rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
 
-            std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ...
+            std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // TODO : assert if scalingNode is a Scaling ...
 
-            multiplyScalingFactor(scalingNode,rescaling) ;
+            multiplyScalingFactor(scalingNode, rescaling);
         }
 
         // Handle the Scaling Nodes ...
@@ -933,23 +921,24 @@
         {
             if (!noQuant)
             {
-                // Replace the Scaling Node by Quantizer
+                // Replace the Scaling Node by a Quantizer
+
                 auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
 
                 std::shared_ptr<Tensor> fallback;
                 const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
-                double old_sf = localTensor.get<double>(0);//!\\
+                double oldScalingFactor = localTensor.get<double>(0); //!\\
 
-                std::shared_ptr<Node> quantizerNode = Quantizer(old_sf, -(signedMax + 1), signedMax, node->name());
+                std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name());
                 quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 quantizerNode->getOperator()->setBackend(node->getOperator()->backend());
-                graphView->replace({node,node->getParent(1)}, {quantizerNode});
+                graphView->replace({node, node->getParent(1)}, {quantizerNode});
 
                 if (optimizeSigns)
                 {
                     double rescaling = 1.0;
 
-                    bool inputIsUnsigned = signMap[node->name()].first;
-                    bool outputIsUnsigned = signMap[node->name()].second;
+                    bool inputIsUnsigned = signMap[node].first;
+                    bool outputIsUnsigned = signMap[node].second;
 
                     rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
                     rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
@@ -957,11 +946,9 @@
                     double currScalingFactor = getScalingFactor(quantizerNode);
                     updateScalingFactor(quantizerNode, currScalingFactor * rescaling);
 
-                    if(outputIsUnsigned)
-                    {
-                        setClipRange(quantizerNode,0,unsignedMax);
-                    }
+                    if (outputIsUnsigned)
+                        setClipRange(quantizerNode, 0, unsignedMax);
                 }
             }
         }
@@ -1022,7 +1008,7 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
     {
         if (isAffine(node) || (node->type() == "Mul" && node->attributes()->hasAttr("quantization.ptq.isCompensation")))
         {
-            std::shared_ptr<Node> scalingNode = (*node->getChildren().begin());
+            std::shared_ptr<Node> scalingNode = (*node->getChildren().begin()); // TODO : use index = 0
 
             double base = getScalingFactor(scalingNode);
@@ -1095,7 +1081,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     normalizeParameters(graphView);
 
     Log::info(" Computing the value ranges ...");
-    std::map<std::string, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
+    std::unordered_map<std::shared_ptr<Node>, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
 
     //Log::info(" === RANGES (BEFORE ADJUST) ===");
@@ -1132,9 +1118,9 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     Log::info(" Network is quantized !");
 }
 
-std::map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView)
+std::unordered_map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView)
 {
-    std::map<std::string, double> weightRanges;
+    std::unordered_map<std::string, double> weightRanges;
 
     for (std::shared_ptr<Node> node : graphView->getNodes())
     {
diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp
index 9b51e846df498a9303b7373ae1c86d4b007a96f0..dcac6819365e134d777be7479a95d6b8e4093b5e 100644
--- a/src/QAT/QAT_LSQ.cpp
+++ b/src/QAT/QAT_LSQ.cpp
@@ -9,205 +9,164 @@
  *
 ********************************************************************************/
 
-#include "aidge/quantization/QAT/QAT_LSQ.hpp"
-#include "aidge/operator/LSQ.hpp"
-#include "aidge/operator/ReLU.hpp"
-
-
-#include "aidge/data/Tensor.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/scheduler/SequentialScheduler.hpp"
-#include "aidge/scheduler/Scheduler.hpp"
-#include "aidge/graph/Matching.hpp"
-#include "aidge/recipes/QuantRecipes.hpp"
-
-namespace Aidge {
-
-void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize)
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
-
-    for (const auto& match : matches)
-    {
-        auto linearNode = match.graph->rootNode();
-
-        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
-        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
-
-        // INPUT QUANTIZERS INSERTION
-
-        // TODO : double check this, and use createUniqueName()
-        auto inputQuantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
-        auto inputQuantizerNode = LSQ(signedRange, inputQuantizerName);
-
-        // Set the step size
-
-        auto inputStepSizeOp = inputQuantizerNode->getParent(1)->getOperator();
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
-
-        // Absorb the ReLU when possible ...
-
-        // XXX is this safe ???
-        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]);
-        // bool nodeHasParent = (linearNode->getParents().size() != 0);
-
-        if (nodeHasParent) {
-            auto parentNode = linearNode->getParents()[0];
-            if (parentNode->type() == "ReLU") {
-                auto inputQuantizerOp = std::static_pointer_cast<LSQ_Op> (inputQuantizerNode->getOperator());
-                inputQuantizerOp->range() = unsignedRange;
-                graphView->replace({parentNode}, {});
-            }
-        }
-
-        // We need to handle the case where the linear node is the first one ...
-
-        if (nodeHasParent) {
-            graphView->insertParent(linearNode, inputQuantizerNode, 0, 0, 0);
-        } else {
-            inputQuantizerNode->addChild(graphView);
-            graphView->add(inputQuantizerNode);
-        }
-
-        // PARAM QUANTIZERS INSERTION
-
-        // TODO : double check this, and use createUniqueName()
-        auto paramQuantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
-        auto paramQuantizerNode = LSQ(signedRange, paramQuantizerName);
-        graphView->insertParent(linearNode, paramQuantizerNode, 1, 0, 0);
-
-        // Set the step size
-
-        auto paramStepSizeOp = paramQuantizerNode->getParent(1)->getOperator();
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
-    }
-
-}
-
-static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
-{
-    auto backend = tensor->backend();
-    if (backend == "cuda")
-        tensor->setBackend("cpu");
-
-    float acc = 0;
-    float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr());
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        acc += std::abs(castedTensor[i]);
-    acc /= static_cast<float> (tensor->size());
-
-    if (backend == "cuda")
-        tensor->setBackend("cuda");
-
-    return acc;
-}
-
-static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> calibrationData, bool useCuda)
-{
-    // Propagate the calibration tensor
-
-    SequentialScheduler scheduler(graphView);
-    scheduler.resetScheduling();
-    scheduler.forward(true, {calibrationData});
-
-    // Store the input tensor statistics
-
-    if (useCuda)
-        graphView->setBackend("cpu");
-
-    std::map<std::string, float> inputStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float inputAbsMean = getTensorAbsMean(op->getInput(0));
-            inputStats.insert(std::make_pair(node->name(), inputAbsMean));
-            fmt::println("{} -> {}", node->name(), inputAbsMean);
-        }
-    }
-
-    if (useCuda)
-        graphView->setBackend("cuda");
-
-    return inputStats;
-}
-
-static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> graphView, bool useCuda)
-{
-    if (useCuda)
-        graphView->setBackend("cpu");
-
-    std::map<std::string, float> paramStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float paramAbsMean = getTensorAbsMean(op->getInput(1));
-            paramStats.insert(std::make_pair(node->name(), paramAbsMean));
-            fmt::println("{} -> {}", node->name(), paramAbsMean);
-        }
-    }
-
-    if (useCuda)
-        graphView->setBackend("cuda");
-
-    return paramStats;
-}
-
-static void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView, std::map<std::string, float> inputStats, std::map<std::string, float> paramStats)
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
-
-    for (const auto& match : matches)
-    {
-        auto linearNode = match.graph->rootNode();
-
-        // INPUT QUANTIZERS STEP-SIZES
-
-        auto inputQuantNode = linearNode->getParent(0);
-        auto inputQuantOp = std::static_pointer_cast<LSQ_Op>(inputQuantNode->getOperator());
-
-        float absMean = inputStats[linearNode->name()];
-        float stepSize = 2.0f * (absMean / std::sqrt(inputQuantOp->range().second));
-
-        auto inputStepSizeOp = inputQuantNode->getParent(1)->getOperator();
-        // XXX inputStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
-
-        // PARAM QUANTIZERS STEP-SIZES
-
-        auto paramQuantNode = linearNode->getParent(1);
-        auto paramQuantOp = std::static_pointer_cast<LSQ_Op>(paramQuantNode->getOperator());
-
-        absMean = paramStats[linearNode->name()];
-        stepSize = 2.0f * (absMean / std::sqrt(paramQuantOp->range().second));
-
-        auto paramStepSizeOp = paramQuantNode->getParent(1)->getOperator();
-        // XXX paramStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
-    }
-}
-
-void QuantLSQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData)
-{
-    bool useCuda = (calibrationData->backend() == "cuda");
-
-    // Collect the tensor statisics
-    auto inputStats = collectInputStats(graphView, calibrationData, useCuda);
-
-    auto paramStats = collectParamStats(graphView, useCuda);
-
-    // Insert the quantizers
-    insertQuantizers(graphView, nbBits, 1.0);
-
-    // Adjust the quantizers step-sizes
-    adjustQuantizersStepSizes(graphView, inputStats, paramStats);
-}
-
-}
\ No newline at end of file
+#include "aidge/quantization/QAT/QAT_LSQ.hpp"
+#include "aidge/operator/LSQ.hpp"
+#include "aidge/operator/ReLU.hpp"
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/scheduler/SequentialScheduler.hpp"
+#include "aidge/scheduler/Scheduler.hpp"
+#include "aidge/graph/Matching.hpp"
+#include "aidge/recipes/QuantRecipes.hpp"
+
+namespace Aidge
+{
+
+static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor).abs().mean();
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    return localTensor.get<float>(0);
+}
+
+static float getTensorStd(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor);
+
+    auto skewedTensor = valueTensor - valueTensor.mean();
+    auto squaredTensor = skewedTensor * skewedTensor;
+    auto varianceTensor = squaredTensor.mean();
+
+    std::shared_ptr<Tensor> fallback;
+    auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+
+    float variance = localTensor.get<float>(0);
+    return std::sqrt(variance);
+}
+
+// INIT THE STEP SIZE OF A QUANTIZER NODE
+
+static bool initStepSize(std::shared_ptr<Node> quantizer)
+{
+    const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
+
+    // This formula is the one proposed in the paper ...
+
+    // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
+    // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
+
+    // ... but this formula seems to work better !!!
+
+    float inputStd = getTensorStd(quantizerOp->getInput(0));
+    float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
+
+    // TODO : use the scalar constructor
+    auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
+
+    // XXX Manage backend here ?
+    stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
+    stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
+
+    auto stepSizeProducer = quantizer->getParent(1);
+
+    stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
+
+    Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
+
+    return false;
+}
+
+static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
+
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
+
+        // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
+
+        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
+
+        // Create the input quantizer node
+
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
+
+        // Init the step-size using the node call stack
+
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
+
+        // Absorb the ReLU when possible ...
+
+        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ?
+
+        if (nodeHasParent)
+        {
+            bool allParentsAreReLU = true;
+            for (auto parentNode : linearNode->getParents())
+                if (parentNode->type() != "ReLU")
+                    allParentsAreReLU = false;
+
+            if (allParentsAreReLU) {
+                auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
+                quantizerOp->range() = unsignedRange;
+            }
+
+            // TODO : remove the ReLUs when possible
+        }
+
+        // Insert the quantizer in the graphView ...
+        // (We need to handle the case where the linear node is the first one)
+
+        if (nodeHasParent) {
+            graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
+        } else {
+            quantizerNode->addChild(graphView);
+            graphView->add(quantizerNode);
+        }
+    }
+}
+
+// PARAM QUANTIZERS INSERTION
+
+static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
+
+    std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
+
+        // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
+
+        // TODO : double check this, and use createUniqueName()
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
+
+        // Init the step-size using the node call stack
+
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
+
+        // Insert the quantizer in the graphView
+
+        graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
+    }
+}
+
+void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    sanitizeNodeNames(graphView);
+    setupInputQuantizers(graphView, nbBits);
+    setupParamQuantizers(graphView, nbBits);
+}
+
+}
\ No newline at end of file
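The comments in `initStepSize()` above contrast the LSQ paper's initialization, stepSize = 2 * E[|x|] / sqrt(Qp) (Esser et al.), with the std-based heuristic the patch adopts, stepSize = 8 * std(x) / Qp, where Qp is `range().second`. A self-contained sketch comparing the two rules on toy data (the input values are illustrative only, not taken from the patch):

```cpp
#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> activations = {0.1f, -0.4f, 0.25f, 0.8f, -0.05f, 0.3f};
    const float qMax = 127.0f;  // range().second for 8-bit signed quantization

    // Mean and mean of absolute values over the sample.
    float absMean = 0.0f, mean = 0.0f;
    for (float v : activations) { absMean += std::abs(v); mean += v; }
    absMean /= activations.size();
    mean    /= activations.size();

    // Biased variance estimate, matching getTensorStd() above.
    float variance = 0.0f;
    for (float v : activations) variance += (v - mean) * (v - mean);
    variance /= activations.size();
    const float stdDev = std::sqrt(variance);

    // LSQ paper rule: stepSize = 2 * E[|x|] / sqrt(Qp)
    const float paperStepSize = 2.0f * absMean / std::sqrt(qMax);

    // Heuristic used in the patch: stepSize = 8 * std(x) / Qp
    const float patchStepSize = 8.0f * stdDev / qMax;

    std::cout << paperStepSize << " vs " << patchStepSize << '\n';
}
```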