diff --git a/aidge_quantization/unit_tests/assets/BranchNetV4.onnx b/aidge_quantization/unit_tests/assets/BranchNetV4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..34cccc47c4b5014f0adc4757d0b8e362a8e5ddce Binary files /dev/null and b/aidge_quantization/unit_tests/assets/BranchNetV4.onnx differ diff --git a/aidge_quantization/unit_tests/assets/MLP.onnx b/aidge_quantization/unit_tests/assets/MLP.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f6b72dbbd8c829a1d3609d923478887892b9e231 Binary files /dev/null and b/aidge_quantization/unit_tests/assets/MLP.onnx differ diff --git a/aidge_quantization/unit_tests/assets/TestNet.onnx b/aidge_quantization/unit_tests/assets/TestNet.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7f73e9b11d8a2ca43c88e52295dd201211f1e741 Binary files /dev/null and b/aidge_quantization/unit_tests/assets/TestNet.onnx differ diff --git a/aidge_quantization/unit_tests/test_ptq.py b/aidge_quantization/unit_tests/test_ptq.py index 56080bff0d1f4a95248fa983316dbafd35565501..e2acab95c7c3c7ee517ebaba5af102d336679cbe 100644 --- a/aidge_quantization/unit_tests/test_ptq.py +++ b/aidge_quantization/unit_tests/test_ptq.py @@ -1,118 +1,125 @@ import unittest -import gzip import numpy as np -from pathlib import Path - +import gzip import aidge_core -import aidge_backend_cpu import aidge_onnx +import aidge_backend_cpu import aidge_quantization +import sys +from pathlib import Path -from aidge_core import Log, Level - +""" +Unit test for the PTQ pipeline: +This script is designed to test and validate the accuracy of five small model topologies on the MNIST dataset: +["MiniResNet", "ConvNet", "BranchNetV4", "TestNet", "MLP"] +It compares the results for three configurations: the baseline, quantization, and quantization with single shift. +The value of sigma represents the tolerance for the tests. 
+""" +aidge_core.Log.set_console_level(aidge_core.Level.Error) # Reduce useless logs # -------------------------------------------------------------- -# CONFIGS +# CONFIGURATION # -------------------------------------------------------------- -NB_SAMPLES = 1000 # max : 1000 -SAMPLE_SHAPE = (1, 1, 28, 28) -MODEL_NAME = 'MiniResNet.onnx' # 'ConvNet.onnx' -ACCURACIES = (95.4, 94.4) # (97.9, 97.7) -NB_BITS = 4 +NB_SAMPLES = 1000 +SAMPLE_SHAPE = (1, 1, 28, 28) +NB_BITS = 4 +CLIPPING = aidge_quantization.Clipping.MSE +EXPECTED_RESULTS = { + "MiniResNet.onnx": (95.4, 94.5, 94.7), + "ConvNet.onnx": (97.9, 97.7, 97.4), + "BranchNetV4.onnx": (93.8, 93.2, 93.7), + "TestNet.onnx": (95.5, 94.2, 94.2), + "MLP.onnx": (94.7, 94.2, 93.3) +} +SIGMA = 0.05 # -------------------------------------------------------------- # UTILS # -------------------------------------------------------------- def propagate(model, scheduler, sample): + sample = np.reshape(sample, SAMPLE_SHAPE) input_tensor = aidge_core.Tensor(sample) scheduler.forward(True, [input_tensor]) output_node = model.get_output_nodes().pop() output_tensor = output_node.get_operator().get_output(0) return np.array(output_tensor) -def prepare_sample(sample): - sample = np.reshape(sample, SAMPLE_SHAPE) - return sample.astype('float32') - def compute_accuracy(model, samples, labels): - acc = 0 - scheduler = aidge_core.SequentialScheduler(model) - for i, sample in enumerate(samples): - x = prepare_sample(sample) - y = propagate(model, scheduler, x) - if labels[i] == np.argmax(y): - acc += 1 + schedueler = aidge_core.SequentialScheduler(model) + acc = sum(labels[i] == np.argmax(propagate(model, schedueler, x)) for i, x in enumerate(samples)) return acc / len(samples) # -------------------------------------------------------------- # TEST CLASS # -------------------------------------------------------------- -class test_ptq(unittest.TestCase): +class TestQuantization(unittest.TestCase): def setUp(self): - - # load the samples / labels (numpy) - curr_file_dir = Path(__file__).parent.resolve() self.samples = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_samples.npy.gz', "r")) - self.labels = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_labels.npy.gz', "r")) - - # load the model in AIDGE - - self.model = aidge_onnx.load_onnx(curr_file_dir / "assets/" / MODEL_NAME, verbose=False) - aidge_core.remove_flatten(self.model) - - self.model.set_datatype(aidge_core.dtype.float32) - self.model.set_backend("cpu") - - def tearDown(self): - pass - - - def test_model(self): - - Log.set_console_level(Level.Info) - # compute the base accuracy - accuracy = compute_accuracy(self.model, self.samples[0:NB_SAMPLES], self.labels) - self.assertAlmostEqual(accuracy * 100, ACCURACIES[0], msg='base accuracy does not meet the baseline !', delta=0.1) - - def test_quant_model(self): - - Log.set_console_level(Level.Debug) - - # create the calibration dataset - - tensors = [] - for sample in self.samples[0:NB_SAMPLES]: - sample = prepare_sample(sample) - tensor = aidge_core.Tensor(sample) - tensors.append(tensor) - - # quantize the model - - aidge_quantization.quantize_network( - self.model, - NB_BITS, - tensors, - clipping_mode=aidge_quantization.Clipping.MSE, - no_quantization=False, - optimize_signs=True, - single_shift=False - ) - - # rescale the inputs - - scaling = 2**(NB_BITS-1)-1 - for i in range(NB_SAMPLES): - self.samples[i] = self.samples[i]*scaling # XXX np.round ??? 
- - # compute the quantized accuracy - - accuracy = compute_accuracy(self.model, self.samples, self.labels) - self.assertAlmostEqual(accuracy * 100, ACCURACIES[1], msg='quantized accuracy does not meet the baseline !', delta=0.1) - + self.labels = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_labels.npy.gz', "r")) + self.quantized_samples = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_samples.npy.gz', "r")) * ((1 << (NB_BITS - 1)) - 1) + + def run_model_test(self, model_name): + model_path = Path(__file__).parent / "assets" / model_name + model = aidge_onnx.load_onnx(model_path, verbose=False) + aidge_core.remove_flatten(model) + model.set_datatype(aidge_core.dtype.float64) + model.set_backend("cpu") + + expected_base, expected_quant, expected_quant_ss = EXPECTED_RESULTS[model_name] + + # Baseline Accuracy + base_accuracy = compute_accuracy(model, self.samples[:NB_SAMPLES], self.labels) + self.assertAlmostEqual(base_accuracy * 100, expected_base, delta=SIGMA, msg=f"[X] Baseline accuracy mismatch for {model_name}. Expected accuracy was: {expected_base}, but got: {base_accuracy * 100}") + + # Quantize + tensors = [aidge_core.Tensor(np.reshape(sample, SAMPLE_SHAPE)) for sample in self.samples[:NB_SAMPLES]] + + aidge_quantization.quantize_network(network = model, + nb_bits = NB_BITS, + input_dataset = tensors, + clipping_mode = CLIPPING, + target_type = aidge_core.dtype.float64, + no_quantization = False, + optimize_signs = True, + single_shift = False, + use_cuda = False, + fold_graph = True, + bitshift_rounding = False, + verbose = False) + quant_accuracy = compute_accuracy(model, self.quantized_samples[:NB_SAMPLES], self.labels) + + self.assertAlmostEqual(quant_accuracy * 100, expected_quant, delta=SIGMA, msg=f"[X] Quantized accuracy mismatch for {model_name}. Expected accuracy was: {expected_quant}, but got: {quant_accuracy * 100}") + + # Quantize with Single Shift + model_ss = aidge_onnx.load_onnx(model_path, verbose=False) + aidge_core.remove_flatten(model_ss) + model_ss.set_datatype(aidge_core.dtype.float64) + model_ss.set_backend("cpu") + + aidge_quantization.quantize_network(network = model_ss, + nb_bits = NB_BITS, + input_dataset = tensors, + clipping_mode = CLIPPING, + target_type = aidge_core.dtype.float64, + no_quantization = False, + optimize_signs = True, + single_shift = True, + use_cuda = False, + fold_graph = True, + bitshift_rounding = False, + verbose = False) + + quant_accuracy_ss = compute_accuracy(model_ss, self.quantized_samples[:NB_SAMPLES], self.labels) + self.assertAlmostEqual(quant_accuracy_ss * 100, expected_quant_ss, delta=SIGMA, msg=f"[X] Quantized Single Shift accuracy mismatch for {model_name}. Expected accuracy was: {expected_quant_ss}, but got: {quant_accuracy_ss * 100}") + + def test_models(self): + for model in EXPECTED_RESULTS.keys(): + with self.subTest(model=model): + self.run_model_test(model) if __name__ == '__main__': unittest.main() diff --git a/include/aidge/operator/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp index 9ca76fbd40b9366aa82c6521fba931d284da137a..ff8235c6ea4c92935b869d8ba522a3fdcbc9b8e2 100644 --- a/include/aidge/operator/PTQMetaOps.hpp +++ b/include/aidge/operator/PTQMetaOps.hpp @@ -29,6 +29,34 @@ namespace Aidge { /// @return A shared pointer to an instance of the meta-operator node.
std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name); +/// @brief IntQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration +/// into computation graphs with a data type other than Float while preserving floating-point precision. +/// +/// This operator modifies the provided Quantizer by inserting explicit casting operations before and after +/// the quantization process. It first casts the input to Float64, applies the quantization steps (Mul, Clip, Round), +/// and then casts the result back to the target data type. This ensures compatibility with integer-based computation graphs +/// while maintaining the precision of floating-point operations. +/// +/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted. +/// @param targetType The target data type to which the final output should be cast after the quantization process. +/// @param name The name of the meta-operator node created. +/// @return A shared pointer to a new instance of the modified meta-operator node. +std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name); + +/// @brief BitShiftQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration +/// into computation graphs with a data type other than Float, without falling back to floating-point scaling. +/// +/// This operator replaces the provided Quantizer with an integer-friendly equivalent: the scaling factor is +/// converted into a power-of-two shift amount, and the scaling is applied with a [BitShift] (left or right, +/// optionally rounded) followed by a [Clip], so that the quantization step runs directly on the target integer type. +/// +/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted. +/// @param targetType The target data type in which the bit-shift and clip operations are performed. +/// @param bitshiftRounding Whether the result of the bit-shift is rounded to the nearest integer. +/// @param name The name of the meta-operator node created. +/// @return A shared pointer to a new instance of the modified meta-operator node. +std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, bool bitshiftRounding, const std::string& name); + /// @brief Updates the scaling factor of a PTQ meta-operator node, allowing for dynamic adjustment of the scaling parameter. /// This function sets a new scaling factor for a specified meta-operator node, modifying the scalar applied in the [Mul] operation. /// The meta-operator node must be a PTQ-specific operator, such as a Quantizer or Scaling node. diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp index 1c911801c543cac8cb464acaab80e6061703e6e7..7f11c012636b04d0434ff7a6a04bbf5131096171 100644 --- a/include/aidge/quantization/PTQ/PTQ.hpp +++ b/include/aidge/quantization/PTQ/PTQ.hpp @@ -181,14 +181,29 @@ namespace Aidge { * @param graphView The GraphView to be quantized. * @param nbBits The desired number of bits of the quantization. * @param inputDataSet The input dataset on which the value ranges are computed. - * @param clippingMode: Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'. + * @param clippingMode Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
+ * @param targetType Desired target type to cast the graph into (default is Float64, which will NOT apply casting on the network) * @param noQuant Whether to apply the rounding operations or not. * @param optimizeSigns Whether to take account of the IO signs of the operators or not. * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights. + * @param useCuda Whether to speed up the PTQ by computing the value ranges using CUDA kernels. + * This flag does not set the backend of the GraphView to "cuda" at the end of the PTQ pipeline. + * @param foldGraph Whether to apply the constant folding recipe, which makes the resulting GraphView much easier to read. + * @param bitshiftRounding Whether rounding should be applied after bit-shifting operations. If enabled, the result of bit-shifting is rounded to the nearest integer. * @param verbose Whether to print internal informations about the quantization process. */ - void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose); - + void quantizeNetwork(std::shared_ptr<GraphView> graphView, + std::uint8_t nbBits, + std::vector<std::shared_ptr<Tensor>> inputDataSet, + Clipping clippingMode, + DataType targetType, + bool noQuant, + bool optimizeSigns, + bool singleShift, + bool useCuda, + bool foldGraph, + bool bitshiftRounding, + bool verbose); /** * @brief Compute the weight ranges of every affine node. Provided for debugging purposes. * @param graphView The GraphView containing the affine nodes. diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp index 12d14340f9353114d06121fa8f1e1fd4f050e3f4..d7bc00dcc095736419732b9ed56918ca37663b50 100644 --- a/python_binding/pybind_PTQ.cpp +++ b/python_binding/pybind_PTQ.cpp @@ -93,7 +93,20 @@ void init_PTQ(py::module &m) { :type verbose: bool )mydelimiter"); - m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false, + m.def("quantize_network", + &quantizeNetwork, + py::arg("network"), + py::arg("nb_bits"), + py::arg("input_dataset"), + py::arg("clipping_mode") = Clipping::MAX, + py::arg("target_type") = DataType::Float64, + py::arg("no_quantization") = false, + py::arg("optimize_signs") = false, + py::arg("single_shift") = false, + py::arg("use_cuda") = false, + py::arg("fold_graph") = true, + py::arg("bitshift_rounding") = false, + py::arg("verbose") = false, R"mydelimiter( Main quantization routine. Performs every step of the quantization pipeline. :param network: The GraphView to be quantized.
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 0eecc450d7567b8eb0421cd95251ba8ace447a7e..df203f2547e720bcfbef109e05e7ccca5ed42b9e 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -20,6 +20,7 @@ #include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/Scheduler.hpp" #include "aidge/utils/Log.hpp" +#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Mul.hpp" @@ -30,6 +31,9 @@ #include "aidge/operator/ArgMax.hpp" #include "aidge/operator/Reshape.hpp" #include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Cast.hpp" + + #include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/QuantRecipes.hpp" @@ -79,6 +83,20 @@ bool isNotQuantized(std::shared_ptr<Node> node) { return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); } +std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView) +{ + return graphView->getOrderedInputs()[0].first; +} +void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType) +{ + for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes()) + { + for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++) + { + inputNode->getOperator()->resetInput(index); + } + } +} bool checkArchitecture(std::shared_ptr<GraphView> graphView) { std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"}); @@ -249,6 +267,59 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n graphView->add(newNode); } +void applyConstFold(std::shared_ptr<GraphView> &graphView) +{ + for (const std::shared_ptr<Node> node : graphView->getNodes()) + { + if (node->type() == "Producer" ) + { + const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator()); + producer->constant() = true; + } + } + constantFolding(graphView); +} + +bool castQuantizedGraph(std::shared_ptr<GraphView> &graphView, Aidge::DataType targetType, bool singleShift,bool bitshiftRounding) +{ + // We need a deep copy of the graph's nodes since we will replace some of them + std::vector<std::shared_ptr<Node>> nodeVector(graphView->getNodes().begin(), graphView->getNodes().end()); + + for (std::shared_ptr<Node> node : nodeVector) + { + if (node->type() == "Round" && node->attributes()->hasAttr("quantization.ptq.isProducerRounding")) + { + std::shared_ptr<Aidge::Node> castNode = Cast(targetType,node->name() + "_Cast"); + castNode->getOperator()->setDataType(targetType); + castNode->getOperator()->setBackend(node->getOperator()->backend()); + insertChildren(node,castNode,graphView); + castNode->attributes()->addAttr("quantization.ptq.isProducerCasting",0.0); + node->getOperator()->setDataType(DataType::Float64); + } + else if(node->type() == "Quantizer") + { + if(singleShift) + { + std::shared_ptr<Node> newBitShiftQuantizer = BitShiftQuantizer(node,targetType,bitshiftRounding,node->name()+"_BitShift_Quantizer"); + newBitShiftQuantizer->getOperator()->setBackend(node->getOperator()->backend()); + graphView->replace({node},{newBitShiftQuantizer}); + + } + else // If single shift is not enabled we keep using the alternative IntQuantizer (which casts the data before and after the regular Quantizer operations) + { + std::shared_ptr<Node> newIntQuantizer = IntQuantizer(node,targetType,node->name()); + newIntQuantizer->getOperator()->setBackend(node->getOperator()->backend()); + graphView->replace({node},{newIntQuantizer}); + } + } + else if (node->type() != "Producer"
&& + !node->attributes()->hasAttr("quantization.ptq.isProducerScaling")) + { + node->getOperator()->setDataType(targetType); + } + } + return true; +} bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView) { @@ -270,7 +341,6 @@ double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { // get the abs tensor std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR - std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs()); // flatten the abs tensor @@ -371,10 +441,6 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> return nodeVector; } -static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView) -{ - return retrieveNodeVector(graphView)[0]; -} // TODO : enhance this by modifying OperatorImpl in "core" ... static DataType getDataType(std::shared_ptr<Node> node) @@ -554,7 +620,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// - std::shared_ptr<Node> firstNode = getFirstNode(graphView); + std::shared_ptr<Node> firstNode =getFirstNode(graphView); for (std::shared_ptr<Node> node : nodeVector) { @@ -699,8 +765,6 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< std::unordered_map<std::shared_ptr<Node>, double> valueRanges; std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); - // std::shared_ptr<Node> inputNode = getFirstNode(graphView); - for (std::shared_ptr<Node> node : nodeSet) if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) valueRanges.insert(std::make_pair(node, 0)); @@ -768,7 +832,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges) { - std::shared_ptr<Node> firstNode = getFirstNode(graphView); + std::shared_ptr<Node> firstNode = getFirstNode(graphView); // CREATE THE ACCUMULATED RATIO MAP /////////////////////////////////////// @@ -871,7 +935,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose) { - std::shared_ptr<Node> firstNode = getFirstNode(graphView); + std::shared_ptr<Node> firstNode = getFirstNode(graphView); std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap; @@ -1248,7 +1312,18 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std: tensor->setDataType(dataType); } -void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose) +void quantizeNetwork(std::shared_ptr<GraphView> graphView, + std::uint8_t nbBits, + std::vector<std::shared_ptr<Tensor>> inputDataSet, + Clipping clippingMode, + DataType targetType, + bool noQuant, + bool optimizeSigns, + bool singleShift, + bool useCuda, + bool foldGraph, + bool bitshiftRounding, + bool verbose) { Log::notice(" === QUANT PTQ 0.2.21 === "); @@ -1292,13 +1367,33 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, Log::notice(" Performing the Single-Shift approximation ..."); performSingleShiftApproximation(graphView, noQuant); } + if( targetType != 
DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16) + { + AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!") + Log::notice("Starting to cast operators into the desired type ..."); + castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding); + + graphView->updateInputsOutputs(); + clearGraphViewInputNodes(graphView,targetType); // Convert all input tensors of the GraphView into targetType + } + else + { + setupDataType(graphView, inputDataSet, targetType); + } + if(foldGraph) + { + Log::notice("Applying constant folding recipe to the graph ..."); + applyConstFold(graphView); + } + // Mandatory to handle all of the newly added connections! + graphView->updateInputsOutputs(); + + // Clearing input nodes + Log::notice("Clearing all input nodes ..."); if (verbose) printScalingFactors(graphView); - - if (useCuda) - graphView->setBackend("cuda"); - + Log::notice(" Resetting the scheduler ..."); SequentialScheduler scheduler(graphView); scheduler.resetScheduling(); @@ -1333,4 +1428,4 @@ void clearBiases(std::shared_ptr<GraphView> graphView) } } } -} +} \ No newline at end of file diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index c70a7726c143ed4cd028099f849de25a16ab11d3..c043e4739991c6b7b93fc8e5a710e56e0ce2ac30 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -18,6 +18,8 @@ #include "aidge/operator/Clip.hpp" #include "aidge/operator/Mul.hpp" #include "aidge/operator/Round.hpp" +#include "aidge/operator/BitShift.hpp" +#include "aidge/operator/Cast.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/OpArgs.hpp" @@ -33,7 +35,15 @@ namespace Aidge { +static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType) +{ + std::shared_ptr<Node> mulNode = nullptr; + for(std::shared_ptr<Node> node : graphView->getNodes()) + if (node->type() == nodeType) + mulNode = node; + return mulNode; +} std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name) { // create the nodes @@ -55,19 +65,65 @@ std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double cli // return the metaop - std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype + return MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype - return metaopNode; } +std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType,bool bitshiftRounding, const std::string& name) +{ + double scalingFactor = getScalingFactor(oldQuantizer); -static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType) + std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (oldQuantizer->getOperator()); + std::shared_ptr<Node> oldclipNode = getSubNode(metaOp->getMicroGraph(), "Clip"); + + if (!oldclipNode) { + Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", oldQuantizer->type()); + return nullptr; + } + + std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(oldclipNode->getOperator()); + int shift = std::log2(scalingFactor); + BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left; + + if(shift < 0 ) + { + direction = BitShift_Op::BitShiftDirection::right; + shift = -shift; + } + + std::shared_ptr<Node> bitShiftNode = BitShift(direction,bitshiftRounding,(!name.empty())
? name + "_MulIQuant" : ""); + std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_IClipQuant" : "", clipOp->min(), clipOp->max()); + + std::shared_ptr<Tensor> bitshiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {shift}); + std::shared_ptr<Node> bitshiftProducer = addProducer(bitShiftNode, 1, {1}, "ScalingFactor"); + + bitshiftProducer->getOperator()->setOutput(0, bitshiftTensor); + bitshiftProducer->attributes()->addAttr("quantization.ptq.ShiftAmount",shift); + bitshiftProducer->getOperator()->setDataType(targetType); + + // connect the scaling factor producer + + bitShiftNode->getOperator()->setDataType(targetType); + clipNode->getOperator()->setDataType(targetType); + + // create the metaop graph + + std::shared_ptr<GraphView> graphView = Sequential({bitShiftNode,clipNode}); + std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(bitShiftNode); // XXX why not use the graphView ??? + + // return the metaop + return MetaOperator("BitShiftQuantizer", connectedGraphView, {}, name); // XXX alternative prototype +} +std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name) { - std::shared_ptr<Node> mulNode = nullptr; - for(std::shared_ptr<Node> node : graphView->getNodes()) - if (node->type() == nodeType) - mulNode = node; + std::shared_ptr<Node> castPreNode = Cast(DataType::Float64,((!name.empty()) ? name + "_PreCast" : "")); + std::shared_ptr<Node> castPostNode = Cast(targetType,((!name.empty()) ? name + "_PostCast" : "")); - return mulNode; + castPreNode->getOperator()->setDataType(DataType::Float64); + castPostNode->getOperator()->setDataType(targetType); + + std::shared_ptr<GraphView> graphView = Sequential({castPreNode, oldQuantizer->clone(), castPostNode}); + + return MetaOperator("IntQuantizer", graphView, {}, name); // XXX alternative prototype } void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)