diff --git a/.gitignore b/.gitignore
index ba5c59398b68083c6c1c5fe820fb9070d999c18e..c64cbb5b6997c5c332326460eb36296247a88979 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,8 +5,10 @@
 build*/
 install*/
 include/aidge/backend/quantization_version.h
+include/aidge/quantization_version.h
-# VSCode
+
+# VSCode
 .vscode
 
 # Python
diff --git a/aidge_quantization/_version.py b/aidge_quantization/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d34d3557071ed5c22aea83c63bfb7684b180cf9
--- /dev/null
+++ b/aidge_quantization/_version.py
@@ -0,0 +1,4 @@
+# file generated by setuptools_scm
+# don't change, don't track in version control
+__version__ = version = '0.2.1.dev60+g8044e79.d20250106'
+__version_tuple__ = version_tuple = (0, 2, 1, 'dev60', 'g8044e79.d20250106')
\ No newline at end of file
diff --git a/include/aidge/quantization/PTQ/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp
similarity index 64%
rename from include/aidge/quantization/PTQ/PTQMetaOps.hpp
rename to include/aidge/operator/PTQMetaOps.hpp
index 62fac873235f2b89a242042de9260fc350ad6aa8..a65e4d52a11eb83463208088707da57cbc78eae2 100644
--- a/include/aidge/quantization/PTQ/PTQMetaOps.hpp
+++ b/include/aidge/operator/PTQMetaOps.hpp
@@ -37,13 +37,33 @@ namespace Aidge {
 /// @return A shared pointer to an instance of the meta-operator node.
 std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name);
 
-/// @brief The purpose of Scaling is to encapsulate the Mul operator and tag it as a PTQ node rather than a regular Mul operator.
-/// Therefore, this meta-operator consists solely of a [Mul] operation.
+/// @brief IntQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration
+/// into computation graphs with a data type other than Float while preserving floating-point precision.
+///
+/// This operator modifies the provided Quantizer by inserting explicit casting operations before and after
+/// the quantization process. It first casts the input to Float64, applies the quantization steps (Mul, Clip, Round),
+/// and then casts the result back to the target data type. This ensures compatibility with integer-based computation graphs
+/// while maintaining the precision of floating-point operations.
 ///
-/// @param scalingFactor The scaling factor to apply to the input (a scalar to multiply the input with).
+/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
+/// @param targetType The target data type to which the final output should be cast after the quantization process.
 /// @param name The name of the meta-operator node created.
-/// @return A shared pointer to an instance of the scaling node.
-std::shared_ptr<Aidge::Node> Scaling(double scalingFactor, const std::string& name = "");
+/// @return A shared pointer to a new instance of the modified meta-operator node.
+std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name);
+
+/// @brief BitShiftQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration
+/// into computation graphs with a data type other than Float while preserving floating-point precision.
+///
+/// This operator modifies the provided Quantizer by inserting explicit casting operations before and after
+/// the quantization process. It first casts the input to Float64, applies the quantization steps (Mul, Clip, Round),
+/// and then casts the result back to the target data type. This ensures compatibility with integer-based computation graphs
+/// while maintaining the precision of floating-point operations.
+///
+/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
+/// @param targetType The target data type to which the final output should be cast after the quantization process.
+/// @param name The name of the meta-operator node created.
+/// @return A shared pointer to a new instance of the modified meta-operator node.
+std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name);
 
 /// @brief Updates the scaling factor of a PTQ meta-operator node, allowing for dynamic adjustment of the scaling parameter.
 /// This function sets a new scaling factor for a specified meta-operator node, modifying the scalar applied in the [Mul] operation.
diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index d2b8b7f78fccc15cf4afd598b02f0f7b391375e9..3a35017404337c60845578aee8d0f0bb249bb0b7 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -66,6 +66,26 @@ namespace Aidge {
      * @return The scheduled vector of nodes
      */
     std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule = true, bool verbose = false);
+
+    /**
+     * @brief Inserts a scaling node below the given producer node in the graph view.
+     * If the node is already a producer scaling node, it accumulates the scaling factor by multiplying its value directly.
+     *
+     * @param node A shared pointer to the producer node below which the scaling node will be inserted.
+     * @param scalingFactor The scaling factor to apply.
+     * @param graphView A shared pointer to the graph view in which the nodes are located.
+     * @return True if the scaling node was successfully inserted or the scaling factor was accumulated; False otherwise.
+     */
+    bool insertScalingBelowProducer(std::shared_ptr<Node> node, double scalingFactor, std::shared_ptr<GraphView> graphView);
+
+    /**
+     * @brief Inserts a rounding node below the given producer node (and below its own producerScaling node, if any) in the graph view.
+     *
+     * @param node A shared pointer to the producer node below which the rounding node will be inserted.
+     * @param graphView A shared pointer to the graph view in which the nodes are located.
+     * @return True if the rounding node was successfully inserted; False otherwise.
+     */
+    bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView);
 
     /**
      * @brief Determine whether an input GraphView can be quantized or not.
@@ -74,6 +94,14 @@ namespace Aidge {
      */
     bool checkArchitecture(std::shared_ptr<GraphView> graphView);
 
+    /**
+     * @brief This function multiplies the existing scaling factor by a given coefficient. It verifies that the node is of the correct type ("Mul")
+     * and has the `isScaling` attribute. If these conditions are not met, a warning is logged.
+     * @param node A shared pointer to an `Aidge::Node` object representing the node to modify.
+     * @param coeff A double representing the multiplication coefficient to apply to the scaling factor.
+     */
+    void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff);
+
     void prepareNetwork(std::shared_ptr<GraphView> graphView);
 
@@ -138,7 +166,8 @@ namespace Aidge {
      * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights.
      * @param verbose Whether to print internal informations about the quantization process.
      */
-    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool applyRounding, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose);
+    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet,
+                         Clipping clippingMode, DataType targetType, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose);
 
     /**
      * @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp
index 4970be07fae8737a1c2863600757bb81ff3a65f9..d7d03ca78ff63b328ba068dd4ff82c61270218e3 100644
--- a/include/aidge/quantization/QAT/QAT_LSQ.hpp
+++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp
@@ -20,22 +20,14 @@ namespace Aidge {
 namespace QuantLSQ {
 
 /**
- * @brief Insert the LSQ quantizer nodes in a given GraphView
- * @param graphView The GraphView containing the graph to quantize.
+ * @brief Given a GraphView with parameters properly initialized, insert
+ * the LSQ quantizer nodes, and set up the adjustment of their step-sizes.
+ * @param graphView The GraphView containing the network to quantize.
  * @param nbBits Number of quantization bits.
- * @param span Fixed output span of the quantizers.
  */
-void insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float step_size);
+void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
 
-/**
- * @brief Given a GraphView with parameters properly initialized and some calibration data,
- * insert the LSQ quantizer nodes, and adjust their step-sizes.
- * @param graphView The GraphView containing the graph to quantize.
- * @param nbBits Number of quantization bits.
- * @param calibrationData Calibration data used to adjust the spans.
- * @param scale Multiplicative constant applied to the spans.
- */
-void insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData);
+void devLSQ(std::shared_ptr<Tensor> tensor);
 
 }
 }
diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h
index 546263af3a7e8b7a73991173f48d0b095c7d9501..909ab28d77313e34ed93f46af9ef1dc1d086a036 100644
--- a/include/aidge/quantization_version.h
+++ b/include/aidge/quantization_version.h
@@ -3,9 +3,9 @@
 namespace Aidge {
 static constexpr const int PROJECT_VERSION_MAJOR = 0;
-static constexpr const int PROJECT_VERSION_MINOR = 2;
+static constexpr const int PROJECT_VERSION_MINOR = 3;
 static constexpr const int PROJECT_VERSION_PATCH = 0;
-static constexpr const char * PROJECT_VERSION = "0.2.0";
-static constexpr const char * PROJECT_GIT_HASH = "f50c860";
+static constexpr const char * PROJECT_VERSION = "0.3.0";
+static constexpr const char * PROJECT_GIT_HASH = "f0f9e60";
 }
 #endif // VERSION_H
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index b5193bddcfe345a1702f02fcc139a4cf5b94a1ce..290d59d822cd34f861533f3adc0019ab7fa538e9 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -13,11 +13,10 @@
 #include <pybind11/stl.h>
 
 #include <string>
-
+#include "aidge/operator/PTQMetaOps.hpp"
 #include "aidge/quantization/PTQ/Clipping.hpp"
 #include "aidge/quantization/PTQ/CLE.hpp"
 #include "aidge/quantization/PTQ/PTQ.hpp"
-
 #include "aidge/graph/GraphView.hpp"
 
 namespace py = pybind11;
@@ -40,6 +39,8 @@ void init_PTQ(py::module &m) {
     :rtype: bool
     )mydelimiter");
 
+    m.def("quantizer", &Quantizer, py::arg("sf"), py::arg("min"), py::arg("max"), py::arg("name"));
+
     m.def("insert_scaling_nodes", &insertScalingNodes, py::arg("network"),
     R"mydelimiter(
     Insert a scaling node after each affine node of the GraphView.
@@ -48,6 +49,14 @@ void init_PTQ(py::module &m) {
     :type network: :py:class:`aidge_core.GraphView`
     )mydelimiter");
 
+    m.def("multiply_scaling_factor", &multiplyScalingFactor, py::arg("node"), py::arg("coeff"),
+    R"mydelimiter(
+    Updates the scaling factor of a "Mul" node in a graph if the node is marked as a scaling node. This function multiplies the existing scaling factor by a given coefficient.
+    :param node: The node to modify.
+    :param coeff: A floating-point value representing the multiplication coefficient to apply to the scaling factor.
+    )mydelimiter"
+    );
+
     m.def("normalize_parameters", &normalizeParameters, py::arg("network"),
     R"mydelimiter(
     Normalize the parameters of each parametrized node, so that they fit in the [-1:1] range.
@@ -93,7 +102,9 @@ void init_PTQ(py::module &m) {
     :type verbose: bool
     )mydelimiter");
 
-    m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = true, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false,
+    m.def("quantize_network", &quantizeNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"),
+          py::arg("clipping_mode") = Clipping::MAX, py::arg("target_type") = DataType::Float64, py::arg("no_quantization") = true, py::arg("optimize_signs") = false,
+          py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false,
    R"mydelimiter(
    Main quantization routine. Performs every step of the quantization pipeline.
    :param network: The GraphView to be quantized.
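For readers of this patch, here is a minimal, untested sketch of how the updated `quantize_network` binding might be driven from Python. `model` and `tensors` are assumed to be an already-loaded GraphView and a list of calibration Tensors, and `aidge_core.dtype.int32` is an assumed spelling, inferred from the `aidge_core.dtype.float64` usage in the test script below:

    import aidge_core
    import aidge_quantization

    # `model`: a loaded aidge_core.GraphView, `tensors`: a list of calibration aidge_core.Tensor
    aidge_quantization.quantize_network(
        model, 8, tensors,
        clipping_mode=aidge_quantization.Clipping.MSE,
        target_type=aidge_core.dtype.int32,  # assumed spelling; a non-float target triggers the graph casting path
        no_quantization=False,
        optimize_signs=True,
        single_shift=False,
        use_cuda=False,
        fold_graph=True,
        verbose=False)
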
diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp
index 206985efe4558a84ce1ed67a1264bd6902213764..0b9fcc29d1a144708537084d4538eaa47873cd05 100644
--- a/python_binding/pybind_QAT_LSQ.cpp
+++ b/python_binding/pybind_QAT_LSQ.cpp
@@ -23,8 +23,9 @@ void init_QAT_LSQ(py::module &m) {
     auto mQuantLSQ = m.def_submodule("lsq");
 
-    mQuantLSQ.def("insert_quantizers", &QuantLSQ::insertQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("step_size"));
+    mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits"));
+
+    mQuantLSQ.def("dev_lsq", &QuantLSQ::devLSQ, py::arg("tensor"));
 
-    mQuantLSQ.def("insert_and_init_quantizers", &QuantLSQ::insertAndInitQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_data"));
 }
 } // namespace Aidge
diff --git a/scripts/PTQ/ptq_ts.py b/scripts/PTQ/ptq_ts.py
new file mode 100644
index 0000000000000000000000000000000000000000..b836a7b41981735299ffedb610dc42acda37d903
--- /dev/null
+++ b/scripts/PTQ/ptq_ts.py
@@ -0,0 +1,135 @@
+import unittest
+import re
+import numpy as np
+import gzip
+import aidge_core
+import aidge_onnx
+import os
+import copy
+import aidge_backend_cpu
+import aidge_quantization
+import sys
+import concurrent.futures
+
+aidge_core.Log.set_console_level(aidge_core.Level.Error)
+
+SIGMA = 0.05  # tolerance
+
+def print_in_color(text, color_code):
+    print(f"\033[{color_code}m{text}\033[0m")
+
+def run_model_test(model_name, expected_values, use_multithreading, asset_path, model_path):
+    NB_SAMPLES = 1000
+    NB_BITS = 4
+    CLIPPING = aidge_quantization.Clipping.MSE
+    VERBOSE = False
+
+    results = []
+
+    samples = np.load(gzip.GzipFile(asset_path + '/mnist_samples.npy.gz', "r"))
+    labels = np.load(gzip.GzipFile(asset_path + '/mnist_labels.npy.gz', "r"))
+
+    def load_model():
+        model = aidge_onnx.load_onnx(model_path + '/' + model_name + ".onnx", verbose=False)
+        aidge_core.remove_flatten(model)
+        model.set_datatype(aidge_core.dtype.float32)
+        model.set_backend("cpu")
+        return model
+
+    aidge_model = load_model()
+    scheduler = aidge_core.SequentialScheduler(aidge_model)
+
+    def propagate(model, scheduler, sample):
+        sample = np.reshape(sample, (1, 1, 28, 28))
+        input_tensor = aidge_core.Tensor(sample)
+        scheduler.forward(True, [input_tensor])
+        output_node = model.get_output_nodes().pop()
+        output_tensor = output_node.get_operator().get_output(0)
+        return np.array(output_tensor)
+
+    def compute_accuracy(model, samples, labels):
+        acc = sum(labels[i] == np.argmax(propagate(model, scheduler, x)) for i, x in enumerate(samples))
+        return acc / len(samples)
+
+    base_accuracy = compute_accuracy(aidge_model, samples[:NB_SAMPLES], labels)
+    if abs(base_accuracy * 100 - expected_values[0]) >= SIGMA:
+        results.append(f"❌ [ERROR] Baseline accuracy mismatch for {model_name}: Expected {expected_values[0]}, got {base_accuracy * 100:.2f}")
+    else:
+        results.append(f"✅ Baseline accuracy for {model_name}: Expected {expected_values[0]}, got {base_accuracy * 100:.2f}")
+
+    quant_model = load_model()
+    tensors = [aidge_core.Tensor(np.reshape(sample, (1, 1, 28, 28))) for sample in samples[:NB_SAMPLES]]
+    aidge_quantization.quantize_network(quant_model, NB_BITS, tensors, CLIPPING, aidge_core.dtype.float64, False, True, False, VERBOSE)
+    scheduler = aidge_core.SequentialScheduler(quant_model)
+
+    scaling = 2**(NB_BITS - 1) - 1
+    samples = samples * scaling
+
+    quant_accuracy = compute_accuracy(quant_model, samples[:NB_SAMPLES], labels)
+    if abs(quant_accuracy * 100 - expected_values[1]) >= SIGMA:
+        results.append(f"❌ [ERROR] Quantized accuracy mismatch for {model_name}: Expected {expected_values[1]}, got {quant_accuracy * 100:.2f}")
+    else:
+        results.append(f"✅ Quantized accuracy for {model_name}: Expected {expected_values[1]}, got {quant_accuracy * 100:.2f}")
+
+    # single-shift quantization
+    quant_model_ss = load_model()
+    aidge_quantization.quantize_network(quant_model_ss, NB_BITS, tensors, CLIPPING, aidge_core.dtype.float64, False, True, True, VERBOSE)
+    scheduler = aidge_core.SequentialScheduler(quant_model_ss)
+    quant_accuracy_ss = compute_accuracy(quant_model_ss, samples[:NB_SAMPLES], labels)
+
+    if abs(quant_accuracy_ss * 100 - expected_values[2]) >= SIGMA:
+        results.append(f"❌ [ERROR] Quantized Single Shift Approximation accuracy mismatch for {model_name}: Expected {expected_values[2]}, got {quant_accuracy_ss * 100:.2f}")
+    else:
+        results.append(f"✅ Quantized Single Shift Approximation accuracy for {model_name}: Expected {expected_values[2]}, got {quant_accuracy_ss * 100:.2f}")
+
+    return model_name, results
+
+def run_quantization_test(use_multithreading, model_path, asset_path):
+    EXPECTED_RESULTS = {
+        "MiniResNet": (95.4, 94.5, 94.7),
+        "ConvNet": (97.9, 97.7, 97.4),
+        "BranchNetV4": (93.8, 93.2, 93.7),
+        "TestNet": (95.5, 94.2, 94.2),
+        "MLP": (94.7, 94.2, 93.3)
+    }
+
+    all_results = []
+
+    if use_multithreading:
+        with concurrent.futures.ProcessPoolExecutor() as executor:
+            futures = {executor.submit(run_model_test, model, values, use_multithreading, asset_path, model_path): model for model, values in EXPECTED_RESULTS.items()}
+
+            for future in concurrent.futures.as_completed(futures):
+                model_name = futures[future]
+                try:
+                    model_name, results = future.result()
+                    all_results.append((model_name, results))
+                except Exception as exc:
+                    all_results.append((model_name, [f"❌ [ERROR] {model_name} test failed with exception: {exc}"]))
+    else:
+        for model, values in EXPECTED_RESULTS.items():
+            try:
+                model_name, results = run_model_test(model, values, use_multithreading, asset_path, model_path)
+                all_results.append((model_name, results))
+            except Exception as exc:
+                all_results.append((model, [f"❌ [ERROR] {model} test failed with exception: {exc}"]))
+
+    os.system("clear")
+    for model_name, results in all_results:
+        print(f"Results for {model_name}:")
+        for result in results:
+            if "❌ [ERROR]" in result:
+                print_in_color(result, 31)
+            else:
+                print_in_color(result, 32)
+        print()
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="Run quantization tests.")
+    parser.add_argument("-j", action="store_true", help="Enable multithreading")
+    parser.add_argument("--models_path", type=str, default="/data1/is156025/nz280189/sbx/Models", help="Path to the models directory")
+    parser.add_argument("--asset_path", type=str, default="/data1/is156025/nz280189/sbx/assets", help="Path to the assets directory")
+    args = parser.parse_args()
+
+    run_quantization_test(use_multithreading=args.j, model_path=args.models_path, asset_path=args.asset_path)
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 2c818155877349ad5e5a141469de9f6657873be7..eb5ca7a04ae28326094523d4f6e6974b99aec283 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -14,11 +14,18 @@
 #include "aidge/quantization/PTQ/PTQ.hpp"
 
 #include "aidge/graph/GraphView.hpp"
+
 #include "aidge/scheduler/SequentialScheduler.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/utils/Log.hpp"
 
 #include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Abs.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/Round.hpp"
+
 
 namespace Aidge {
 
@@ -34,27 +41,68 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
 
 static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
 {
-    // Get the tensor data pointer
-    double * castedTensor = static_cast<double *> (tensor->getImpl()->rawPtr());
-
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] *= scaling;
+    auto mulOp = Mul_Op();
+    mulOp.setDataType(tensor->dataType());
+    mulOp.setBackend(tensor->backend());
+
+    std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(Aidge::Array1D<double, 1> {scaling});
+    scalingTensor->setDataType(tensor->dataType());
+    scalingTensor->setBackend(tensor->backend());
+
+    mulOp.associateInput(0, tensor);
+    mulOp.associateInput(1, scalingTensor);
+
+    mulOp.forward();
+
+    auto outTensor = mulOp.getOutput(0);
+    *tensor = *outTensor;
+    //tensor->copyCast(*outTensor);
 }
 
-static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
+// TODO: make the retrieval of argmax values backend independent (refCastFrom)
+static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
-    // Get the tensor data pointer and edit it
-    double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr());
-
-    // Get the tensor absolute max value
-    double maxValue = 0.0f;
-    for(std::size_t i = 0; i < tensor->size(); ++i) {
-        if(std::fabs(castedTensor[i]) > maxValue) {
-            maxValue = std::fabs(castedTensor[i]);
-        }
+    // get the abs tensor
+    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
+
+    // flatten the abs tensor
+    std::int64_t nbElement = tensor->size();
+
+    auto reshapeOp = Reshape_Op({nbElement});
+    reshapeOp.setDataType(tensor->dataType());
+    reshapeOp.setBackend(tensor->backend());
+
+    reshapeOp.associateInput(0, absTensor);
+    reshapeOp.forward();
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
+
+    // get the argmax
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
+
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
+
+    // return the max
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
+}
+
+// Function used to extract the local tensor (from a ProducerScaling node)
+std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) {
+    if (node->getParent(1)->attributes()->hasAttr("isProducerScaling")) {
+        std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator());
+        operatorTensor->forward(); // the forward pass is needed to compute the scaled value of the tensor
+        return operatorTensor->getOutput(0);
+    } else {
+        return getWeightTensor(node);
     }
-    return maxValue;
 }
 
 void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta)
@@ -94,16 +142,18 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
         std::shared_ptr<Node> n1 = affineNodeVector[i];
         std::shared_ptr<Node> n2 = affineNodeVector[i+1];
 
-        double r1 = getTensorAbsoluteMax(getWeightTensor(n1));
-        double r2 = getTensorAbsoluteMax(getWeightTensor(n2));
+        std::shared_ptr<Aidge::Tensor> n1localTensor = getLocalTensor(n1);
+        std::shared_ptr<Aidge::Tensor> n2localTensor = getLocalTensor(n2);
+
+        double r1 = getTensorAbsoluteMax(n1localTensor);
+        double r2 = getTensorAbsoluteMax(n2localTensor);
 
         double s1 = std::sqrt(r1 * r2) / r1;
         double s2 = std::sqrt(r1 * r2) / r2;
 
-        rescaleTensor(getWeightTensor(n1), s1);
-        rescaleTensor(getWeightTensor(n2), s2);
-
-        rescaleTensor(getBiasTensor(n1), s1);
+        insertScalingBelowProducer(n1->getParent(1), s1, graphView);
+        insertScalingBelowProducer(n2->getParent(1), s2, graphView);
+        insertScalingBelowProducer(n1->getParent(2), s1, graphView);
 
         double rangeDelta = std::abs(r1 - r2);
         if (rangeDelta > maxRangeDelta)
diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp
index 57ad7a836bbb6251a8eeb6da87e3647b4f54afe2..1901e3864066d3e9bc00f3093fe099c5bcfdec94 100644
--- a/src/PTQ/Clipping.cpp
+++ b/src/PTQ/Clipping.cpp
@@ -222,7 +222,7 @@ std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::
 
     for (std::shared_ptr<Node> node : graphView->getNodes())
     {
-        if (node->type() == "Scaling")
+        if (node->attributes()->hasAttr("isScaling"))
         {
             std::vector<int> histogram = histograms[node->name()];
 
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 0e26313475bbbda23a56dcdda52d55a0a5af8204..c2bc0e20dee70ba88c05f49d1b7acacb66da047b 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -12,8 +12,7 @@
 #include "aidge/quantization/PTQ/CLE.hpp"
 #include "aidge/quantization/PTQ/Clipping.hpp"
 #include "aidge/quantization/PTQ/PTQ.hpp"
-#include "aidge/quantization/PTQ/PTQMetaOps.hpp"
-
+#include "aidge/operator/PTQMetaOps.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/GraphView.hpp"
@@ -22,11 +21,16 @@
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/utils/Log.hpp"
 
+#include "aidge/operator/BitShift.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/Mul.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 #include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/Cast.hpp"
+
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
 
@@ -49,6 +53,155 @@ bool isMerging(std::shared_ptr<Node> node)
 {
     return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end());
 }
 
+static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
+{
+    int index = 0;
+    while (node->getParent(index) != parentNode)
+        index++;
+    return index;
+}
+
+void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff)
+{
+    AIDGE_ASSERT(node->type() == "Mul" && (node->attributes()->hasAttr("isProducerScaling") || node->attributes()->hasAttr("isScaling")),
+                 "Cannot update the scaling factor on Node of type {} with no scaling tag", node->type());
+    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+
+    double previousScalingFactor = localTensor.get<double>(0);
+    std::shared_ptr<Tensor> finalTensor = std::make_shared<Tensor>(Array1D<double, 1> {previousScalingFactor * coeff});
+    node->input(1).first->getOperator()->setOutput(0, finalTensor);
+}
+
+/* Util function to insert a node below another one already connected */
+void insertNodeBetween(std::shared_ptr<Node> parent,
+                       std::shared_ptr<Node> newNode,
+                       std::shared_ptr<GraphView> graphView)
+{
+    // If the parent has no child yet, simply connect the new node below it
+    if (parent->getChildren().size() == 0)
+    {
+        parent->addChild(newNode, 0, 0);
+        graphView->add(newNode);
+        return;
+    }
+
+    std::vector<std::shared_ptr<Node>> nextNodes = parent->getChildren(0);
+    std::vector<int> inputIndices(nextNodes.size());
+    for (std::size_t i = 0; i < nextNodes.size(); i++) {
+        inputIndices[i] = getInputIndex(nextNodes[i], parent);
+    }
+
+    // Disconnect the children from the parent
+    for (std::shared_ptr<Node> nextNode : nextNodes) {
+        parent->removeChild(nextNode, 0);
+    }
+
+    // Insert the new node between the children and the parent
+    parent->addChild(newNode, 0, 0);
+    for (std::size_t i = 0; i < nextNodes.size(); i++) {
+        newNode->addChild(nextNodes[i], 0, inputIndices[i]);
+    }
+
+    graphView->add(newNode);
+}
+
+void applyConstFold(std::shared_ptr<GraphView> &graphView)
+{
+    for (const std::shared_ptr<Node> node : graphView->getNodes())
+    {
+        if (node->type() == "Producer")
+        {
+            const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
+            producer->constant() = true;
+        }
+    }
+    constantFolding(graphView);
+}
+
+// Add a condition to insert a Cast node that casts user input data into the desired type
+bool castQuantizedGraph(std::shared_ptr<GraphView> &graphView, Aidge::DataType targetType, bool singleShift)
+{
+    // We need a deep copy of the graph's nodes, since we will replace some of them
+    std::vector<std::shared_ptr<Node>> nodeVector(graphView->getNodes().begin(), graphView->getNodes().end());
+
+    for (std::shared_ptr<Node> node : nodeVector)
+    {
+        if (node->type() == "Round" && node->attributes()->hasAttr("isProducerRounding"))
+        {
+            std::shared_ptr<Aidge::Node> castNode = Cast(targetType, node->name() + "_Cast");
+            castNode->getOperator()->setDataType(targetType);
+            castNode->getOperator()->setBackend(node->getOperator()->backend());
+            insertNodeBetween(node, castNode, graphView);
+            castNode->attributes()->addAttr("isProducerCasting", 0.0);
+            node->getOperator()->setDataType(DataType::Float64);
+        }
+        else if (node->type() == "Quantizer")
+        {
+            if (singleShift)
+            {
+                std::shared_ptr<Node> newBitShiftQuantizer = BitShiftQuantizer(node, targetType, node->name() + "_BitShift_Quantizer");
+                newBitShiftQuantizer->getOperator()->setBackend(node->getOperator()->backend());
+                graphView->replace({node}, {newBitShiftQuantizer});
+            }
+            else // if single shift is not enabled, we keep using the alternative IntQuantizer (which casts the data before and after the regular Quantizer operations)
+            {
+                std::shared_ptr<Node> newIntQuantizer = IntQuantizer(node, targetType, node->name());
+                newIntQuantizer->getOperator()->setBackend(node->getOperator()->backend());
+                graphView->replace({node}, {newIntQuantizer});
+            }
+        }
+        else if (node->type() != "Producer" &&
+                 !node->attributes()->hasAttr("isProducerScaling"))
+        {
+            node->getOperator()->setDataType(targetType);
+        }
+    }
+    return true;
+}
+
+bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView)
+{
+    std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
+    roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
+    roundNode->getOperator()->setBackend("cpu");
+
+    insertNodeBetween(node, roundNode, graphView);
+
+    roundNode->attributes()->addAttr("isProducerRounding", 0.0);
+    return true;
+}
+
+bool insertScalingBelowProducer(std::shared_ptr<Node> node, double scalingFactor, std::shared_ptr<GraphView> graphView)
+{
+    if (node->attributes()->hasAttr("isProducerRounding"))
+    {
+        // In this case we 'bump' the node to the one above it (an actual ProducerScaling),
+        // because the Round node is not usable (it is only used when SSA is enabled)
+        node = node->getParent(0);
+    }
+    if (node->attributes()->hasAttr("isProducerScaling"))
+    {
+        // We accumulate the multiple scaling factors by multiplying the SF of the ProducerScaling node
+        // (adding new nodes each time would make the graph unusable)
+        multiplyScalingFactor(node, scalingFactor);
+        return true;
+    }
+    AIDGE_ASSERT(node->type() == "Producer", "Cannot apply a scaling factor on node of type: {} which is not a producer", node->type());
+    std::string scalingNodeName = makeUniqueName(node->name() + "_Producer_Scaling", graphView);
+
+    std::shared_ptr<Aidge::Node> scalingNode = Mul(scalingNodeName);
+    scalingNode->attributes()->addAttr("isProducerScaling", 0.0);
+
+    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+    std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "Factor");
+    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
+    graphView->add(scalingFactorProducer);
+
+    scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
+    scalingNode->getOperator()->setBackend("cpu");
+
+    insertNodeBetween(node, scalingNode, graphView);
+
+    return true;
+}
 
 bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 {
@@ -66,51 +219,43 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView)
     return true;
 }
 
-static void fillTensor(std::shared_ptr<Tensor> tensor, double value)
+// TODO: make the retrieval of argmax values backend independent (refCastFrom)
+static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
-    // Get the tensor data pointer
-    double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr());
+    // get the abs tensor
+    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
 
-    // Fill the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] = value;
-}
+    // flatten the abs tensor
+    std::int64_t nbElement = tensor->size();
 
-static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
-{
-    // Get the tensor data pointer
-    double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr());
+    auto reshapeOp = Reshape_Op({nbElement});
+    reshapeOp.setDataType(tensor->dataType());
+    reshapeOp.setBackend(tensor->backend());
 
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] *= scaling;
-}
+    reshapeOp.associateInput(0, absTensor);
+    reshapeOp.forward();
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
 
-static void roundTensor(std::shared_ptr<Tensor> tensor)
-{
-    // Get the tensor data pointer
-    double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr());
+    // get the argmax
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
 
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] = std::nearbyint(castedTensor[i]); // round
-}
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
 
-static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
-{
-    // Get the tensor data pointer and edit it
-    double * castedTensor = static_cast<double*>(tensor->getImpl()->rawPtr());
-
-    // Get the tensor absolute max value
-    double maxValue = 0.0f;
-    for(std::size_t i = 0; i < tensor->size(); ++i) {
-        if(std::fabs(castedTensor[i]) > maxValue) {
-            maxValue = std::fabs(castedTensor[i]);
-        }
-    }
-    return maxValue;
+    // return the max
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
 }
 
 // TODO : pass nodeVector by reference ...
 static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::shared_ptr<Node>> nodeVector, std::string nodeType)
 {
@@ -121,6 +266,15 @@ static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::s
     return remainingNodes;
 }
 
+static std::vector<std::shared_ptr<Node>> removeProdScalingNodes(std::vector<std::shared_ptr<Node>> nodeVector)
+{
+    std::vector<std::shared_ptr<Node>> remainingNodes;
+    for (std::shared_ptr<Node> node : nodeVector)
+        if (!node->attributes()->hasAttr("isProducerScaling"))
+            remainingNodes.push_back(node);
+
+    return remainingNodes;
+}
 
 static void fixScheduling(std::vector<std::shared_ptr<Node>>& nodeVector)
 {
@@ -165,12 +319,13 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView>
         fixScheduling(nodeVector);
 
     nodeVector = removeMatchingNodes(nodeVector, "Producer");
+    nodeVector = removeProdScalingNodes(nodeVector);
 
     if (verbose)
     {
-        Log::info("NB OF NODES = {}", nodeVector.size());
+        Log::notice("NB OF NODES = {}", nodeVector.size());
         for (std::shared_ptr<Node> node : nodeVector)
-            Log::info("{} {}", node->type(), node->name());
+            Log::notice("{} {}", node->type(), node->name());
     }
 
     return nodeVector;
@@ -184,6 +339,7 @@ static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
 void prepareNetwork(std::shared_ptr<GraphView> graphView)
 {
     removeFlatten(graphView);
+    sanitizeNodeNames(graphView);
 
     bool containsBatchNorm = false;
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
@@ -228,29 +384,30 @@ void insertResidualNodes(std::shared_ptr<GraphView> graphView)
                 if (parentIsForking)
                 {
                     // temporary verbose ...
-                    Log::info(" ### found residual branch at index {}", i);
-                    Log::info(" ### inserting multiplicative node ...");
+                    Log::notice(" ### found residual branch at index {}", i);
+                    Log::notice(" ### inserting multiplicative node ...");
 
                     std::string residualNodeName = makeUniqueName(parentNode->name() + "_Res", graphView);
-                    std::shared_ptr<Node> residualNode = Scaling(1.0, residualNodeName);
+                    std::shared_ptr<Node> residualNode = Mul(residualNodeName);
+                    residualNode->attributes()->addAttr("isScaling", 0.0);
+                    residualNode->attributes()->addAttr("isResidual", 0.0);
+
+                    // Add the scaling factor as a producer of the node
+                    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {1.0});
+                    std::shared_ptr<Node> scalingFactorProducer = addProducer(residualNode, 1, {1}, "ScalingFactor");
+                    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
 
-                    residualNode->getOperator()->setDataType(DataType::Float64); //getDataType(parentNode)
+                    residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                     residualNode->getOperator()->setBackend("cpu");
 
                     graphView->insertParent(node, residualNode, i, 0, 0);
+                    graphView->add(scalingFactorProducer);
                 }
             }
         }
     }
 }
 
-static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
-{
-    int index = 0;
-    while (node->getParent(index) != parentNode)
-        index++;
-    return index;
-}
 
 void insertScalingNodes(std::shared_ptr<GraphView> graphView)
 {
@@ -263,37 +420,30 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
         if (isAffine(parentNode) || isMerging(parentNode))
         {
             std::string scalingNodeName = makeUniqueName(parentNode->name() + "_Scaling", graphView);
-            std::shared_ptr<Node> scalingNode = Scaling(1.0, scalingNodeName);
+            //std::shared_ptr<Node> scalingNode = Scaling(1.0, scalingNodeName);
+
+            // Add a Mul operator with the "isScaling" tag
+            std::shared_ptr<Aidge::Node> scalingNode = Mul(scalingNodeName);
+            scalingNode->attributes()->addAttr("isScaling", 0.0);
+
+            // Add the scaling factor as a producer of the node
+            std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {1.0});
+            std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "ScalingFactor");
+            scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
 
             scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
             scalingNode->getOperator()->setBackend("cpu");
 
             if (parentNode->getChildren().size() > 0)
             {
-                // SCALING NODE INSERTION
-
-                // We always have one output from Affine and Add nodes, but possibly multiple childs
-                std::vector<std::shared_ptr<Node>> nextNodes = parentNode->getChildren(0);
-
-                // For each node in nextNodes store the connexion index
-                std::vector<int> inputIndices(nextNodes.size());
-                for (std::size_t i = 0; i < nextNodes.size(); i++)
-                    inputIndices[i] = getInputIndex(nextNodes[i], parentNode);
-
-                for (std::shared_ptr<Node> nextNode : nextNodes)
-                    parentNode->removeChild(nextNode, 0);
-
-                parentNode->addChild(scalingNode, 0, 0);
-
-                for (std::size_t i = 0; i < nextNodes.size(); i++)
-                    scalingNode->addChild(nextNodes[i], 0, inputIndices[i]);
-
-                graphView->add(scalingNode);
+                insertNodeBetween(parentNode, scalingNode, graphView);
+                graphView->add(scalingFactorProducer);
             }
             else
             {
-                // Log::info(" last node reached ! ");
+                // Log::notice(" last node reached ! ");
                 parentNode->addChild(scalingNode, 0, 0);
+                graphView->add(scalingFactorProducer);
                 graphView->add(scalingNode);
             }
         }
@@ -303,7 +453,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
 static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> mergingNode)
 {
     std::shared_ptr<Node> currNode = mergingNode;
-    while(currNode->type() != "Scaling")
+    while(!currNode->attributes()->hasAttr("isScaling"))
     {
         if (currNode->getParents().size() == 0)
         {
@@ -346,7 +496,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
     for (std::shared_ptr<Node> node : nodeVector)
     {
         // Scaling nodes still have a ratio of 1, so they are seamless ...
-        if (node->type() == "ReLU" || node->type() == "Scaling" || isSeamless(node))
+        if (node->type() == "ReLU" || node->attributes()->hasAttr("isScaling") || isSeamless(node))
         {
             if (node != firstNode)
             {
@@ -362,7 +512,8 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
             double scaling = getTensorAbsoluteMax(weightTensor);
             double ratio = 1.0 / scaling;
-            rescaleTensor(weightTensor, ratio);
+            //rescaleTensor(weightTensor, ratio);
+            insertScalingBelowProducer(node->getParent(1), ratio, graphView);
 
             // Accumulate the ratio
             if (node == firstNode)
@@ -380,7 +531,8 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             if (nodeHasBias(node))
             {
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                rescaleTensor(biasTensor, accumulatedRatios[node->name()] );
+                //rescaleTensor(biasTensor, accumulatedRatios[node->name()] );
+                insertScalingBelowProducer(node->getParent(2), accumulatedRatios[node->name()], graphView);
             }
         }
 
@@ -407,8 +559,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
 
                 std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
 
-                double currScalingFactor = getScalingFactor(scalingNode);
-                updateScalingFactor(scalingNode, currScalingFactor / rescaling);
+                multiplyScalingFactor(scalingNode, 1 / rescaling);
 
                 accumulatedRatios[mergingNode->name()] /= rescaling; // optional ...
             }
@@ -433,7 +584,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
     std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
     for (std::shared_ptr<Node> node : nodeSet)
     {
-        if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+        if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
         {
             std::shared_ptr<Operator> nodeOperator = node->getOperator();
             std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -455,7 +606,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
     // std::shared_ptr<Node> inputNode = getFirstNode(graphView);
 
     for (std::shared_ptr<Node> node : nodeSet)
-        if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+        if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
            valueRanges.insert(std::make_pair(node->name(), 0));
 
     if (useCuda)
@@ -468,7 +619,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
 
     for (std::shared_ptr<Tensor> sample : inputDataSet)
    {
-        //Log::info(" IT : {}", it++);
+        //Log::notice(" IT : {}", it++);
 
         // Inference ...
@@ -482,7 +633,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
         std::map<std::string, double> sampleRanges;
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
             {
                 std::shared_ptr<Operator> nodeOperator = node->getOperator();
                 std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -504,7 +655,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
 
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
             {
                 std::string nodeName = node->name();
                 if (sampleRanges[nodeName] > valueRanges[nodeName])
@@ -540,7 +691,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
     for (std::shared_ptr<Node> node : nodeVector)
     {
         // Seamless scaling factor propagation ...
-
+
         if (isAffine(node) || isSeamless(node) || node->type() == "ReLU")
         {
             if (node == firstNode)
@@ -554,11 +705,13 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
             }
         }
 
+
         // Here prevNode is either a 'Affine' or a 'Merging'
         // => do not split the cases, just handle the bias ...
 
-        if (node->type() == "Scaling")
+        if (node->attributes()->hasAttr("isScaling"))
         {
+
             // retrieve the previous scaling factor ...
             std::shared_ptr<Node> prevNode = node->getParent(0);
             double prevScalingFactor = scalingFactors[prevNode->name()];
@@ -566,8 +719,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
             // ValueRanges must contains all the scaling nodes !!!
             double scalingFactor = valueRanges[node->name()];
 
-            double currScalingFactor = getScalingFactor(node);
-            updateScalingFactor(node, currScalingFactor / (scalingFactor / prevScalingFactor));
+            multiplyScalingFactor(node, 1 / (scalingFactor / prevScalingFactor));
 
             scalingFactors[node->name()] = scalingFactor;
 
@@ -575,11 +727,13 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
 
             if (isAffine(prevNode))
             {
+
                 bool prevNodeHasBias = nodeHasBias(prevNode);
                 if (prevNodeHasBias)
-                {
+                {
                     std::shared_ptr<Tensor> biasTensor = getBiasTensor(prevNode);
-                    rescaleTensor(biasTensor, 1.0 / prevScalingFactor);
+                    //rescaleTensor(biasTensor, 1.0 / prevScalingFactor);
+                    insertScalingBelowProducer(prevNode->getParent(2), 1.0 / prevScalingFactor, graphView);
                 }
             }
         }
@@ -608,10 +762,9 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
                 double rescaling = mergingNodeScaling / maxScaling;
 
                 std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
-                //Log::info(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
-
-                double currScalingFactor = getScalingFactor(scalingNode);
-                updateScalingFactor(scalingNode, currScalingFactor * rescaling);
+                //Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
+
+                multiplyScalingFactor(scalingNode, rescaling);
             }
         }
     }
@@ -647,7 +800,7 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
             signMap[node->name()].second = false;
         }
 
-        if (node->type() == "Scaling")
+        if (node->attributes()->hasAttr("isScaling"))
         {
             signMap[node->name()].second = false;
 
@@ -694,7 +847,7 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
                 // Arbitration : Signed type wins !
                 for(std::shared_ptr<Node> parent : parentNodes)
                 {
-                    while (parent->type() != "Scaling")
+                    while (!parent->attributes()->hasAttr("isScaling"))
                     {
                         signMap[parent->name()] = std::make_pair(false, false);
                         // We are on a branch so nodes always have 1 parent ...
@@ -725,9 +878,9 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
 
     if (verbose)
     {
-        Log::info(" === SIGN MAP === ");
+        Log::notice(" === SIGN MAP === ");
         for (std::shared_ptr<Node> node : nodeVector)
-            Log::info(" {}{} | {}", static_cast<int>(signMap[node->name()].first), static_cast<int>(signMap[node->name()].second), node->name());
+            Log::notice(" {}{} | {}", static_cast<int>(signMap[node->name()].first), static_cast<int>(signMap[node->name()].second), node->name());
     }
 
     // SANITY CHECK (TEMPORARY)
@@ -776,26 +929,23 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
         if (isAffine(node))
         {
             // Rescale the weight tensor
-            std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
-            rescaleTensor(weightTensor, signedMax);
+            insertScalingBelowProducer(node->getParent(1), signedMax, graphView);
 
             if (!noQuant)
-                roundTensor(weightTensor);
+                insertRoundBelowProducer(node->getParent(1), graphView);
 
             // Rescale the bias tensor
-
             if (nodeHasBias(node))
             {
                 bool inputIsUnsigned = signMap[node->name()].first;
                 double rescaling = inputIsUnsigned ? unsignedMax * signedMax : signedMax * signedMax;
-
-
+
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                rescaleTensor(biasTensor, rescaling);
+                insertScalingBelowProducer(node->getParent(2), rescaling, graphView);
 
                 if (!noQuant)
-                    roundTensor(biasTensor);
+                    insertRoundBelowProducer(node->getParent(2), graphView);
             }
 
             // Compensate the rescaling using the next Scaling node
@@ -810,8 +960,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ...
 
-            double currScalingFactor = getScalingFactor(scalingNode);
-            updateScalingFactor(scalingNode, currScalingFactor * rescaling);
+            multiplyScalingFactor(scalingNode, rescaling);
         }
 
         if (isMerging(node))
@@ -826,23 +975,25 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ...
 
-            double currScalingFactor = getScalingFactor(scalingNode); // XXX bad naming
-            updateScalingFactor(scalingNode, currScalingFactor * rescaling);
+            multiplyScalingFactor(scalingNode, rescaling);
         }
 
         // Handle the Scaling Nodes ...
 
-        if (node->type() == "Scaling")
+        if (node->attributes()->hasAttr("isScaling"))
         {
             if (!noQuant)
             {
                 // Replace the Scaling Node by Quantizer
+                auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
+                std::shared_ptr<Tensor> fallback;
+                const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+                double old_sf = localTensor.get<double>(0); //!\\
 
-                std::shared_ptr<Node> quantizerNode = Quantizer(getScalingFactor(node), -(signedMax + 1), signedMax, node->name());
+                std::shared_ptr<Node> quantizerNode = Quantizer(old_sf, -(signedMax + 1), signedMax, node->name());
                 quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 quantizerNode->getOperator()->setBackend("cpu");
-
-                graphView->replace({node}, {quantizerNode});
+                graphView->replace({node, node->getParent(1)}, {quantizerNode});
 
                 if (optimizeSigns)
                 {
@@ -856,6 +1007,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
 
                     double currScalingFactor = getScalingFactor(quantizerNode);
                     updateScalingFactor(quantizerNode, currScalingFactor * rescaling);
+
 
                     if(outputIsUnsigned)
                     {
@@ -876,51 +1028,40 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // A merging node is always followed by a scaling node at this point ...
+        // A merging node is always followed by a Quantizer node at this point
 
-        if (node->type() == "Quantizer")
+        if (node->type() == "Quantizer" && (node->attributes()->hasAttr("isResidual") || !isAffine(node->getParent(0))))
         {
-            bool prevNodeIsForking = ((node->getParent(0))->getChildren().size() > 1);
-            bool prevNodeIsAffine = isAffine(node->getParent(0));
-            bool insertNode = prevNodeIsForking || !prevNodeIsAffine;
-
-            if (insertNode)
-            {
-                // create and insert the multplicative node
-
-                std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
-                std::shared_ptr<Node> mulNode = Mul(mulNodeName);
-
-                mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
-                mulNode->getOperator()->setBackend("cpu");
-
-                graphView->insertParent(node, mulNode, 0, 0, 0);
-
-                // create and insert the producer node
+            // check if the Quantizer is a residual one, and insert a compensation node if so ...
+            // create and insert the multiplicative node before the Quantizer
 
-                std::shared_ptr<Tensor> inputTensor = std::static_pointer_cast<Tensor> (mulNode->getOperator()->getRawInput(0));
-                std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>();
+            std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
+            std::shared_ptr<Node> mulNode = Mul(mulNodeName);
+
+            mulNode->attributes()->addAttr("isCompensation", 0.0);
+            mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
+            mulNode->getOperator()->setBackend("cpu");
 
-                coeffTensor->setDataType(DataType::Float64); // getDataType(parentNode)
-                coeffTensor->setBackend("cpu");
+            graphView->insertParent(node, mulNode, 0, 0, 0);
 
-                coeffTensor->resize(inputTensor->dims());
-                fillTensor(coeffTensor, 1);
+            // Add the coeff producer to the multiplier node
 
-                std::shared_ptr<Node> producerNode = Producer(coeffTensor, makeUniqueName("coeff", graphView));
-                producerNode->addChild(mulNode);
-                graphView->add(producerNode);
+            std::shared_ptr<Node> coeffProducer = addProducer(mulNode, 1, {1}, "");
+            std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(Array1D<double, 1> {signedMax});
+            coeffProducer->getOperator()->setOutput(0, coeffTensor);
 
-                // rescale the coeffs and edit scaling factor
+            coeffProducer->getOperator()->setDataType(DataType::Float64);
+            coeffProducer->attributes()->addAttr("quantization.ptq.CompensationCoeff", signedMax);
+            coeffProducer->getOperator()->setBackend("cpu");
 
-                fillTensor(coeffTensor, signedMax);
+            graphView->add(coeffProducer); // needed ?
 
-                double currScalingFactor = getScalingFactor(node); // XXX bad naming !
-                updateScalingFactor(node, currScalingFactor / signedMax);
+            // Adapt the scaling factor value accordingly
 
-                // TODO : double check this !!!
-                //std::cout << getTensorAbsoluteMax(coeffTensor) << std::endl;
-            }
+            double currScalingFactor = getScalingFactor(node);
+            updateScalingFactor(node, currScalingFactor / signedMax);
+        }
     }
 }
 
@@ -931,10 +1072,11 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // Use A meatoperator of type Scaling of MulCompensation instead
-        if (isAffine(node) || (node->type() == "Mul"))
+        if (isAffine(node) || (node->type() == "Mul" && node->attributes()->hasAttr("isCompensation")))
         {
             std::shared_ptr<Node> scalingNode = (*node->getChildren().begin());
+            if(scalingNode->attributes()->hasAttr("isCasting"))
+                scalingNode = (*scalingNode->getChildren().begin());
 
             double base = getScalingFactor(scalingNode);
 
@@ -944,17 +1086,16 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
             double ratio = base / approx;
 
-            std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
-            rescaleTensor(weightTensor, ratio);
-            if (!noQuant)
-                roundTensor(weightTensor);
+            insertScalingBelowProducer(node->getParent(1), ratio, graphView);
+            if (!noQuant && !node->getParent(1)->attributes()->hasAttr("isProducerRounding"))
+                insertRoundBelowProducer(node->getParent(1), graphView);
 
             if (nodeHasBias(node))
             {
-                std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                rescaleTensor(biasTensor, ratio);
-                if (!noQuant)
-                    roundTensor(biasTensor);
+                insertScalingBelowProducer(node->getParent(2), ratio, graphView);
+
+                if (!noQuant && !node->getParent(2)->attributes()->hasAttr("isProducerRounding"))
+                    insertRoundBelowProducer(node->getParent(2), graphView);
             }
         }
     }
@@ -962,12 +1103,12 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
 static void printScalingFactors(std::shared_ptr<GraphView> graphView)
 {
-    Log::info(" === SCALING FACTORS === ");
+    Log::notice(" === SCALING FACTORS === ");
     for (auto node : retrieveNodeVector(graphView))
-        if (node->type() == "Scaling" || node->type() == "Quantizer")
+        if (node->attributes()->hasAttr("isScaling") || node->type() == "Quantizer")
         {
             double scalingFactor = getScalingFactor(node);
-            Log::info(" {:.6f} ({})", scalingFactor, node->name());
+            Log::notice(" {:.6f} ({})", scalingFactor, node->name());
         }
 }
 
@@ -994,13 +1135,14 @@ static void printRanges(std::shared_ptr<GraphView> graphView, std::map<std::stri
     auto scheduling = scheduler.getStaticScheduling();
     for (auto node : scheduling)
-        if (node->type() == "Scaling")
-            fmt::println("{} range = {}", node->name(), valueRanges[node->name()]);
+        if (node->attributes()->hasAttr("isScaling"))
+            Log::debug("{} range = {}", node->name(), valueRanges[node->name()]);
 }
 
-void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose)
+void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet,
+                     Clipping clippingMode, DataType targetType, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
 {
-    Log::info(" === QUANT PTQ 0.2.21 === ");
+    Log::notice(" === QUANT PTQ 0.2.21 === ");
 
     graphView->setBackend("cpu");
 
@@ -1010,62 +1152,79 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     if (!checkArchitecture(graphView))
         return;
 
-    Log::info(" Preparing the network for the PTQ ... ");
+    Log::notice(" Preparing the network for the PTQ ... ");
     prepareNetwork(graphView);
 
-    Log::info(" Inserting the scaling nodes ...");
+    Log::notice(" Inserting the scaling nodes ...");
     insertScalingNodes(graphView);
 
     crossLayerEqualization(graphView);
-
-    Log::info(" Normalizing the parameters ...");
+    Log::notice(" Normalizing the parameters ...");
    normalizeParameters(graphView);
 
-    Log::info(" Computing the value ranges ...");
+    Log::notice(" Computing the value ranges ...");
    std::map<std::string, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
 
-    //std::cout << " === RANGES (BEFORE ADJUST) ===" << std::endl;
+    //Log::debug("=== RANGES (BEFORE ADJUST) ===");
    //printRanges(graphView, valueRanges);
 
-    Log::info(" Optimizing the clipping values ...");
+    Log::notice(" Optimizing the clipping values ...");
    valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose);
 
-    //std::cout << " === RANGES (AFTER ADJUST) ===" << std::endl;
+    //Log::debug("=== RANGES (AFTER ADJUST) ===");
    //printRanges(graphView, valueRanges);
-
-    Log::info(" Normalizing the activations ...");
+    Log::notice(" Normalizing the activations ...");
    normalizeActivations(graphView, valueRanges);
 
-    Log::info(" Quantizing the normalized network ...");
+    Log::notice(" Quantizing the normalized network ...");
    quantizeNormalizedNetwork(graphView, nbBits, noQuant, optimizeSigns, verbose);
-
+
    if (singleShift)
    {
-        Log::info( " Inserting the compensation nodes ...");
+        Log::notice(" Inserting the compensation nodes ...");
        insertCompensationNodes(graphView, nbBits);
 
-        Log::info(" Performing the Single-Shift approximation ...");
+        Log::notice(" Performing the Single-Shift approximation ...");
        performSingleShiftApproximation(graphView, noQuant);
    }
 
+    if (targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16)
+    {
+        AIDGE_ASSERT(!noQuant, "Cannot cast operators with the noQuant (Fake Quantization) flag set to true!");
+        Log::notice("Starting to cast operators into the desired type ...");
+        castQuantizedGraph(graphView, DataType::Int32, singleShift);
+    }
+    else
+    {
+        setupDataType(graphView, inputDataSet, targetType);
+    }
+
+    if (foldGraph)
+    {
+        Log::notice("Applying constant folding recipe to the graph ...");
+        applyConstFold(graphView);
+    }
+
+    // Mandatory to handle all of the newly added connections!
+    graphView->updateInputsOutputs();
+
+    // reset input nodes
+    /*for (Aidge::NodePtr input_node : graphView->inputNodes())
+    {
+        std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->resetInput();
+    }*/
 
     if (verbose)
         printScalingFactors(graphView);
 
-    //std::cout << " === SCALINGS (BEFORE CAST) ===" << std::endl;
-    //printScalingFactors(graphView);
-    setupDataType(graphView, inputDataSet, initialDataType);
 
-    if (useCuda)
-        graphView->setBackend("cuda");
+    // if (useCuda)
+    //     graphView->setBackend("cuda");
 
-    //std::cout << " === SCALINGS (AFTER CAST) ===" << std::endl;
-    //printScalingFactors(graphView);
-
-    Log::info(" Reseting the scheduler ...");
+    Log::notice(" Resetting the scheduler ...");
     SequentialScheduler scheduler(graphView);
     scheduler.resetScheduling();
 
-    Log::info(" Network is quantized !");
+    Log::notice(" Network is quantized !");
+
 }
 
 std::map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView)
@@ -1090,15 +1249,14 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> node : graphView->getNodes())
     {
         if (node->type() == "FC" || node->type() == "Conv2D")
         {
             std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
-            rescaleTensor(biasTensor, 0);
+            //rescaleTensor(biasTensor, 0);
+            insertScalingBelowProducer(node->getParent(2), 0, graphView);
         }
     }
 }
-
 void devPTQ(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> node : graphView->getNodes())
-        fmt::println(" UUU : {}", node->name());
+        Log::debug(" UUU : {}", node->name());
 }
-
 }
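
For reference, a minimal sketch of how the extended quantizeNetwork entry point can be driven after this change. The model and calibration tensors are assumed to be loaded elsewhere, and the clipping mode and target type are illustrative choices, not mandated by this patch:

    #include "aidge/quantization/PTQ/PTQ.hpp"

    using namespace Aidge;

    void runPTQ(std::shared_ptr<GraphView> model,
                std::vector<std::shared_ptr<Tensor>> calibrationSet)
    {
        // 8-bit PTQ, casting the quantized graph to Int32, with the
        // single-shift approximation and constant folding enabled.
        quantizeNetwork(model, 8, calibrationSet,
                        Clipping::MAX,    // clippingMode (illustrative)
                        DataType::Int32,  // targetType
                        false,            // noQuant
                        true,             // optimizeSigns
                        true,             // singleShift
                        false,            // useCuda
                        true,             // foldGraph
                        true);            // verbose
    }
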
name + "_ClipQuant" : "", clipMin, clipMax); - - // connect the scaling factor producer - - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); - std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); - scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); - - // create the metaop graph - - std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode}); - std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ??? - - // return the metaop - - std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype - - return metaopNode; -} - -std::shared_ptr<Node> Scaling(double scalingFactor, const std::string& name) -{ - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); - - std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_Scaling" : ""); - - std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); - scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); - - std::shared_ptr<GraphView> graphView = Sequential({mulNode}); - std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); - - NodePtr metaopNode = MetaOperator("Scaling", connectedGraphView, {}, name); - - return metaopNode; -} - -static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType) -{ - std::shared_ptr<Node> mulNode = nullptr; - for(std::shared_ptr<Node> node : graphView->getNodes()) - if (node->type() == nodeType) - mulNode = node; - - return mulNode; -} - -void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor) -{ - if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer") - Log::warn(" Cannot update the scaling factor on Node of type {}", metaOpNode->type()); - - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); - - std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(metaOpNode->getOperator()); - - std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul"); - - if (!mulNode) - Log::warn(" Invalid PTQ MetaOperator, no Mul node found inside ! 
"); - - mulNode->input(1).first->getOperator()->setOutput(0, scalingFactorTensor); -} - -double getScalingFactor(std::shared_ptr<Node> MetaOpNode) -{ - if (MetaOpNode->type() != "Scaling" && MetaOpNode->type() != "Quantizer") { - Log::warn(" Cannot get the scaling factor on Node of type {}", MetaOpNode->type()); - return 0; - } - - std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(MetaOpNode->getOperator()); - - std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul"); - - if (!mulNode) { - Log::warn(" Invalid PTQ MetaOperator, no Mul found inside node of type {}", MetaOpNode->type()); - return 0; - } - - auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1); - std::shared_ptr<Tensor> fallback; - const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); - - return localTensor.get<double>(0); -} - - -void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max) -{ - if (quantizerNode->type() != "Quantizer") { - Log::warn(" Cannot set the clipping range on Node of type {}", quantizerNode->type()); - return; - } - - std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (quantizerNode->getOperator()); - - std::shared_ptr<Node> clipNode = getSubNode(metaOp->getMicroGraph(), "Clip"); - - if (!clipNode) { - Log::warn(" Invalid PTQ MetaOperator, no Clip found inside node of type {}", quantizerNode->type()); - return; - } - - std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(clipNode->getOperator()); - clipOp->max() = max; - clipOp->min() = min; -} -} \ No newline at end of file diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp index 9b51e846df498a9303b7373ae1c86d4b007a96f0..8a42770ac9ff5c9426c0538d407c7f58d0021c15 100644 --- a/src/QAT/QAT_LSQ.cpp +++ b/src/QAT/QAT_LSQ.cpp @@ -13,7 +13,6 @@ #include "aidge/operator/LSQ.hpp" #include "aidge/operator/ReLU.hpp" - #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" @@ -23,7 +22,42 @@ namespace Aidge { -void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize) +static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) +{ + auto valueTensor = (*tensor).abs().mean(); + std::shared_ptr<Tensor> fallback; + const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu"); + return localTensor.get<float>(0); +} + +// INIT THE STEP SIZE OF A QUANTIZER NODE + +static bool initStepSize(std::shared_ptr<Node> quantizer) +{ + const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator()); + + float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0)); + + float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second)); + + auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); + + // XXX Manage backend here ? 
diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp
index 9b51e846df498a9303b7373ae1c86d4b007a96f0..8a42770ac9ff5c9426c0538d407c7f58d0021c15 100644
--- a/src/QAT/QAT_LSQ.cpp
+++ b/src/QAT/QAT_LSQ.cpp
@@ -13,7 +13,6 @@
 #include "aidge/operator/LSQ.hpp"
 #include "aidge/operator/ReLU.hpp"
-
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
@@ -23,7 +22,42 @@
 namespace Aidge {
 
-void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize)
+static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor).abs().mean();
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    return localTensor.get<float>(0);
+}
+
+// INIT THE STEP SIZE OF A QUANTIZER NODE
+
+static bool initStepSize(std::shared_ptr<Node> quantizer)
+{
+    const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
+
+    float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
+
+    float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
+
+    auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
+
+    // XXX Manage backend here ?
+    stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
+    stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
+
+    auto stepSizeProducer = quantizer->getParent(1);
+
+    stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
+
+    Log::debug("[ INIT STEP SIZE = {} ]", stepSize);
+
+    return false;
+}
+
+// INPUT QUANTIZERS INSERTION
+
+static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
     const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
 
@@ -34,180 +68,76 @@ void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbB
         std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
         std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
 
-        // INPUT QUANTIZERS INSERTION
+        // Create the input quantizer node
 
-        // TODO : double check this, and use createUniqueName()
-        auto inputQuantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
-        auto inputQuantizerNode = LSQ(signedRange, inputQuantizerName);
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
 
-        // Set the step size
+        // Init the step-size using the node call stack
 
-        auto inputStepSizeOp = inputQuantizerNode->getParent(1)->getOperator();
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
         // Absorb the ReLU when possible ...
 
-        // XXX is this safe ???
-        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]);
-        // bool nodeHasParent = (linearNode->getParents().size() != 0);
+        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ?
 
         if (nodeHasParent)
         {
             auto parentNode = linearNode->getParents()[0];
             if (parentNode->type() == "ReLU")
             {
-                auto inputQuantizerOp = std::static_pointer_cast<LSQ_Op> (inputQuantizerNode->getOperator());
-                inputQuantizerOp->range() = unsignedRange;
+                auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
+                quantizerOp->range() = unsignedRange;
                 graphView->replace({parentNode}, {});
             }
         }
 
-        // We need to handle the case where the linear node is the first one ...
+        // Insert the quantizer in the graphView ...
+        // (We need to handle the case where the linear node is the first one)
 
         if (nodeHasParent)
         {
-            graphView->insertParent(linearNode, inputQuantizerNode, 0, 0, 0);
+            graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
         }
         else
         {
-            inputQuantizerNode->addChild(graphView);
-            graphView->add(inputQuantizerNode);
+            quantizerNode->addChild(graphView);
+            graphView->add(quantizerNode);
         }
-
-        // PARAM QUANTIZERS INSERTION
-
-        // TODO : double check this, and use createUniqueName()
-        auto paramQuantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
-        auto paramQuantizerNode = LSQ(signedRange, paramQuantizerName);
-        graphView->insertParent(linearNode, paramQuantizerNode, 1, 0, 0);
-
-        // Set the step size
-
-        auto paramStepSizeOp = paramQuantizerNode->getParent(1)->getOperator();
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
     }
-
 }
 
-static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
-{
-    auto backend = tensor->backend();
-    if (backend == "cuda")
-        tensor->setBackend("cpu");
-
-    float acc = 0;
-    float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr());
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        acc += std::abs(castedTensor[i]);
-    acc /= static_cast<float> (tensor->size());
-
-    if (backend == "cuda")
-        tensor->setBackend("cuda");
-
-    return acc;
-}
+// PARAM QUANTIZERS INSERTION
 
-static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> calibrationData, bool useCuda)
+static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    // Propagate the calibration tensor
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
 
-    SequentialScheduler scheduler(graphView);
-    scheduler.resetScheduling();
-    scheduler.forward(true, {calibrationData});
+    std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
 
-    // Store the input tensor statistics
+    for (const auto& match : matches)
+    {
-    if (useCuda)
-        graphView->setBackend("cpu");
+        auto linearNode = match.graph->rootNode();
 
-    std::map<std::string, float> inputStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float inputAbsMean = getTensorAbsMean(op->getInput(0));
-            inputStats.insert(std::make_pair(node->name(), inputAbsMean));
-            fmt::println("{} -> {}", node->name(), inputAbsMean);
-        }
-    }
+        // TODO : double check this, and use createUniqueName()
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
 
-    if (useCuda)
-        graphView->setBackend("cuda");
+        // Init the step-size using the node call stack
 
-    return inputStats;
-}
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
+        // Insert the quantizer in the graphView
 
-static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> graphView, bool useCuda)
-{
-    if (useCuda)
-        graphView->setBackend("cpu");
-
-    std::map<std::string, float> paramStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float paramAbsMean = getTensorAbsMean(op->getInput(1));
-            paramStats.insert(std::make_pair(node->name(), paramAbsMean));
-            fmt::println("{} -> {}", node->name(), paramAbsMean);
-        }
+        graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
     }
-
-    if (useCuda)
-        graphView->setBackend("cuda");
-
-    return paramStats;
-}
 }
 
-static void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView, std::map<std::string, float> inputStats, std::map<std::string, float> paramStats)
+void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
-
-    for (const auto& match : matches)
-    {
-        auto linearNode = match.graph->rootNode();
-
-        // INPUT QUANTIZERS STEP-SIZES
-
-        auto inputQuantNode = linearNode->getParent(0);
-        auto inputQuantOp = std::static_pointer_cast<LSQ_Op>(inputQuantNode->getOperator());
-
-        float absMean = inputStats[linearNode->name()];
-        float stepSize = 2.0f * (absMean / std::sqrt(inputQuantOp->range().second));
-
-        auto inputStepSizeOp = inputQuantNode->getParent(1)->getOperator();
-        // XXX inputStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
-
-        // PARAM QUANTIZERS STEP-SIZES
-
-        auto paramQuantNode = linearNode->getParent(1);
-        auto paramQuantOp = std::static_pointer_cast<LSQ_Op>(paramQuantNode->getOperator());
-
-        absMean = paramStats[linearNode->name()];
-        stepSize = 2.0f * (absMean / std::sqrt(paramQuantOp->range().second));
-
-        auto paramStepSizeOp = paramQuantNode->getParent(1)->getOperator();
-        // XXX paramStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
-    }
+    setupInputQuantizers(graphView, nbBits);
+    setupParamQuantizers(graphView, nbBits);
 }
 
-void QuantLSQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData)
+void QuantLSQ::devLSQ(std::shared_ptr<Tensor> tensor)
 {
-    bool useCuda = (calibrationData->backend() == "cuda");
-
-    // Collect the tensor statisics
-    auto inputStats = collectInputStats(graphView, calibrationData, useCuda);
-
-    auto paramStats = collectParamStats(graphView, useCuda);
-
-    // Insert the quantizers
-    insertQuantizers(graphView, nbBits, 1.0);
-
-    // Adjust the quantizers step-sizes
-    adjustQuantizersStepSizes(graphView, inputStats, paramStats);
+    float mean = (tensor->mean()).get<float> (0);
+    Log::debug("MEAN = {}", mean);
 }
 }
\ No newline at end of file
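
With insertAndInitQuantizers replaced by setupQuantizers, no calibration pass is needed up front: each LSQ node now initializes its own step size through the addBeforeForward hook on its first forward call. A minimal usage sketch (the include path is assumed, not shown in this patch):

    #include "aidge/quantization/QAT/QAT_LSQ.hpp" // assumed header path

    using namespace Aidge;

    void prepareLSQ(std::shared_ptr<GraphView> model)
    {
        // Insert input and parameter quantizers (4-bit here); the step
        // sizes are computed lazily from the first batch that flows through.
        QuantLSQ::setupQuantizers(model, 4);
    }
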
diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fb7366467cf9d1e84b3465929027e0217b2a354f
--- /dev/null
+++ b/src/operator/PTQMetaOps.cpp
@@ -0,0 +1,229 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/operator/PTQMetaOps.hpp"
+
+#include <array>
+#include <cmath>   // std::log2
+#include <memory>
+#include <utility>
+
+//Operator
+#include "aidge/operator/Clip.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/Round.hpp"
+#include "aidge/operator/Cast.hpp"
+#include "aidge/operator/BitShift.hpp"
+
+#include "aidge/graph/Node.hpp"
+#include "aidge/graph/OpArgs.hpp"
+#include "aidge/operator/MetaOperator.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/utils/ArrayHelpers.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/operator/Identity.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/Log.hpp"
+
+
+namespace Aidge
+{
+static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType)
+{
+    std::shared_ptr<Node> mulNode = nullptr;
+    for (std::shared_ptr<Node> node : graphView->getNodes())
+        if (node->type() == nodeType)
+            mulNode = node;
+
+    return mulNode;
+}
+
+std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name)
+{
+    // create the nodes
+
+    std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_MulQuant" : "");
+    std::shared_ptr<Node> roundNode = Round((!name.empty()) ? name + "_RoundQuant" : "");
+    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_ClipQuant" : "", clipMin, clipMax);
+
+    // connect the scaling factor producer
+
+    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+    std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor");
+    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
+
+    // create the metaop graph
+
+    std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode});
+    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ???
+
+    // return the metaop
+
+    std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype
+
+    return metaopNode;
+}
+
+std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name)
+{
+    double scalingFactor = getScalingFactor(oldQuantizer);
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (oldQuantizer->getOperator());
+    std::shared_ptr<Node> oldclipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
+
+    if (!oldclipNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", oldQuantizer->type());
+        return nullptr;
+    }
+
+    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(oldclipNode->getOperator());
+    int shift = std::log2(scalingFactor);
+    BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left;
+
+    if (shift < 0)
+    {
+        direction = BitShift_Op::BitShiftDirection::right;
+        shift = -shift;
+    }
+
+    std::shared_ptr<Node> bitShiftNode = BitShift(direction, (!name.empty()) ? name + "_BitShiftQuant" : "");
+    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_IClipQuant" : "", clipOp->min(), clipOp->max());
name + "_IClipQuant" : "", clipOp->min(), clipOp->max()); + + std::shared_ptr<Tensor> bitshiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {shift}); + std::shared_ptr<Node> bitshiftProducer = addProducer(bitShiftNode, 1, {1}, "ScalingFactor"); + + bitshiftProducer->getOperator()->setOutput(0, bitshiftTensor); + bitshiftProducer->attributes()->addAttr("quantization.ptq.ShiftAmount",shift); + bitshiftProducer->getOperator()->setDataType(targetType); + + // connect the scaling factor producer + + bitShiftNode->getOperator()->setDataType(targetType); + clipNode->getOperator()->setDataType(targetType); + + // create the metaop graph + + std::shared_ptr<GraphView> graphView = Sequential({bitShiftNode,clipNode}); + std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(bitShiftNode); // XXX why not use the graphView ??? + + // return the metaop + std::shared_ptr<Node> metaopNode = MetaOperator("BitShiftQuantizer", connectedGraphView, {}, name); // XXX alternative prototype + + return metaopNode; +} +std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name) +{ + double scalingFactor = getScalingFactor(oldQuantizer); + + std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (oldQuantizer->getOperator()); + std::shared_ptr<Node> oldclipNode = getSubNode(metaOp->getMicroGraph(), "Clip"); + + if (!oldclipNode) { + Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", oldQuantizer->type()); + return nullptr; + } + std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(oldclipNode->getOperator()); + + std::shared_ptr<Node> castPreNode = Cast(DataType::Float64,((!name.empty()) ? name + "_PreCast" : "")); + + std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_MulIQuant" : ""); + std::shared_ptr<Node> roundNode = Round((!name.empty()) ? name + "_IRoundQuant" : ""); + std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_IClipQuant" : "", clipOp->min(), clipOp->max()); + + std::shared_ptr<Node> castPostNode = Cast(targetType,((!name.empty()) ? name + "_PostCast" : "")); + + // connect the scaling factor producer + + castPreNode->getOperator()->setDataType(DataType::Float64); + mulNode->getOperator()->setDataType(DataType::Float64); + roundNode->getOperator()->setDataType(DataType::Float64); + clipNode->getOperator()->setDataType(DataType::Float64); + + castPostNode->getOperator()->setDataType(targetType); + + std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); + std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); + scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); + + // create the metaop graph + + std::shared_ptr<GraphView> graphView = Sequential({castPreNode, mulNode, roundNode, clipNode, castPostNode}); + std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ??? 
+
+    // return the metaop
+
+    std::shared_ptr<Node> metaopNode = MetaOperator("IntQuantizer", connectedGraphView, {}, name); // XXX alternative prototype
+
+    return metaopNode;
+}
+
+
+void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)
+{
+    if (metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer")
+        Log::warn("Cannot update the scaling factor on Node of type {}", metaOpNode->type());
+
+    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(metaOpNode->getOperator());
+
+    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
+
+    if (!mulNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Mul node found inside !");
+        return; // avoid dereferencing a null node below
+    }
+
+    mulNode->input(1).first->getOperator()->setOutput(0, scalingFactorTensor);
+}
+
+double getScalingFactor(std::shared_ptr<Node> MetaOpNode)
+{
+    if (MetaOpNode->type() != "Scaling" && MetaOpNode->type() != "Quantizer") {
+        Log::warn("Cannot get the scaling factor on Node of type {}", MetaOpNode->type());
+        return 0;
+    }
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(MetaOpNode->getOperator());
+
+    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
+
+    if (!mulNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Mul found inside node of type {}", MetaOpNode->type());
+        return 0;
+    }
+
+    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1);
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+
+    return localTensor.get<double>(0);
+}
+
+
+void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max)
+{
+    if (quantizerNode->type() != "Quantizer") {
+        Log::warn("Cannot set the clipping range on Node of type {}", quantizerNode->type());
+        return;
+    }
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (quantizerNode->getOperator());
+
+    std::shared_ptr<Node> clipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
+
+    if (!clipNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", quantizerNode->type());
+        return;
+    }
+
+    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(clipNode->getOperator());
+    clipOp->max() = max;
+    clipOp->min() = min;
+}
+}
\ No newline at end of file
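
Finally, a hedged sketch of how the two new metaops are intended to be used together: an existing float Quantizer is rebuilt as an integer-friendly variant and swapped into the graph. The replacement step mirrors the graphView->replace(...) usage seen in QAT_LSQ.cpp above, and whether the single-shift path was taken decides which variant applies; the helper itself is illustrative, not part of this patch:

    #include "aidge/operator/PTQMetaOps.hpp"

    using namespace Aidge;

    // Sketch: 'oldQuant' is assumed to be a "Quantizer" metaop already in the graph.
    void convertQuantizer(std::shared_ptr<GraphView> graphView,
                          std::shared_ptr<Node> oldQuant, bool singleShift)
    {
        // With single-shift, scaling factors are powers of two, so a BitShift can
        // replace the Mul; otherwise keep the Cast -> Mul/Round/Clip -> Cast chain.
        std::shared_ptr<Node> newQuant = singleShift
            ? BitShiftQuantizer(oldQuant, DataType::Int32, oldQuant->name())
            : IntQuantizer(oldQuant, DataType::Int32, oldQuant->name());

        if (newQuant)
            graphView->replace({oldQuant}, {newQuant}); // swap the metaops in place
    }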