diff --git a/.gitignore b/.gitignore index ba5c59398b68083c6c1c5fe820fb9070d999c18e..57409a5cddc52f82eb67bf88b0ae28ca23e8a72b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,7 @@ # C++ Build build*/ install*/ -include/aidge/backend/quantization_version.h - +include/aidge/quantization_version.h # VSCode .vscode diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6b36832776146dedcd397491fbaa3771e6558fdd..c970ca03506986817525463dc0fb9fc5dcade666 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -12,48 +12,39 @@ stages: - deploy include: - - project: 'eclipse/aidge/gitlab_shared_files' + - project: 'eclipse/aidge/gitlab_shared_files' ref: 'main' - file: + file: # choose which jobs to run by including the corresponding files. - '.gitlab/ci/ubuntu_cpp.gitlab-ci.yml' - - '.gitlab/ci/ubuntu_python.gitlab-ci.yml' - - '.gitlab/ci/release/cibuildwheel_ubuntu.gitlab-ci.yml' - - # Cannot find successful job on aidge_backend_cuda yet - # - '.gitlab/ci/windows_cpp.gitlab-ci.yml' - - # - '.gitlab/ci/windows_python.gitlab-ci.yml' - # - '.gitlab/ci/release/cibuildwheel_windows.gitlab-ci.yml' - + - '.gitlab/ci/release/cibuildwheel_ubuntu.gitlab-ci.yml' test:ubuntu_python: before_script: - - !reference [.retrieve_deps:apt, script] - - source venv/bin/activate - - python -m pip install numpy unittest-xml-reporting - - python -m pip list + - !reference [.setup:test:ubuntu_python, before_script] - DEPS_NAMES=("aidge_onnx") - DEPENDENCY_JOB="build:ubuntu_python" - !reference [.ubuntu:download:artifacts, script] coverage:ubuntu_python: - before_script: - - !reference [.retrieve_deps:apt, script] - - source venv/bin/activate - - python -m pip install numpy coverage + before_script: + - !reference [.setup:coverage:ubuntu_python, before_script] - DEPS_NAMES=("aidge_onnx") - DEPENDENCY_JOB="build:ubuntu_python" - !reference [.ubuntu:download:artifacts, script] -release:pip:ubuntu: - before_script: - - !reference [.retrieve_deps:apt, script] - - DEPS_NAMES=("aidge_core" "aidge_backend_cpu" "aidge_backend_cuda" "aidge_onnx") - - DEPENDENCY_JOB="build:ubuntu_python" - - !reference [.ubuntu:download:repositories, script] # located in common.gitlab-ci.yml - - curl -sSL https://get.docker.com/ | sh +# release:pip:ubuntu: +# variables: +# # Building CPU quantization package +# CIBW_ENVIRONMENT: >- +# AIDGE_WITH_CUDA=OFF +# before_script: +# - !reference [.retrieve_deps:apt, script] +# - DEPS_NAMES=("aidge_core" "aidge_backend_cpu" "aidge_onnx") +# - DEPENDENCY_JOB="build:ubuntu_python" +# - !reference [.ubuntu:download:repositories, script] # located in common.gitlab-ci.yml +# - curl -sSL https://get.docker.com/ | sh # release:pip:windows: @@ -66,12 +57,12 @@ release:pip:ubuntu: # # Install dependencies # - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y # - choco install git -Y -# - choco install python --version=$python_version -Y +# - choco install python --version=$python_version -Y # # Update PATH # - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") # - python -m pip install cibuildwheel==2.17.0 # # Download repositories # - $DEPS_NAMES = "aidge_core","aidge_backend_cpu","aidge_backend_cuda","aidge_onnx" # - $DEPENDENCY_JOB="build:windows_python" -# - !reference [.windows:download:repositories, script] +# - !reference [.windows:download:repositories, script] diff --git a/.gitlab/ci/cibuildwheel_build_deps_before_build_wheel.sh b/.gitlab/ci/cibuildwheel_build_deps_before_build_wheel.sh index
4f74488ae41714a4ce03ba7514bf93842768c5ae..da504a28acee381078a178f4721b6b70d82eb37b 100755 --- a/.gitlab/ci/cibuildwheel_build_deps_before_build_wheel.sh +++ b/.gitlab/ci/cibuildwheel_build_deps_before_build_wheel.sh @@ -1,6 +1,6 @@ #!/bin/bash set -e -if [[ "$1" == "" ]]; then +if [[ "$1" == "" ]]; then echo "build aidge deps in cibuildwheel container before building wheel." echo "search path defines where the dependencies will be searched." echo "Hint : In wheel containers, files are mounted on /host by default." @@ -10,13 +10,14 @@ set -x if [[ $AIDGE_DEPENDENCIES == "" ]]; then # case for aidge_ core mkdir -p build # creating build if its not already there to hold the build of cpp files rm -rf build/* # build from scratch -else +else for repo in $AIDGE_DEPENDENCIES ; do # case for other projects search_path=$1 REPO_PATH=$(find $search_path ! -writable -prune -o -type d \ -name "$repo" \ -not -path "*/install/*" \ -not -path "*/.git/*" \ + -not -path "*/.mypy_cache/*" \ -not -path "*/miniconda/*" \ -not -path "*/conda/*" \ -not -path "*/.local/*" \ @@ -24,7 +25,7 @@ else -not -path "*/$repo/$repo/*" \ -not -path "*/proc/*" \ -print -quit) - if [[ -z "$REPO_PATH" ]]; then + if [[ -z "$REPO_PATH" ]]; then echo "ERROR : dependency $repo not found in search_path \"$search_path\". ABORTING." exit -1 fi @@ -33,6 +34,10 @@ else mkdir -p build # creating build if its not already there to hold the build of cpp files rm -rf build/* # build from scratch pip install . -v + + # Give all rights on generated build folder to avoid root issues once out of the Docker + chmod -R a+rwX build/ + chmod -R a+rwX *.egg-info/ cd - done fi diff --git a/CMakeLists.txt b/CMakeLists.txt index b3c6d459dfaf29f5accbc0be4565a3709e9ffd3b..afb882af0d02000e5490f1d2a0c56b4487481be9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,7 +61,7 @@ option(PYBIND "python binding" OFF) option(WERROR "Warning as error" OFF) option(TEST "Enable tests" OFF) option(COVERAGE "Enable coverage" OFF) -option(CUDA "Enable CUDA backend" OFF) # XXX OFF +option(CUDA "Enable CUDA backend" ON) # XXX OFF option(ENABLE_ASAN "Enable ASan (AddressSanitizer) for runtime analysis of memory use (over/underflow, memory leak, ...)" OFF) ############################################## @@ -182,7 +182,6 @@ endif() # Coverage flags for GCC if(CMAKE_COMPILER_IS_GNUCXX AND COVERAGE) - include(CodeCoverage) append_coverage_compiler_flags() endif() diff --git a/aidge_quantization/unit_tests/assets/BranchNetV4.onnx b/aidge_quantization/unit_tests/assets/BranchNetV4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..34cccc47c4b5014f0adc4757d0b8e362a8e5ddce Binary files /dev/null and b/aidge_quantization/unit_tests/assets/BranchNetV4.onnx differ diff --git a/aidge_quantization/unit_tests/assets/MLP.onnx b/aidge_quantization/unit_tests/assets/MLP.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f6b72dbbd8c829a1d3609d923478887892b9e231 Binary files /dev/null and b/aidge_quantization/unit_tests/assets/MLP.onnx differ diff --git a/aidge_quantization/unit_tests/assets/TestNet.onnx b/aidge_quantization/unit_tests/assets/TestNet.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7f73e9b11d8a2ca43c88e52295dd201211f1e741 Binary files /dev/null and b/aidge_quantization/unit_tests/assets/TestNet.onnx differ diff --git a/aidge_quantization/unit_tests/test_ptq.py b/aidge_quantization/unit_tests/test_ptq.py index 56080bff0d1f4a95248fa983316dbafd35565501..f6b243c27b1a08dbbfc5da522e385ceb4ec9c2f4 100644 --- 
a/aidge_quantization/unit_tests/test_ptq.py +++ b/aidge_quantization/unit_tests/test_ptq.py @@ -1,118 +1,137 @@ import unittest -import gzip import numpy as np -from pathlib import Path - +import gzip import aidge_core -import aidge_backend_cpu import aidge_onnx +import aidge_backend_cpu import aidge_quantization +import sys +from pathlib import Path -from aidge_core import Log, Level - +""" + Unit test for the PTQ pipeline + ============================== + This script is designed to test and validate the accuracy of five small topologies on the MNIST dataset: + ["MiniResNet", "ConvNet", "BranchNetV4", "TestNet", "MLP"] + It compares the results for three configurations: baseline, quantization, and quantization with single shift. + The value of delta represents the tolerance of the tests. +""" +aidge_core.Log.set_console_level(aidge_core.Level.Error) # Reduce useless logs # -------------------------------------------------------------- -# CONFIGS +# CONFIGURATION # -------------------------------------------------------------- -NB_SAMPLES = 1000 # max : 1000 -SAMPLE_SHAPE = (1, 1, 28, 28) -MODEL_NAME = 'MiniResNet.onnx' # 'ConvNet.onnx' -ACCURACIES = (95.4, 94.4) # (97.9, 97.7) -NB_BITS = 4 +NB_SAMPLES = 1000 +SAMPLE_SHAPE = (1, 1, 28, 28) +NB_BITS = 4 +TARGET_TYPE = aidge_core.dtype.int32 +CLIPPING = aidge_quantization.Clipping.MSE +NO_QUANT = False +OPTIM_SIGNS = True +FOLD_GRAPH = True +DELTA = 0.05 + +EXPECTED_RESULTS = { + "MiniResNet.onnx" : (95.4, 94.4, 95.0), + "ConvNet.onnx" : (97.9, 97.2, 96.7), + "BranchNetV4.onnx" : (93.8, 92.7, 93.7), + "TestNet.onnx" : (95.5, 94.0, 94.5), + "MLP.onnx" : (94.7, 92.9, 93.8) +} # -------------------------------------------------------------- # UTILS # -------------------------------------------------------------- def propagate(model, scheduler, sample): + sample = np.reshape(sample, SAMPLE_SHAPE) input_tensor = aidge_core.Tensor(sample) scheduler.forward(True, [input_tensor]) output_node = model.get_output_nodes().pop() output_tensor = output_node.get_operator().get_output(0) return np.array(output_tensor) -def prepare_sample(sample): - sample = np.reshape(sample, SAMPLE_SHAPE) - return sample.astype('float32') - def compute_accuracy(model, samples, labels): - acc = 0 scheduler = aidge_core.SequentialScheduler(model) - for i, sample in enumerate(samples): - x = prepare_sample(sample) - y = propagate(model, scheduler, x) - if labels[i] == np.argmax(y): - acc += 1 + acc = sum(labels[i] == np.argmax(propagate(model, scheduler, x)) for i, x in enumerate(samples)) return acc / len(samples) # -------------------------------------------------------------- # TEST CLASS # -------------------------------------------------------------- -class test_ptq(unittest.TestCase): +class TestQuantization(unittest.TestCase): def setUp(self): - - # load the samples / labels (numpy) - curr_file_dir = Path(__file__).parent.resolve() self.samples = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_samples.npy.gz', "r")) - self.labels = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_labels.npy.gz', "r")) - - # load the model in AIDGE + self.labels = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_labels.npy.gz', "r")) + self.quant_samples = np.round(self.samples.copy() * (2**(NB_BITS-1)-1)) - self.model = aidge_onnx.load_onnx(curr_file_dir / "assets/" / MODEL_NAME, verbose=False) - aidge_core.remove_flatten(self.model) + def run_model_test(self, model_name): - self.model.set_datatype(aidge_core.dtype.float32) - self.model.set_backend("cpu") + expected_base,
expected_quant, expected_quant_ssa = EXPECTED_RESULTS[model_name] - def tearDown(self): - pass + # load the model ... + model_path = Path(__file__).parent / "assets" / model_name + model = aidge_onnx.load_onnx(model_path, verbose=False) + aidge_core.remove_flatten(model) - def test_model(self): + model.set_datatype(aidge_core.dtype.float32) + model.set_backend("cpu") + + # create the tensor subset - Log.set_console_level(Level.Info) - # compute the base accuracy - accuracy = compute_accuracy(self.model, self.samples[0:NB_SAMPLES], self.labels) - self.assertAlmostEqual(accuracy * 100, ACCURACIES[0], msg='base accuracy does not meet the baseline !', delta=0.1) + tensors = [aidge_core.Tensor(np.reshape(sample, SAMPLE_SHAPE)) for sample in self.samples[:NB_SAMPLES]] - def test_quant_model(self): + # BASELINE ACCURACY - Log.set_console_level(Level.Debug) + base_accuracy = compute_accuracy(model, tensors, self.labels[:NB_SAMPLES]) + self.assertAlmostEqual(base_accuracy * 100, expected_base, delta=DELTA, msg=f"[X] Baseline accuracy mismatch for {model_name}. Expected accuracy was: {expected_base}, but got: {base_accuracy * 100}") - # create the calibration dataset - - tensors = [] - for sample in self.samples[0:NB_SAMPLES]: - sample = prepare_sample(sample) - tensor = aidge_core.Tensor(sample) - tensors.append(tensor) - - # quantize the model + # QUANTIZED ACCURACY aidge_quantization.quantize_network( - self.model, - NB_BITS, - tensors, - clipping_mode=aidge_quantization.Clipping.MSE, - no_quantization=False, - optimize_signs=True, - single_shift=False - ) - - # rescale the inputs - - scaling = 2**(NB_BITS-1)-1 - for i in range(NB_SAMPLES): - self.samples[i] = self.samples[i]*scaling # XXX np.round ??? - - # compute the quantized accuracy - - accuracy = compute_accuracy(self.model, self.samples, self.labels) - self.assertAlmostEqual(accuracy * 100, ACCURACIES[1], msg='quantized accuracy does not meet the baseline !', delta=0.1) - + network=model, + nb_bits=NB_BITS, + calibration_set=tensors, + target_type=TARGET_TYPE, + clipping_mode=CLIPPING, + no_quant=NO_QUANT, + optimize_signs=OPTIM_SIGNS, + single_shift=False, + use_cuda=False, + fold_graph=FOLD_GRAPH) + + quant_accuracy = compute_accuracy(model, self.quant_samples[:NB_SAMPLES], self.labels) + self.assertAlmostEqual(quant_accuracy * 100, expected_quant, delta=DELTA, msg=f"[X] Quantized accuracy mismatch for {model_name}. Expected accuracy was: {expected_quant}, but got: {quant_accuracy * 100}") + + # QUANTIZED ACCURACY WITH SSA + + model = aidge_onnx.load_onnx(model_path, verbose=False) + model.set_datatype(aidge_core.dtype.float32) + model.set_backend("cpu") + + aidge_quantization.quantize_network( + network=model, + nb_bits=NB_BITS, + calibration_set=tensors, + target_type=TARGET_TYPE, + clipping_mode=CLIPPING, + no_quant=NO_QUANT, + optimize_signs=OPTIM_SIGNS, + single_shift=True, + use_cuda=False, + fold_graph=FOLD_GRAPH) + + quant_accuracy_ssa = compute_accuracy(model, self.quant_samples[:NB_SAMPLES], self.labels) + self.assertAlmostEqual(quant_accuracy_ssa * 100, expected_quant_ssa, delta=DELTA, msg=f"[X] Quantized accuracy (with SSA) mismatch for {model_name}. 
Expected accuracy was: {expected_quant_ssa}, but got: {quant_accuracy_ssa * 100}") + + def test_models(self): + for model in EXPECTED_RESULTS.keys(): + with self.subTest(model=model): + self.run_model_test(model) if __name__ == '__main__': unittest.main() diff --git a/cmake/PybindModuleCreation.cmake b/cmake/PybindModuleCreation.cmake index 07cdd658b0e6ae6549b5dfd7663e9973c59c6a9f..e3fe6a7383656e053fe7f89da2fda1083d6374ae 100644 --- a/cmake/PybindModuleCreation.cmake +++ b/cmake/PybindModuleCreation.cmake @@ -4,7 +4,7 @@ function(generate_python_binding pybind_module_name target_to_bind) Include(FetchContent) - set(PYBIND_VERSION v2.10.4) + set(PYBIND_VERSION v2.13.6) message(STATUS "Retrieving pybind ${PYBIND_VERSION} from git") FetchContent_Declare( diff --git a/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp b/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp index 9d7a106cdd6b1a0970c87b2a27cc7d6637846b49..935d8f065a5e91729c5c0ff25b13f5ea1234a8b6 100644 --- a/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp @@ -23,7 +23,7 @@ void FixedQImpl_cpu_forward_kernel( std::size_t nbBits, float span_, bool isOutputUnsigned, - std::size_t inputLenght, + std::size_t inputLength, const void* input_, void* output_) { @@ -40,7 +40,7 @@ void FixedQImpl_cpu_forward_kernel( const I* input = static_cast<const I*>(input_); O* output = static_cast<O*>(output_); - for (std::size_t i = 0; i < inputLenght; ++i) { + for (std::size_t i = 0; i < inputLength; ++i) { I clipped = std::max(lower, std::min(input[i], upper)); output[i] = std::round(clipped / stepSize) * stepSize; } @@ -49,14 +49,14 @@ void FixedQImpl_cpu_forward_kernel( template <class GI, class GO> void FixedQImpl_cpu_backward_kernel( - const std::size_t inputLenght, + const std::size_t inputLength, const void* grad_output_, void* grad_input_) { const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); - for (std::size_t i = 0; i < inputLenght; ++i) { + for (std::size_t i = 0; i < inputLength; ++i) { // Straight Through Estimator grad_input[i] = grad_output[i]; } diff --git a/include/aidge/operator/LSQ.hpp b/include/aidge/operator/LSQ.hpp index 970c476cb7be18b8d001edb27d60079de85b9349..b6abf90371a3053fa7971b9242a5309362ea478e 100644 --- a/include/aidge/operator/LSQ.hpp +++ b/include/aidge/operator/LSQ.hpp @@ -55,7 +55,7 @@ public: */ LSQ_Op(const LSQ_Op& op) : OperatorTensor(op), - mAttributes(op.mAttributes) + mAttributes(std::make_shared<Attributes_>(*op.mAttributes)) { if (op.mImpl){ SET_IMPL_MACRO(LSQ_Op, *this, op.backend()); diff --git a/include/aidge/operator/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp new file mode 100644 index 0000000000000000000000000000000000000000..96182202f38be20afa539eb41a8d32b989afcf9f --- /dev/null +++ b/include/aidge/operator/PTQMetaOps.hpp @@ -0,0 +1,85 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#ifndef AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQMETAOPS_H_ +#define AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQMETAOPS_H_ + +#include <memory> +#include <string> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/data/Data.hpp" + +namespace Aidge { + + /** + * @brief Create a Quantizer node that initially consists of a single multiplier applying a scaling factor. + * @param scalingFactor The value of the multiplicative coefficient. + * @param name Name of the Quantizer. + */ + std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, const std::string& name); + + /** + * @brief Given a Quantizer, multiply its internal multiplicative coefficient by a value. + * @param coeff The value of the multiplicative coefficient. + */ + void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff); + + /** + * @brief Given a Quantizer, create a copy of it that has a Round node and a Clip node added + * at its endpoint, and replace the given Quantizer by it (a swap is also done by reference). + * @param quantizer The quantizer to modify and replace. + * @param clipMin The min value of the clip node. + * @param clipMax The max value of the clip node. + */ + void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax); + + /** + * @brief Given a Quantizer, create a copy of it that has the Round node removed, + * and replace the given Quantizer by it (a swap is also done by reference). + * @param quantizer The quantizer to modify and replace. + */ + void removeRound(std::shared_ptr<Node>& quantizer); + + /** + * @brief Given a Quantizer, create a copy of it where the Mul node is replaced by + * a Bit-Shift node, and replace the given Quantizer by it (a swap is also done by reference). + * @param quantizer The quantizer to modify and replace. + */ + void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer); + + /** + * @brief Given a Quantizer, create a copy of it that has Cast nodes inserted at its IOs, + * and replace the given Quantizer by it (a swap is also done by reference). The input cast + * node converts the input data to the internal type, while the output cast converts it back + * to the external type. + * @param quantizer The quantizer to modify and replace. + * @param externalType The external data type used for the casts. + */ + void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType); + + /** + * @brief Given a Quantizer, retrieve the coefficient of its Mul node. + * @param quantizer The quantizer containing the multiplicative coefficient. + */ + double getScalingFactor(std::shared_ptr<Aidge::Node> quantizer); + + /** + * @brief Given a Quantizer containing a Clip node, replace its clipping values. + * @param quantizer The quantizer containing the Clip node. + * @param min The min clipping value. + * @param max The max clipping value.
+ */ + void setClipRange(std::shared_ptr<Aidge::Node> quantizer, double min, double max); + +} + +#endif /* AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQMETAOPS_H_ */ diff --git a/include/aidge/quantization/PTQ/Clipping.hpp b/include/aidge/quantization/PTQ/Clipping.hpp index 3f65c42eb2032da10c4d337b53fb1bdd08a7aa55..35f23f5f2022128238e1991717876d6462d0b6da 100644 --- a/include/aidge/quantization/PTQ/Clipping.hpp +++ b/include/aidge/quantization/PTQ/Clipping.hpp @@ -13,7 +13,7 @@ #define AIDGE_QUANTIZATION_QUANTIZATION_PTQ_CLIP_H_ #include <cstdint> // std::uint8_t -#include <map> +#include <unordered_map> #include <memory> #include <string> #include <vector> @@ -33,10 +33,10 @@ namespace Aidge * @param valueRanges A map associating each considered node name to its corresponding output range. * @param nbBins Desired number of bins of the returned histograms. * @param graphView The GraphView containing the considered nodes. - * @param inputDataSet The input dataset, consisting of a vector of input samples. + * @param calibrationSet The calibration dataset, consisting of a vector of input samples. * @return A map associating each node name to it's corresponding activation histogram. */ - std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda); + std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda); /** * @brief Given an input activation histogram, compute the optimal clipping value in the sense of the Lp norm. @@ -63,11 +63,11 @@ namespace Aidge * @param valueRanges The map associating each affine node to its output range. * @param nbBits The quantization number of bits. * @param graphView The GraphView containing the considered nodes. - * @param inputDataSet The input dataset, consisting of a vector of input samples. + * @param calibrationSet The calibration dataset, consisting of a vector of input samples. * @param verbose Whether to print the clipping values or not. * @return The corrected map associating each provided node to its clipped range. 
*/ - std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::string, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose); + std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda, bool verbose); } diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp index 4fc38bc3b959ec8264ddaddbd4673fbe1f75e4ab..d9b944e33cc5706bb8f62ddeb1553ace0619245d 100644 --- a/include/aidge/quantization/PTQ/PTQ.hpp +++ b/include/aidge/quantization/PTQ/PTQ.hpp @@ -13,7 +13,7 @@ #define AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQ_H_ #include <cstdint> // std::uint8_t -#include <map> +#include <unordered_map> #include <memory> #include <set> #include <string> @@ -41,6 +41,11 @@ namespace Aidge { */ static const std::set<std::string> mergingNodeTypes({"Add", "Concat", "Sub"}); + /** + * @brief Set of the types of the nodes that won't be quantized. + */ + static const std::set<std::string> notQuantizedNodeTypes({"Sigmoid", "Tanh"}); + /** * @brief Determine if a node contains an affine transform (that is Y = A.X + B) * @param node The node to be checked @@ -62,6 +67,19 @@ */ bool isMerging(std::shared_ptr<Node> node); + /** + * @brief Determine if a node contains an operator that won't be quantized + * @param node The node to be checked + * @return True if the node is not quantized, else false. + */ + bool isNotQuantized(std::shared_ptr<Node> node); + + /** + * @brief Compute the absolute max of a tensor + * @param tensor The Tensor to process + */ + double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor); + /** * @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes. * @param graphView The graphView containing the nodes * @return The scheduled vector of nodes */ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule = true, bool verbose = false); + + /** + * @brief Inserts a scaling node below the given producer node in the graphView. + * @param node A shared pointer to the producer node where the scaling node will be inserted (below). + * @param graphView A shared pointer to the graph view in which the nodes are located. + */ + void insertScalingBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView); /** * @brief Determine whether an input GraphView can be quantized or not. @@ -77,9 +103,21 @@ */ bool checkArchitecture(std::shared_ptr<GraphView> graphView); + /** + * @brief This function multiplies the existing scaling factor by a given coefficient. It verifies that the node is of the correct type ("Mul") + * and has the `isScaling` attribute. If these conditions are not met, a warning is logged. + * @param node A shared pointer to an `Aidge::Node` object representing the node to modify. + * @param coeff A double representing the multiplication coefficient to apply to the scaling factor.
+ */ + void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff); - void prepareNetwork(std::shared_ptr<GraphView> graphView); + /** + * @brief Prepare a network before the quantization is applied to it, by removing, replacing + * or fusing the nodes that are not supported by the PTQ pipeline. + * @param graphView The network to prepare for the quantization + */ + void prepareNetwork(std::shared_ptr<GraphView> graphView); /** * @brief Insert a scaling node after each affine node of the GraphView. @@ -97,11 +135,11 @@ /** * @brief Compute the activation ranges of every affine node, given an input dataset. * @param graphView The GraphView containing the affine nodes, on which the inferences are performed. - * @param inputDataSet The input dataset, consisting of a vector of input samples. + * @param calibrationSet The calibration dataset, consisting of a vector of input samples. * @param scalingNodesOnly Whether to restrain the retrieval of the ranges to scaling nodes only or not. * @return A map associating each affine node name to its corresponding output range. */ - std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda); + std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool scalingNodesOnly, bool useCuda); /** * @brief Normalize the activations of each affine node so that they fit in the [-1:1] range. @@ -109,7 +147,7 @@ * @param graphView The GraphView containing the affine nodes. * @param valueRanges The node output value ranges computed over the calibration dataset. */ - void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, double> valueRanges); + void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges); /** * @brief For each node, compute the sign of its input and output values. @@ -118,37 +156,52 @@ * @param verbose Whether to print the sign map or not. * @return A map associating a pair of sign to each node of the GraphView (a sign for the input and one for the output). */ - std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose); + std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose); /** * @brief Quantize an already normalized (in terms of parameters and activations) network. * @param graphView The GraphView to be quantized. * @param nbBits The desired number of bits of the quantization. - * @param applyRounding Whether to apply the rounding operations or not. + * @param noQuant Whether to skip the rounding operations or not. * @param optimizeSigns Whether to take account of the IO signs of the operators or not. * @param verbose Whether to print the sign map or not. */ - void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool applyRounding, bool optimizeSigns, bool verbose); + void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant, bool optimizeSigns, bool verbose); + + /** + * @brief Take a quantized GraphView represented in floating precision and cast it to the desired target precision.
+ * If single-shift option is set to True, the scaling nodes contained in the activation quantizers are replaced with bit-shifts. + * @param graphView The GraphView to modify. + * @param targetType The desired precision of the cast. + * @param singleShift If set to True, replace the scaling-factors by bit-shifts. + * @param bitShiftRounding If singleShift is True, specifies the kind of bit-shift rounding. + * + */ + void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType targetType, bool singleShift /*, bool bitShiftRounding*/); /** * @brief Main quantization routine. Performs every step of the quantization pipeline. * @param graphView The GraphView to be quantized. * @param nbBits The desired number of bits of the quantization. - * @param inputDataSet The input dataset on which the value ranges are computed. + * @param calibrationSet The calibration dataset used for the activation calibration. + * @param targetType The desired data-type of the output GraphView. * @param clippingMode: Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'. - * @param applyRounding Whether to apply the rounding operations or not. + * @param noQuant Whether to skip the rounding operations or not. * @param optimizeSigns Whether to take account of the IO signs of the operators or not. - * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights. - * @param verbose Whether to print internal informations about the quantization process. + * @param singleShift Whether to convert the scaling factors into powers of two. If true, the approximations are compensated using the previous nodes' parameters. + * @param useCuda Whether to use the CUDA backend for performing the activation calibration or not. + * @param foldGraph Whether to fold the parameter quantizers after the quantization or not. + * @param verbose Whether to print internal information about the quantization process or not. */ - void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool applyRounding, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose); + void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> calibrationSet, DataType targetType, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose); + /** * @brief Compute the weight ranges of every affine node. Provided for debugging purposes. * @param graphView The GraphView containing the affine nodes. * @return A map associating each affine node name to its corresponding weight range. */ - std::map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView); + std::unordered_map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView); /** * @brief Clear the affine nodes biases. Provided for debugging purposes.
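[Editor's note] The reworked quantizeNetwork signature above maps one-to-one onto the Python binding quantize_network shown later in this patch (python_binding/pybind_PTQ.cpp). A minimal usage sketch, mirroring the rewritten unit test above — the model file name, sample count and input shape are illustrative placeholders, not part of this patch:

import numpy as np
import aidge_core
import aidge_backend_cpu  # registers the "cpu" backend implementations
import aidge_onnx
import aidge_quantization

# Load a float32 model and prepare it for the PTQ (hypothetical model file).
model = aidge_onnx.load_onnx("model.onnx", verbose=False)
aidge_core.remove_flatten(model)
model.set_datatype(aidge_core.dtype.float32)
model.set_backend("cpu")

# Calibration samples wrapped as Tensors (random data here; real samples in practice).
calibration_set = [aidge_core.Tensor(np.random.rand(1, 1, 28, 28).astype("float32"))
                   for _ in range(100)]

# One call runs the whole pipeline: range calibration, clipping optimization,
# sign optimization, cast to the target type and folding of parameter quantizers.
aidge_quantization.quantize_network(
    network=model,
    nb_bits=4,
    calibration_set=calibration_set,
    target_type=aidge_core.dtype.int32,
    clipping_mode=aidge_quantization.Clipping.MSE,
    no_quant=False,
    optimize_signs=True,
    single_shift=False,
    use_cuda=False,
    fold_graph=True)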
diff --git a/include/aidge/quantization/PTQ/PTQMetaOps.hpp b/include/aidge/quantization/PTQ/PTQMetaOps.hpp deleted file mode 100644 index b9bad0d18f099e94d4c52254b08629c7f947db6a..0000000000000000000000000000000000000000 --- a/include/aidge/quantization/PTQ/PTQMetaOps.hpp +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ -#ifndef AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQMETAOPS_H_ -#define AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQMETAOPS_H_ - -#include <memory> -#include <string> - -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Node.hpp" - -namespace Aidge { - -/// @brief Quantizer acts as a meta-operator to handle scaling operations in the PTQ, replacing the Scaling Operator. -/// This operator is composed of a sequence of [Mul] -> [Clip] -> [Round] operations. -/// -/// @param scalingFactor The scaling factor to apply to the input (essentially a scalar to multiply the input with). -/// @param clip_min The minimum value for the clip operation. -/// @param clip_max The maximum value for the clip operation. -/// @param name The name of the meta-operator node created. -/// @return A shared pointer to an instance of the meta-operator node. -std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name); - -/// @brief The purpose of Scaling is to encapsulate the Mul operator and tag it as a PTQ node rather than a regular Mul operator. -/// Therefore, this meta-operator consists solely of a [Mul] operation. -/// -/// @param scalingFactor The scaling factor to apply to the input (a scalar to multiply the input with). -/// @param name The name of the meta-operator node created. -/// @return A shared pointer to an instance of the scaling node. -std::shared_ptr<Aidge::Node> Scaling(double scalingFactor, const std::string& name = ""); - -/// @brief Updates the scaling factor of a PTQ meta-operator node, allowing for dynamic adjustment of the scaling parameter. -/// This function sets a new scaling factor for a specified meta-operator node, modifying the scalar applied in the [Mul] operation. -/// The meta-operator node must be a PTQ-specific operator, such as a Quantizer or Scaling node. -/// -/// @param MetaOpNode A shared pointer to the PTQ meta-operator node whose scaling factor will be updated. -/// @param newScalingFactor The new scaling factor to apply to the meta-operator node. -/// @return True if the scaling factor was successfully updated, false if the operation failed (e.g., if MetaOpNode is null or incompatible). -void updateScalingFactor(std::shared_ptr<Aidge::Node> MetaOpNode, double newScalingFactor); - -/// @brief Retrieves the current scaling factor of a PTQ meta-operator node. -/// This function returns the scaling factor associated with the specified PTQ meta-operator node, -/// allowing inspection of the current scalar applied in the [Mul] operation. -/// -/// @param MetaOpNode A shared pointer to the PTQ meta-operator node whose scaling factor is being queried. 
-/// @return The scaling factor currently applied to the meta-operator node, or -1 if the operation fails (e.g., if MetaOpNode is null or incompatible). -double getScalingFactor(std::shared_ptr<Aidge::Node> MetaOpNode); - -/// @brief Sets the clip range for an existing Quantizer node by specifying minimum and maximum clipping values. -/// This function modifies the clip range of a Quantizer node, allowing adjustment of the range within which values are clipped -/// in the [Clip] operation of the Quantizer sequence. -/// -/// @param QuantizerNode A shared pointer to the Quantizer node whose clip range is being set. -/// This node should have been created using the Quantizer function. -/// @param min The minimum value for the clip range. Values below this will be clipped to this minimum. -/// @param max The maximum value for the clip range. Values above this will be clipped to this maximum. -/// @return True if the clip range was successfully set, false if the operation failed (e.g., if QuantizerNode is null). -void setClipRange(std::shared_ptr<Aidge::Node> QuantizerNode, double min, double max); - -} - -#endif /* AIDGE_QUANTIZATION_QUANTIZATION_PTQ_PTQMETAOPS_H_ */ diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp index a44c71b04ca9e9c6a8fba27c615c99b4893d3d8c..7919b1af10647379f11d8819d1c3583a6c1fe9cb 100644 --- a/include/aidge/quantization/QAT/QAT_LSQ.hpp +++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp @@ -22,25 +22,17 @@ namespace Aidge { namespace QuantLSQ { /** - * @brief Insert the LSQ quantizer nodes in a given GraphView - * @param graphView The GraphView containing the graph to quantize. + * @brief Given a GraphView with parameters properly initialized, insert + * the LSQ quantizer nodes, and set up the adjustment of their step-sizes. + * @param graphView The GraphView containing the network to quantize. * @param nbBits Number of quantization bits. - * @param span Fixed output span of the quantizers. */ -void insertQuantizers(std::shared_ptr<GraphView> graphView, std::size_t nbBits, float step_size); -/** - * @brief Given a GraphView with parameters properly initialized and some calibration data, - * insert the LSQ quantizer nodes, and adjust their step-sizes. - * @param graphView The GraphView containing the graph to quantize. - * @param nbBits Number of quantization bits. - * @param calibrationData Calibration data used to adjust the spans. - * @param scale Multiplicative constant applied to the spans.
- */ -void insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, std::size_t nbBits, std::shared_ptr<Tensor> calibrationData); +void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits); } // namespace QuantLSQ } // namespace Aidge #endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */ - + + \ No newline at end of file diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h deleted file mode 100644 index 546263af3a7e8b7a73991173f48d0b095c7d9501..0000000000000000000000000000000000000000 --- a/include/aidge/quantization_version.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef VERSION_H -#define VERSION_H - -namespace Aidge { -static constexpr const int PROJECT_VERSION_MAJOR = 0; -static constexpr const int PROJECT_VERSION_MINOR = 2; -static constexpr const int PROJECT_VERSION_PATCH = 0; -static constexpr const char * PROJECT_VERSION = "0.2.0"; -static constexpr const char * PROJECT_GIT_HASH = "f50c860"; -} -#endif // VERSION_H diff --git a/include/aidge/recipes/QuantRecipes.hpp b/include/aidge/recipes/QuantRecipes.hpp index 39349f962d61970020741ba533403ba03559a53f..1e78699c579d53549ada884247ff545ac451f737 100644 --- a/include/aidge/recipes/QuantRecipes.hpp +++ b/include/aidge/recipes/QuantRecipes.hpp @@ -40,6 +40,14 @@ namespace Aidge * @param graphView The GraphView to process. */ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView); + + /** + * @brief Given a GraphView, set all its MatMul weights to index 1 (required for the PTQ). + * This operation involves the insertion of Transpose nodes as well as the transposition of + * the MatMul weight tensors. + * @param graphView The GraphView to process. + */ + void reorderMatMulInputs(std::shared_ptr<GraphView> graphView); } #endif /* AIDGE_QUANTIZATION_QUANTRECIPES_H_ */ diff --git a/pyproject.toml b/pyproject.toml index 088200e44f589e221982ddaab825986c4224243d..a483e7fee71ca5ec8565c51914755f71fe2647dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description="Quantization algorithms to compress aidge networks." dependencies = [ "numpy>=1.21.6", ] -requires-python = ">= 3.8" +requires-python = ">= 3.10" readme = "README.md" license = { file = "LICENSE" } classifiers = [ @@ -61,37 +61,29 @@ exclude = [ # CIBUILDWHEEL [tool.cibuildwheel] build-frontend = "build" -test-requires = "pytest" -test-command = "pytest {package}/aidge_quantization/unit_tests" +# test-requires = "pytest" +# test-command = "pytest {package}/aidge_quantization/unit_tests" +test-command = "" +manylinux-x86_64-image = "quay.io/pypa/manylinux_2_28_x86_64:2024.11.16-1" # uncomment to run cibuildwheel locally on selected distros -# build=[ +build=[ # "cp38-manylinux_x86_64", # "cp39-manylinux_x86_64", # "cp310-manylinux_x86_64", # "cp38-win_amd64", # "cp39-win_amd64", # "cp310-win_amd64", -# ] -# PYLINT -[tool.pylint.main] -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code. -extension-pkg-allow-list = ["aidge_core", "aidge_backend_cpu", "aidge_quantization", "onnx"] -# Files or directories to be skipped. They should be base names, not paths. -ignore = ["CVS"] -# List of module names for which member attributes should not be checked (useful -# for modules/projects where namespaces are manipulated during runtime and thus -# existing member attributes cannot be deduced by static analysis). It supports -# qualified module names, as well as Unix pattern matching.
-ignored-modules = ["aidge_core", "aidge_backend_cpu", "aidge_quantization", "onnx"] -## AIDGE DEPENDENCIES DECLARATION +] + [tool.cibuildwheel.environment] -AIDGE_DEPENDENCIES = "aidge_core aidge_backend_cpu aidge_onnx" # format => "dep_1 dep_2 ... dep_n" +AIDGE_DEPENDENCIES = "aidge_core aidge_backend_cpu" # format => "dep_1 dep_2 ... dep_n" AIDGE_INSTALL="/AIDGE_INSTALL_CIBUILDWHEEL" +SEARCH_PATH="" +AIDGE_WITH_CUDA="OFF" + [tool.cibuildwheel.linux] before-build = [ - "bash .gitlab/ci/cibuildwheel_build_deps_before_build_wheel.sh /host" + "bash .gitlab/ci/cibuildwheel_build_deps_before_build_wheel.sh /host/$SEARCH_PATH" ] before-test = [ "pip install aidge_core aidge_backend_cpu aidge_onnx" @@ -103,3 +95,18 @@ before-build = [ before-test = [ "pip install aidge_core aidge_backend_cpu aidge_onnx" ] + +# PYLINT +[tool.pylint.main] +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list = ["aidge_core", "aidge_backend_cpu", "aidge_quantization", "onnx"] +# Files or directories to be skipped. They should be base names, not paths. +ignore = ["CVS"] +# List of module names for which member attributes should not be checked (useful +# for modules/projects where namespaces are manipulated during runtime and thus +# existing member attributes cannot be deduced by static analysis). It supports +# qualified module names, as well as Unix pattern matching. +ignored-modules = ["aidge_core", "aidge_backend_cpu", "aidge_quantization", "onnx"] +## AIDGE DEPENDENCIES DECLARATION diff --git a/python_binding/operator/pybind_PTQMetaOps.cpp b/python_binding/operator/pybind_PTQMetaOps.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8df2992702da05d9336208d7dc09ac7510dc4f5 --- /dev/null +++ b/python_binding/operator/pybind_PTQMetaOps.cpp @@ -0,0 +1,121 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + + #include <pybind11/pybind11.h> + #include <pybind11/stl.h> + #include <pybind11/functional.h> + + #include "aidge/operator/PTQMetaOps.hpp" + #include "aidge/graph/Node.hpp" + #include "aidge/utils/Types.h" + + namespace py = pybind11; + + namespace Aidge { + + void init_PTQMetaOps(py::module &m) { + // Quantizer creation and manipulation + m.def("quantizer", &Quantizer, + py::arg("scaling_factor"), + py::arg("name") = "", + R"doc( + Create a quantizer node with specified scaling factor. + + Args: + scaling_factor (float): The scaling factor to apply + name (str): Optional name for the quantizer node + + Returns: + Node: The created quantizer node + )doc"); + + m.def("multiply_scaling_factor", &multiplyScalingFactor, + py::arg("quantizer"), + py::arg("coefficient"), + R"doc( + Multiply the scaling factor of a quantizer by a coefficient. + + Args: + quantizer (Node): The quantizer node to modify + coefficient (float): The multiplication factor + )doc"); + + m.def("get_scaling_factor", &getScalingFactor, + py::arg("quantizer"), + R"doc( + Get the current scaling factor of a quantizer. 
+ + Args: + quantizer (Node): The quantizer node to query + + Returns: + float: The current scaling factor + )doc"); + + // Quantizer modification functions + m.def("append_round_clip", &appendRoundClip, + py::arg("quantizer"), + py::arg("clip_min"), + py::arg("clip_max"), + R"doc( + Append round and clip operations to a quantizer. + + Args: + quantizer (Node): The quantizer node to modify + clip_min (float): Minimum clipping value + clip_max (float): Maximum clipping value + )doc"); + + m.def("set_clip_range", &setClipRange, + py::arg("quantizer"), + py::arg("min"), + py::arg("max"), + R"doc( + Set the clipping range of a quantizer that already has clip operations. + + Args: + quantizer (Node): The quantizer node to modify + min (float): New minimum clipping value + max (float): New maximum clipping value + )doc"); + + m.def("remove_round", &removeRound, + py::arg("quantizer"), + R"doc( + Remove the round operation from a quantizer. + + Args: + quantizer (Node): The quantizer node to modify + )doc"); + + // Advanced quantization operations + m.def("replace_scaling_with_bitshift", &replaceScalingWithBitShift, + py::arg("quantizer"), + R"doc( + Replace multiplicative scaling with bit-shift operations. + + Args: + quantizer (Node): The quantizer node to modify + )doc"); + + m.def("cast_quantizer_ios", &castQuantizerIOs, + py::arg("quantizer"), + py::arg("external_type"), + R"doc( + Cast the input/output of a quantizer to specified data type. + + Args: + quantizer (Node): The quantizer node to modify + external_type (DataType): Target data type for I/O + )doc"); + } + + } // namespace Aidge \ No newline at end of file diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp index b5193bddcfe345a1702f02fcc139a4cf5b94a1ce..ad6931c8f6dcc9e6f3dd8d16fb57e6cadf06efe6 100644 --- a/python_binding/pybind_PTQ.cpp +++ b/python_binding/pybind_PTQ.cpp @@ -55,13 +55,13 @@ void init_PTQ(py::module &m) { :type network: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("compute_ranges", &computeRanges, py::arg("network"), py::arg("input_dataset"), py::arg("scaling_nodes_only"), py::arg("use_cuda"), + m.def("compute_ranges", &computeRanges, py::arg("network"), py::arg("calibration_set"), py::arg("scaling_nodes_only"), py::arg("use_cuda"), R"mydelimiter( Compute the activation ranges of every affine node, given an input dataset. :param network: The GraphView containing the affine nodes, on which the inferences are performed. :type network: :py:class:`aidge_core.GraphView` - :param input_dataset: The input dataset, consisting of a vector of input samples. - :type input_dataset: list of :py:class:`aidge_core.Tensor` + :param calibration_set: The input dataset, consisting of a vector of input samples. + :type calibration_set: list of :py:class:`aidge_core.Tensor` :param scaling_nodes_only: Whether to restrain the retrieval of the ranges to scaling nodes only or not :type scaling_nodes_only: bool :return: A map associating each considered node name to its corresponding output range. @@ -85,36 +85,54 @@ void init_PTQ(py::module &m) { :type network: :py:class:`aidge_core.GraphView` :param nb_bits: The desired number of bits of the quantization. :type nb_bits: int - :param apply_rounding: Whether to apply the rounding operations or not. - :type apply_rounding: bool + :param no_quant: Whether to skip the rounding operations or not. + :type no_quant: bool :param optimize_signs: Whether to take account of the IO signs of the operators or not.
:type optimize_signs: bool :param verbose: Whether to print the sign map or not. :type verbose: bool )mydelimiter"); - m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = true, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false, + m.def("cast_quantized_network", &castQuantizedNetwork, py::arg("network"), py::arg("target_type"), py::arg("single_shift"), /* py::arg("bitshift_rounding"),*/ + R"mydelimiter( + Take a quantized GraphView represented in floating precision and cast it to the desired target precision. + If single-shift option is set to True, the scaling nodes contained in the activation quantizers are replaced with bit-shifts. + :param network: The GraphView to cast. + :type network: :py:class:`aidge_core.GraphView` + :param target_type: The desired precision of the cast. + :type target_type: :py:class:`aidge_core.DataType` + :param single_shift: If set to True, replace the scaling-factors by bit-shifts. + :type single_shift: bool + )mydelimiter"); + + m.def("quantize_network", &quantizeNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_set"), py::arg("target_type"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quant") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false, R"mydelimiter( Main quantization routine. Performs every step of the quantization pipeline. :param network: The GraphView to be quantized. :type network: :py:class:`aidge_core.GraphView` :param nb_bits: The desired number of bits of the quantization. :type nb_bits: int - :param input_dataset: The input dataset on which the value ranges are computed. - :type input_dataset: list of :py:class:`aidge_core.Tensor` + :param calibration_set: The input dataset used for the activation calibration. + :type calibration_set: list of :py:class:`aidge_core.Tensor` + :param target_type: The desired data-type of the output GraphView. + :type target_type: :py:class:`aidge_core.DataType` :param clipping_mode: Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'. - :type clipping_mode: string - :param no_quantization: Whether to truly quantize the network or not. - :type no_quantization: bool + :type clipping_mode: :py:class:`aidge_quantization.Clipping` + :param no_quant: Whether to skip the rounding operations or not. + :type no_quant: bool :param optimize_signs: Whether to take account of the IO signs of the operators or not. :type optimize_signs: bool - :param single_shift: Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights. + :param single_shift: Whether to convert the scaling factors into powers of two. If true, the approximations are compensated using the previous nodes' parameters. :type single_shift: bool + :param use_cuda: Whether to use the CUDA backend for performing the activation calibration or not. + :type use_cuda: bool + :param fold_graph: Whether to fold the parameter quantizers after the quantization or not. + :type fold_graph: bool :param verbose: Whether to print internal information about the quantization process.
:type verbose: bool )mydelimiter"); - m.def("compute_histograms", &computeHistograms, py::arg("value_ranges"), py::arg("nb_bins"), py::arg("network"), py::arg("input_dataset"), py::arg("use_cuda"), + m.def("compute_histograms", &computeHistograms, py::arg("value_ranges"), py::arg("nb_bins"), py::arg("network"), py::arg("calibration_set"), py::arg("use_cuda"), R"mydelimiter( Compute the histograms of the activations of each node contained in the map of the ranges (passed as argument). :param value_ranges: A map associating each considered node name to its corresponding output range. @@ -123,8 +141,8 @@ void init_PTQ(py::module &m) { :type nb_bins: int :param network: The GraphView containing the considered nodes. :type network: :py:class:`aidge_core.GraphView` - :param input_dataset: The input dataset, consisting of a list of input samples. - :type input_dataset: list of :py:class:`aidge_core.Tensor` + :param calibration_set: The input dataset, consisting of a list of input samples. + :type calibration_set: list of :py:class:`aidge_core.Tensor` :return: A map associating each node name to it's corresponding activation histogram. :rtype: dict )mydelimiter"); @@ -153,7 +171,7 @@ void init_PTQ(py::module &m) { :rtype: float )mydelimiter"); - m.def("adjust_ranges", &adjustRanges, py::arg("clipping_mode"), py::arg("value_ranges"), py::arg("nb_bits"), py::arg("network"), py::arg("input_dataset"), py::arg("use_cuda"), py::arg("verbose") = false, + m.def("adjust_ranges", &adjustRanges, py::arg("clipping_mode"), py::arg("value_ranges"), py::arg("nb_bits"), py::arg("network"), py::arg("calibration_set"), py::arg("use_cuda"), py::arg("verbose") = false, R"mydelimiter( Return a corrected map of the provided activation ranges. To do so compute the optimal clipping values for every node and multiply the input ranges by those values. @@ -166,15 +184,14 @@ void init_PTQ(py::module &m) { :type nb_bits: int :param network: The GraphView containing the considered nodes. :type network: :py:class:`aidge_core.GraphView` - :param input_dataset: The input dataset, consisting of a list of input samples. - :type input_dataset: list of :py:class:`aidge_core.Tensor` + :param calibration_set: The input dataset, consisting of a list of input samples. + :type calibration_set: list of :py:class:`aidge_core.Tensor` :param verbose: Whether to print the clipping values or not. :type verbose: bool :return: The corrected map associating to each provided node its clipped range. :rtype: dict )mydelimiter"); - m.def("compute_sign_map", &computeSignMap, py::arg("network"), py::arg("verbose") = false, R"mydelimiter( For each node, compute the sign of its input and output values. @@ -213,15 +230,7 @@ void init_PTQ(py::module &m) { :type network: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("dev_ptq", &devPTQ, py::arg("network"), - R"mydelimiter( - Developement and test routine. - :param network: The GraphView under test. 
- :type network: :py:class:`aidge_core.GraphView` - )mydelimiter"); - m.def("prepare_network", &prepareNetwork, py::arg("network"), "prepare the network for the PTQ"); - } } // namespace Aidge diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp index 206985efe4558a84ce1ed67a1264bd6902213764..dd118dccc24dca71185c9401a924fbae0d22cc6c 100644 --- a/python_binding/pybind_QAT_LSQ.cpp +++ b/python_binding/pybind_QAT_LSQ.cpp @@ -9,22 +9,21 @@ * ********************************************************************************/ -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> - -#include "aidge/quantization/QAT/QAT_LSQ.hpp" -#include "aidge/graph/GraphView.hpp" - -namespace py = pybind11; - -namespace Aidge { - -void init_QAT_LSQ(py::module &m) { - - auto mQuantLSQ = m.def_submodule("lsq"); - - mQuantLSQ.def("insert_quantizers", &QuantLSQ::insertQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("step_size")); - - mQuantLSQ.def("insert_and_init_quantizers", &QuantLSQ::insertAndInitQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_data")); -} -} // namespace Aidge + #include <pybind11/pybind11.h> + #include <pybind11/stl.h> + + #include "aidge/quantization/QAT/QAT_LSQ.hpp" + #include "aidge/graph/GraphView.hpp" + + namespace py = pybind11; + + namespace Aidge { + + void init_QAT_LSQ(py::module &m) { + + auto mQuantLSQ = m.def_submodule("lsq"); + + mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits")); + } + } // namespace Aidge + \ No newline at end of file diff --git a/python_binding/pybind_Quantization.cpp b/python_binding/pybind_Quantization.cpp index 7ac344dcfcd4fc93e3bba1dcd19c1413f5a29d0c..27f6885d48fcafc81d8679deebced6f36ce67e4a 100644 --- a/python_binding/pybind_Quantization.cpp +++ b/python_binding/pybind_Quantization.cpp @@ -31,6 +31,7 @@ void init_DoReFa(py::module& m); // quantization routines void init_PTQ(py::module &m); +void init_PTQMetaOps(py::module &m); void init_QAT_FixedQ(py::module &m); void init_QAT_LSQ(py::module &m); void init_QuantRecipes(py::module &m); @@ -45,6 +46,7 @@ PYBIND11_MODULE(aidge_quantization, m) init_DoReFa(m); init_PTQ(m); + init_PTQMetaOps(m); init_QAT_FixedQ(m); init_QAT_LSQ(m); init_QuantRecipes(m); diff --git a/python_binding/recipes/pybind_QuantRecipes.cpp b/python_binding/recipes/pybind_QuantRecipes.cpp index 0b96aef775a32cd362013998dd786a9985cc3fc1..15257b0a6b292d3205b6256fecb221ea0a7c7297 100644 --- a/python_binding/recipes/pybind_QuantRecipes.cpp +++ b/python_binding/recipes/pybind_QuantRecipes.cpp @@ -18,13 +18,15 @@ namespace py = pybind11; -namespace Aidge { - -void init_QuantRecipes(py::module &m) { +namespace Aidge +{ +void init_QuantRecipes(py::module &m) +{ m.def("pop_softmax", &popSoftMax, py::arg("network")); m.def("insert_batchnorm_nodes", &insertBatchNormNodes, py::arg("network")); m.def("sanitize_node_names", &sanitizeNodeNames, py::arg("network")); + m.def("reorder_matmul_inputs", &reorderMatMulInputs, py::arg("network")); } } // namespace Aidge diff --git a/setup.py b/setup.py index 1bfc0ac515fd8cceeec4cba666addc1e7666fd25..4fc2bc419addb150db8b1d4610275bbed5479e4a 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ class AidgePkgBuild(build_ext): cxx_compiler = os.environ.get("AIDGE_CXX_COMPILER", "g++") build_type = os.environ.get("AIDGE_BUILD_TYPE", "Release") asan = os.environ.get("AIDGE_ASAN", "OFF") - with_cuda = os.environ.get("AIDGE_WITH_CUDA", "OFF") + with_cuda = os.environ.get("AIDGE_WITH_CUDA", "ON") # 
default could be "OFF" cmake_arch = os.environ.get("AIDGE_CMAKE_ARCH", "") build_gen = os.environ.get("AIDGE_BUILD_GEN", "") @@ -88,6 +88,7 @@ class AidgePkgBuild(build_ext): f"-DENABLE_ASAN={asan}", f"-DCUDA={with_cuda}", "-DPYBIND=ON", + "-DPYBIND11_FINDPYTHON=ON", f"-DPYBIND_INSTALL_PREFIX:PATH={pybind_install_prefix}", "-DCMAKE_EXPORT_COMPILE_COMMANDS=1", "-DCOVERAGE=OFF", diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp index 5265d9c9b1326e73ee4080fe5f69fed5047a0dbb..33cf14667de7121f93d2804d66e6c8037b643f81 100644 --- a/src/PTQ/CLE.cpp +++ b/src/PTQ/CLE.cpp @@ -20,9 +20,17 @@ #include "aidge/quantization/PTQ/PTQ.hpp" // retrieveNodeVector #include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Node.hpp" -#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/scheduler/Scheduler.hpp" #include "aidge/utils/Log.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +#include "aidge/operator/Mul.hpp" +#include "aidge/operator/ArgMax.hpp" +#include "aidge/operator/Abs.hpp" +#include "aidge/operator/Reshape.hpp" +#include "aidge/operator/Round.hpp" +#include "aidge/operator/MetaOperator.hpp" namespace Aidge { @@ -37,29 +45,40 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node) { return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2); } -static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) +static bool nodeHasBias(std::shared_ptr<Node> node) { - // Get the tensor data pointer - double * castedTensor = static_cast<double *> (tensor->getImpl()->rawPtr()); - - // Rescale the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] *= scaling; + if (node->getParents().size() == 3) { + std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); + if (biasTensor) + return true; + } + return false; } -static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) +std::shared_ptr<Aidge::Tensor> getScaledWeightTensor(std::shared_ptr<Node> node) { - // Get the tensor data pointer and edit it - double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr()); - - // Get the tensor absolute max value - double maxValue = 0.0; - for(std::size_t i = 0; i < tensor->size(); ++i) { - if(std::fabs(castedTensor[i]) > maxValue) { - maxValue = std::fabs(castedTensor[i]); - } + if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerQuantizer")) + { + auto quantizer = node->getParent(1); + + // perform an inference on the branch + + auto graphView = Sequential({quantizer}); + graphView->add(quantizer->getParent(0)); + SequentialScheduler scheduler(graphView); + scheduler.forward(true, {}); + + // gather and return the result + + auto op = std::static_pointer_cast<MetaOperator_Op>(quantizer->getOperator()); + auto result = op->getOutput(0); + return result; + } + else + { + auto result = getWeightTensor(node); + return result; } - return maxValue; } void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta) @@ -67,14 +86,20 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); // Check if the CLE can be applied ... + for (std::shared_ptr<Node> node : nodeVector) - if (node->getChildren().size() > 1) - { - Log::notice("Network have multiple branches, skipping the CLE ... "); + { + if (node->getChildren().size() > 1) { + Log::warn(" Network has multiple branches, skipping the CLE ... 
"); + return; + } + if (isNotQuantized(node)) { + Log::warn(" Network contains non-linear nodes, skipping the CLE ... "); return; } + } - Log::info("Applying the Cross-Layer Equalization ... "); + Log::notice(" Applying the Cross-Layer Equalization ... "); // Get the vector of affine nodes @@ -83,46 +108,44 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD if (isAffine(node)) affineNodeVector.push_back(node); - if (affineNodeVector.empty()) { - Log::notice("No affine nodes found in the network. CLE cannot be applied."); - return; - } double maxRangeDelta; - int iteration = 0; - do { - ++iteration; maxRangeDelta = 0.0; - //std::cout << " ----- " << std::endl; - //for (std::shared_ptr<Node> node : affineNodeVector) - // std::cout << getTensorAbsoluteMax(getWeightTensor(node)) << std::endl; - - for (std::size_t i = 0; i < (affineNodeVector.size() - 1); i++) + + for (size_t i = 0; i < (affineNodeVector.size() - 1); i++) { + // Log::notice(" node index : {} ", i); + std::shared_ptr<Node> n1 = affineNodeVector[i]; std::shared_ptr<Node> n2 = affineNodeVector[i+1]; - double r1 = getTensorAbsoluteMax(getWeightTensor(n1)); - double r2 = getTensorAbsoluteMax(getWeightTensor(n2)); + std::shared_ptr<Aidge::Tensor> w1 = getScaledWeightTensor(n1); + std::shared_ptr<Aidge::Tensor> w2 = getScaledWeightTensor(n2); + + //Log::notice(" TENSOR : \n {}", *w1); + + double r1 = getTensorAbsoluteMax(w1); + double r2 = getTensorAbsoluteMax(w2); double s1 = std::sqrt(r1 * r2) / r1; double s2 = std::sqrt(r1 * r2) / r2; - rescaleTensor(getWeightTensor(n1), s1); - rescaleTensor(getWeightTensor(n2), s2); + multiplyScalingFactor(n1->getParent(1), s1); - rescaleTensor(getBiasTensor(n1), s1); + if (nodeHasBias(n1)) + multiplyScalingFactor(n1->getParent(2), s1); + + multiplyScalingFactor(n2->getParent(1), s2); double rangeDelta = std::abs(r1 - r2); if (rangeDelta > maxRangeDelta) maxRangeDelta = rangeDelta; } + + // Log::notice(" CLE delta = {} ", maxRangeDelta); } while (maxRangeDelta > targetDelta); - - Log::notice("CLE completed after {} iterations. Final max range delta: {:.6f}", - iteration, maxRangeDelta); } } \ No newline at end of file diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp index 57ad7a836bbb6251a8eeb6da87e3647b4f54afe2..4107c555e6b356401ed06057c9a06463084b74bd 100644 --- a/src/PTQ/Clipping.cpp +++ b/src/PTQ/Clipping.cpp @@ -18,34 +18,27 @@ namespace Aidge { - -std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda) + +std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda) { if (useCuda) graphView->setBackend("cuda"); std::shared_ptr<Node> firstNode = retrieveNodeVector(graphView)[0]; - //std::cout << " COMPUTING HISTOGRAMS ... " << std::endl; + // Log::debug(" COMPUTING HISTOGRAMS ... "); - std::map<std::string, std::vector<int>> histograms; + std::unordered_map<std::shared_ptr<Node>, std::vector<int>> histograms; SequentialScheduler scheduler(graphView); scheduler.resetScheduling(); // Setup the histograms ... 
- for (std::shared_ptr<Node> node : graphView->getNodes()) + for (std::pair<std::shared_ptr<Node>, double> pair : valueRanges) { - bool isInsideRanges = (valueRanges.find(node->name()) != valueRanges.end()); - if (isInsideRanges) - { - std::vector<int> histogram; - for (int i = 0; i < nbBins; i++) - histogram.push_back(0); - - histograms.insert(std::make_pair(node->name(), histogram)); - } + std::vector<int> histogram(nbBins, 0); + histograms.insert(std::make_pair(pair.first, histogram)); } // Fill the histograms ... @@ -54,7 +47,7 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, int it = 0; - for (std::shared_ptr<Tensor> inputTensor : inputDataSet) + for (std::shared_ptr<Tensor> inputTensor : calibrationSet) { Log::debug(" IT (BIS) : {}", it++); @@ -66,35 +59,32 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, scheduler.forward(true, {inputTensor}); // Gather values ... - - for (std::shared_ptr<Node> node : graphView->getNodes()) + + for (std::pair<std::shared_ptr<Node>, double> pair : valueRanges) { - bool isInsideRanges = (valueRanges.find(node->name()) != valueRanges.end()); - if (isInsideRanges) - { - double valueRange = valueRanges[node->name()]; + std::shared_ptr<Node> node = pair.first; + double valueRange = pair.second; - std::shared_ptr<Operator> nodeOperator = node->getOperator(); - std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); + std::shared_ptr<Operator> nodeOperator = node->getOperator(); + std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); - if (useCuda) - valueTensor->setBackend("cpu"); + if (useCuda) + valueTensor->setBackend("cpu"); - double * castedTensor = static_cast<double *> (valueTensor->getImpl()->rawPtr()); + double * castedTensor = static_cast<double *> (valueTensor->getImpl()->rawPtr()); - std::vector<int> nodeHistogram = histograms[node->name()]; - for(std::size_t i = 0; i < valueTensor->size(); i++) - { - std::size_t bin = std::round(std::abs(castedTensor[i] / valueRange * nbBins)); - bin = std::min(bin, nodeHistogram.size() - 1); - nodeHistogram[bin]++; - } + std::vector<int> nodeHistogram = histograms[node]; + for(std::size_t i = 0; i < valueTensor->size(); i++) + { + std::size_t bin = std::round(std::abs(castedTensor[i] / valueRange * nbBins)); + bin = std::min(bin, nodeHistogram.size() - 1); + nodeHistogram[bin]++; + } - histograms[node->name()] = nodeHistogram; + histograms[node] = nodeHistogram; - if (useCuda) - valueTensor->setBackend("cuda"); - } + if (useCuda) + valueTensor->setBackend("cuda"); } if (useCuda) @@ -206,8 +196,7 @@ double computeKLClipping(std::vector<int> refHistogram, std::uint8_t nbBits) return bestClipping; } - -std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::string, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose) +std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda, bool verbose) { double clipping = 1.0f; @@ -218,13 +207,13 @@ std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std:: if (verbose) Log::info(" === CLIPPING VALUES === "); - std::map<std::string, std::vector<int>> 
histograms = computeHistograms(valueRanges, nbBins, graphView, inputDataSet, useCuda); + std::unordered_map<std::shared_ptr<Node>, std::vector<int>> histograms = computeHistograms(valueRanges, nbBins, graphView, calibrationSet, useCuda); for (std::shared_ptr<Node> node : graphView->getNodes()) { - if (node->type() == "Scaling") + if (node->attributes()->hasAttr("quantization.ptq.isActivationQuantizer")) { - std::vector<int> histogram = histograms[node->name()]; + std::vector<int> histogram = histograms[node]; if (clippingMode == Clipping::MSE) clipping = computeMEClipping(histogram, nbBits, 2.0); @@ -236,12 +225,11 @@ std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std:: if (verbose) Log::info(" {:.6f} ({})", clipping, node->name()); - valueRanges[node->name()] *= clipping; + valueRanges[node] *= clipping; } } } - return valueRanges; } diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 0e26313475bbbda23a56dcdda52d55a0a5af8204..9e2a62c8b975bbca4f63c90d500324a75e492e64 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -9,35 +9,54 @@ * ********************************************************************************/ -#include "aidge/quantization/PTQ/CLE.hpp" -#include "aidge/quantization/PTQ/Clipping.hpp" -#include "aidge/quantization/PTQ/PTQ.hpp" -#include "aidge/quantization/PTQ/PTQMetaOps.hpp" +#include "aidge/quantization/PTQ/CLE.hpp" +#include "aidge/quantization/PTQ/Clipping.hpp" +#include "aidge/quantization/PTQ/PTQ.hpp" +#include "aidge/operator/PTQMetaOps.hpp" + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/scheduler/Scheduler.hpp" +#include "aidge/utils/Log.hpp" + +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/Mul.hpp" +#include "aidge/operator/Round.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/BatchNorm.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ArgMax.hpp" +#include "aidge/operator/Reshape.hpp" +#include "aidge/operator/MatMul.hpp" + +#include "aidge/recipes/Recipes.hpp" +#include "aidge/recipes/QuantRecipes.hpp" +#include "aidge/operator/MetaOperator.hpp" +namespace Aidge +{ -#include "aidge/data/Tensor.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Node.hpp" -#include "aidge/scheduler/SequentialScheduler.hpp" -#include "aidge/scheduler/Scheduler.hpp" -#include "aidge/utils/Log.hpp" - -#include "aidge/operator/Producer.hpp" -#include "aidge/operator/Mul.hpp" -#include "aidge/operator/ReLU.hpp" -#include "aidge/operator/BatchNorm.hpp" -#include "aidge/operator/Conv.hpp" - -#include "aidge/recipes/Recipes.hpp" -#include "aidge/recipes/QuantRecipes.hpp" - +static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr) +{ + return node->attributes()->hasAttr("quantization.ptq." + attr); +} -namespace Aidge +static void addAttr(std::shared_ptr<Aidge::Node> node, std::string attr, double value = 0.0) { + node->attributes()->addAttr("quantization.ptq." 
+ attr, value); +} bool isAffine(std::shared_ptr<Node> node) -{ - return (affineNodeTypes.find(node->type()) != affineNodeTypes.end()); +{ + if (affineNodeTypes.find(node->type()) != affineNodeTypes.end()) + return true; + + if ((node->type() == "MatMul") && hasAttr(node, "isWeighted")) + return true; + + return false; } bool isSeamless(std::shared_ptr<Node> node) @@ -47,79 +66,315 @@ bool isSeamless(std::shared_ptr<Node> node) bool isMerging(std::shared_ptr<Node> node) { - return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end()); + if (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end()) + return true; + + if ((node->type() == "MatMul") && !hasAttr(node, "isWeighted")) + return true; + + return false; +} + +bool isNotQuantized(std::shared_ptr<Node> node) +{ + return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); +} + +void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType) +{ + for (std::shared_ptr<Aidge::Node> node : graphView->inputNodes()) + for (Aidge::IOIndex_t i = node->getFirstFreeDataInput(); i < node->getNbFreeDataInputs(); i++) + node->getOperator()->resetInput(i); } bool checkArchitecture(std::shared_ptr<GraphView> graphView) { - std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"}); + std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"}); + + std::set<std::string> specialNodeTypes({"MatMul", "ReLU", "Producer"}); + + std::set<std::string> notQuantizedNodesTypes; for (std::shared_ptr<Node> node : graphView->getNodes()) { - bool isOther = otherNodeTypes.find(node->type()) != otherNodeTypes.end(); - if (!isOther && !isAffine(node) && !isSeamless(node) && !isMerging(node)) { + bool isRemoved = removedNodeTypes.find(node->type()) != removedNodeTypes.end(); + bool isSpecial = specialNodeTypes.find(node->type()) != specialNodeTypes.end(); + if (!isRemoved && !isSpecial && !isAffine(node) && !isSeamless(node) && !isMerging(node) && !isNotQuantized(node)) { Log::warn(" GraphView can't be quantized : node type {} is not supported !", node->type()); return false; } + + if (isNotQuantized(node)) + notQuantizedNodesTypes.insert(node->type()); + } + + if (!notQuantizedNodesTypes.empty()) { + std::string tokens; + for (std::string s : notQuantizedNodesTypes) + tokens += (s + " "); + Log::warn(" Network contains non-linear nodes that won't be quantized : {}", tokens); } return true; } -static void fillTensor(std::shared_ptr<Tensor> tensor, double value) +void prepareNetwork(std::shared_ptr<GraphView> graphView) { - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + // remove the flatten nodes + + removeFlatten(graphView); - // Fill the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] = value; + // handle the MatMuls + + reorderMatMulInputs(graphView); + // matMulToFC(graphView); // not working properly atm ! 
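The tagging step that follows marks a MatMul as weighted when its second input comes straight from a Producer, so later passes can treat it like a Conv or FC layer. The decision rule, restated as a Python sketch (assuming aidge_core's snake_case Node accessors):

    def is_weighted_matmul(node):
        # A MatMul fed by a Producer on input #1 behaves like an affine layer;
        # any other MatMul is handled as a merging node instead.
        parent = node.get_parent(1)
        return node.type() == "MatMul" and parent is not None \
            and parent.type() == "Producer"
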
+ + // tag the weighted nodes + + std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); + + for (std::shared_ptr<Node> node : nodeVector) + { + bool isWeighted = isAffine(node); + if (node->type() == "MatMul") + { + std::shared_ptr<Node> parent = node->getParent(1); + if (parent) + if (parent->type() == "Producer") + isWeighted = true; + } + + if (isWeighted) + addAttr(node, "isWeighted"); + } + + // fuse the batchnorms + + bool containsBatchNorm = false; + for (std::shared_ptr<Node> node : nodeVector) { + if (node->type() == "BatchNorm") { + containsBatchNorm = true; + break; + } + } + + if (containsBatchNorm) + fuseBatchNorm(graphView); + + // pop the softmax + + popSoftMax(graphView); +} + +static std::shared_ptr<Aidge::Node> getUniqueChild(std::shared_ptr<Aidge::Node> node) +{ + std::set<std::shared_ptr<Aidge::Node>> childrenSet = node->getChildren(); + AIDGE_ASSERT(childrenSet.size() == 1, " Attempted to access a unique child while the parent has multiple ones ! "); + return *(childrenSet.begin()); } -static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) +static std::string determineBackend(std::shared_ptr<Aidge::Node> node) { - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + std::string backend = node->getOperator()->backend(); + + if (backend != "") + return backend; + else + { + // gather the parent backends - // Rescale the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] *= scaling; + std::set<std::string> parentBackends; + for (auto parent : node->getParents()) + parentBackends.insert(determineBackend(parent)); // it always returns a non-empty value ! + + // check if we have two or more different backends gathered + + if (parentBackends.size() > 1) + { + Log::warn(" Unable to determine backend of node {} due to conflicting parent ones !", node->name()); + return (*parentBackends.begin()); + } + + // if all parents have the same backend return it + + if (parentBackends.size() == 1) + return (*parentBackends.begin()); + } + + return "cpu"; // escape path when no parents are found +} + +static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) +{ + int index = 0; + while (node->getParent(index) != parentNode) + index++; + return index; } -static void roundTensor(std::shared_ptr<Tensor> tensor) +static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> newNode, std::shared_ptr<GraphView> graphView) { - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + // Check that the parent always has at least one child + + AIDGE_ASSERT(parent->getChildren().size() > 0, " Parent node must have at least one child to insert a new node ! 
"); + + // Retreive children connection indexes + + std::vector<std::shared_ptr<Node>> nextNodes = parent->getChildren(0); + std::vector<int> inputIndices(nextNodes.size()); + for (std::size_t i = 0; i < nextNodes.size(); i++) { + inputIndices[i] = getInputIndex(nextNodes[i], parent); + } + + // Disconnect childs from parent + + for (std::shared_ptr<Node> nextNode : nextNodes) { + parent->removeChild(nextNode, 0); + } - // Rescale the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] = std::nearbyint(castedTensor[i]);//Round + // Insert the new node between the child and the parent + + parent->addChild(newNode, 0, 0); + for (std::size_t i = 0; i < nextNodes.size(); i++) { + newNode->addChild(nextNodes[i], 0, inputIndices[i]); + } + + graphView->add(newNode); } -static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor) +void foldProducerQuantizers(std::shared_ptr<GraphView> graphView) { - // Get the tensor data pointer and edit it - double * castedTensor = static_cast<double*>(tensor->getImpl()->rawPtr()); - - // Get the tensor absolute max value - double maxValue = 0.0f; - for(std::size_t i = 0; i < tensor->size(); ++i) { - if(std::fabs(castedTensor[i]) > maxValue) { - maxValue = std::fabs(castedTensor[i]); + std::vector<std::shared_ptr<Node>> producerQuantizers; + for (std::shared_ptr<Node> node : graphView->getNodes()) + if (hasAttr(node, "isProducerQuantizer")) + producerQuantizers.push_back(node); + + for (std::shared_ptr<Node> quantizer : producerQuantizers) + { + // Set the param producer to be constant + + auto paramProducer = quantizer->getParent(0); + auto paramProducerOp = std::static_pointer_cast<Producer_Op>(paramProducer->getOperator()); + paramProducerOp->constant() = true; + + // Set the internal producers of the quantizer to be constant + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + + auto microGraph = quantizerOp->getMicroGraph(); + + for (auto producer : microGraph->getNodes()) + if (producer->type() == "Producer") + { + auto producerOp = std::static_pointer_cast<Producer_Op>(producer->getOperator()); + producerOp->constant() = true; + } + + expandMetaOp(quantizer); // mandatory for now !!! + } + + constantFolding(graphView); +} + +void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType targetType, bool singleShift /*, bool bitShiftRounding*/) +{ + std::set<Aidge::DataType> supportedFloatTypes = {DataType::Float16, DataType::Float32, DataType::Float64}; + std::set<Aidge::DataType> supportedIntTypes = {DataType::Int16, DataType::Int32, DataType::Int64}; + + bool castToFloat = (supportedFloatTypes.find(targetType) != supportedFloatTypes.end()); + bool castToInt = (supportedIntTypes.find(targetType) != supportedIntTypes.end()); + + if (castToFloat) + { + graphView->setDataType(targetType); + } + else if (castToInt) + { + if (singleShift) + { + Log::notice(" Replacing scaling nodes with bit-shifts ..."); + + // Replace the scaling nodes with bit-shifts (activations only) + + std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); // must be called again because of removeRound() ! 
+ for (std::shared_ptr<Node> node : nodes) + { + if (node->type() == "Quantizer") + { + if (hasAttr(node, "isActivationQuantizer")) + { + removeRound(node); + replaceScalingWithBitShift(node); + } + else if (hasAttr(node, "isProducerQuantizer")) + castQuantizerIOs(node, targetType); + } + } + + // Cast the nodes (except the producers and quantizers) to integer precision + nodes = graphView->getNodes(); + for (std::shared_ptr<Node> node : nodes) + if (node->type() != "Producer" && !hasAttr(node, "isProducerQuantizer")) // TODO : double check this ! + node->getOperator()->setDataType(targetType); + } + else + { + // Set the nodes (except the producers and quantizers) to have integer IOs + + std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); + for (std::shared_ptr<Node> node : nodes) + if (node->type() != "Quantizer" && node->type() != "Producer") + node->getOperator()->setDataType(targetType); + + // Cast the quantizer inputs and outputs by inserting Cast nodes + + for (std::shared_ptr<Node> node : nodes) + if (node->type() == "Quantizer") + castQuantizerIOs(node, targetType); } } - return maxValue; + else + { + Log::error(" Cannot cast the quantized network : target type '{}' is not supported ! ", targetType); + } } -// TODO : pass nodeVector by reference ... -static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::shared_ptr<Node>> nodeVector, std::string nodeType) +double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { - std::vector<std::shared_ptr<Node>> remainingNodes; - for (std::shared_ptr<Node> node : nodeVector) - if (node->type() != nodeType) - remainingNodes.push_back(node); + std::shared_ptr<Tensor> fallback; + + // get the abs tensor + + std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs()); - return remainingNodes; + + // flatten the abs tensor + + std::int64_t nbElement = tensor->size(); + + auto reshapeOp = Reshape_Op({nbElement}); + reshapeOp.setDataType(tensor->dataType()); + reshapeOp.setBackend(tensor->backend()); + + reshapeOp.associateInput(0, absTensor); + reshapeOp.forward(); + std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0); + const Tensor& localFlatTensor = flatTensor->refCastFrom(fallback, DataType::Float64, "cpu"); + + // Get the argmax + + auto argmaxOp = ArgMax_Op(0, true, false); + argmaxOp.setDataType(tensor->dataType()); + argmaxOp.setBackend(tensor->backend()); + + argmaxOp.associateInput(0, flatTensor); + argmaxOp.forward(); + const Tensor& argMaxTensor = argmaxOp.getOutput(0)->refCastFrom(fallback, DataType::Float64, "cpu"); + + // Return the max + + int maxIndex = std::round(argMaxTensor.get<double>(0)); + + return localFlatTensor.get<double>(maxIndex); } static void fixScheduling(std::vector<std::shared_ptr<Node>>& nodeVector) { @@ -141,34 +396,46 @@ static void fixScheduling(std::vector<std::shared_ptr<Node>>& nodeVector) { static std::shared_ptr<Tensor> getWeightTensor(std::shared_ptr<Node> node) { - return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1); + std::shared_ptr<Node> producer = node->getParent(1); + + if (producer->type() == "Quantizer") + producer = 
producer->getParent(0); + + return std::static_pointer_cast<OperatorTensor>(producer->getOperator())->getOutput(0); } static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node) { - return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2); + std::shared_ptr<Node> producer = node->getParent(2); + + if (producer->type() == "Quantizer") + producer = producer->getParent(0); + + return std::static_pointer_cast<OperatorTensor>(producer->getOperator())->getOutput(0); } std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose) { - std::vector<std::shared_ptr<Node>> nodeVector; + std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes(); + + // Remove duplicate nodes. XXX Is it still needed ??? - SequentialScheduler scheduler(graphView); + fixScheduling(nodeVector); - if (newSchedule) - { - scheduler.resetScheduling(); - scheduler.generateScheduling(); // old way : scheduler.forward(); - } + // Remove Producers and their Scalings + + std::vector<std::shared_ptr<Node>> remainingNodes; + for (std::shared_ptr<Node> node : nodeVector) + if ((node->type() != "Producer") && !hasAttr(node, "isProducerQuantizer")) + remainingNodes.push_back(node); - nodeVector = scheduler.getStaticScheduling(); + nodeVector = remainingNodes; - fixScheduling(nodeVector); - nodeVector = removeMatchingNodes(nodeVector, "Producer"); + // Verbose if (verbose) { - Log::info("NB OF NODES = {}", nodeVector.size()); + Log::info(" NB OF NODES = {}", nodeVector.size()); for (std::shared_ptr<Node> node : nodeVector) Log::info("{} {}", node->type(), node->name()); } @@ -181,35 +448,67 @@ static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView) return retrieveNodeVector(graphView)[0]; } -void prepareNetwork(std::shared_ptr<GraphView> graphView) +// TODO : enhance this by modifying OperatorImpl in "core" ... +static DataType getDataType(std::shared_ptr<Node> node) { - removeFlatten(graphView); + auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + return op->getOutput(0)->dataType(); +} - bool containsBatchNorm = false; - std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); +// XXX double check this ! +static bool nodeHasBias(std::shared_ptr<Node> node) +{ + if (node->getParents().size() == 3) { + std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); + if (biasTensor) + return true; + } + return false; +} - for (std::shared_ptr<Node> node : nodeVector) - if (node->type() == "BatchNorm") - { - containsBatchNorm = true; +// TODO: rework this ! +static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> node) +{ + std::shared_ptr<Node> currNode = node; + while(!hasAttr(currNode, "isActivationQuantizer")) { + if (currNode->getParents().size() == 0) { + Log::warn(" Warning : No previous Scaling node was found ! "); break; } + currNode = currNode->getParents()[0]; + } + return currNode; +} - if (containsBatchNorm) - fuseBatchNorm(graphView); +void insertScalingBelowProducer(std::shared_ptr<Node> producerNode, std::shared_ptr<GraphView> graphView) +{ + std::string scalingNodeName = makeUniqueName(producerNode->name() + "_ProducerScaling", graphView); + std::shared_ptr<Node> scalingNode = Quantizer(1.0, scalingNodeName); + addAttr(scalingNode, "isProducerQuantizer"); - popSoftMax(graphView); + scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) + scalingNode->getOperator()->setBackend(determineBackend(producerNode)); // XXX use the producer parent instead ??? + + insertChildren(producerNode, scalingNode, graphView); } -// TODO : enhance this by modifying OperatorImpl in "core" ... 
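The operator-based getTensorAbsoluteMax() above keeps the reduction on the tensor's own backend (CPU or CUDA) instead of touching a raw host pointer; numerically it boils down to (NumPy sketch, for illustration only):

    import numpy as np

    def tensor_absolute_max(values: np.ndarray) -> float:
        # Abs then flatten (the Reshape step), then ArgMax and a single
        # element read-back, as in the Abs/Reshape/ArgMax pipeline above.
        flat = np.abs(values).ravel()
        return float(flat[np.argmax(flat)])
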
-static DataType getDataType(std::shared_ptr<Node> node) +void insertProducerScalingNodes(std::shared_ptr<GraphView> graphView) { - auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); - return op->getOutput(0)->dataType(); + std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); + + for (std::shared_ptr<Node> node : nodeSet) + { + if (isAffine(node)) + { + insertScalingBelowProducer(node->getParent(1), graphView); + if (nodeHasBias(node)) + insertScalingBelowProducer(node->getParent(2), graphView); + } + } } // XXX HERE : Branches containing only Seamless nodes should be considered as residual too !!! -void insertResidualNodes(std::shared_ptr<GraphView> graphView) +void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) { // TODO: double check this ... @@ -228,14 +527,18 @@ void insertResidualNodes(std::shared_ptr<GraphView> graphView) if (parentIsForking) { // temporary verbose ... + Log::info(" ### found residual branch at index {}", i); Log::info(" ### inserting multiplicative node ..."); std::string residualNodeName = makeUniqueName(parentNode->name() + "_Res", graphView); - std::shared_ptr<Node> residualNode = Scaling(1.0, residualNodeName); + + // XXX + std::shared_ptr<Node> residualNode = Quantizer(1.0, residualNodeName); + addAttr(residualNode, "isActivationQuantizer"); - residualNode->getOperator()->setDataType(DataType::Float64); //getDataType(parentNode) - residualNode->getOperator()->setBackend("cpu"); + residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) + residualNode->getOperator()->setBackend(determineBackend(parentNode)); graphView->insertParent(node, residualNode, i, 0, 0); } @@ -244,87 +547,53 @@ void insertResidualNodes(std::shared_ptr<GraphView> graphView) } } -static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) -{ - int index = 0; - while (node->getParent(index) != parentNode) - index++; - return index; -} - void insertScalingNodes(std::shared_ptr<GraphView> graphView) { - insertResidualNodes(graphView); + insertProducerScalingNodes(graphView); + insertResidualScalingNodes(graphView); std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); for (std::shared_ptr<Node> parentNode : nodeSet) { - if (isAffine(parentNode) || isMerging(parentNode)) + // Insert a Scaling node after each node that have to be quantized + + if (isAffine(parentNode) || isMerging(parentNode) || isNotQuantized(parentNode)) { std::string scalingNodeName = makeUniqueName(parentNode->name() + "_Scaling", graphView); - std::shared_ptr<Node> scalingNode = Scaling(1.0, scalingNodeName); - - scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - scalingNode->getOperator()->setBackend("cpu"); - if (parentNode->getChildren().size() > 0) - { - // SCALING NODE INSERTION - - // We always have one output from Affine and Add nodes, but possibly multiple childs - std::vector<std::shared_ptr<Node>> nextNodes = parentNode->getChildren(0); + // XXX XXX XXX + std::shared_ptr<Node> scalingNode = Quantizer(1.0, scalingNodeName); + addAttr(scalingNode, "isActivationQuantizer"); - // For each node in nextNodes store the connexion index - std::vector<int> inputIndices(nextNodes.size()); - for (std::size_t i = 0; i < nextNodes.size(); i++) - inputIndices[i] = getInputIndex(nextNodes[i], parentNode); - - for (std::shared_ptr<Node> nextNode : nextNodes) - parentNode->removeChild(nextNode, 0); + scalingNode->getOperator()->setDataType(DataType::Float64); // 
getDataType(parentNode) + scalingNode->getOperator()->setBackend(determineBackend(parentNode)); + if (parentNode->getChildren().size() > 0) { + insertChildren(parentNode, scalingNode, graphView); + } else { parentNode->addChild(scalingNode, 0, 0); - - for (std::size_t i = 0; i < nextNodes.size(); i++) - scalingNode->addChild(nextNodes[i], 0, inputIndices[i]); - graphView->add(scalingNode); } - else + + // In the case the node is a non-linear operator we want to add an extra + // scaling node before it to rescale it's input ... + + if (isNotQuantized(parentNode)) { - // Log::info(" last node reached ! "); - parentNode->addChild(scalingNode, 0, 0); - graphView->add(scalingNode); - } - } - } -} + std::string prevScalingNodeName = makeUniqueName(parentNode->name() + "_PrevScaling", graphView); -static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> mergingNode) -{ - std::shared_ptr<Node> currNode = mergingNode; - while(currNode->type() != "Scaling") - { - if (currNode->getParents().size() == 0) - { - Log::warn(" Warning : No previous Scaling node were found ! "); - break; - } - currNode = currNode->getParents()[0]; - } - return currNode; -} + // XXX XXX XXX + std::shared_ptr<Node> prevScalingNode = Quantizer(1.0, prevScalingNodeName); + addAttr(prevScalingNode, "isActivationQuantizer"); -// XXX double check this ! -static bool nodeHasBias(std::shared_ptr<Node> node) -{ - if (node->getParents().size() == 3) - { - std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); - if (biasTensor) - return true; + prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) + prevScalingNode->getOperator()->setBackend(determineBackend(parentNode)); + + graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); + } + } } - return false; } void normalizeParameters(std::shared_ptr<GraphView> graphView) @@ -333,11 +602,9 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); - std::map<std::string, double> accumulatedRatios; + std::unordered_map<std::shared_ptr<Node>, double> accumulatedRatios; for (std::shared_ptr<Node> node : nodeVector) - { - accumulatedRatios.insert(std::make_pair(node->name(), 1.0)); - } + accumulatedRatios.insert(std::make_pair(node, 1.0)); // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// @@ -346,12 +613,12 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) for (std::shared_ptr<Node> node : nodeVector) { // Scaling nodes still have a ratio of 1, so they are seamless ... 
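normalizeParameters() walks the scheduled nodes and accumulates, for every node, the product of the weight rescalings applied upstream of it. The bookkeeping reduces to the following (Python pseudocode sketch; parent_of, weight_of and the other helpers are stand-ins, not actual bindings):

    accumulated_ratios = {node: 1.0 for node in node_vector}
    for node in node_vector:
        if is_seamless_like(node):      # ReLU, activation quantizer, seamless
            accumulated_ratios[node] = accumulated_ratios.get(parent_of(node), 1.0)
        elif is_affine(node):
            # Normalize the weights into [-1, 1] and remember by how much.
            ratio = 1.0 / tensor_absolute_max(weight_of(node))
            multiply_scaling_factor(weight_quantizer_of(node), ratio)
            accumulated_ratios[node] = accumulated_ratios.get(parent_of(node), 1.0) * ratio
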
- if (node->type() == "ReLU" || node->type() == "Scaling" || isSeamless(node)) + if (node->type() == "ReLU" || hasAttr(node, "isActivationQuantizer") || isSeamless(node)) { if (node != firstNode) { std::shared_ptr<Node> prevNode = node->getParent(0); - accumulatedRatios[node->name()] = accumulatedRatios[prevNode->name()]; + accumulatedRatios[node] = accumulatedRatios[prevNode]; } } @@ -359,20 +626,21 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) if (isAffine(node)) { // Rescale the weight tensor + std::shared_ptr<Tensor> weightTensor = getWeightTensor(node); - double scaling = getTensorAbsoluteMax(weightTensor); - double ratio = 1.0 / scaling; - rescaleTensor(weightTensor, ratio); + + double ratio = 1.0 / getTensorAbsoluteMax(weightTensor); + + // rescaleTensor(weightTensor, ratio); + multiplyScalingFactor(node->getParent(1), ratio); // Accumulate the ratio - if (node == firstNode) - { - accumulatedRatios[node->name()] = ratio; - } - else - { + + if (node == firstNode) { + accumulatedRatios[node] = ratio; + } else { std::shared_ptr<Node> prevNode = node->getParent(0); - accumulatedRatios[node->name()] = accumulatedRatios[prevNode->name()] * ratio; + accumulatedRatios[node] = accumulatedRatios[prevNode] * ratio; } // Handle the bias .. @@ -380,83 +648,83 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) if (nodeHasBias(node)) { std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); - rescaleTensor(biasTensor, accumulatedRatios[node->name()] ); + //rescaleTensor(biasTensor, accumulatedRatios[node] ); + multiplyScalingFactor(node->getParent(2), accumulatedRatios[node]); } } - if (isMerging(node)) + if (isNotQuantized(node)) { - std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); + // Gather the previous scaling factor - // Compute the max ratio ... - double maxRatio = 0; - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double merginNodeRatio = accumulatedRatios[mergingNode->name()]; - if (merginNodeRatio > maxRatio) - maxRatio = merginNodeRatio; - } + std::shared_ptr<Node> prevScalingNode = getPreviousScalingNode(node); + double prevRatio = accumulatedRatios[prevScalingNode]; - accumulatedRatios[node->name()] = maxRatio; + // Cancel the accumulated ratio - // Rescale the previous scaling Nodes - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double mergingNodeRatio = accumulatedRatios[mergingNode->name()]; - double rescaling = mergingNodeRatio / maxRatio; + multiplyScalingFactor(prevScalingNode, 1 / prevRatio); + + // Revert the canceling by using the next scaling node + + accumulatedRatios[node] = prevRatio; + std::shared_ptr<Node> nextScalingNode = getUniqueChild(node); + multiplyScalingFactor(nextScalingNode, prevRatio); + } - std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); + if (isMerging(node)) + { + if (node->type() == "MatMul") + { + // Multiply the input scaling factors ! - double currScalingFactor = getScalingFactor(scalingNode); - updateScalingFactor(scalingNode, currScalingFactor / rescaling); + double leftRatio = accumulatedRatios[node->getParent(0)]; + double rightRatio = accumulatedRatios[node->getParent(1)]; - accumulatedRatios[mergingNode->name()] /= rescaling; // optional ... + accumulatedRatios[node] = leftRatio * rightRatio; } - } - } -} + else + { + // Use a maximum arbitration ! -// XXX TODO : take care of the CUDA backend for this too !!! 
-std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> inputTensor, bool scalingNodesOnly) -{ - std::map<std::string, double> valueRanges; + std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - SequentialScheduler scheduler(graphView); - scheduler.resetScheduling(); + // Compute the max ratio ... - // Inference ... + double maxRatio = 0; + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double merginNodeRatio = accumulatedRatios[mergingNode]; + if (merginNodeRatio > maxRatio) + maxRatio = merginNodeRatio; + } - scheduler.forward(true, {inputTensor}); + accumulatedRatios[node] = maxRatio; - // Gather ranges ... + // Rescale the previous scaling Nodes + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + double rescaling = mergingNodeRatio / maxRatio; - std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); - for (std::shared_ptr<Node> node : nodeSet) - { - if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) - { - std::shared_ptr<Operator> nodeOperator = node->getOperator(); - std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); - double range = getTensorAbsoluteMax(valueTensor); + std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); + + multiplyScalingFactor(scalingNode, 1 / rescaling); - // Associate the value to the scaling node ... - valueRanges.insert(std::make_pair(node->name(), range)); + accumulatedRatios[mergingNode] /= rescaling; // optional ... + } + } } } - - return valueRanges; } -std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda) +std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool scalingNodesOnly, bool useCuda) { - std::map<std::string, double> valueRanges; + std::unordered_map<std::shared_ptr<Node>, double> valueRanges; std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); - // std::shared_ptr<Node> inputNode = getFirstNode(graphView); - for (std::shared_ptr<Node> node : nodeSet) - if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) - valueRanges.insert(std::make_pair(node->name(), 0)); + if ((scalingNodesOnly && hasAttr(node, "isActivationQuantizer")) || (!scalingNodesOnly && (node->type() != "Producer"))) + valueRanges.insert(std::make_pair(node, 0)); if (useCuda) graphView->setBackend("cuda"); @@ -466,7 +734,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView int it = 0; - for (std::shared_ptr<Tensor> sample : inputDataSet) + for (std::shared_ptr<Tensor> sample : calibrationSet) { //Log::info(" IT : {}", it++); @@ -479,10 +747,10 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView // Gather the sample ranges ... 
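computeRanges() is a running max of absolute activations over the whole calibration set, now keyed by node rather than by node name. In outline (illustrative Python sketch; forward_sample and output_of stand in for the scheduler and operator accessors):

    value_ranges = {node: 0.0 for node in watched_nodes}
    for sample in calibration_set:
        forward_sample(graph_view, sample)        # one scheduler pass
        for node in watched_nodes:
            sample_range = tensor_absolute_max(output_of(node))
            value_ranges[node] = max(value_ranges[node], sample_range)
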
- std::map<std::string, double> sampleRanges; + std::unordered_map<std::shared_ptr<Node>, double> sampleRanges; for (std::shared_ptr<Node> node : nodeSet) { - if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) + if ((scalingNodesOnly && hasAttr(node, "isActivationQuantizer")) || (!scalingNodesOnly && (node->type() != "Producer"))) { std::shared_ptr<Operator> nodeOperator = node->getOperator(); std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); @@ -493,7 +761,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView double range = getTensorAbsoluteMax(valueTensor); // Associate the value to the scaling node ... - sampleRanges.insert(std::make_pair(node->name(), range)); + sampleRanges.insert(std::make_pair(node, range)); if (useCuda) valueTensor->setBackend("cuda"); @@ -504,12 +772,9 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView for (std::shared_ptr<Node> node : nodeSet) { - if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) - { - std::string nodeName = node->name(); - if (sampleRanges[nodeName] > valueRanges[nodeName]) - valueRanges[nodeName] = sampleRanges[nodeName]; - } + if ((scalingNodesOnly && hasAttr(node, "isActivationQuantizer")) || (!scalingNodesOnly && (node->type() != "Producer"))) + if (sampleRanges[node] > valueRanges[node]) + valueRanges[node] = sampleRanges[node]; } if (useCuda) @@ -522,111 +787,116 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView return valueRanges; } -void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, double> valueRanges) +void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges) { std::shared_ptr<Node> firstNode = getFirstNode(graphView); - // CREATE THE SCALING FACTOR MAP ////////////////////////////////////////// + // CREATE THE ACCUMULATED RATIO MAP /////////////////////////////////////// std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); - std::map<std::string, double> scalingFactors; + std::unordered_map<std::shared_ptr<Node>, double> accumulatedRatios; for (std::shared_ptr<Node> node : nodeVector) - scalingFactors.insert(std::make_pair(node->name(), 1.0)); + accumulatedRatios.insert(std::make_pair(node, 1.0)); // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// for (std::shared_ptr<Node> node : nodeVector) { // Seamless scaling factor propagation ... - + if (isAffine(node) || isSeamless(node) || node->type() == "ReLU") { - if (node == firstNode) - { - scalingFactors[node->name()] = 1.0; - } - else - { + if (node == firstNode) { + accumulatedRatios[node] = 1.0; + } else { std::shared_ptr<Node> prevNode = node->getParent(0); - scalingFactors[node->name()] = scalingFactors[prevNode->name()]; + accumulatedRatios[node] = accumulatedRatios[prevNode]; } } - // Here prevNode is either a 'Affine' or a 'Merging' - // => do not split the cases, just handle the bias ... + // Use the Scaling nodes to rescale the ranges ... - if (node->type() == "Scaling") + if (hasAttr(node, "isActivationQuantizer")) { - // retrieve the previous scaling factor ... std::shared_ptr<Node> prevNode = node->getParent(0); - double prevScalingFactor = scalingFactors[prevNode->name()]; - // ValueRanges must contains all the scaling nodes !!! 
- double scalingFactor = valueRanges[node->name()]; + double prevRatio = accumulatedRatios[prevNode]; + double nodeRange = valueRanges[node]; - double currScalingFactor = getScalingFactor(node); - updateScalingFactor(node, currScalingFactor / (scalingFactor / prevScalingFactor)); + multiplyScalingFactor(node, prevRatio / nodeRange); - scalingFactors[node->name()] = scalingFactor; + accumulatedRatios[node] = nodeRange; // If prevNode is Affine, fix the bias ... if (isAffine(prevNode)) - { - bool prevNodeHasBias = nodeHasBias(prevNode); - if (prevNodeHasBias) - { - std::shared_ptr<Tensor> biasTensor = getBiasTensor(prevNode); - rescaleTensor(biasTensor, 1.0 / prevScalingFactor); - } - } + if (nodeHasBias(prevNode)) + multiplyScalingFactor(prevNode->getParent(2), 1.0 / prevRatio); } - // Merging nodes handling : use a maximum arbritation ... + // Merging nodes handling : use a maximum arbitration ... if (isMerging(node)) { - std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - - // Compute the max scaling ... - double maxScaling = 0; - for (std::size_t i = 0; i < mergingNodes.size(); i++) + if (node->type() == "MatMul") { - double merginNodeScaling = scalingFactors[mergingNodes[i]->name()]; - if (merginNodeScaling > maxScaling) { - maxScaling = merginNodeScaling; - } + double leftRatio = accumulatedRatios[node->getParent(0)]; + double rightRatio = accumulatedRatios[node->getParent(1)]; + accumulatedRatios[node] = leftRatio * rightRatio; } + else + { + std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - scalingFactors[node->name()] = maxScaling; + // Compute the max ratio ... - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double mergingNodeScaling = scalingFactors[mergingNode->name()]; - double rescaling = mergingNodeScaling / maxScaling; + double maxRatio = 0; + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + if (mergingNodeRatio > maxRatio) + maxRatio = mergingNodeRatio; + } - std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); - //Log::info(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + accumulatedRatios[node] = maxRatio; - double currScalingFactor = getScalingFactor(scalingNode); - updateScalingFactor(scalingNode, currScalingFactor * rescaling); + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); + multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); + // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + } } } + + if (isNotQuantized(node)) + { + std::shared_ptr<Node> prevScalingNode = node->getParent(0); + double prevRatio = accumulatedRatios[prevScalingNode]; + Log::notice(" prev ratio : {} ", prevRatio); + + // This causes the previous range not to fill the [-1, 1] interval !!! + // It could be avoided by systematically adding an extra Scaling node before each + // non-linearity ... 
+ + multiplyScalingFactor(prevScalingNode, prevRatio); + } } } -std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose) +std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose) { std::shared_ptr<Node> firstNode = getFirstNode(graphView); - std::map<std::string, std::pair<bool, bool>> signMap; + std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap; std::pair<bool, bool> unsignedPair(true, true); for (std::shared_ptr<Node> node : graphView->getNodes()) - if (node->type() != "Producer") - signMap.insert(std::make_pair(node->name(), unsignedPair)); + if (node->type() != "Producer") // XXX XXX XXX we should use nodeVector instead ... + signMap.insert(std::make_pair(node, unsignedPair)); // ITERATE OVER THE GRAPH @@ -639,17 +909,17 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap if (isAffine(node)) { // Affine nodes always have a single parent - if (!isFirstNode) - signMap[node->name()].first = signMap[node->getParent(0)->name()].second; + if (!isFirstNode) + signMap[node].first = signMap[node->getParent(0)].second; else - signMap[node->name()].first = false; + signMap[node].first = false; - signMap[node->name()].second = false; + signMap[node].second = false; } - if (node->type() == "Scaling") + if (hasAttr(node, "isActivationQuantizer")) { - signMap[node->name()].second = false; + signMap[node].second = false; // Scaling nodes always have a single parent std::shared_ptr<Node> parent = node->getParent(0); @@ -662,14 +932,14 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap // Correct the previous single node (when it is an Affine node) ... if (allChildrenAreReLU) if (isAffine(parent) || isMerging(parent)) - signMap[parent->name()].second = true; + signMap[parent].second = true; // Maintain unsigned output - if (signMap[parent->name()].second) - signMap[node->name()].second = true; + if (signMap[parent].second) + signMap[node].second = true; // Set the link ... - signMap[node->name()].first = signMap[parent->name()].second; + signMap[node].first = signMap[parent].second; } if (isMerging(node)) @@ -680,42 +950,42 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap bool allParentAreUnsigned = true; for(std::shared_ptr<Node> parent : parentNodes) { - bool parentSign = signMap[parent->name()].second; + bool parentSign = signMap[parent].second; allParentAreSigned &= !parentSign; allParentAreUnsigned &= parentSign; } if (allParentAreSigned) - signMap[node->name()] = std::make_pair(false, false); + signMap[node] = std::make_pair(false, false); else if (allParentAreUnsigned) - signMap[node->name()] = std::make_pair(true, true); + signMap[node] = std::make_pair(true, true); else { // Arbitration : Signed type wins ! for(std::shared_ptr<Node> parent : parentNodes) { - while (parent->type() != "Scaling") + while (!hasAttr(parent, "isActivationQuantizer")) { - signMap[parent->name()] = std::make_pair(false, false); + signMap[parent] = std::make_pair(false, false); // We are on a branch so nodes always have 1 parent ... 
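The sign map stores, for every node, the pair (input is unsigned, output is unsigned); when a link is proven unsigned, the following quantizer can spend the full 2^n - 1 levels instead of 2^(n-1) - 1. For 8 bits, the bounds used by quantizeNormalizedNetwork() are:

    nb_bits = 8
    signed_max = (1 << (nb_bits - 1)) - 1    # 127 : symmetric signed range
    unsigned_max = (1 << nb_bits) - 1        # 255 : usable on unsigned links
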
parent = parent->getParent(0); } - signMap[parent->name()].second = false; + signMap[parent].second = false; } - signMap[node->name()].first = false; + signMap[node].first = false; } } - if (node->type() == "ReLU" || isSeamless(node)) + if (node->type() == "ReLU" || isSeamless(node) || isNotQuantized(node)) { // Thoses nodes always have a single parent std::shared_ptr<Node> parent = node->getParent(0); if (parent) { - signMap[node->name()].first = signMap[parent->name()].second; - signMap[node->name()].second = signMap[node->name()].first; + signMap[node].first = signMap[parent].second; + signMap[node].second = signMap[node].first; } } @@ -727,7 +997,7 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap { Log::info(" === SIGN MAP === "); for (std::shared_ptr<Node> node : nodeVector) - Log::info(" {}{} | {}", static_cast<int>(signMap[node->name()].first), static_cast<int>(signMap[node->name()].second), node->name()); + Log::info(" {}{} | {}", static_cast<int>(signMap[node].first), static_cast<int>(signMap[node].second), node->name()); } // SANITY CHECK (TEMPORARY) @@ -736,7 +1006,7 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap { for (std::shared_ptr<Node> child : node->getChildren()) { - if (signMap[node->name()].second != signMap[child->name()].first) + if (signMap[node].second != signMap[child].first) Log::error(" computeSignMap : link is not sane ! ({} -> {})", node->name(), child->name()); } } @@ -744,27 +1014,25 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap return signMap; } - void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant, bool optimizeSigns, bool verbose) { if (optimizeSigns && noQuant) - { - AIDGE_THROW_OR_ABORT(std::runtime_error,"Signs optimization can not be applied if network is not fully quantized ..."); - } + AIDGE_THROW_OR_ABORT(std::runtime_error, " Sign-optimization can not be applied if network is not fully quantized ..."); double signedMax = (1 << (nbBits - 1)) - 1; double unsignedMax = (1 << nbBits) - 1; - std::map<std::string, std::pair<bool, bool>> signMap; + std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap; if (optimizeSigns) signMap = computeSignMap(graphView, verbose); else { + // XXX XXX XXX we should use the (retreive) node vector std::pair<bool, bool> signedPair(false, false); for (std::shared_ptr<Node> node : graphView->getNodes()) if (node->type() != "Producer") - signMap.insert(std::make_pair(node->name(), signedPair)); + signMap.insert(std::make_pair(node, signedPair)); } // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// @@ -776,98 +1044,111 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ if (isAffine(node)) { // Rescale the weight tensor + multiplyScalingFactor(node->getParent(1), signedMax); - std::shared_ptr<Tensor> weightTensor = getWeightTensor(node); - rescaleTensor(weightTensor, signedMax); - + // UUU Quantize the Producer !!! if (!noQuant) - roundTensor(weightTensor); + appendRoundClip(node->getParent(1), -(signedMax + 1), signedMax); // Rescale the bias tensor - if (nodeHasBias(node)) { - bool inputIsUnsigned = signMap[node->name()].first; + bool inputIsUnsigned = signMap[node].first; double rescaling = inputIsUnsigned ? 
unsignedMax * signedMax : signedMax * signedMax; - - - std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); - rescaleTensor(biasTensor, rescaling); + + multiplyScalingFactor(node->getParent(2), rescaling); + // XXX TODO : enhance this ! + int biasMax = (1 << (12 + nbBits)); if (!noQuant) - roundTensor(biasTensor); + appendRoundClip(node->getParent(2), -(biasMax + 1), biasMax); } // Compensate the rescaling using the next Scaling node double rescaling = 1.0 / signedMax; - bool inputIsUnsigned = signMap[node->name()].first; - bool outputIsUnsigned = signMap[node->name()].second; + bool inputIsUnsigned = signMap[node].first; + bool outputIsUnsigned = signMap[node].second; rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ... + std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ... - double currScalingFactor = getScalingFactor(scalingNode); - updateScalingFactor(scalingNode, currScalingFactor * rescaling); + multiplyScalingFactor(scalingNode, rescaling); } if (isMerging(node)) { double rescaling = 1.0; - bool inputIsUnsigned = signMap[node->name()].first; - bool outputIsUnsigned = signMap[node->name()].second; + bool inputIsUnsigned = signMap[node].first; + bool outputIsUnsigned = signMap[node].second; rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ... + std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ... - double currScalingFactor = getScalingFactor(scalingNode); // XXX bad naming - updateScalingFactor(scalingNode, currScalingFactor * rescaling); + // TODO : double check this ... + if (node->type() == "MatMul") + rescaling /= inputIsUnsigned ? unsignedMax : signedMax; + + multiplyScalingFactor(scalingNode, rescaling); + } + + if (isNotQuantized(node)) + { + double rescaling = 1 / signedMax; // XXX handle the signs !!! + + std::shared_ptr<Node> prevScalingNode = node->getParent(0); + multiplyScalingFactor(prevScalingNode, rescaling); + + std::shared_ptr<Node> nextScalingNode = getUniqueChild(node); + multiplyScalingFactor(nextScalingNode, 1 / rescaling); } // Handle the Scaling Nodes ... - if (node->type() == "Scaling") + if (hasAttr(node, "isActivationQuantizer")) { - if (!noQuant) + // Don't touch the scalings that precede non-linearities ... + + bool precedesNonLinearNode = false; + if (node->getChildren().size() == 1) + if (isNotQuantized(getUniqueChild(node))) + precedesNonLinearNode = true; + + if (!noQuant && !precedesNonLinearNode) { - // Replace the Scaling Node by Quantizer + // we need to gather the sign informations before we modify + // the node pointer with appendRoundClip() ... 
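Once appendRoundClip() has run, an activation quantizer is the chain scale, round, clip. Its numeric effect is simply (NumPy sketch, illustration only):

    import numpy as np

    def quantize(x, scaling_factor, clip_min, clip_max):
        # Micro-graph of a Quantizer after appendRoundClip():
        # Mul (scaling), then Round, then Clip to the integer range.
        return np.clip(np.round(x * scaling_factor), clip_min, clip_max)
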
- std::shared_ptr<Node> quantizerNode = Quantizer(getScalingFactor(node), -(signedMax + 1), signedMax, node->name()); - quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - quantizerNode->getOperator()->setBackend("cpu"); + bool inputIsUnsigned = signMap[node].first; + bool outputIsUnsigned = signMap[node].second; - graphView->replace({node}, {quantizerNode}); + appendRoundClip(node, -(signedMax + 1), signedMax); if (optimizeSigns) { double rescaling = 1.0; - bool inputIsUnsigned = signMap[node->name()].first; - bool outputIsUnsigned = signMap[node->name()].second; - rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - double currScalingFactor = getScalingFactor(quantizerNode); - updateScalingFactor(quantizerNode, currScalingFactor * rescaling); + // XXX XXX XXX + multiplyScalingFactor(node, rescaling); - if(outputIsUnsigned) - { - setClipRange(quantizerNode,0,unsignedMax); - } + if (outputIsUnsigned) + setClipRange(node, 0, unsignedMax); } } } } } -static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits) +void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant) { // XXX Use the signMap to increase the resolution when possible ... double signedMax = (1 << (nbBits - 1)) - 1; @@ -876,102 +1157,102 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u for (std::shared_ptr<Node> node : nodeVector) { - // A merging node is always followed by a scaling node at this point ... - + // The appropriate strategy is to check whether the Quantizer is preceded + // by a weighted node (that is not forking), and to insert a Mul node + // (Compensation) before it if not ... + if (node->type() == "Quantizer") - { - bool prevNodeIsForking = ((node->getParent(0))->getChildren().size() > 1); - bool prevNodeIsAffine = isAffine(node->getParent(0)); - bool insertNode = prevNodeIsForking || !prevNodeIsAffine; + { + // Note : this works because a Quantizer has only one Parent ... - if (insertNode) - { - // create and insert the multplicative node + std::shared_ptr<Node> parentNode = node->getParent(0); + bool parentHasWeight = isAffine(parentNode); + bool parentIsForking = (parentNode->getChildren().size() > 1); + + if (parentIsForking || !parentHasWeight) // insert a Compensation Node ...
+ { + // Create and insert the multiplicative node before the Quantizer std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView); std::shared_ptr<Node> mulNode = Mul(mulNodeName); + + // XXX XXX XXX addAttr(mulNode, "isCompensation"); mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - mulNode->getOperator()->setBackend("cpu"); + mulNode->getOperator()->setBackend(determineBackend(node)); graphView->insertParent(node, mulNode, 0, 0, 0); - // create and insert the producer node + // Add the coeff producer to the multiplier node - std::shared_ptr<Tensor> inputTensor = std::static_pointer_cast<Tensor> (mulNode->getOperator()->getRawInput(0)); - std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(); + std::shared_ptr<Node> coeffProducer = addProducer(mulNode, 1, {1}, ""); + std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(Array1D<double, 1> {signedMax}); + coeffProducer->getOperator()->setOutput(0, coeffTensor); - coeffTensor->setDataType(DataType::Float64); // getDataType(parentNode) - coeffTensor->setBackend("cpu"); + coeffProducer->getOperator()->setDataType(DataType::Float64); + coeffProducer->getOperator()->setBackend(determineBackend(node)); - coeffTensor->resize(inputTensor->dims()); - fillTensor(coeffTensor, 1); + graphView->add(coeffProducer); // needed ? - std::shared_ptr<Node> producerNode = Producer(coeffTensor, makeUniqueName("coeff", graphView)); - producerNode->addChild(mulNode); - graphView->add(producerNode); + // Adapt the scaling factor value accordingly - // rescale the coeffs and edit scaling factor + multiplyScalingFactor(node, 1.0 / signedMax); // XXX XXX XXX OK - fillTensor(coeffTensor, signedMax); + // Insert a Quantizer for the coeffProducer that will handle + // the single-shift approximation via its scalingFactor ... - double currScalingFactor = getScalingFactor(node); // XXX bad naming ! - updateScalingFactor(node, currScalingFactor / signedMax); + insertScalingBelowProducer(coeffProducer, graphView); + + if (!noQuant) + { + // XXX XXX XXX double check this ... + std::shared_ptr<Node> coeffQuantizer = mulNode->getParent(1); + appendRoundClip(coeffQuantizer, -(signedMax + 1), signedMax); + } - // TODO : double check this !!! - //std::cout << getTensorAbsoluteMax(coeffTensor) << std::endl; } } } } -void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool noQuant) +void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView) { std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); for (std::shared_ptr<Node> node : nodeVector) { - // Use A metaoperator of type Scaling or MulCompensation instead - if (isAffine(node) || (node->type() == "Mul")) + if (node->type() == "Quantizer") { - std::shared_ptr<Node> scalingNode = (*node->getChildren().begin()); - - double base = getScalingFactor(scalingNode); + std::shared_ptr<Node> linearNode = node->getParent(0); + double base = getScalingFactor(node); double approx = std::pow(2, std::ceil(std::log2(base))); + double ratio = approx / base; - updateScalingFactor(scalingNode,approx); + // set the scaling factor value to the approximation ... - double ratio = base / approx; + multiplyScalingFactor(node, ratio); - std::shared_ptr<Tensor> weightTensor = getWeightTensor(node); - rescaleTensor(weightTensor, ratio); - if (!noQuant) - roundTensor(weightTensor); + // compensate the ratio using the previous node scaling factors ...
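A quick numeric check of the power-of-two (single-shift) approximation used above, as a standalone sketch (the 0.003 value is just an assumed quantizer scaling factor):

#include <cmath>
#include <cstdio>

int main()
{
    double base   = 0.003;                                   // assumed scaling factor
    double approx = std::pow(2, std::ceil(std::log2(base))); // 2^-8 = 0.00390625
    double ratio  = approx / base;                           // ~1.3021
    // multiplyScalingFactor(node, ratio) turns 'base' into the pure power of two 'approx',
    // and dividing the weights (and bias) by 'ratio' leaves the overall output unchanged.
    std::printf("approx = %g, ratio = %g\n", approx, ratio);
    return 0;
}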
- if (nodeHasBias(node)) - { - std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); - rescaleTensor(biasTensor, ratio); - if (!noQuant) - roundTensor(biasTensor); - } + multiplyScalingFactor(linearNode->getParent(1), 1.0 / ratio); + if (nodeHasBias(linearNode)) + multiplyScalingFactor(linearNode->getParent(2), 1.0 / ratio); } } } static void printScalingFactors(std::shared_ptr<GraphView> graphView) { - Log::info(" === SCALING FACTORS === "); for (auto node : retrieveNodeVector(graphView)) - if (node->type() == "Scaling" || node->type() == "Quantizer") + if (hasAttr(node, "isActivationQuantizer") || node->type() == "Quantizer") { double scalingFactor = getScalingFactor(node); - Log::info(" {:.6f} ({})", scalingFactor, node->name()); + Log::notice(" SCALING FACTOR : {} ({})", scalingFactor, node->name()); } } -static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, DataType dataType) +static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, DataType dataType) { graphView->setDataType(dataType); @@ -982,95 +1263,83 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std: inputTensor->setDataType(dataType); } - for (auto tensor : inputDataSet) + for (auto tensor : calibrationSet) tensor->setDataType(dataType); } -static void printRanges(std::shared_ptr<GraphView> graphView, std::map<std::string, double> valueRanges) -{ - SequentialScheduler scheduler(graphView); - scheduler.resetScheduling(); - scheduler.generateScheduling(); - - auto scheduling = scheduler.getStaticScheduling(); - for (auto node : scheduling) - if (node->type() == "Scaling") - fmt::println("{} range = {}", node->name(), valueRanges[node->name()]); -} - -void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose) +void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> calibrationSet, DataType targetType, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose) { - Log::info(" === QUANT PTQ 0.2.21 === "); + Log::notice(" === QUANT PTQ 0.3.0 === "); graphView->setBackend("cpu"); - DataType initialDataType = (inputDataSet[0])->dataType(); - setupDataType(graphView, inputDataSet, DataType::Float64); - if (!checkArchitecture(graphView)) return; - Log::info(" Preparing the network for the PTQ ... "); + DataType initialDataType = (calibrationSet[0])->dataType(); + setupDataType(graphView, calibrationSet, DataType::Float64); + + Log::notice(" Preparing the network for the PTQ ... "); prepareNetwork(graphView); - Log::info(" Inserting the scaling nodes ..."); + Log::notice(" Inserting the scaling nodes ..."); insertScalingNodes(graphView); - crossLayerEqualization(graphView); + // TODO : double check the CLE ... 
+ crossLayerEqualization(graphView); // XXX XXX XXX - Log::info(" Normalizing the parameters ..."); + Log::notice(" Normalizing the parameters ..."); normalizeParameters(graphView); - Log::info(" Computing the value ranges ..."); - std::map<std::string, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda); + Log::notice(" Computing the value ranges ..."); + std::unordered_map<std::shared_ptr<Node>, double> valueRanges = computeRanges(graphView, calibrationSet, true, useCuda); - //std::cout << " === RANGES (BEFORE ADJUST) ===" << std::endl; - //printRanges(graphView, valueRanges); + Log::notice(" Optimizing the clipping values ..."); + valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, calibrationSet, useCuda, verbose); - Log::info(" Optimizing the clipping values ..."); - valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose); - - //std::cout << " === RANGES (AFTER ADJUST) ===" << std::endl; - //printRanges(graphView, valueRanges); - - Log::info(" Normalizing the activations ..."); + Log::notice(" Normalizing the activations ..."); normalizeActivations(graphView, valueRanges); - Log::info(" Quantizing the normalized network ..."); + Log::notice(" Quantizing the normalized network ..."); quantizeNormalizedNetwork(graphView, nbBits, noQuant, optimizeSigns, verbose); if (singleShift) { - Log::info( " Inserting the compensation nodes ..."); - insertCompensationNodes(graphView, nbBits); + Log::notice( " Inserting the compensation nodes ..."); + insertCompensationNodes(graphView, nbBits, noQuant); - Log::info(" Performing the Single-Shift approximation ..."); - performSingleShiftApproximation(graphView, noQuant); + Log::notice(" Performing the Single-Shift approximation ..."); + performSingleShiftApproximation(graphView); } + Log::notice(" Casting the network to the target type ({}) ...", targetType); + castQuantizedNetwork(graphView, targetType, singleShift); + + if (foldGraph) + { + Log::notice(" Folding the Producer's Quantizers ..."); + foldProducerQuantizers(graphView); + } + + // TODO ... 
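With the new signature, a typical invocation could look as follows (a sketch only : 'model' and 'calibrationSet' are assumed to exist, and Clipping::MSE is assumed to be one of the available clipping modes):

// std::shared_ptr<GraphView> model = ...;
// std::vector<std::shared_ptr<Tensor>> calibrationSet = ...;
quantizeNetwork(model, 8, calibrationSet,
                DataType::Int32, // targetType
                Clipping::MSE,   // clippingMode (assumed enum value)
                false,           // noQuant
                true,            // optimizeSigns
                true,            // singleShift
                false,           // useCuda
                true,            // foldGraph
                false);          // verbose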
+ // Log::notice(" Clearing the input nodes ..."); if (verbose) printScalingFactors(graphView); - //std::cout << " === SCALINGS (BEFORE CAST) ===" << std::endl; - //printScalingFactors(graphView); - - setupDataType(graphView, inputDataSet, initialDataType); if (useCuda) graphView->setBackend("cuda"); - //std::cout << " === SCALINGS (AFTER CAST) ===" << std::endl; - //printScalingFactors(graphView); - - Log::info(" Reseting the scheduler ..."); + Log::notice(" Resetting the scheduler ..."); SequentialScheduler scheduler(graphView); scheduler.resetScheduling(); - Log::info(" Network is quantized !"); + Log::notice(" Network is quantized !"); } -std::map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView) +std::unordered_map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView) { - std::map<std::string, double> weightRanges; + std::unordered_map<std::string, double> weightRanges; for (std::shared_ptr<Node> node : graphView->getNodes()) { @@ -1090,15 +1359,10 @@ void clearBiases(std::shared_ptr<GraphView> graphView) for (std::shared_ptr<Node> node : graphView->getNodes()) { if (node->type() == "FC" || node->type() == "Conv2D") { std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2); - rescaleTensor(biasTensor, 0); + //rescaleTensor(biasTensor, 0); + //insertScalingBelowProducer(node->getParent(2), 0, graphView); + multiplyScalingFactor(node->getParent(2), 0); } } } - -void devPTQ(std::shared_ptr<GraphView> graphView) -{ - for (std::shared_ptr<Node> node : graphView->getNodes()) - fmt::println(" UUU : {}", node->name()); -} - } diff --git a/src/PTQ/PTQMetaOps.cpp b/src/PTQ/PTQMetaOps.cpp deleted file mode 100644 index 77018c23aee2f1ef6f430389393fd35e97baa0f6..0000000000000000000000000000000000000000 --- a/src/PTQ/PTQMetaOps.cpp +++ /dev/null @@ -1,152 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/quantization/PTQ/PTQMetaOps.hpp" - -#include <memory> -#include <string> -#include <utility> - -//Operator -#include "aidge/operator/Clip.hpp" -#include "aidge/operator/Mul.hpp" -#include "aidge/operator/Round.hpp" - -#include "aidge/graph/Node.hpp" -#include "aidge/graph/OpArgs.hpp" -#include "aidge/operator/MetaOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/utils/ArrayHelpers.hpp" -#include "aidge/utils/Types.h" -#include "aidge/operator/Identity.hpp" -#include "aidge/data/Tensor.hpp" -#include "aidge/operator/OperatorTensor.hpp" -#include "aidge/utils/Log.hpp" - - -namespace Aidge -{ - -std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name) -{ - // create the nodes - - std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_MulQuant" : ""); - std::shared_ptr<Node> roundNode = Round((!name.empty()) ? name + "_RoundQuant" : ""); - std::shared_ptr<Node> clipNode = Clip((!name.empty()) ?
name + "_ClipQuant" : "", clipMin, clipMax); - - // connect the scaling factor producer - - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); - std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); - scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); - - // create the metaop graph - - std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode}); - std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ??? - - // return the metaop - - std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype - - return metaopNode; -} - -std::shared_ptr<Node> Scaling(double scalingFactor, const std::string& name) -{ - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); - - std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_Scaling" : ""); - - std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); - scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); - - std::shared_ptr<GraphView> graphView = Sequential({mulNode}); - std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); - - NodePtr metaopNode = MetaOperator("Scaling", connectedGraphView, {}, name); - - return metaopNode; -} - -static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType) -{ - std::shared_ptr<Node> mulNode = nullptr; - for(std::shared_ptr<Node> node : graphView->getNodes()) - if (node->type() == nodeType) - mulNode = node; - - return mulNode; -} - -void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor) -{ - if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer") - Log::warn(" Cannot update the scaling factor on Node of type {}", metaOpNode->type()); - - std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); - - std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(metaOpNode->getOperator()); - - std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul"); - - if (!mulNode) - Log::warn(" Invalid PTQ MetaOperator, no Mul node found inside ! 
"); - - mulNode->input(1).first->getOperator()->setOutput(0, scalingFactorTensor); -} - -double getScalingFactor(std::shared_ptr<Node> MetaOpNode) -{ - if (MetaOpNode->type() != "Scaling" && MetaOpNode->type() != "Quantizer") { - Log::warn(" Cannot get the scaling factor on Node of type {}", MetaOpNode->type()); - return 0; - } - - std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(MetaOpNode->getOperator()); - - std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul"); - - if (!mulNode) { - Log::warn(" Invalid PTQ MetaOperator, no Mul found inside node of type {}", MetaOpNode->type()); - return 0; - } - - auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1); - std::shared_ptr<Tensor> fallback; - const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); - - return localTensor.get<double>(0); -} - - -void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max) -{ - if (quantizerNode->type() != "Quantizer") { - Log::warn(" Cannot set the clipping range on Node of type {}", quantizerNode->type()); - return; - } - - std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (quantizerNode->getOperator()); - - std::shared_ptr<Node> clipNode = getSubNode(metaOp->getMicroGraph(), "Clip"); - - if (!clipNode) { - Log::warn(" Invalid PTQ MetaOperator, no Clip found inside node of type {}", quantizerNode->type()); - return; - } - - std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(clipNode->getOperator()); - clipOp->max() = max; - clipOp->min() = min; -} -} \ No newline at end of file diff --git a/src/QAT/QAT_FixedQ.cpp b/src/QAT/QAT_FixedQ.cpp index 9160b4ae6add5ae0347e008962956dc90c3a36fd..b51123cc0d612714b280423e515fbf25006bbc72 100644 --- a/src/QAT/QAT_FixedQ.cpp +++ b/src/QAT/QAT_FixedQ.cpp @@ -91,7 +91,7 @@ static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> const auto op = std::static_pointer_cast<FixedQ_Op>(node->getOperator()); float inputStd = getTensorStd(op->getInput(0)); inputStats.insert(std::make_pair(node->name(), inputStd)); - fmt::println("{} -> {}", node->name(), inputStd); + Log::info(" {} -> {} ", node->name(), inputStd); } } @@ -108,7 +108,7 @@ static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> const auto op = std::static_pointer_cast<FixedQ_Op>(node->getOperator()); float paramStd = getTensorStd(op->getInput(1)); paramStats.insert(std::make_pair(node->name(), paramStd)); - fmt::println("{} -> {}", node->name(), paramStd); + Log::info(" {} -> {} ", node->name(), paramStd); } } @@ -152,11 +152,9 @@ void QuantFixedQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView) { - SequentialScheduler scheduler(graphView); - scheduler.generateScheduling(); - auto s = scheduler.getStaticScheduling(); - for (std::shared_ptr<Node> node : s) - fmt::println(" name : {}", node->name()); + auto nodeVector = graphView->getOrderedNodes(); + for (std::shared_ptr<Node> node : nodeVector) + Log::info(" name : {} ", node->name()); } } \ No newline at end of file diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp index 9b51e846df498a9303b7373ae1c86d4b007a96f0..6eae077b060027eb4029f6b59f55376a1674df70 100644 --- a/src/QAT/QAT_LSQ.cpp +++ b/src/QAT/QAT_LSQ.cpp @@ -21,193 +21,152 @@ #include "aidge/graph/Matching.hpp" #include "aidge/recipes/QuantRecipes.hpp" -namespace 
Aidge { -void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize) +namespace Aidge { - const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)"); - for (const auto& match : matches) - { - auto linearNode = match.graph->rootNode(); +static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) +{ + auto valueTensor = (*tensor).abs().mean(); + std::shared_ptr<Tensor> fallback; + const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu"); + return localTensor.get<float>(0); +} - std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; - std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1}; +static float getTensorStd(std::shared_ptr<Tensor> tensor) +{ + auto valueTensor = (*tensor); + + auto skewedTensor = valueTensor - valueTensor.mean(); + auto squaredTensor = skewedTensor * skewedTensor; + auto varianceTensor = squaredTensor.mean(); - // INPUT QUANTIZERS INSERTION + std::shared_ptr<Tensor> fallback; + auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu"); + + float variance = localTensor.get<float>(0); + return std::sqrt(variance); +} - // TODO : double check this, and use createUniqueName() - auto inputQuantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView); - auto inputQuantizerNode = LSQ(signedRange, inputQuantizerName); - // Set the step size +// INIT THE STEP SIZE OF A QUANTIZER NODE - auto inputStepSizeOp = inputQuantizerNode->getParent(1)->getOperator(); - auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - inputStepSizeOp->setOutput(0, inputStepSizeTensor); +static bool initStepSize(std::shared_ptr<Node> quantizer) +{ + const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator()); - // Absorb the ReLU when possible ... + // This formula is the one proposed in the paper ... - // XXX is this safe ??? - bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); - // bool nodeHasParent = (linearNode->getParents().size() != 0); + // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0)); + // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second)); - if (nodeHasParent) { - auto parentNode = linearNode->getParents()[0]; - if (parentNode->type() == "ReLU") { - auto inputQuantizerOp = std::static_pointer_cast<LSQ_Op> (inputQuantizerNode->getOperator()); - inputQuantizerOp->range() = unsignedRange; - graphView->replace({parentNode}, {}); - } - } + // .. but this formula seems to work better !!! - // We need to handle the case where the linear node is the first one ... + float inputStd = getTensorStd(quantizerOp->getInput(0)); + float stepSize = 8.0f * (inputStd / (quantizerOp->range().second)); - if (nodeHasParent) { - graphView->insertParent(linearNode, inputQuantizerNode, 0, 0, 0); - } else { - inputQuantizerNode->addChild(graphView); - graphView->add(inputQuantizerNode); - } + // TODO : use the scalar constructor + auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - // PARAM QUANTIZERS INSERTION + // XXX Manage backend here ? 
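For a sense of scale : with nbBits = 8 (so range().second = 127) and a zero-mean, unit-variance input (for which E|x| = sqrt(2/pi) ~= 0.80), the two initializations differ noticeably; a standalone sanity-check sketch, not part of the patch:

#include <cmath>
#include <cstdio>

int main()
{
    const float kPi = 3.14159265f;
    float rangeMax     = 127.0f;                // range().second for nbBits = 8
    float inputStd     = 1.0f;                  // assumed unit-variance input
    float inputAbsMean = std::sqrt(2.0f / kPi); // ~0.80 for a standard normal input
    float paperStepSize     = 2.0f * inputAbsMean / std::sqrt(rangeMax); // ~0.142
    float heuristicStepSize = 8.0f * inputStd / rangeMax;                // ~0.063
    std::printf("paper = %f, heuristic = %f\n", paperStepSize, heuristicStepSize);
    return 0;
}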
+ stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend()); + stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType()); - // TODO : double check this, and use createUniqueName() - auto paramQuantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView); - auto paramQuantizerNode = LSQ(signedRange, paramQuantizerName); - graphView->insertParent(linearNode, paramQuantizerNode, 1, 0, 0); + auto stepSizeProducer = quantizer->getParent(1); - // Set the step size + stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor); - auto paramStepSizeOp = paramQuantizerNode->getParent(1)->getOperator(); - auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - paramStepSizeOp->setOutput(0, paramStepSizeTensor); - } + Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize); + return false; } -static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) +static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) { - auto backend = tensor->backend(); - if (backend == "cuda") - tensor->setBackend("cpu"); + const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); - float acc = 0; - float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr()); - for(std::size_t i = 0; i < tensor->size(); i++) - acc += std::abs(castedTensor[i]); - acc /= static_cast<float> (tensor->size()); + for (const auto& match : matches) + { + auto linearNode = match.graph->rootNode(); - if (backend == "cuda") - tensor->setBackend("cuda"); + // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type()); - return acc; -} + std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; + std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1}; -static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> calibrationData, bool useCuda) -{ - // Propagate the calibration tensor + // Create the input quantizer node - SequentialScheduler scheduler(graphView); - scheduler.resetScheduling(); - scheduler.forward(true, {calibrationData}); + auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView); + auto quantizerNode = LSQ(signedRange, quantizerName); - // Store the input tensor statistics + // Init the step-size using the node call stack - if (useCuda) - graphView->setBackend("cpu"); + quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); - std::map<std::string, float> inputStats; - for (auto node : graphView->getNodes()) - { - if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!! - { - const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator()); - float inputAbsMean = getTensorAbsMean(op->getInput(0)); - inputStats.insert(std::make_pair(node->name(), inputAbsMean)); - fmt::println("{} -> {}", node->name(), inputAbsMean); - } - } + // Absorb the ReLU when possible ... - if (useCuda) - graphView->setBackend("cuda"); + bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ? 
- return inputStats; -} + if (nodeHasParent) + { + bool allParentsAreReLU = true; + for (auto parentNode : linearNode->getParents()) + if (parentNode->type() != "ReLU") + allParentsAreReLU = false; + + if (allParentsAreReLU) { + auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator()); + quantizerOp->range() = unsignedRange; + } -static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> graphView, bool useCuda) -{ - if (useCuda) - graphView->setBackend("cpu"); + // TODO : remove the ReLUs when possible + } - std::map<std::string, float> paramStats; - for (auto node : graphView->getNodes()) - { - if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!! - { - const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator()); - float paramAbsMean = getTensorAbsMean(op->getInput(1)); - paramStats.insert(std::make_pair(node->name(), paramAbsMean)); - fmt::println("{} -> {}", node->name(), paramAbsMean); + // Insert the quantizer in the graphView ... + // (We need to handle the case where the linear node is the first one) + + if (nodeHasParent) { + graphView->insertParent(linearNode, quantizerNode, 0, 0, 0); + } else { + quantizerNode->addChild(graphView); + graphView->add(quantizerNode); } } - - if (useCuda) - graphView->setBackend("cuda"); - - return paramStats; } -static void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView, std::map<std::string, float> inputStats, std::map<std::string, float> paramStats) -{ - const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)"); +// PARAM QUANTIZERS INSERTION - for (const auto& match : matches) - { - auto linearNode = match.graph->rootNode(); +static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) +{ + const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); - // INPUT QUANTIZERS STEP-SIZES + std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; - auto inputQuantNode = linearNode->getParent(0); - auto inputQuantOp = std::static_pointer_cast<LSQ_Op>(inputQuantNode->getOperator()); + for (const auto& match : matches) + { + auto linearNode = match.graph->rootNode(); - float absMean = inputStats[linearNode->name()]; - float stepSize = 2.0f * (absMean / std::sqrt(inputQuantOp->range().second)); + // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type()); - auto inputStepSizeOp = inputQuantNode->getParent(1)->getOperator(); - // XXX inputStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}))); - auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - inputStepSizeOp->setOutput(0, inputStepSizeTensor); + // TODO : double check this, and use createUniqueName() + auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView); + auto quantizerNode = LSQ(signedRange, quantizerName); - // PARAM QUANTIZERS STEP-SIZES + // Init the step-size using the node call stack - auto paramQuantNode = linearNode->getParent(1); - auto paramQuantOp = std::static_pointer_cast<LSQ_Op>(paramQuantNode->getOperator()); + quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); - absMean = paramStats[linearNode->name()]; - stepSize = 2.0f * (absMean / std::sqrt(paramQuantOp->range().second)); + // Insert the quantizer in the graphView - auto paramStepSizeOp = paramQuantNode->getParent(1)->getOperator(); - // XXX paramStepSizeOp->setOutput(0, 
std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}))); - auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - paramStepSizeOp->setOutput(0, paramStepSizeTensor); + graphView->insertParent(linearNode, quantizerNode, 1, 0, 0); } } -void QuantLSQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData) +void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) { - bool useCuda = (calibrationData->backend() == "cuda"); - - // Collect the tensor statisics - auto inputStats = collectInputStats(graphView, calibrationData, useCuda); - - auto paramStats = collectParamStats(graphView, useCuda); - - // Insert the quantizers - insertQuantizers(graphView, nbBits, 1.0); - - // Adjust the quantizers step-sizes - adjustQuantizersStepSizes(graphView, inputStats, paramStats); + sanitizeNodeNames(graphView); + setupInputQuantizers(graphView, nbBits); + setupParamQuantizers(graphView, nbBits); } } \ No newline at end of file diff --git a/src/backend/cuda/operator/LSQImpl.cpp b/src/backend/cuda/operator/LSQImpl.cpp index c66bd8a5aa78513b4bcceec83f9c9d87ffed2b11..bb30cc10b6e87f3d6797918d02874ebca48d47ea 100644 --- a/src/backend/cuda/operator/LSQImpl.cpp +++ b/src/backend/cuda/operator/LSQImpl.cpp @@ -52,25 +52,12 @@ void Aidge::LSQImpl_cuda::backward() { std::shared_ptr<Tensor> gra_int1 = op_.getInput(1)->grad(); std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); - // XXX -/* - size_t tmp; - - cudaDeviceSetLimit(cudaLimitStackSize, 2048); - cudaDeviceGetLimit(&tmp, cudaLimitStackSize ); - printf(" stack limit = %ld \n", tmp); - - cudaDeviceSetLimit(cudaLimitMallocHeapSize, 100000000); - cudaDeviceGetLimit(&tmp, cudaLimitMallocHeapSize); - printf(" heap limit = %ld \n", tmp); -*/ - if (gra_int0->size() > mWorkspaceSize) { - // std::cout << " reallocation " << sizeof(gra_int0) << " " << gra_int0->size() << std::endl; if (mWorkspace != nullptr) { cudaFree(mWorkspace); } - CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, 8 * gra_int0->size())); // XXX This must be changed !!! 
+ std::size_t sizeOfData = getDataTypeBitWidth(gra_int0->dataType()) / 8; + CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, sizeOfData * gra_int0->size())); mWorkspaceSize = gra_int0->size(); } @@ -87,12 +74,7 @@ void Aidge::LSQImpl_cuda::backward() { gra_int0->getImpl()->rawPtr(), gra_int1->getImpl()->rawPtr(), mWorkspace); -/* - gra_int1->setBackend("cpu"); - float *castedTensor = static_cast<float *> (gra_int1->getImpl()->rawPtr()); - std::cout << castedTensor[0] << std::endl; - gra_int1->setBackend("cuda"); -*/ + } Aidge::LSQImpl_cuda::~LSQImpl_cuda() { diff --git a/src/operator/FixedQ.cpp b/src/operator/FixedQ.cpp index 9828ce98f4918b3d2336c57fe018c9129804cf01..ce9a65defc71909a61e03b1d603b6037a777697a 100644 --- a/src/operator/FixedQ.cpp +++ b/src/operator/FixedQ.cpp @@ -22,7 +22,7 @@ const std::string Aidge::FixedQ_Op::Type = "FixedQ"; Aidge::FixedQ_Op::FixedQ_Op(const Aidge::FixedQ_Op& op) : OperatorTensor(op), - mAttributes(op.mAttributes) + mAttributes(std::make_shared<Attributes_>(*op.mAttributes)) { if (op.mImpl){ SET_IMPL_MACRO(FixedQ_Op, *this, op.backend()); diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82b2e501dc275b361899a0ae8284f8a5409d32dc --- /dev/null +++ b/src/operator/PTQMetaOps.cpp @@ -0,0 +1,377 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/PTQMetaOps.hpp" + +#include <memory> +#include <string> +#include <utility> + +#include "aidge/operator/Clip.hpp" +#include "aidge/operator/Mul.hpp" +#include "aidge/operator/Round.hpp" +#include "aidge/operator/BitShift.hpp" +#include "aidge/operator/Cast.hpp" + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/ArrayHelpers.hpp" +#include "aidge/utils/Types.h" +#include "aidge/operator/Identity.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Log.hpp" + +namespace Aidge +{ + +static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr) +{ + return node->attributes()->hasAttr("quantization.ptq." + attr); +} + +static void addAttr(std::shared_ptr<Aidge::Node> node, std::string attr, double value = 0.0) +{ + node->attributes()->addAttr("quantization.ptq." 
+ attr, value); +} + +// TODO : rework this +static void copyDynamicAttributes(std::shared_ptr<Aidge::Node> prevNode, std::shared_ptr<Aidge::Node> newNode) +{ + if (hasAttr(prevNode, "isProducerQuantizer")) + addAttr(newNode, "isProducerQuantizer"); + + if (hasAttr(prevNode, "isActivationQuantizer")) + addAttr(newNode, "isActivationQuantizer"); +} + +std::shared_ptr<Node> Quantizer(double scalingFactor, const std::string& name) +{ + std::shared_ptr<Node> mulNode = Mul(name + "_MulQuant"); + + // Scaling Factor Producer + + std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor}); + std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); + scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor); + + // TODO : the above could be replaced by : + // std::shared_ptr<Node> scalingFactorProducer = Producer(scalingFactorTensor); + // scalingFactorProducer->addChild(mulNode, 0, 1); + + // create the graphView ... + + std::shared_ptr<GraphView> graphView = Sequential({mulNode}); + graphView->add(scalingFactorProducer); + + // alternative : capture the Producer ... + // std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); + + std::shared_ptr<Node> quantizer = MetaOperator("Quantizer", graphView, {}, name); // a simpler prototype exists ... + + return quantizer; +} + +void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff) +{ + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + + // Get the Mul node from the microGraph + + std::shared_ptr<Node> mulNode = nullptr; + auto microGraph = quantizerOp->getMicroGraph(); + for (auto node : microGraph->getNodes()) + if (node->type() == "Mul") + mulNode = node; + + // Retrieve the previous scaling factor + + auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1); + + std::shared_ptr<Tensor> fallback; + const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); + double prevScalingFactor = localTensor.get<double>(0); + + // Create the new scaling factor tensor + + std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(prevScalingFactor * coeff); + newScalingFactorTensor->setBackend(scalingFactorTensor->backend()); + newScalingFactorTensor->setDataType(scalingFactorTensor->dataType()); + + // Set the tensor of the producer + + auto producer = mulNode->getParent(1); + producer->getOperator()->setOutput(0, newScalingFactorTensor); +} + +void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax) +{ + // Retrieve a clone of the microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph()->clone(); + + // Save the datatype / backend + + auto outputNode = *(microGraph->outputNodes().begin()); + auto outputOp = std::static_pointer_cast<OperatorTensor> (outputNode->getOperator()); + + auto dataType = outputOp->getOutput(0)->dataType(); + auto backend = outputOp->getOutput(0)->backend(); + + // append round + + auto roundNode = Round(quantizer->name() + "_RoundQuant"); + outputNode->addChild(roundNode, 0, 0); + microGraph->add(roundNode); + + // append clip + + auto clipNode = Clip(quantizer->name() + "_ClipQuant"); + + auto minTensor = std::make_shared<Tensor>(clipMin); + auto minNode = Producer(minTensor); + minNode->addChild(clipNode, 0, 1); +
microGraph->add(minNode); + + auto maxTensor = std::make_shared<Tensor>(clipMax); + auto maxNode = Producer(maxTensor); + maxNode->addChild(clipNode, 0, 2); + microGraph->add(maxNode); + + roundNode->addChild(clipNode, 0, 0); + microGraph->add(clipNode); + + // set the datatype / backend + + microGraph->setDataType(dataType); + microGraph->setBackend(backend); + + // create the new quantizer and replace the previous one + + std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name()); + copyDynamicAttributes(quantizer, newQuantizer); + GraphView::replace({quantizer}, {newQuantizer}); + + // replace the old pointer with the new one (by reference) + + quantizer = newQuantizer; +} + +double getScalingFactor(std::shared_ptr<Node> quantizer) +{ + // Retrieve the previous microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph(); + + // Get the Mul node from the microGraph + + std::shared_ptr<Node> mulNode = nullptr; + for (auto node : microGraph->getNodes()) + if (node->type() == "Mul") + mulNode = node; + + auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); + + // Retrieve the scaling factor + + auto scalingFactorTensor = mulOp->getInput(1); + + std::shared_ptr<Tensor> fallback; + const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); + double scalingFactor = localTensor.get<double>(0); + + return scalingFactor; +} + +void setClipRange(std::shared_ptr<Node> quantizer, double min, double max) +{ + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph(); + + std::shared_ptr<Node> clipNode = nullptr; + for (auto node : microGraph->getNodes()) + if (node->type() == "Clip") + clipNode = node; + + // TODO : assert that clipNode is not a nullptr ...
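Taken together, these helpers compose as follows (a usage sketch based on the signatures above; the 0.02 scaling factor is arbitrary):

std::shared_ptr<Node> q = Quantizer(0.02, "q0"); // micrograph : Mul (x 0.02)
appendRoundClip(q, -128.0, 127.0);               // micrograph : Mul -> Round -> Clip ('q' is rebound to the new metaop)
multiplyScalingFactor(q, 2.0);                   // the scaling factor becomes 0.04
double s = getScalingFactor(q);                  // s == 0.04
setClipRange(q, 0.0, 255.0);                     // switch the Clip to an unsigned range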
+ + auto clipOp = std::static_pointer_cast<Clip_Op> (clipNode->getOperator()); + + // set the attributes + + clipOp->max() = max; + clipOp->min() = min; + + // Retrieve the previous min/max tensors + + auto minTensor = std::static_pointer_cast<OperatorTensor>(clipNode->getOperator())->getInput(1); + auto maxTensor = std::static_pointer_cast<OperatorTensor>(clipNode->getOperator())->getInput(2); + + // Create the new min/max tensors + + std::shared_ptr<Tensor> newMinTensor = std::make_shared<Tensor>(min); + newMinTensor->setBackend(minTensor->backend()); + newMinTensor->setDataType(minTensor->dataType()); + + std::shared_ptr<Tensor> newMaxTensor = std::make_shared<Tensor>(max); + newMaxTensor->setBackend(maxTensor->backend()); + newMaxTensor->setDataType(maxTensor->dataType()); + + // Set the tensors of the producer + + auto minProducer = clipNode->getParent(1); + minProducer->getOperator()->setOutput(0, newMinTensor); + + auto maxProducer = clipNode->getParent(2); + maxProducer->getOperator()->setOutput(0, newMaxTensor); +} + +void removeRound(std::shared_ptr<Node>& quantizer) +{ + // Retrieve a clone of the microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph()->clone(); + + // retrieve the rounding node + + std::shared_ptr<Node> roundNode = nullptr; + for (auto node : microGraph->getNodes()) + if (node->type() == "Round") + roundNode = node; + + if (roundNode == nullptr) + return; + + // remove the Round node + + microGraph->replace({roundNode}, {}); + + // Create the new quantizer and replace the previous one + + std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name()); + copyDynamicAttributes(quantizer, newQuantizer); + GraphView::replace({quantizer}, {newQuantizer}); + + // replace the old pointer with the new one (by reference) + + quantizer = newQuantizer; +} + +void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer) +{ + // Retrieve a clone of the microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph()->clone(); + + // retrieve the multiplicative (scaling) node + + std::shared_ptr<Node> mulNode = nullptr; + for (auto node : microGraph->getNodes()) + if (node->type() == "Mul") + mulNode = node; + + auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); + + // Save the datatype / backend + + auto dataType = mulOp->getOutput(0)->dataType(); + auto backend = mulOp->getOutput(0)->backend(); + + // compute the shift value + + double scalingFactor = getScalingFactor(quantizer); + int bitShiftAmount = -std::round(std::log2(scalingFactor)); + auto bitShiftDirection = BitShift_Op::BitShiftDirection::right; + + Log::notice(" SHIFT AMOUNT = {} ({})", bitShiftAmount, scalingFactor); + + if (bitShiftAmount < 0) + { + bitShiftDirection = BitShift_Op::BitShiftDirection::left; + bitShiftAmount = -bitShiftAmount; + } + + bool bitShiftRounding = true; // XXX use an argument !!!
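The shift computation above maps a power-of-two scaling factor to an integer shift amount; a standalone numeric check (0.0625 is an assumed post-single-shift factor):

#include <cmath>
#include <cstdio>

int main()
{
    double scalingFactor = 0.0625;                            // 2^-4, as produced by the single-shift approximation
    int amount = (int) -std::round(std::log2(scalingFactor)); // 4 : multiplying by 2^-4 becomes a right shift by 4
    std::printf("x * %g == x >> %d (up to rounding)\n", scalingFactor, amount);
    return 0;
}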
+ // create the replacement bit-shift nodes + + auto bitShiftNode = BitShift(bitShiftDirection, bitShiftRounding, quantizer->name() + "_BitShiftQuant"); + auto bitShiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {bitShiftAmount}); + + auto bitShiftProducer = Producer(bitShiftTensor, "bitShiftAmount"); + bitShiftProducer->addChild(bitShiftNode, 0, 1); + + // edit the micrograph + + microGraph->replace({mulNode, mulNode->getParent(1)}, {bitShiftNode, bitShiftNode->getParent(1)}); + + // set the datatype / backend + + microGraph->setDataType(dataType); + microGraph->setBackend(backend); + + // create the new quantizer and replace the previous one + + std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name()); + copyDynamicAttributes(quantizer, newQuantizer); + GraphView::replace({quantizer}, {newQuantizer}); + + // replace the old pointer with the new one (by reference) + + quantizer = newQuantizer; +} + +void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType) +{ + // Retrieve a clone of the microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph()->clone(); + + // Edit the micrograph (insert Cast nodes at its IOs) + + auto mulNode = *(microGraph->inputNodes().begin()); // TODO : assert that mulNode is a Mul ! + auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); + + auto internalType = mulOp->getOutput(0)->dataType(); + auto castInputNode = Cast(internalType, quantizer->name() + "_CastIn"); + auto castOutputNode = Cast(externalType, quantizer->name() + "_CastOut"); + + microGraph = Sequential({castInputNode, microGraph, castOutputNode}); + + // Set the micrograph datatype + + microGraph->setDataType(internalType); + castOutputNode->getOperator()->setDataType(externalType); + + // Set the micrograph backend + + auto backend = mulOp->getOutput(0)->backend(); + microGraph->setBackend(backend); + + // Create the new quantizer and replace the old one + + std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name()); + copyDynamicAttributes(quantizer, newQuantizer); + GraphView::replace({quantizer}, {newQuantizer}); + + // replace the old pointer with the new one (by reference) + + quantizer = newQuantizer; +} + +} \ No newline at end of file diff --git a/src/operator/SAT/DoReFa.cpp b/src/operator/SAT/DoReFa.cpp index 426e330e7f8426d256ca76a843548a91a62b036a..f722631e543c46b1307a372a8d2cb35e65215b2f 100644 --- a/src/operator/SAT/DoReFa.cpp +++ b/src/operator/SAT/DoReFa.cpp @@ -23,7 +23,7 @@ const std::string DoReFa_Op::Type = "DoReFa"; DoReFa_Op::DoReFa_Op(const DoReFa_Op& op) : OperatorTensor(op), - mAttributes(op.mAttributes) + mAttributes(std::make_shared<Attributes_>(*op.mAttributes)) { if (op.mImpl) { SET_IMPL_MACRO(DoReFa_Op, *this, op.backend()); diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp index 6e1dcdb1b64c0a1e94c74ce66cb71f1a458bca35..1806e3d4fd4da402f76c9046b84fcb6acfe69606 100644 --- a/src/recipes/QuantRecipes.cpp +++ b/src/recipes/QuantRecipes.cpp @@ -9,24 +9,18 @@ * ********************************************************************************/ -/* -#include "aidge/data/Tensor.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Node.hpp" -#include "aidge/scheduler/SequentialScheduler.hpp" -#include "aidge/scheduler/Scheduler.hpp" -#include "aidge/utils/Log.hpp" - +#include "aidge/graph/OpArgs.hpp" #include
"aidge/operator/Producer.hpp" -#include "aidge/operator/Mul.hpp" -#include "aidge/operator/ReLU.hpp" -#include "aidge/operator/Scaling.hpp" -*/ - #include "aidge/operator/Conv.hpp" +#include "aidge/operator/Transpose.hpp" +#include "aidge/operator/MatMul.hpp" #include "aidge/operator/BatchNorm.hpp" //#include "aidge/quantization/PTQ/PTQ.hpp" #include "aidge/recipes/QuantRecipes.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/graph/Matching.hpp" + namespace Aidge { @@ -55,14 +49,16 @@ void insertBatchNormNodes(std::shared_ptr<GraphView> graphView) { for (std::shared_ptr<Node> parentNode : graphView->getNodes()) { - if (parentNode->type() == "Conv2D") + // TODO : use graph matching + + if (parentNode->type() == "Conv2D" || parentNode->type() == "PaddedConv2D") { - std::shared_ptr<Conv_Op<2>> convOperator = std::static_pointer_cast<Conv_Op<2>> (parentNode->getOperator()); - int nb_channels = convOperator->getInput(1)->dims()[0]; - fmt::println(" NB CHANNELS = {}", nb_channels); // TODO : remove this ... + std::shared_ptr<OperatorTensor> convOperator = std::static_pointer_cast<OperatorTensor> (parentNode->getOperator()); + int nbChannels = convOperator->getInput(1)->dims()[0]; + Log::notice(" NB CHANNELS = {} ", nbChannels); std::string batchnormNodeName = makeUniqueName(parentNode->name() + "_BN", graphView); - std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nb_channels, 1e-5, 0.1, false, batchnormNodeName); + std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nbChannels, 1e-5, 0.1, false, batchnormNodeName); batchnormNode->getOperator()->setDataType(DataType::Float32); batchnormNode->getOperator()->setBackend("cpu"); @@ -118,6 +114,7 @@ std::string makeUniqueName(std::string baseName, std::shared_ptr<GraphView> grap return newName; } + void sanitizeNodeNames(std::shared_ptr<GraphView> graphView) { for (std::shared_ptr<Node> node : graphView->getNodes()) @@ -129,4 +126,66 @@ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView) } } -} \ No newline at end of file +void reorderMatMulInputs(std::shared_ptr<GraphView> graphView) +{ + const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)"); + + for (auto match : matches) + { + auto node = match.graph->rootNode(); + + // Check if the MatMul inputs have to be permuted + + bool permuteInputs = false; + + if (node->getParent(0)) + if (node->getParent(0)->type() == "Producer") + permuteInputs = true; + + if (node->getParent(1)) + if (node->getParent(1)->type() == "Producer") + permuteInputs = false; + + // Perform the permutation of the inputs ... 
+ + if (permuteInputs) + { + auto prevMatMul = node; + auto prevTensor = (std::static_pointer_cast<OperatorTensor> (node->getOperator()))->getInput(0); + + // Create the new MatMul op and its Producer + + auto newMatMul = MatMul(); + + auto newDims = prevTensor->dims(); + std::swap(newDims[0], newDims[1]); + auto newTensor = std::make_shared<Tensor>(newDims); + + newTensor->setDataType(prevTensor->dataType()); + newTensor->setBackend(prevTensor->backend()); + newTensor->copyTranspose(*prevTensor, std::vector<Aidge::DimSize_t>({1, 0})); + + auto newProducer = Producer(newTensor, ""); + newProducer->addChild(newMatMul, 0, 1); + + // Replace the node by a micrograph + + auto prevMicroGraph = Sequential({prevMatMul}); + prevMicroGraph->add(prevMatMul->getParent(0)); + + auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})}); + newMicroGraph->add(newMatMul->getParent(1)); + + newMicroGraph->setDataType(prevTensor->dataType()); + newMicroGraph->setBackend(prevTensor->backend()); + + graphView->replace(prevMicroGraph, newMicroGraph); + } + } + + // TODO : fold the Transpose operators when possible ... + + // USE REGEXPS !!! +} + +} diff --git a/unit_tests/Test_QuantPTQ.cpp b/unit_tests/Test_QuantPTQ.cpp index e7211ce4092f789c8c6263671ad236b97934ffbb..018b111bfb8355af7c1cff96a603eafdd1bd6c64 100644 --- a/unit_tests/Test_QuantPTQ.cpp +++ b/unit_tests/Test_QuantPTQ.cpp @@ -200,7 +200,7 @@ TEST_CASE("[tmp] basic test") { // //no need to do this anymore, forward does it autimatically now ... // //scheduler.generateScheduling(true); -// std::vector<std::shared_ptr<Node>> ordered_graph_view = scheduler.getStaticScheduling(); +// std::vector<std::shared_ptr<Node>> ordered_graph_view = scheduler.getSequentialStaticScheduling(); // printf("Going to quantize network :\n"); @@ -226,7 +226,7 @@ TEST_CASE("[tmp] basic test") { // scheduler_v2.forward(); // scheduler_v2.generateScheduling(false); -// std::vector<std::shared_ptr<Node>> ordered_graph_view_v2 = scheduler_v2.getStaticScheduling(); +// std::vector<std::shared_ptr<Node>> ordered_graph_view_v2 = scheduler_v2.getSequentialStaticScheduling(); // if(verbose) { // printf("Ordered graph after quantization :\n"); diff --git a/version.txt b/version.txt index 9e11b32fcaa96816319e5d0dcff9fb2873f04061..1d0ba9ea182b0f7354f3daf12120744ec5e0c2f8 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.1 +0.4.0