Commit 91a067d9 authored by Benjamin Halimi

Merge branch 'DevPTQ' into 'main'

Quantization Updates and Optimizations

See merge request !9
parents 3b7cc254 0e591cf8
Showing 2007 additions and 151 deletions
@@ -42,17 +42,19 @@ Particular care is needed for the bias rescaling at each step.
## Doing quantization step by step
It is possible to perform the PTQ step by step, thank's to the exposed functions of the API.
It is possible to perform the PTQ step by step, thanks to the exposed functions of the API.
In that case, here is the standard pipeline:
1) remove the flatten and dropout nodes
2) expand the meta-operators (if there are some)
3) insert the scaling nodes
4) perform the parameter normalization
5) perform the output value normalization, over a calibration dataset
6) quantize the normalized network
- Prepare the network for the PTQ (remove the flatten nodes, fuse the BatchNorms ...)
- Insert the scaling nodes that will allow the model calibration
- Perform the Cross Layer Equalization if possible
- Perform the parameter normalization
- Compute the node output ranges over an input calibration dataset
- Adjust the output ranges using a specified error metric (MSE, KL, ...)
- Perform the activation normalization
- Quantize the normalized network
- Convert the scaling factors to bit-shifting operations if needed (a Python sketch of these steps follows)
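
A minimal sketch of these steps in Python, based on the bindings introduced in this merge request (`model` is assumed to be a `GraphView` loaded with `aidge_onnx.load_onnx`, and `tensors` a list of calibration `aidge_core.Tensor`; the bit-shifting conversion is not shown, as it is handled by the `single_shift` option of the all-in-one `quantize_network` routine):

```python
import aidge_core
import aidge_quantization

NB_BITS = 8  # assumed target bit-width

aidge_core.remove_flatten(model)                         # prepare the network
aidge_quantization.insert_scaling_nodes(model)           # insert the scaling nodes
aidge_quantization.cross_layer_equalization(model, target_delta=0.01)
aidge_quantization.normalize_parameters(model)           # parameter normalization
ranges = aidge_quantization.compute_ranges(model, tensors, scaling_nodes_only=True)
ranges = aidge_quantization.adjust_ranges(
    aidge_quantization.Clipping.MSE, ranges, NB_BITS, model, tensors, verbose=False)
aidge_quantization.normalize_activations(model, ranges)  # activation normalization
aidge_quantization.quantize_normalized_network(
    model, NB_BITS, apply_rounding=True, optimize_signs=True, verbose=False)
```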
## Further work
* add smart clipping methods for the normalizations.
* add Quantization Aware Training (QAT).
* add Quantization Aware Training (QAT)
\ No newline at end of file
import gzip
import numpy as np
import matplotlib.pyplot as plt
import aidge_core
import aidge_backend_cpu
import aidge_onnx
import aidge_quantization
NB_SAMPLES = 100 # max : 1000
NB_BITS = 4
# --------------------------------------------------------------
# LOAD THE MODEL IN AIDGE
# --------------------------------------------------------------
aidge_model = aidge_onnx.load_onnx("assets/ConvNet.onnx", verbose=False)
aidge_core.remove_flatten(aidge_model)
# --------------------------------------------------------------
# LOAD THE SAMPLES / LABELS (NUMPY)
# --------------------------------------------------------------
samples = np.load(gzip.GzipFile('assets/mnist_samples.npy.gz', "r"))
labels = np.load(gzip.GzipFile('assets/mnist_labels.npy.gz', "r"))
# --------------------------------------------------------------
# SETUP THE AIDGE SCHEDULER
# --------------------------------------------------------------
# Create the Producer node
input_array = np.zeros(784).astype('float32')
input_tensor = aidge_core.Tensor(input_array)
input_node = aidge_core.Producer(input_tensor, "X")
# Configuration for the inputs
input_node.get_operator().set_datatype(aidge_core.dtype.float32)
input_node.get_operator().set_backend("cpu")
# Link Producer to the Graph
input_node.add_child(aidge_model)
# Configuration for the model
aidge_model.set_datatype(aidge_core.dtype.float32)
aidge_model.set_backend("cpu")
# Create the Scheduler
scheduler = aidge_core.SequentialScheduler(aidge_model)
# --------------------------------------------------------------
# RUN SOME EXAMPLE INFERENCES WITH AIDGE
# --------------------------------------------------------------
def propagate(model, scheduler, sample):
# Setup the input
input_tensor = aidge_core.Tensor(sample)
input_node.get_operator().set_output(0, input_tensor)
# Run the inference
scheduler.forward(verbose=False)
# Gather the results
output_node = model.get_output_nodes().pop()
output_tensor = output_node.get_operator().get_output(0)
return np.array(output_tensor)
def bake_sample(sample):
sample = np.reshape(sample, (1, 1, 28, 28))
return sample.astype('float32')
print('\n EXAMPLE INFERENCES :')
for i in range(10):
input_array = bake_sample(samples[i])
output_array = propagate(aidge_model, scheduler, input_array)
print(labels[i], ' -> ', np.round(output_array, 2))
# --------------------------------------------------------------
# COMPUTE THE MODEL ACCURACY
# --------------------------------------------------------------
def compute_accuracy(model, samples, labels):
acc = 0
for i, sample in enumerate(samples):
x = bake_sample(sample)
y = propagate(model, scheduler, x)
if labels[i] == np.argmax(y):
acc += 1
return acc / len(samples)
accuracy = compute_accuracy(aidge_model, samples[0:NB_SAMPLES], labels)
print(f'\n MODEL ACCURACY : {accuracy * 100:.3f}%')
# --------------------------------------------------------------
# CREATE THE TENSOR SUBSET
# --------------------------------------------------------------
tensors = []
for sample in samples[0:NB_SAMPLES]:
sample = bake_sample(sample)
tensor = aidge_core.Tensor(sample)
tensors.append(tensor)
# --------------------------------------------------------------
# APPLY THE PTQ TO THE MODEL
# --------------------------------------------------------------
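# quantize over NB_BITS bits, keeping the binding's default options :
# clipping_mode = 'MAX', apply_rounding = True, optimize_signs = False, single_shift = False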
aidge_quantization.quantize_network(aidge_model, NB_BITS, tensors)
# --------------------------------------------------------------
# UPDATE THE SCHEDULER
# --------------------------------------------------------------
scheduler = aidge_core.SequentialScheduler(aidge_model)
# --------------------------------------------------------------
# QUANTIZE THE INPUT TENSORS
# --------------------------------------------------------------
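# the quantized network expects inputs in the signed NB_BITS range,
# so rescale the samples by the maximum quantized value : 2^(NB_BITS-1) - 1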
scaling = 2**(NB_BITS-1)-1
for i in range(NB_SAMPLES):
samples[i] = np.round(samples[i]*scaling)
# --------------------------------------------------------------
# RUN SOME QUANTIZED INFERENCES WITH AIDGE
# --------------------------------------------------------------
print('\n EXAMPLE QUANTIZED INFERENCES :')
for i in range(10):
input_array = bake_sample(samples[i])
output_array = propagate(aidge_model, scheduler, input_array)
print(labels[i], ' -> ', np.round(output_array, 2))
# --------------------------------------------------------------
# COMPUTE THE MODEL ACCURACY
# --------------------------------------------------------------
accuracy = compute_accuracy(aidge_model, samples[0:NB_SAMPLES], labels)
print(f'\n QUANTIZED MODEL ACCURACY : {accuracy * 100:.3f}%')
# --------------------------------------------------------------
# WORK IS DONE !
# --------------------------------------------------------------
print('\n that\'s all folks !\n')
MiniResNet.onnx filter=lfs diff=lfs merge=lfs -text
ConvNet.onnx filter=lfs diff=lfs merge=lfs -text
mnist_labels.npy.gz filter=lfs diff=lfs merge=lfs -text
mnist_samples.npy.gz filter=lfs diff=lfs merge=lfs -text
import unittest
import gzip
import numpy as np
import aidge_core
import aidge_backend_cpu
import aidge_onnx
import aidge_quantization
# --------------------------------------------------------------
# CONFIGS
# --------------------------------------------------------------
NB_SAMPLES = 1000 # max : 1000
SAMPLE_SHAPE = (1, 1, 28, 28)
MODEL_NAME = 'MiniResNet.onnx' # 'ConvNet.onnx'
ACCURACIES = (95.4, 94.5) # (97.9, 97.7)
NB_BITS = 4
# --------------------------------------------------------------
# UTILS
# --------------------------------------------------------------
def propagate(model, scheduler, sample):
input_tensor = aidge_core.Tensor(sample)
scheduler.forward(True, [input_tensor])
output_node = model.get_output_nodes().pop()
output_tensor = output_node.get_operator().get_output(0)
return np.array(output_tensor)
def prepare_sample(sample):
sample = np.reshape(sample, SAMPLE_SHAPE)
return sample.astype('float32')
def compute_accuracy(model, samples, labels):
acc = 0
scheduler = aidge_core.SequentialScheduler(model)
for i, sample in enumerate(samples):
x = prepare_sample(sample)
y = propagate(model, scheduler, x)
if labels[i] == np.argmax(y):
acc += 1
return acc / len(samples)
# --------------------------------------------------------------
# TEST CLASS
# --------------------------------------------------------------
class test_ptq(unittest.TestCase):
def setUp(self):
# load the samples / labels (numpy)
self.samples = np.load(gzip.GzipFile('assets/mnist_samples.npy.gz', "r"))
self.labels = np.load(gzip.GzipFile('assets/mnist_labels.npy.gz', "r"))
# load the model in AIDGE
self.model = aidge_onnx.load_onnx("assets/" + MODEL_NAME, verbose=False)
aidge_core.remove_flatten(self.model)
self.model.set_datatype(aidge_core.dtype.float32)
self.model.set_backend("cpu")
def tearDown(self):
pass
def test_model(self):
# compute the base accuracy
accuracy = compute_accuracy(self.model, self.samples[0:NB_SAMPLES], self.labels)
self.assertAlmostEqual(accuracy * 100, ACCURACIES[0], msg='base accuracy does not meet the baseline !', delta=0.1)
def test_quant_model(self):
# create the calibration dataset
tensors = []
for sample in self.samples[0:NB_SAMPLES]:
sample = prepare_sample(sample)
tensor = aidge_core.Tensor(sample)
tensors.append(tensor)
# quantize the model
aidge_quantization.quantize_network(
self.model,
NB_BITS,
tensors,
clipping_mode=aidge_quantization.Clipping.MSE,
apply_rounding=True,
optimize_signs=True,
single_shift=False
)
# rescale the inputs
scaling = 2**(NB_BITS-1)-1
for i in range(NB_SAMPLES):
self.samples[i] = self.samples[i]*scaling # XXX np.round ???
# compute the quantized accuracy
accuracy = compute_accuracy(self.model, self.samples, self.labels)
self.assertAlmostEqual(accuracy * 100, ACCURACIES[1], msg='quantized accuracy does not meet the baseline !', delta=0.1)
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_QUANTIZATION_PTQ_CLE_H_
#define AIDGE_QUANTIZATION_PTQ_CLE_H_
//#include <cstdint>
//#include <map>
//#include <memory>
//#include <string>
//#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge
{
/**
* @brief Equalize the ranges of the nodes parameters by proceeding iteratively.
* Can only be applied to single-branch networks (otherwise the graphView is left unchanged).
* @param graphView The GraphView to process.
* @param targetDelta the stopping criterion (typical value : 0.01)
*/
void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta = 0.01);
}
#endif /* AIDGE_QUANTIZATION_PTQ_CLE_H_ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_QUANTIZATION_PTQ_CLIP_H_
#define AIDGE_QUANTIZATION_PTQ_CLIP_H_
//#include <cstdint>
//#include <map>
//#include <memory>
//#include <string>
//#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge
{
/**
* @brief Kind of clipping policy to apply during the activation quantization
*/
enum Clipping {MAX = 1, MSE, AA, KL};
/**
* @brief Compute the activation histograms of the nodes referenced in the range map passed as argument.
* @param valueRanges A map associating each considered node name to its corresponding output range.
* @param nbBins Desired number of bins of the returned histograms.
* @param graphView The GraphView containing the considered nodes.
* @param inputDataSet The input dataset, consisting of a vector of input samples.
* @return A map associating each node name to its corresponding activation histogram.
*/
std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, float> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet);
/**
* @brief Given an input activation histogram, compute the optimal clipping value in the sense of the Lp norm.
* @param histogram: The provided activation histogram.
* @param nbBits: The quantization number of bits.
* @param exponent: The exponent of the Lp norm (e.g. 2 for the MSE).
* @return The optimal clipping value.
*/
float computeMEClipping(std::vector<int> histogram, std::uint8_t nbBits, float exponent);
/**
* @brief Given an input activation histogram, compute the optimal clipping value in the sense of the KL divergence.
* @param histogram: The provided activation histogram.
* @param nbBits: The quantization number of bits.
* @return The optimal clipping value.
*/
float computeKLClipping(std::vector<int> histogram, std::uint8_t nbBits);
/**
* @brief Return a corrected map of the provided activation ranges.
* To do so, compute the optimal clipping value for every node and multiply the input ranges by those values.
* The method used to compute the clippings can be either 'MSE', 'AA', 'KL' or 'MAX'.
* @param clippingMode The method used to compute the optimal clippings.
* @param valueRanges The map associating each affine node to its output range.
* @param nbBits The quantization number of bits.
* @param graphView The GraphView containing the considered nodes.
* @param inputDataSet The input dataset, consisting of a vector of input samples.
* @param verbose Whether to print the clipping values or not.
* @return The corrected map associating each provided node to its clipped range.
*/
std::map<std::string, float> adjustRanges(Clipping clippingMode, std::map<std::string, float> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool verbose);
}
#endif /* AIDGE_QUANTIZATION_PTQ_CLIP_H_ */
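
These clipping routines are also exposed through the Python bindings added later in this merge request. A hedged usage sketch, assuming `ranges`, `model` and `tensors` come from a prior calibration step (as in the README pipeline above); the `adjust_ranges` binding wraps the same logic in a single call:

```python
import aidge_quantization

# refine the MAX-calibrated ranges with MSE-optimal clipping values
histograms = aidge_quantization.compute_histograms(ranges, 256, model, tensors)
for name, histogram in histograms.items():
    clipping = aidge_quantization.compute_me_clipping(histogram, nb_bits=8, exponent=2.0)
    ranges[name] *= clipping
```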
@@ -9,20 +9,64 @@
*
********************************************************************************/
#ifndef AIDGE_QUANTIZATION_QUANTPTQ_H_
#define AIDGE_QUANTIZATION_QUANTPTQ_H_
#ifndef AIDGE_QUANTIZATION_PTQ_PTQ_H_
#define AIDGE_QUANTIZATION_PTQ_PTQ_H_
#include <cstdint> // std::uint8_t
#include <map>
#include <memory>
#include <string>
#include <vector>
//#include <cstdint>
//#include <map>
//#include <memory>
//#include <string>
//#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge {
/**
* @brief Set of the types of the nodes which contain affine transforms (that is Y = A.X + B)
*/
static const std::set<std::string> affineNodeTypes({"FC", "Conv", "ConvDepthWise", "PaddedConv", "PaddedConvDepthWise"});
/**
* @brief Set of the types of the nodes that do not affect the PTQ process
*/
static const std::set<std::string> seamlessNodeTypes({"Pad", "MaxPooling", "AvgPooling", "PaddedMaxPooling", "PaddedAvgPooling", "GlobalAveragePooling", "Reshape", "Transpose", "Gather"});
/**
* @brief Set of the types of the nodes that merge multiple branches into one
*/
static const std::set<std::string> mergingNodeTypes({"Add", "Concat", "Sub"});
/**
* @brief Determine if a node contains an affine transform (that is Y = A.X + B)
* @param node The node to be checked
* @return True if the node is affine, else false.
*/
bool isAffine(std::shared_ptr<Node> node);
/**
* @brief Determine if a node contains an operator that does not affect the PTQ process
* @param node The node to be checked
* @return True if the node is seamless, else false.
*/
bool isSeamless(std::shared_ptr<Node> node);
/**
* @brief Determine if a node contains an operator that merges multiple branches into one
* @param node The node to be checked
* @return True if the node is merging, else false.
*/
bool isMerging(std::shared_ptr<Node> node);
/**
* @brief Retrieve the scheduled vector of nodes of a GraphView, without the Producer nodes.
* @param graphView The graphView containing the nodes
* @param newSchedule Whether to generate a new scheduling or reuse the existing one
* @param verbose Whether to print the node vector or not
* @return The scheduled vector of nodes
*/
std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule = true, bool verbose = false);
/**
* @brief Determine whether an input GraphView can be quantized or not.
* @param graphView The GraphView to be checked.
@@ -32,6 +76,7 @@ namespace Aidge {
/**
* @brief Insert a scaling node after each affine node of the GraphView.
* Also insert a scaling node in every purely residual branch.
* @param graphView The GraphView containing the affine nodes.
*/
void insertScalingNodes(std::shared_ptr<GraphView> graphView);
@@ -43,57 +88,74 @@ namespace Aidge {
void normalizeParameters(std::shared_ptr<GraphView> graphView);
/**
* @brief Compute the value ranges of every affine node output, given an input dataset.
* @brief Compute the activation ranges of every affine node, given an input dataset.
* @param graphView The GraphView containing the affine nodes, on which the inferences are performed.
* @param inputDataSet The input dataset, consisting of a vector of input samples.
* @param scalingNodesOnly Whether to restrict the range retrieval to scaling nodes only or not.
* @return A map associating each affine node name to its corresponding output range.
*/
std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet);
std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly);
/**
* @brief Normalize the activations of each affine node so that it become equal to one.
* @brief Normalize the activations of each affine node so that they fit in the [-1:1] range.
* This is done by reconfiguring the scaling nodes, as well as rescaling the weights and biases tensors.
* @param graphView The GraphView containing the affine nodes.
* @param valueRanges The node output value ranges computed over the calibration dataset.
*/
void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, float> valueRanges);
/**
* @brief For each node, compute the sign of its input and output values.
* The goal of the routine is to maximize the number of unsigned IOs in order to double the value resolution when possible.
* @param graphView The GraphView to analyze.
* @param verbose Whether to print the sign map or not.
* @return A map associating a pair of signs to each node of the GraphView (a sign for the input and one for the output).
*/
std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose);
/**
* @brief Quantize an already normalized (in terms of parameters and activations) network.
* @param graphView The GraphView to be quantized.
* @param nbBits The desired number of bits of the quantization.
* @param applyRounding Whether to apply the rounding operations or not.
* @param optimizeSigns Whether to take account of the IO signs of the operators or not.
* @param verbose Whether to print the sign map or not.
*/
void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits);
void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool applyRounding, bool optimizeSigns, bool verbose);
/**
* @brief Main quantization routine. Performs every step of the quantization pipeline.
* @param graphView The GraphView to be quantized.
* @param nbBits The desired number of bits of the quantization.
* @param inputDataSet The input dataset on which the value ranges are computed.
* @param clippingMode: Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
* @param applyRounding Whether to apply the rounding operations or not.
* @param optimizeSigns Whether to take account of the IO signs of the operators or not.
* @param singleShift Whether to convert the scaling factors into powers of two. If true, the approximations are compensated using the previous nodes' weights.
* @param verbose Whether to print internal information about the quantization process.
*/
void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool OptimizeCliping);
void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool applyRounding, bool optimizeSigns, bool singleShift, bool verbose);
/**
* @brief Compute the weight ranges of every affine node. Provided for debuging purposes.
* @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
* @param graphView The GraphView containing the affine nodes.
* @return A map associating each affine node name to its corresponding weight range.
*/
std::map<std::string, float> getWeightRanges(std::shared_ptr<GraphView> graphView);
/**
* @brief Clear the affine nodes biases. Provided form debuging purposes.
* @brief Clear the affine nodes biases. Provided for debugging purposes.
* @param graphView The GraphView containing the affine nodes.
*/
void clearBiases(std::shared_ptr<GraphView> graphView);
void devPTQ(std::shared_ptr<GraphView> graphView);
std::map<std::string, std::vector<int>> computeScalingHistograms(std::map<std::string, float> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet);
float computeBestClipping(std::vector<int> histogram, std::uint8_t nbBits);
/**
* @brief Development and test routine.
* @param graphView The GraphView under test.
*/
void devPTQ(std::shared_ptr<GraphView> graphView);
}
#endif /* AIDGE_QUANTIZATION_QUANTPTQ_H_ */
#endif /* AIDGE_QUANTIZATION_PTQ_PTQ_H_ */
@@ -14,7 +14,10 @@
#include <string>
#include "aidge/QuantPTQ.hpp"
#include "aidge/PTQ/Clip.hpp"
#include "aidge/PTQ/CLE.hpp"
#include "aidge/PTQ/PTQ.hpp"
#include "aidge/hook/Hook.hpp"
#include "aidge/graph/GraphView.hpp"
@@ -23,18 +26,25 @@ namespace py = pybind11;
namespace Aidge {
void init_QuantPTQ(py::module &m) {
py::enum_<Clipping>(m, "Clipping", "Kind of clipping policy to apply during the activation quantization")
.value("MAX", Clipping::MAX)
.value("MSE", Clipping::MSE)
.value("AA" , Clipping::AA)
.value("KL" , Clipping::KL);
m.def("check_architecture", &checkArchitecture, py::arg("network"),
R"mydelimiter(
Determine whether an input GraphView can be quantized or not.
:param network: The GraphView to be checked.
:type network: :py:class:`aidge_core.GraphView`
:return: True if the GraphView can be quantized, else false.
:return: True if the GraphView can be quantized, else False.
:rtype: bool
)mydelimiter");
m.def("insert_scaling_nodes", &insertScalingNodes, py::arg("network"),
R"mydelimiter(
Insert a scaling node after each affine node of the GraphView.
Also insert a scaling node in every purely residual branch.
:param network: The GraphView containing the affine nodes.
:type network: :py:class:`aidge_core.GraphView`
)mydelimiter");
@@ -46,20 +56,22 @@ void init_QuantPTQ(py::module &m) {
:type network: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("compute_ranges", &computeRanges, py::arg("network"), py::arg("input_dataset"),
m.def("compute_ranges", &computeRanges, py::arg("network"), py::arg("input_dataset"), py::arg("scaling_nodes_only"),
R"mydelimiter(
Compute the value ranges of every affine node output, given an input dataset.
Compute the activation ranges of every affine node, given an input dataset.
:param network: The GraphView containing the affine nodes, on which the inferences are performed.
:type network: :py:class:`aidge_core.GraphView`
:param input_dataset: inputDataSet The input dataset, consisting of a vector of input samples.
:type input_dataset: A list of :py:class:`aidge_core.Tensor`
:return: A map associating each affine node name to it's corresponding output range.
:param input_dataset: The input dataset, consisting of a vector of input samples.
:type input_dataset: list of :py:class:`aidge_core.Tensor`
:param scaling_nodes_only: Whether to restrict the range retrieval to scaling nodes only or not
:type scaling_nodes_only: bool
:return: A map associating each considered node name to its corresponding output range.
:rtype: dict
)mydelimiter");
m.def("normalize_activations", &normalizeActivations, py::arg("network"), py::arg("value_ranges"),
R"mydelimiter(
Normalize the activations of each affine node so that it become equal to one.
Normalize the activations of each affine node so that they fit in the [-1:1] range.
This is done by reconfiguring the scaling nodes, as well as rescaling the weights and biases tensors.
:param network: The GraphView containing the affine nodes.
:type network: :py:class:`aidge_core.GraphView`
@@ -67,16 +79,22 @@ void init_QuantPTQ(py::module &m) {
:type value_ranges: dict
)mydelimiter");
m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"),
m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("apply_rounding"), py::arg("optimize_signs"), py::arg("verbose"),
R"mydelimiter(
Quantize an already normalized (in terms of parameters and activations) network.
:param network: The GraphView to be quantized.
:type network: :py:class:`aidge_core.GraphView`
:param nb_bits: The desired number of bits of the quantization.
:type nb_bits: int
:param apply_rounding: Whether to apply the rounding operations or not.
:type apply_rounding: bool
:param optimize_signs: Whether to take account of the IO signs of the operators or not.
:type optimize_signs: bool
:param verbose: Whether to print the sign map or not.
:type verbose: bool
)mydelimiter");
m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("optimize_cliping") = false,
m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = "MAX", py::arg("apply_rounding") = true, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("verbose") = false,
R"mydelimiter(
Main quantization routine. Performs every step of the quantization pipeline.
:param network: The GraphView to be quantized.
@@ -85,14 +103,105 @@ void init_QuantPTQ(py::module &m) {
:type nb_bits: int
:param input_dataset: The input dataset on which the value ranges are computed.
:type input_dataset: list of :py:class:`aidge_core.Tensor`
:param optimize_cliping: Whether to optimize the cliping values or not.
:type optimize_cliping: bool
:param clipping_mode: Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
:type clipping_mode: string
:param apply_rounding: Whether to apply the rounding operations or not.
:type apply_rounding: bool
:param optimize_signs: Whether to take account of the IO signs of the operators or not.
:type optimize_signs: bool
:param single_shift: Whether to convert the scaling factors into powers of two. If true, the approximations are compensated using the previous nodes' weights.
:type single_shift: bool
:param verbose: Whether to print internal information about the quantization process.
:type verbose: bool
)mydelimiter");
m.def("compute_histograms", &computeHistograms, py::arg("value_ranges"), py::arg("nb_bins"), py::arg("network"), py::arg("input_dataset"),
R"mydelimiter(
Compute the activation histograms of each node referenced in the range map passed as argument.
:param value_ranges: A map associating each considered node name to its corresponding output range.
:type value_ranges: dict
:param nb_bins: Desired number of bins of the returned histograms.
:type nb_bins: int
:param network: The GraphView containing the considered nodes.
:type network: :py:class:`aidge_core.GraphView`
:param input_dataset: The input dataset, consisting of a list of input samples.
:type input_dataset: list of :py:class:`aidge_core.Tensor`
:return: A map associating each node name to its corresponding activation histogram.
:rtype: dict
)mydelimiter");
m.def("compute_me_clipping", &computeMEClipping, py::arg("histogram"), py::arg("nb_bits"), py::arg("exponent"),
R"mydelimiter(
Given an input activation histogram, compute the optimal clipping value in the sense of the Lp norm.
:param histogram: The provided activation histogram.
:type histogram: list
:param nb_bits: The quantization number of bits.
:type nb_bits: int
:param exponent: The exponent of the Lp norm (e.g. 2 for the MSE).
:type exponent: float
:return: The optimal clipping value.
:rtype: float
)mydelimiter");
m.def("compute_kl_clipping", &computeKLClipping, py::arg("histogram"), py::arg("nb_bits"),
R"mydelimiter(
Given an input activation histogram, compute the optimal clipping value in the sense of the KL divergence.
:param histogram: The provided activation histogram.
:type histogram: list
:param nb_bits: The quantization number of bits.
:type nb_bits: int
:return: The optimal clipping value.
:rtype: float
)mydelimiter");
m.def("adjust_ranges", &adjustRanges, py::arg("clipping_mode"), py::arg("value_ranges"), py::arg("nb_bits"), py::arg("network"), py::arg("input_dataset"), py::arg("verbose"),
R"mydelimiter(
Return a corrected map of the provided activation ranges.
To do so, compute the optimal clipping value for every node and multiply the input ranges by those values.
The method used to compute the clippings can be either 'MSE', 'AA', 'KL' or 'MAX'.
:param clipping_mode: The method used to compute the optimal clippings.
:type clipping_mode: enum
:param value_ranges: The map associating each affine node to its output range.
:type value_ranges: dict
:param nb_bits: The quantization number of bits.
:type nb_bits: int
:param network: The GraphView containing the considered nodes.
:type network: :py:class:`aidge_core.GraphView`
:param input_dataset: The input dataset, consisting of a list of input samples.
:type input_dataset: list of :py:class:`aidge_core.Tensor`
:param verbose: Whether to print the clipping values or not.
:type verbose: bool
:return: The corrected map associating each provided node to its clipped range.
:rtype: dict
)mydelimiter");
m.def("compute_sign_map", &computeSignMap, py::arg("network"), py::arg("verbose"),
R"mydelimiter(
For each node, compute the sign of its input and output values.
The goal of the routine is to maximize the number of unsigned IOs in order to double the value resolution when possible.
:param network: The GraphView to analyze.
:type network: :py:class:`aidge_core.GraphView`
:param verbose: Whether to print the sign map or not.
:type verbose: bool
:return: A map associating a pair of signs to each node of the GraphView (a sign for the input and one for the output).
:rtype: dict
)mydelimiter");
m.def("cross_layer_equalization", &crossLayerEqualization, py::arg("network"), py::arg("target_delta"),
R"mydelimiter(
Equalize the ranges of the nodes parameters by proceeding iteratively.
Can only be applied to single-branch networks (otherwise the graphView is left unchanged).
:param network: The GraphView to process.
:type network: :py:class:`aidge_core.GraphView`
:param target_delta: The stopping criterion (typical value: 0.01)
:type target_delta: float
)mydelimiter");
m.def("get_weight_ranges", &getWeightRanges, py::arg("network"),
R"mydelimiter(
Compute the weight ranges of every affine node. Provided for debuging purposes.
:param network: graphView The GraphView containing the affine nodes.
Compute the weight ranges of every affine node. Provided for debugging purposes.
:param network: The GraphView containing the affine nodes.
:type network: :py:class:`aidge_core.GraphView`
:return: A map associating each affine node name to its corresponding weight range.
:rtype: dict
@@ -105,10 +214,12 @@ void init_QuantPTQ(py::module &m) {
:type network: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("compute_scaling_histograms", &computeScalingHistograms, py::arg("value_ranges"), py::arg("nb_bins"), py::arg("network"), py::arg("input_dataset"), "compute scaling histogram");
m.def("compute_best_clipping", &computeBestClipping, py::arg("histogram"), py::arg("nb_bits"), "compute the best clipping for an histogram");
m.def("dev_ptq", &devPTQ, py::arg("network"), "dev ptq");
m.def("dev_ptq", &devPTQ, py::arg("network"),
R"mydelimiter(
Development and test routine.
:param network: The GraphView under test.
:type network: :py:class:`aidge_core.GraphView`
)mydelimiter");
}
PYBIND11_MODULE(aidge_quantization, m) {
...
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/quantization/PTQ/CLE.hpp"
#include "aidge/quantization/PTQ/Clip.hpp"
#include "aidge/quantization/PTQ/PTQ.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace Aidge
{
static std::shared_ptr<Tensor> getWeightTensor(std::shared_ptr<Node> node)
{
return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
}
static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
{
return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
}
static void rescaleTensor(std::shared_ptr<Tensor> tensor, float scaling)
{
// Get the tensor data pointer
float * castedTensor = static_cast <float *> (tensor->getImpl()->rawPtr());
// Rescale the tensor
for(std::size_t i = 0; i < tensor->size(); i++)
castedTensor[i] *= scaling;
}
static float getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
{
// Get the tensor data pointer
float * castedTensor = static_cast<float*>(tensor->getImpl()->rawPtr());
// Get the tensor absolute max value
float maxValue = 0.0f;
for(std::size_t i = 0; i < tensor->size(); ++i) {
if(std::fabs(castedTensor[i]) > maxValue) {
maxValue = std::fabs(castedTensor[i]);
}
}
return maxValue;
}
void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta)
{
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
// Check if the CLE can be applied ...
for (std::shared_ptr<Node> node : nodeVector)
if (node->getChildren().size() > 1)
{
Log::info(" Network have multiple branches, skipping the CLE ... ");
return;
}
Log::info(" Applying the Cross-Layer Equalization ... ");
// Get the vector of affine nodes
std::vector<std::shared_ptr<Node>> affineNodeVector;
for (std::shared_ptr<Node> node : nodeVector)
if (isAffine(node))
affineNodeVector.push_back(node);
float maxRangeDelta;
do
{
maxRangeDelta = 0.0;
/*
std::cout << " ----- " << std::endl;
for (std::shared_ptr<Node> node : affineNodeVector)
std::cout << getTensorAbsoluteMax(getWeightTensor(node)) << std::endl;
*/
for (size_t i = 0; i < (affineNodeVector.size() - 1); i++)
{
std::shared_ptr<Node> n1 = affineNodeVector[i];
std::shared_ptr<Node> n2 = affineNodeVector[i+1];
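// Equalize the two weight ranges around their geometric mean :
// after rescaling, both ranges are equal to sqrt(r1 * r2)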
float r1 = getTensorAbsoluteMax(getWeightTensor(n1));
float r2 = getTensorAbsoluteMax(getWeightTensor(n2));
float s1 = std::sqrt(r1 * r2) / r1;
float s2 = std::sqrt(r1 * r2) / r2;
rescaleTensor(getWeightTensor(n1), s1);
rescaleTensor(getWeightTensor(n2), s2);
rescaleTensor(getBiasTensor(n1), s1);
float rangeDelta = std::abs(r1 - r2);
if (rangeDelta > maxRangeDelta)
maxRangeDelta = rangeDelta;
}
}
while (maxRangeDelta > targetDelta);
}
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/quantization/PTQ/CLE.hpp"
#include "aidge/quantization/PTQ/Clip.hpp"
#include "aidge/quantization/PTQ/PTQ.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp"
namespace Aidge
{
std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string, float> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet)
{
std::shared_ptr<Node> firstNode = retrieveNodeVector(graphView)[0];
//std::cout << " COMPUTING HISTOGRAMS ... " << std::endl;
std::map<std::string, std::vector<int>> histograms;
SequentialScheduler scheduler(graphView);
scheduler.resetScheduling();
// Setup the histograms ...
for (std::shared_ptr<Node> node : graphView->getNodes())
{
bool isInsideRanges = (valueRanges.find(node->name()) != valueRanges.end());
if (isInsideRanges)
{
std::vector<int> histogram;
for (int i = 0; i < nbBins; i++)
histogram.push_back(0);
histograms.insert(std::make_pair(node->name(), histogram));
}
}
// Fill the histograms ...
scheduler.resetScheduling();
int it = 0;
for (std::shared_ptr<Tensor> inputTensor : inputDataSet)
{
Log::info(" IT (BIS) : {}", it++);
// Inference ...
scheduler.forward(true, {inputTensor});
// Gather values ...
for (std::shared_ptr<Node> node : graphView->getNodes())
{
bool isInsideRanges = (valueRanges.find(node->name()) != valueRanges.end());
if (isInsideRanges)
{
float valueRange = valueRanges[node->name()];
std::shared_ptr<Operator> nodeOperator = node->getOperator();
std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
float * castedTensor = static_cast<float *> (valueTensor->getImpl()->rawPtr());
for(std::size_t i = 0; i < valueTensor->size(); i++)
{
int bin = std::round(std::abs(castedTensor[i] / valueRange * nbBins));
bin = std::min(bin, nbBins - 1); // clamp : the max-range sample would otherwise land on bin == nbBins
histograms[node->name()][bin]++;
}
}
}
}
return histograms;
}
float computeMEClipping(std::vector<int> histogram, std::uint8_t nbBits, float exponent)
{
int nbBins = histogram.size();
int nbIter = 100;
int signedMax = (1 << (nbBits - 1)) - 1;
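// signedMax is the largest representable magnitude over nbBits (signed) : 2^(nbBits-1) - 1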
std::vector<float> clippingErrors;
for (int it = 1; it < nbIter; it++)
{
// Compute the rounding cost of this particular clipping ...
float accumulatedError = 0.0;
float clipping = it / static_cast<float> (nbIter);
for (int bin = 0; bin < nbBins; bin++)
{
float value = (bin + 0.5) / nbBins;
float scaling = signedMax / clipping;
float rounded = std::round(value * scaling) / scaling;
float clipped = std::min(clipping, rounded);
float approxError = std::abs(clipped - value);
accumulatedError += std::pow(approxError, exponent) * histogram[bin];
}
clippingErrors.push_back(accumulatedError);
}
std::vector<float>::iterator it = std::min_element(clippingErrors.begin(), clippingErrors.end());
float bestClipping = static_cast<float> (std::distance(clippingErrors.begin(), it)) / static_cast<float> (nbIter);
return bestClipping;
}
float computeKLClipping(std::vector<int> refHistogram, std::uint8_t nbBits)
{
// Find the clipping value that minimizes the KL divergence between the reference histogram and its quantized counterpart
int nbIter = 100;
int signedMax = (1 << (nbBits - 1)) - 1;
float refNorm = 0;
for (int n : refHistogram)
refNorm += static_cast<float> (n);
std::vector<float> clippingErrors;
for (int it = 1; it < nbIter; it++)
{
float clipping = it / static_cast<float> (nbIter);
// Create the histogram for this particular clipping ...
std::vector<int> quantHistogram;
for (int i = 0; i < signedMax; i++)
quantHistogram.push_back(0);
for (std::size_t refBin = 0; refBin < refHistogram.size(); refBin++)
{
float value = (static_cast<float> (refBin) + 0.5f) / static_cast<float> (refHistogram.size());
int quantBin = std::floor(value / clipping * signedMax);
quantBin = std::min(quantBin, signedMax-1);
quantHistogram[quantBin] += refHistogram[refBin];
}
// Compute the mass of the histogram
float quantNorm = 0;
for (std::size_t refBin = 0; refBin < refHistogram.size(); refBin++)
{
float value = (static_cast<float> (refBin) + 0.5f) / static_cast<float> (refHistogram.size());
int quantBin = std::floor(value / clipping * signedMax);
if (quantBin < static_cast<int> (quantHistogram.size()))
quantNorm += quantHistogram[quantBin];
}
// Compute the KL divergence
float accumulatedError = 0.0;
for (std::size_t refBin = 0; refBin < refHistogram.size(); refBin++)
{
float value = (static_cast<float> (refBin) + 0.5f) / static_cast<float> (refHistogram.size());
int quantBin = std::floor(value / clipping * signedMax);
float p = static_cast<float> (refHistogram[refBin]) / refNorm;
float q = (quantBin < static_cast<int> (quantHistogram.size())) ?
static_cast<float> (quantHistogram[quantBin]) / quantNorm : 0;
if (p != 0 && q != 0)
accumulatedError += q * std::log(q / p);
}
clippingErrors.push_back(accumulatedError);
}
std::vector<float>::iterator it = std::min_element(clippingErrors.begin() + 1, clippingErrors.end());
float bestClipping = static_cast<float> (std::distance(clippingErrors.begin(), it)) / static_cast<float> (nbIter);
return bestClipping;
}
std::map<std::string, float> adjustRanges(Clipping clippingMode, std::map<std::string, float> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool verbose)
{
/*
std::set<std::string> supportedModes({"MSE", "KL", "MAX", "AA"});
bool isSupported = supportedModes.find(clippingMode) != supportedModes.end();
if (!isSupported)
{
Log::info(" Clipping mode '{}' is not supported. No clipping will be applied ...", clippingMode);
return valueRanges;
}
*/
float clipping = 1.0f;
int nbBins = (1 << (nbBits + 4)) ; // XXX Enhance this !!!
if (clippingMode != Clipping::MAX)
{
if (verbose)
Log::info(" === CLIPPING VALUES === ");
std::map<std::string, std::vector<int>> histograms = computeHistograms(valueRanges, nbBins, graphView, inputDataSet);
for (std::shared_ptr<Node> node : graphView->getNodes())
if (node->type() == "Scaling")
{
std::vector<int> histogram = histograms[node->name()];
if (clippingMode == Clipping::MSE)
clipping = computeMEClipping(histogram, nbBits, 2.0);
if (clippingMode == Clipping::AA)
clipping = computeMEClipping(histogram, nbBits, 1.0);
if (clippingMode == Clipping::KL)
clipping = computeKLClipping(histogram, nbBits);
if (verbose)
Log::info(" {:.6f} ({})", clipping, node->name());
valueRanges[node->name()] *= clipping;
}
}
return valueRanges;
}
}
\ No newline at end of file
@@ -9,38 +9,26 @@
*
********************************************************************************/
#include "aidge/QuantPTQ.hpp"
#include <algorithm> // std::find, std::reverse
#include <cmath> // std::round
#include <cstddef> // std::size_t
#include <cstdint> // std::uint8_t
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility> // std::make_pair
#include <vector>
#include "aidge/quantization/PTQ/CLE.hpp"
#include "aidge/quantization/PTQ/Clip.hpp"
#include "aidge/quantization/PTQ/PTQ.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/MetaOperator.hpp"
namespace Aidge{
namespace Aidge
{
static std::string makeUniqueName(std::string baseName, std::shared_ptr<GraphView> graphView)
{
std::set<std::string> existingNames;
@@ -64,27 +52,30 @@ static std::string makeUniqueName(std::string baseName, std::shared_ptr<GraphVie
return newName;
}
static bool isAffine(std::shared_ptr<Node> node)
bool isAffine(std::shared_ptr<Node> node)
{
std::set<std::string> affineNodeTypes({"FC", "Conv", "ConvDepthWise", "PaddedConv", "PaddedConvDepthWise"});
return (affineNodeTypes.find(node->type()) != affineNodeTypes.end());
}
static bool isSeamless(std::shared_ptr<Node> node)
bool isSeamless(std::shared_ptr<Node> node)
{
std::set<std::string> seamlessNodeTypes({"Pad", "MaxPooling", "AvgPooling", "PaddedMaxPooling", "PaddedAvgPooling", "GlobalAveragePooling"});
return (seamlessNodeTypes.find(node->type()) != seamlessNodeTypes.end());
}
bool isMerging(std::shared_ptr<Node> node)
{
return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end());
}
bool checkArchitecture(std::shared_ptr<GraphView> graphView)
{
std::set<std::string> otherNodeTypes({"Flatten", "Add", "Concat", "Softmax", "ReLU", "Producer"});
std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "ReLU", "Producer"});
for (std::shared_ptr<Node> node : graphView->getNodes())
{
bool isOther = otherNodeTypes.find(node->type()) != otherNodeTypes.end();
if (!isOther && !isAffine(node) && !isSeamless(node)) {
Log::info(" GraphView can't be quantized : node type {} is not supported !", node->type());
if (!isOther && !isAffine(node) && !isSeamless(node) && !isMerging(node)) {
Log::warn(" GraphView can't be quantized : node type {} is not supported !", node->type());
return false;
}
}
@@ -92,38 +83,6 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView)
return true;
}
static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
std::shared_ptr<Node> currNode = graphView->rootNode();
if (currNode->type() == "Producer")
currNode = *(currNode->getChildren()).begin();
std::shared_ptr<Node> parentNode = currNode->getParent(0);
while (parentNode->type() != "Producer") {
currNode = parentNode;
parentNode = currNode->getParent(0);
}
return currNode;
}
static std::shared_ptr<Node> getLastNode(std::shared_ptr<GraphView> graphView)
{
std::shared_ptr<Node> currNode = graphView->rootNode();
while (currNode->getChildren().size() != 0)
currNode = (*currNode->getChildren().begin());
return currNode;
}
static void popSoftMax(std::shared_ptr<GraphView> graphView)
{
std::shared_ptr<Node> lastNode = getLastNode(graphView);
if (lastNode->type() == "Softmax") {
graphView->replace({lastNode}, {}); // remove does not work !!!
}
}
static void fillTensor(std::shared_ptr<Tensor> tensor, float value)
{
// Get the tensor data pointer
@@ -208,42 +167,21 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
}
void appendIdentity(std::shared_ptr<GraphView> graphView) {
std::shared_ptr<Node> lastNode = getLastNode(graphView);
int size = std::static_pointer_cast<OperatorTensor> (lastNode->getOperator())->getOutput(0)->size();
std::shared_ptr<Node> identityNode = FC(size, size, true, makeUniqueName("Identity", graphView));
identityNode->getOperator()->setDataType(DataType::Float32);
identityNode->getOperator()->setBackend("cpu");
std::shared_ptr<Tensor> weightTensor = std::static_pointer_cast<Tensor> (identityNode->getOperator()->getRawInput(1));
fillTensor(weightTensor, 0);
float * castedWeightTensor = static_cast<float *> (weightTensor->getImpl()->rawPtr());
for (int n = 0; n < size; n++)
castedWeightTensor[n + size * n] = 1.0;
graphView->addChild(identityNode);
}
std::vector<std::shared_ptr<Node>> extractNodeVector(std::shared_ptr<GraphView> graphView, bool verbose)
std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose)
{
std::vector<std::shared_ptr<Node>> nodeVector;
SequentialScheduler scheduler(graphView);
scheduler.forward();
nodeVector = scheduler.getStaticScheduling();
//graphView->forwardDims();
//scheduler.generateScheduling();
//nodeVector = scheduler.getStaticScheduling();
if (newSchedule)
{
scheduler.resetScheduling();
scheduler.generateScheduling(); // old way : scheduler.forward();
}
nodeVector = scheduler.getStaticScheduling();
fixScheduling(nodeVector);
removeMatchingNodes(nodeVector, "Producer");
if (verbose)
@@ -253,19 +191,61 @@ std::vector<std::shared_ptr<Node>> extractNodeVector(std::shared_ptr<GraphView>
Log::info("{} {}", node->type(), node->name());
}
//for (auto node : nodeVector)
// std::cout << node->type() << std::endl;
return nodeVector;
}
static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
return retrieveNodeVector(graphView)[0];
}
static std::shared_ptr<Node> getLastNode(std::shared_ptr<GraphView> graphView)
{
std::shared_ptr<Node> currNode = graphView->rootNode();
while (currNode->getChildren().size() != 0)
currNode = (*currNode->getChildren().begin());
return currNode;
}
static void popSoftMax(std::shared_ptr<GraphView> graphView)
{
std::shared_ptr<Node> lastNode = getLastNode(graphView);
if (lastNode->type() == "Softmax") {
graphView->replace({lastNode}, {}); // remove does not work !!!
}
}
static void prepareNetwork(std::shared_ptr<GraphView> graphView)
{
removeFlatten(graphView);
bool containsBatchNorm = false;
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
for (std::shared_ptr<Node> node : nodeVector)
if (node->type() == "BatchNorm")
{
containsBatchNorm = true;
break;
}
if (containsBatchNorm)
fuseBatchNorm(graphView);
return nodeVector;
popSoftMax(graphView);
}
// XXX HERE : Branches containing only Seamless nodes should be considered as residual too !!!
void insertResidualNodes(std::shared_ptr<GraphView> graphView)
{
std::vector<std::shared_ptr<Node>> nodeVector = extractNodeVector(graphView, false);
// TODO: double check this ...
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
if (node->type() == "Add" || node->type() == "Concat")
if (isMerging(node))
{
int nbParents = node->getParents().size();
for (int i = 0; i < nbParents; i++)
@@ -274,6 +254,7 @@ void insertResidualNodes(std::shared_ptr<GraphView> graphView)
bool parentIsForking = (parentNode->getChildren().size() > 1);
if (parentIsForking)
{
// temporary verbose ...
Log::info(" ### found residual branch at index {}", i);
Log::info(" ### inserting multiplicative node ...");
@@ -287,10 +268,15 @@ void insertResidualNodes(std::shared_ptr<GraphView> graphView)
}
}
}
graphView->forwardDims();
}
static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
{
int index = 0;
while (node->getParent(index) != parentNode)
index++;
return index;
}
void insertScalingNodes(std::shared_ptr<GraphView> graphView)
{
@@ -298,23 +284,36 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
for (std::shared_ptr<Node> node : nodeSet)
for (std::shared_ptr<Node> parentNode : nodeSet)
{
if (isAffine(node))
if (isAffine(parentNode) || isMerging(parentNode))
{
std::string scalingNodeName = makeUniqueName(node->name() + "_Scaling", graphView);
std::string scalingNodeName = makeUniqueName(parentNode->name() + "_Scaling", graphView);
std::shared_ptr<Node> scalingNode = Scaling(1.0, 0, false, scalingNodeName);
scalingNode->getOperator()->setDataType(DataType::Float32);
scalingNode->getOperator()->setBackend("cpu");
if (node->getChildren().size() > 0)
if (parentNode->getChildren().size() > 0)
{
std::shared_ptr<Node> nextNode = node->getChildren(0)[0];
// SCALING NODE INSERTION
// We always have one output from Affine and Add nodes, but possibly multiple childs
std::vector<std::shared_ptr<Node>> nextNodes = parentNode->getChildren(0);
// XXX TODO : be extra careful about this ...
int i = 0;
while (nextNode->getParent(i) != node) i++;
graphView->insertParent(nextNode, scalingNode, i, 0, 0);
// For each node in nextNodes, store the connection index
std::vector<int> inputIndices(nextNodes.size());
for (std::size_t i = 0; i < nextNodes.size(); i++)
inputIndices[i] = getInputIndex(nextNodes[i], parentNode);
for (std::shared_ptr<Node> nextNode : nextNodes)
parentNode->removeChild(nextNode, 0);
parentNode->addChild(scalingNode, 0, 0);
for (std::size_t i = 0; i < nextNodes.size(); i++)
scalingNode->addChild(nextNodes[i], 0, inputIndices[i]);
graphView->add(scalingNode);
}
else
{
@@ -323,16 +322,8 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
}
}
}
graphView->forwardDims();
// XXX Append identity if needed ...
if (getLastNode(graphView)->type() == "Scaling")
appendIdentity(graphView);
}
static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> mergingNode)
{
std::shared_ptr<Node> currNode = mergingNode;
@@ -348,26 +339,28 @@ static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> mergin
return currNode;
}
// Be more careful about the '*' and '/' ...
void normalizeParameters(std::shared_ptr<GraphView> graphView)
{
// CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
std::vector<std::shared_ptr<Node>> nodeVector = extractNodeVector(graphView, false);
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
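// accumulatedRatios[name] tracks the product of the rescaling ratios applied upstream of each node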
std::map<std::string, float> accumulatedRatios;
for (std::shared_ptr<Node> node : nodeVector)
{
accumulatedRatios.insert(std::make_pair(node->name(), 1.0));
}
// ITERATE OVER THE GRAPH /////////////////////////////////////////////////
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
// Scaling nodes still have a ratio of 1, so they are seamless ...
if (node->type() == "ReLU" || node->type() == "Scaling" || isSeamless(node))
{
if (node != getFirstNode(graphView))
if (node != firstNode)
{
std::shared_ptr<Node> prevNode = node->getParent(0);
accumulatedRatios[node->name()] = accumulatedRatios[prevNode->name()];
@@ -384,8 +377,10 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
rescaleTensor(weightTensor, ratio);
// Accumulate the ratio
if (node == getFirstNode(graphView))
if (node == firstNode)
{
accumulatedRatios[node->name()] = ratio;
}
else
{
std::shared_ptr<Node> prevNode = node->getParent(0);
@@ -397,17 +392,15 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
if (nodeHasBias)
{
std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
// Check that a bias is present (as it is optional)
if (biasTensor) {
if (biasTensor)
rescaleTensor(biasTensor, accumulatedRatios[node->name()] );
}
}
}
if (node->type() == "Add" || node->type() == "Concat")
if (isMerging(node))
{
// We should assert if merging nodes are all scalings !
std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
// Compute the max ratio ...
@@ -437,27 +430,23 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
}
}
std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> inputTensor)
std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> inputTensor, bool scalingNodesOnly)
{
std::map<std::string, float> valueRanges;
SequentialScheduler scheduler(graphView);
scheduler.resetScheduling();
std::shared_ptr<Node> inputNode = getFirstNode(graphView);
// Setup the input
std::shared_ptr<Node> inputProducer = inputNode->getParent(0);
inputProducer->getOperator()->setOutput(0, inputTensor);
// Inference ...
// Forward ...
scheduler.forward();
scheduler.forward(true, {inputTensor});
// Gather ranges ...
std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
for (std::shared_ptr<Node> node : nodeSet)
{
if (node->type() == "Scaling") // XXX (node->type() != "Producer")
if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
{
std::shared_ptr<Operator> nodeOperator = node->getOperator();
std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -471,38 +460,69 @@ std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView,
return valueRanges;
}
std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet)
std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly)
{
std::map<std::string, float> valueRanges;
std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
// std::shared_ptr<Node> inputNode = getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeSet)
if (node->type() == "Scaling") // XXX (node->type() != "Producer")
if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
valueRanges.insert(std::make_pair(node->name(), 0));
SequentialScheduler scheduler(graphView);
scheduler.resetScheduling();
int it = 0;
for (std::shared_ptr<Tensor> sample : inputDataSet)
{
Log::info(" IT : {}", it++);
// Inference ...
scheduler.forward(true, {sample});
// Gather the sample ranges ...
std::map<std::string, float> sampleRanges;
for (std::shared_ptr<Node> node : nodeSet)
{
if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
{
std::shared_ptr<Operator> nodeOperator = node->getOperator();
std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
float range = getTensorAbsoluteMax(valueTensor);
// Associate the value to the scaling node ...
sampleRanges.insert(std::make_pair(node->name(), range));
}
}
// Update the global value ranges ...
for (std::shared_ptr<Node> node : nodeSet)
{
if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
{
std::string nodeName = node->name();
if (sampleRanges[nodeName] > valueRanges[nodeName])
valueRanges[nodeName] = sampleRanges[nodeName];
}
}
}
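// At this point, each retained node is mapped to the maximum absolute value
// observed over the whole calibration set (e.g. two samples yielding maxima
// of 0.8 and 1.3 for a given node result in a retained range of 1.3)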
return valueRanges;
}
void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::string, float> valueRanges)
{
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
// CREATE THE SCALING FACTOR MAP //////////////////////////////////////////
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
std::map<std::string, float> scalingFactors;
@@ -514,9 +534,10 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
for (std::shared_ptr<Node> node : nodeVector)
{
// Seamless scaling factor propagation ...
if (isAffine(node) || isSeamless(node) || node->type() == "ReLU")
{
if (node == firstNode)
{
scalingFactors[node->name()] = 1.0;
}
@@ -527,52 +548,51 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
}
}
if (node->type() == "Scaling")
// Here prevNode is either a 'Affine' or a 'Merging'
// => do not split the cases, just handle the bias ...
if (node->type() == "Scaling")
{
// Retrieve the previous scaling factor ...
std::shared_ptr<Node> prevNode = node->getParent(0);
float prevScalingFactor = scalingFactors[prevNode->name()];
// valueRanges must contain all the scaling nodes !
float scalingFactor = valueRanges[node->name()];
std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (node->getOperator());
scalingOperator->scalingFactor() /= (scalingFactor / prevScalingFactor);
scalingFactors[node->name()] = scalingFactor;
// If prevNode is Affine, fix the bias ...
if (isAffine(prevNode))
{
bool prevNodeHasBias = (prevNode->getParents().size() == 3);
if (prevNodeHasBias)
{
std::shared_ptr<Tensor> biasTensor = getBiasTensor(prevNode);
rescaleTensor(biasTensor, 1.0 / prevScalingFactor);
}
}
}
if (node->type() == "Concat" || node->type() == "Add")
// Merging nodes handling : use a maximum arbritation ...
if (isMerging(node))
{
// We should assert that the merging node inputs are all Scaling nodes !
std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
// Compute the max scaling ...
float maxScaling = 0;
int maxNodeIndex = 0;
for (std::size_t i = 0; i < mergingNodes.size(); i++)
{
float mergingNodeScaling = scalingFactors[mergingNodes[i]->name()];
if (mergingNodeScaling > maxScaling) {
maxScaling = merginNodeScaling;
maxNodeIndex = i;
}
}
// Ensure that the adding node does not overflow ...
if (node->type() == "Add") {
std::shared_ptr<Node> maxNode = mergingNodes[maxNodeIndex];
maxScaling /= valueRanges[getPreviousScalingNode(maxNode)->name()];
maxScaling *= valueRanges[getPreviousScalingNode(node)->name()];
}
scalingFactors[node->name()] = maxScaling;
for (std::shared_ptr<Node> mergingNode : mergingNodes)
@@ -590,214 +610,384 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
}
}
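// The sign map associates to each (non Producer) node a pair of booleans :
// 'first' tells if the node input is unsigned, 'second' if its output is.
// Both flags start at 'true' (unsigned) and are refined over the graph.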
std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
{
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::map<std::string, std::pair<bool, bool>> signMap;
std::pair<bool, bool> unsignedPair(true, true);
for (std::shared_ptr<Node> node : graphView->getNodes())
if (node->type() != "Producer")
signMap.insert(std::make_pair(node->name(), unsignedPair));
// ITERATE OVER THE GRAPH
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
bool isFirstNode = (node == firstNode);
if (isAffine(node))
{
// Affine nodes always have a single parent
if (!isFirstNode)
signMap[node->name()].first = signMap[node->getParent(0)->name()].second;
else
signMap[node->name()].first = false;
signMap[node->name()].second = false;
}
if (node->type() == "Scaling")
{
signMap[node->name()].second = false;
// Scaling nodes always have a single parent
std::shared_ptr<Node> parent = node->getParent(0);
bool allChildrenAreReLU = !(node->getChildren().empty());
for (std::shared_ptr <Node> child : node->getChildren())
allChildrenAreReLU &= (child->type() == "ReLU");
// Correct the previous node output sign (when it is an Affine or Merging node) ...
if (allChildrenAreReLU)
if (isAffine(parent) || isMerging(parent))
signMap[parent->name()].second = true;
// Maintain unsigned output
if (signMap[parent->name()].second)
signMap[node->name()].second = true;
// Set the link ...
signMap[node->name()].first = signMap[parent->name()].second;
}
if (isMerging(node))
{
std::vector<std::shared_ptr<Node>> parentNodes = node->getParents();
bool allParentAreSigned = true;
bool allParentAreUnsigned = true;
for(std::shared_ptr<Node> parent : parentNodes)
{
bool parentSign = signMap[parent->name()].second;
allParentAreSigned &= !parentSign;
allParentAreUnsigned &= parentSign;
}
if (allParentAreSigned)
signMap[node->name()] = std::make_pair(false, false);
else if (allParentAreUnsigned)
signMap[node->name()] = std::make_pair(true, true);
else
{
// Arbitration : Signed type wins !
for(std::shared_ptr<Node> parent : parentNodes)
{
while (parent->type() != "Scaling")
{
signMap[parent->name()] = std::make_pair(false, false);
// We are on a branch so nodes always have 1 parent ...
parent = parent->getParent(0);
}
signMap[parent->name()].second = false;
}
signMap[node->name()].first = false;
}
}
if (node->type() == "ReLU" || isSeamless(node))
{
// These nodes always have a single parent
std::shared_ptr<Node> parent = node->getParent(0);
signMap[node->name()].first = signMap[parent->name()].second;
signMap[node->name()].second = signMap[node->name()].first;
}
}
// VERBOSE
if (verbose)
{
Log::info(" === SIGN MAP === ");
for (std::shared_ptr<Node> node : nodeVector)
Log::info(" {}{} | {}", static_cast<int>(signMap[node->name()].first), static_cast<int>(signMap[node->name()].second), node->name());
}
// SANITY CHECK (TEMPORARY)
for (std::shared_ptr<Node> node : nodeVector)
if (node != firstNode)
{
for (std::shared_ptr<Node> parent : node->getParents())
if (parent->type() != "Producer")
if (signMap[parent->name()].second != signMap[node->name()].first)
Log::error(" computeSignMap : link is not sane ! ({} -> {})", parent->name(), node->name());
}
return signMap;
}
void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool applyRounding, bool optimizeSigns, bool verbose)
{
//std::cout << " COMPUTING HISTOGRAMS ... " << std::endl;
float signedMax = (1 << (nbBits - 1)) - 1;
float unsignedMax = (1 << nbBits) - 1;
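// e.g. for nbBits = 8 : signedMax = 127 and unsignedMax = 255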
std::map<std::string, std::pair<bool, bool>> signMap;
if (optimizeSigns)
signMap = computeSignMap(graphView, verbose);
else
{
std::pair<bool, bool> signedPair(false, false);
for (std::shared_ptr<Node> node : graphView->getNodes())
if (node->type() != "Producer")
signMap.insert(std::make_pair(node->name(), signedPair));
}
// ITERATE OVER THE GRAPH /////////////////////////////////////////////////
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
if (node->type() == "Scaling")
if (isAffine(node))
{
// Rescale the weight tensor
std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
rescaleTensor(weightTensor, signedMax);
if (applyRounding)
roundTensor(weightTensor);
// Rescale the bias tensor
bool nodeHasBias = (node->getParents().size() == 3);
if (nodeHasBias)
{
bool inputIsUnsigned = signMap[node->name()].first;
float rescaling = inputIsUnsigned ? unsignedMax * signedMax : signedMax * signedMax;
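// The bias accumulates products of (quantized) inputs and weights, so it
// must be rescaled by both the input range and the weight range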
std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
rescaleTensor(biasTensor, rescaling);
if (applyRounding)
roundTensor(biasTensor);
}
// Compensate the rescaling using the next Scaling node
float rescaling = 1.0 / signedMax;
bool inputIsUnsigned = signMap[node->name()].first;
bool outputIsUnsigned = signMap[node->name()].second;
rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // TODO : assert that scalingNode is indeed a Scaling node ...
std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (scalingNode->getOperator());
scalingOperator->scalingFactor() *= rescaling;
}
if (isMerging(node))
{
float rescaling = 1.0;
bool inputIsUnsigned = signMap[node->name()].first;
bool outputIsUnsigned = signMap[node->name()].second;
rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // TODO : assert that scalingNode is indeed a Scaling node ...
std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (scalingNode->getOperator());
scalingOperator->scalingFactor() *= rescaling;
}
// Handle the Scaling Nodes ...
if (node->type() == "Scaling")
{
std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (node->getOperator());
if (optimizeSigns)
{
float rescaling = 1.0;
bool inputIsUnsigned = signMap[node->name()].first;
bool outputIsUnsigned = signMap[node->name()].second;
rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
scalingOperator->scalingFactor() *= rescaling;
scalingOperator->isOutputUnsigned() = outputIsUnsigned;
}
if (applyRounding)
scalingOperator->quantizedNbBits() = nbBits;
}
}
}
static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits)
{
//std::cout << " TEST " << std::endl;
int nbBins = histogram.size();
int nbIter = 100;
int signedMax = (1 << (nbBits - 1)) - 1;
// XXX Use the signMap to increase the resolution when possible ...
float signedMax = (1 << (nbBits - 1)) - 1;
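// The compensation trick : a constant Mul node (coefficient = signedMax) is
// inserted in front of the Scaling node, and the scaling factor is divided
// by the same amount, leaving the graph function unchanged while making the
// factor small enough for the single-shift approximation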
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
// A merging node is always followed by a scaling node at this point ...
if (node->type() == "Scaling")
{
bool prevNodeIsForking = ((node->getParent(0))->getChildren().size() > 1);
bool prevNodeIsAffine = isAffine(node->getParent(0));
bool insertNode = prevNodeIsForking || !prevNodeIsAffine;
if (insertNode)
{
// create and insert the multiplicative node
std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
std::shared_ptr<Node> mulNode = Mul(mulNodeName);
mulNode->getOperator()->setDataType(DataType::Float32);
mulNode->getOperator()->setBackend("cpu");
graphView->insertParent(node, mulNode, 0, 0, 0);
// create and insert the producer node
std::shared_ptr<Tensor> inputTensor = std::static_pointer_cast<Tensor> (mulNode->getOperator()->getRawInput(0));
std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>();
coeffTensor->setDataType(DataType::Float32);
coeffTensor->setBackend("cpu");
coeffTensor->resize(inputTensor->dims());
fillTensor(coeffTensor, 1);
std::shared_ptr<Node> producerNode = Producer(coeffTensor, makeUniqueName("coeff", graphView));
producerNode->addChild(mulNode);
graphView->add(producerNode);
// rescale the coeffs and edit scaling factor
fillTensor(coeffTensor, signedMax);
std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (node->getOperator());
scalingOperator->scalingFactor() /= signedMax;
// TODO : double check this !!!
}
}
}
}
void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool applyRounding)
{
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
if (isAffine(node) || (node->type() == "Mul"))
{
std::shared_ptr<Node> scalingNode = (*node->getChildren().begin());
std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (scalingNode->getOperator());
float base = scalingOperator->scalingFactor();
float approx = std::pow(2, std::ceil(std::log2(base)));
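// e.g. base = 0.3 : approx = 2^ceil(log2(0.3)) = 2^(-1) = 0.5, and the
// weights are then rescaled by ratio = 0.3 / 0.5 = 0.6 to compensate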
scalingOperator->scalingFactor() = approx;
float ratio = base / approx;
std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
rescaleTensor(weightTensor, ratio);
if (applyRounding)
roundTensor(weightTensor);
bool nodeHasBias = (node->getParents().size() == 3);
if (nodeHasBias)
{
std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
rescaleTensor(biasTensor, ratio);
if (applyRounding)
roundTensor(biasTensor);
}
}
}
}
static void printScalingFactors(std::shared_ptr<GraphView> graphView)
{
Log::info(" === SCALING FACTORS === ");
for (auto node : retrieveNodeVector(graphView))
if (node->type() == "Scaling")
{
std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (node->getOperator());
float factor = scalingOperator->scalingFactor();
Log::info(" {:.6f} ({})", factor, node->name());
}
//std::cout << " AFTER CLIPING : " << std::endl;
//for (it = valueRanges.begin(); it != valueRanges.end(); it++)
// std::cout << it->first << " " << it->second << std::endl;
}
void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool applyRounding, bool optimizeSigns, bool singleShift, bool verbose)
{
Log::info(" === QUANT PTQ 0.2.8 === ");
Log::info(" === QUANT PTQ 0.2.19 === ");
if (!checkArchitecture(graphView))
return;
Log::info(" Removing the flatten nodes ... ");
removeFlatten(graphView);
Log::info(" Removing the Softmax node ... ");
popSoftMax(graphView);
Log::info(" Preparing the network for the PTQ ... ");
prepareNetwork(graphView);
Log::info(" Inserting the scaling nodes ...");
insertScalingNodes(graphView);
crossLayerEqualization(graphView);
Log::info(" Normalizing the parameters ...");
normalizeParameters(graphView);
Log::info(" Computing the value ranges ...");
std::map<std::string, float> valueRanges = computeRanges(graphView, inputDataSet, true);
Log::info(" Optimizing the clipping values ...");
valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, verbose);
Log::info(" Normalizing the activations ...");
normalizeActivations(graphView, valueRanges);
Log::info(" Quantizing the normalized network ...");
quantizeNormalizedNetwork(graphView, nbBits, applyRounding, optimizeSigns, verbose);
if (singleShift)
{
Log::info( " Inserting the compensation nodes ...");
insertCompensationNodes(graphView, nbBits);
Log::info(" Performing the Single-Shift approximation ...");
performSingleShiftApproximation(graphView, applyRounding);
}
if (verbose)
printScalingFactors(graphView);
Log::info(" Resetting the scheduler ...");
SequentialScheduler scheduler(graphView);
scheduler.resetScheduling();
Log::info(" Network is quantized !");
}
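// Minimal usage sketch (hypothetical names : 'model' is a loaded GraphView
// and 'calibrationSet' a vector of input tensors; the 'Clipping::MSE'
// enumerator is an assumption, the available clipping modes depend on the
// Clipping enum definition) :
//
//     quantizeNetwork(model, 8, calibrationSet, Clipping::MSE, true, true, false, true);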
@@ -819,8 +1009,8 @@ std::map<std::string, float> getWeightRanges(std::shared_ptr<GraphView> graphVie
return weightRanges;
}
void clearBiases(std::shared_ptr<GraphView> graphView)
{
for (std::shared_ptr<Node> node : graphView->getNodes()) {
if (node->type() == "FC" || node->type() == "Conv") {
std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
@@ -829,10 +1019,13 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
}
}
void devPTQ(std::shared_ptr<GraphView> graphView)
{
SequentialScheduler scheduler(graphView);
scheduler.generateScheduling();
auto s = scheduler.getStaticScheduling();
for (std::shared_ptr<Node> node : s)
std::cout << " UUU : " << node->name() << std::endl;
}
}