Commit b609b95f authored by Noam Zerah

Real quantization cast for PTQ

parent 3f669a98
2 merge requests: !54 Update 0.3.1 -> 0.4.0, !46 Real quantization cast for PTQ
import unittest
import gzip
import sys
from pathlib import Path

import numpy as np

import aidge_core
import aidge_backend_cpu
import aidge_onnx
import aidge_quantization
from aidge_core import Log, Level
"""
Unit test for the PTQ pipeline:
This script is designed to test and validate the accuracy of five small model topologies on the MNIST dataset:
["MiniResNet", "ConvNet", "BranchNetV4", "TestNet", "MLP"]
It compares the results for three configurations: the baseline, quantization, and quantization with single shift.
The value of sigma represents the tolerance for the tests.
"""
aidge_core.Log.set_console_level(aidge_core.Level.Error) # Reduce useless logs
# --------------------------------------------------------------
# CONFIGURATION
# --------------------------------------------------------------

NB_SAMPLES   = 1000  # max: 1000
SAMPLE_SHAPE = (1, 1, 28, 28)
NB_BITS      = 4
CLIPPING     = aidge_quantization.Clipping.MSE

EXPECTED_RESULTS = {
    "MiniResNet.onnx":  (95.4, 94.5, 94.7),
    "ConvNet.onnx":     (97.9, 97.7, 97.4),
    "BranchNetV4.onnx": (93.8, 93.2, 93.7),
    "TestNet.onnx":     (95.5, 94.2, 94.2),
    "MLP.onnx":         (94.7, 94.2, 93.3)
}

SIGMA = 0.05
# --------------------------------------------------------------
# UTILS
# --------------------------------------------------------------
def propagate(model, scheduler, sample):
    sample = np.reshape(sample, SAMPLE_SHAPE)
    input_tensor = aidge_core.Tensor(sample)
    scheduler.forward(True, [input_tensor])
    output_node = model.get_output_nodes().pop()
    output_tensor = output_node.get_operator().get_output(0)
    return np.array(output_tensor)

def prepare_sample(sample):
    sample = np.reshape(sample, SAMPLE_SHAPE)
    return sample.astype('float32')
def compute_accuracy(model, samples, labels):
    scheduler = aidge_core.SequentialScheduler(model)
    acc = sum(labels[i] == np.argmax(propagate(model, scheduler, x)) for i, x in enumerate(samples))
    return acc / len(samples)
# --------------------------------------------------------------
# TEST CLASS
# --------------------------------------------------------------

class TestQuantization(unittest.TestCase):

    def setUp(self):
        # load the samples / labels (numpy)
        curr_file_dir = Path(__file__).parent.resolve()
        self.samples = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_samples.npy.gz', "r"))
        self.labels = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_labels.npy.gz', "r"))
        self.quantized_sample = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_samples.npy.gz', "r")) * ((1 << (NB_BITS - 1)) - 1)

    def tearDown(self):
        pass

    def run_model_test(self, model_name):
        model_path = Path(__file__).parent / "assets" / model_name
        model = aidge_onnx.load_onnx(model_path, verbose=False)
        aidge_core.remove_flatten(model)
        model.set_datatype(aidge_core.dtype.float64)
        model.set_backend("cpu")

        expected_base, expected_quant, expected_quant_ss = EXPECTED_RESULTS[model_name]

        # Baseline accuracy
        base_accuracy = compute_accuracy(model, self.samples[:NB_SAMPLES], self.labels)
        self.assertAlmostEqual(base_accuracy * 100, expected_base, delta=SIGMA,
                               msg=f"[X] Baseline accuracy mismatch for {model_name}. Expected accuracy was: {expected_base}, but got: {base_accuracy * 100}")

        # Quantize
        tensors = [aidge_core.Tensor(np.reshape(sample, SAMPLE_SHAPE)) for sample in self.samples[:NB_SAMPLES]]
        aidge_quantization.quantize_network(network=model,
                                            nb_bits=NB_BITS,
                                            input_dataset=tensors,
                                            clipping_mode=CLIPPING,
                                            target_type=aidge_core.dtype.float64,
                                            no_quantization=False,
                                            optimize_signs=True,
                                            single_shift=False,
                                            use_cuda=False,
                                            fold_graph=True,
                                            bitshift_rounding=False,
                                            verbose=False)

        quant_accuracy = compute_accuracy(model, self.quantized_sample[:NB_SAMPLES], self.labels)
        self.assertAlmostEqual(quant_accuracy * 100, expected_quant, delta=SIGMA,
                               msg=f"[X] Quantized accuracy mismatch for {model_name}. Expected accuracy was: {expected_quant}, but got: {quant_accuracy * 100}")

        # Quantize with single shift
        model_ss = aidge_onnx.load_onnx(model_path, verbose=False)
        aidge_core.remove_flatten(model_ss)
        model_ss.set_datatype(aidge_core.dtype.float64)
        model_ss.set_backend("cpu")
        aidge_quantization.quantize_network(network=model_ss,
                                            nb_bits=NB_BITS,
                                            input_dataset=tensors,
                                            clipping_mode=CLIPPING,
                                            target_type=aidge_core.dtype.float64,
                                            no_quantization=False,
                                            optimize_signs=True,
                                            single_shift=True,
                                            use_cuda=False,
                                            fold_graph=True,
                                            bitshift_rounding=False,
                                            verbose=False)

        quant_accuracy_ss = compute_accuracy(model_ss, self.quantized_sample[:NB_SAMPLES], self.labels)
        self.assertAlmostEqual(quant_accuracy_ss * 100, expected_quant_ss, delta=SIGMA,
                               msg=f"[X] Quantized single-shift accuracy mismatch for {model_name}. Expected accuracy was: {expected_quant_ss}, but got: {quant_accuracy_ss * 100}")

    def test_models(self):
        for model in EXPECTED_RESULTS.keys():
            with self.subTest(model=model):
                self.run_model_test(model)


if __name__ == '__main__':
    unittest.main()
@@ -29,6 +29,34 @@ namespace Aidge {
/// @return A shared pointer to an instance of the meta-operator node.
std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name);
/// @brief IntQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration
/// into computation graphs with a data type other than Float while preserving floating-point precision.
///
/// This operator modifies the provided Quantizer by inserting explicit casting operations before and after
/// the quantization process. It first casts the input to Float64, applies the quantization steps (Mul, Clip, Round),
/// and then casts the result back to the target data type. This ensures compatibility with integer-based computation graphs
/// while maintaining the precision of floating-point operations.
///
/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
/// @param targetType The target data type to which the final output should be cast after the quantization process.
/// @param name The name of the meta-operator node created.
/// @return A shared pointer to a new instance of the modified meta-operator node.
std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name);
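// Illustrative usage sketch (not part of this commit): "quantizerNode" and "graphView" are placeholder
// names for an existing Quantizer node and the GraphView that contains it. This mirrors how
// castQuantizedGraph() swaps a Quantizer for its cast-wrapped counterpart in PTQ.cpp.
//
//   std::shared_ptr<Node> intQuant = IntQuantizer(quantizerNode, DataType::Int32, quantizerNode->name());
//   intQuant->getOperator()->setBackend(quantizerNode->getOperator()->backend());
//   graphView->replace({quantizerNode}, {intQuant});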
/// @brief BitShiftQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration
/// into computation graphs with a data type other than Float while preserving the quantization behaviour.
///
/// This operator converts the provided Quantizer into an integer-friendly equivalent: the scaling
/// multiplication is replaced by a BitShift operation (the scaling factor is expected to be a power of two,
/// as produced by the single-shift approximation), followed by the original Clip. The inserted nodes are set
/// to the target data type so that the quantizer can run directly inside an integer-based computation graph.
///
/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
/// @param targetType The target data type in which the bit-shift and clip operations are performed.
/// @param bitshiftRounding Whether the result of the bit-shift is rounded to the nearest integer.
/// @param name The name of the meta-operator node created.
/// @return A shared pointer to a new instance of the modified meta-operator node.
std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, bool bitshiftRounding, const std::string& name);
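// Illustrative usage sketch (not part of this commit): assumes the single-shift approximation has already
// turned the Quantizer's scaling factor into a power of two (e.g. 0.0625 = 2^-4, i.e. a right shift by 4).
// "quantizerNode" and "graphView" are placeholder names.
//
//   std::shared_ptr<Node> bsQuant = BitShiftQuantizer(quantizerNode, DataType::Int32,
//                                                     /*bitshiftRounding=*/true,
//                                                     quantizerNode->name() + "_BitShift_Quantizer");
//   bsQuant->getOperator()->setBackend(quantizerNode->getOperator()->backend());
//   graphView->replace({quantizerNode}, {bsQuant});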
/// @brief Updates the scaling factor of a PTQ meta-operator node, allowing for dynamic adjustment of the scaling parameter.
/// This function sets a new scaling factor for a specified meta-operator node, modifying the scalar applied in the [Mul] operation.
/// The meta-operator node must be a PTQ-specific operator, such as a Quantizer or Scaling node.
......
@@ -181,14 +181,29 @@ namespace Aidge {
* @param graphView The GraphView to be quantized.
* @param nbBits The desired number of bits of the quantization.
* @param inputDataSet The input dataset on which the value ranges are computed.
* @param clippingMode Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
* @param targetType Desired target type to cast the graph into (the default, Float64, will NOT apply any casting to the network).
* @param noQuant Whether to apply the rounding operations or not.
* @param optimizeSigns Whether to take account of the IO signs of the operators or not.
* @param singleShift Whether to convert the scaling factors into powers of two. If true, the approximations are compensated using the previous nodes' weights.
* @param useCuda Whether to speed up the PTQ by computing the value ranges using CUDA kernels.
* This flag does not set the backend of the GraphView to "cuda" at the end of the PTQ pipeline.
* @param foldGraph Whether to apply the constant folding recipe, which makes the final GraphView much easier to read.
* @param bitshiftRounding Whether rounding should be applied after bit-shifting operations. If enabled, the result of bit-shifting is rounded to the nearest integer.
* @param verbose Whether to print internal information about the quantization process.
*/
void quantizeNetwork(std::shared_ptr<GraphView> graphView,
std::uint8_t nbBits,
std::vector<std::shared_ptr<Tensor>> inputDataSet,
Clipping clippingMode,
DataType targetType,
bool noQuant,
bool optimizeSigns,
bool singleShift,
bool useCuda,
bool foldGraph,
bool bitshiftRounding,
bool verbose);
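// Illustrative call sketch (not part of this commit): quantize a GraphView to 8 bits and cast the resulting
// operators to Int32. "graphView" and "calibrationSet" are placeholder names; the remaining flag values
// simply mirror the defaults exposed by the Python binding.
//
//   quantizeNetwork(graphView, 8, calibrationSet,
//                   Clipping::MAX, DataType::Int32,
//                   /*noQuant=*/false, /*optimizeSigns=*/false, /*singleShift=*/false,
//                   /*useCuda=*/false, /*foldGraph=*/true, /*bitshiftRounding=*/false,
//                   /*verbose=*/false);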
/**
* @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
* @param graphView The GraphView containing the affine nodes.
......
@@ -93,7 +93,20 @@ void init_PTQ(py::module &m) {
:type verbose: bool
)mydelimiter");
m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false,
m.def("quantize_network",
&quantizeNetwork,
py::arg("network"),
py::arg("nb_bits"),
py::arg("input_dataset"),
py::arg("clipping_mode") = Clipping::MAX,
py::arg("target_type") = DataType::Float64,
py::arg("no_quantization") = false,
py::arg("optimize_signs") = false,
py::arg("single_shift") = false,
py::arg("use_cuda") = false,
py::arg("fold_graph") = true,
py::arg("bitshift_rounding") = false,
py::arg("verbose") = false,
R"mydelimiter(
Main quantization routine. Performs every step of the quantization pipeline.
:param network: The GraphView to be quantized.
......
@@ -20,6 +20,7 @@
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Mul.hpp"
@@ -30,6 +31,9 @@
#include "aidge/operator/ArgMax.hpp"
#include "aidge/operator/Reshape.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Cast.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/recipes/QuantRecipes.hpp"
@@ -79,6 +83,20 @@ bool isNotQuantized(std::shared_ptr<Node> node)
return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
}
std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
    return graphView->getOrderedInputs()[0].first;
}

void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView, DataType targetType)
{
    for (std::shared_ptr<Aidge::Node> inputNode : graphView->inputNodes())
    {
        for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput(); index < inputNode->getNbFreeDataInputs(); index++)
        {
            inputNode->getOperator()->resetInput(index);
        }
    }
}
bool checkArchitecture(std::shared_ptr<GraphView> graphView)
{
std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"});
@@ -249,6 +267,59 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n
graphView->add(newNode);
}
void applyConstFold(std::shared_ptr<GraphView> &graphView)
{
for (const std::shared_ptr<Node> node : graphView->getNodes())
{
if (node->type() == "Producer" )
{
const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
producer->constant() = true;
}
}
constantFolding(graphView);
}
bool castQuantizedGraph(std::shared_ptr<GraphView> &graphView, Aidge::DataType targetType, bool singleShift, bool bitshiftRounding)
{
    // We need a copy of the graph's node set, since some nodes will be replaced while iterating
    std::vector<std::shared_ptr<Node>> nodeVector(graphView->getNodes().begin(), graphView->getNodes().end());

    for (std::shared_ptr<Node> node : nodeVector)
    {
        if (node->type() == "Round" && node->attributes()->hasAttr("quantization.ptq.isProducerRounding"))
        {
            std::shared_ptr<Aidge::Node> castNode = Cast(targetType, node->name() + "_Cast");
            castNode->getOperator()->setDataType(targetType);
            castNode->getOperator()->setBackend(node->getOperator()->backend());
            insertChildren(node, castNode, graphView);
            castNode->attributes()->addAttr("quantization.ptq.isProducerCasting", 0.0);
            node->getOperator()->setDataType(DataType::Float64);
        }
        else if (node->type() == "Quantizer")
        {
            if (singleShift)
            {
                std::shared_ptr<Node> newBitShiftQuantizer = BitShiftQuantizer(node, targetType, bitshiftRounding, node->name() + "_BitShift_Quantizer");
                newBitShiftQuantizer->getOperator()->setBackend(node->getOperator()->backend());
                graphView->replace({node}, {newBitShiftQuantizer});
            }
            else // If single shift is not enabled, we keep using the alternative IntQuantizer (which casts the data before and after the regular Quantizer operations)
            {
                std::shared_ptr<Node> newIntQuantizer = IntQuantizer(node, targetType, node->name());
                newIntQuantizer->getOperator()->setBackend(node->getOperator()->backend());
                graphView->replace({node}, {newIntQuantizer});
            }
        }
        else if (node->type() != "Producer" &&
                 !node->attributes()->hasAttr("quantization.ptq.isProducerScaling"))
        {
            node->getOperator()->setDataType(targetType);
        }
    }
    return true;
}
bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView)
{
@@ -270,7 +341,6 @@ double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
{
// get the abs tensor
std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR
std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
// flatten the abs tensor
@@ -371,10 +441,6 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView>
return nodeVector;
}
static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
return retrieveNodeVector(graphView)[0];
}
// TODO : enhance this by modifying OperatorImpl in "core" ...
static DataType getDataType(std::shared_ptr<Node> node)
@@ -554,7 +620,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
// ITERATE OVER THE GRAPH /////////////////////////////////////////////////
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
@@ -699,8 +765,6 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
std::unordered_map<std::shared_ptr<Node>, double> valueRanges;
std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
// std::shared_ptr<Node> inputNode = getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeSet)
if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
valueRanges.insert(std::make_pair(node, 0));
@@ -768,7 +832,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
{
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
// CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
@@ -871,7 +935,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
{
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
@@ -1248,7 +1312,18 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std:
tensor->setDataType(dataType);
}
void quantizeNetwork(std::shared_ptr<GraphView> graphView,
std::uint8_t nbBits,
std::vector<std::shared_ptr<Tensor>> inputDataSet,
Clipping clippingMode,
DataType targetType,
bool noQuant,
bool optimizeSigns,
bool singleShift,
bool useCuda,
bool foldGraph,
bool bitshiftRounding,
bool verbose)
{
Log::notice(" === QUANT PTQ 0.2.21 === ");
@@ -1292,13 +1367,33 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
Log::notice(" Performing the Single-Shift approximation ...");
performSingleShiftApproximation(graphView, noQuant);
}
if (targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16)
{
AIDGE_ASSERT(!noQuant, "Cannot cast operators with the noQuant (Fake Quantization) flag set to true!");
Log::notice("Starting to cast operators into the desired type ...");
castQuantizedGraph(graphView, targetType, singleShift, bitshiftRounding);
graphView->updateInputsOutputs();
clearGraphViewInputNodes(graphView, targetType); // Convert all input tensors of the GV into targetType
}
else
{
setupDataType(graphView, inputDataSet, targetType);
}
if(foldGraph)
{
Log::notice("Applying constant folding recipe to the graph ...");
applyConstFold(graphView);
}
//Mandatory to handle all of the newly added connections!
graphView->updateInputsOutputs();
//Clearing input nodes
Log::notice("Clearing all input nodes ...");
if (verbose)
printScalingFactors(graphView);
if (useCuda)
graphView->setBackend("cuda");
Log::notice(" Reseting the scheduler ...");
SequentialScheduler scheduler(graphView);
scheduler.resetScheduling();
@@ -1333,4 +1428,4 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
}
}
}
}
}
\ No newline at end of file
@@ -18,6 +18,8 @@
#include "aidge/operator/Clip.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/Round.hpp"
#include "aidge/operator/BitShift.hpp"
#include "aidge/operator/Cast.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/OpArgs.hpp"
@@ -33,7 +35,15 @@
namespace Aidge
{
static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType)
{
std::shared_ptr<Node> mulNode = nullptr;
for(std::shared_ptr<Node> node : graphView->getNodes())
if (node->type() == nodeType)
mulNode = node;
return mulNode;
}
std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name)
{
// create the nodes
@@ -55,19 +65,65 @@ std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double cli
// return the metaop
return MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype
}
std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, bool bitshiftRounding, const std::string& name)
{
    double scalingFactor = getScalingFactor(oldQuantizer);

    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(oldQuantizer->getOperator());
    std::shared_ptr<Node> oldclipNode = getSubNode(metaOp->getMicroGraph(), "Clip");

    if (!oldclipNode) {
        Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", oldQuantizer->type());
        return nullptr;
    }

    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(oldclipNode->getOperator());

    int shift = std::log2(scalingFactor);
    BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left;
    if (shift < 0)
    {
        direction = BitShift_Op::BitShiftDirection::right;
        shift = -shift;
    }

    std::shared_ptr<Node> bitShiftNode = BitShift(direction, bitshiftRounding, (!name.empty()) ? name + "_MulIQuant" : "");
    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_IClipQuant" : "", clipOp->min(), clipOp->max());

    std::shared_ptr<Tensor> bitshiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {shift});
    std::shared_ptr<Node> bitshiftProducer = addProducer(bitShiftNode, 1, {1}, "ScalingFactor");
    bitshiftProducer->getOperator()->setOutput(0, bitshiftTensor);
    bitshiftProducer->attributes()->addAttr("quantization.ptq.ShiftAmount", shift);
    bitshiftProducer->getOperator()->setDataType(targetType);

    // connect the scaling factor producer
    bitShiftNode->getOperator()->setDataType(targetType);
    clipNode->getOperator()->setDataType(targetType);

    // create the metaop graph
    std::shared_ptr<GraphView> graphView = Sequential({bitShiftNode, clipNode});
    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(bitShiftNode); // XXX why not use the graphView ???

    // return the metaop
    return MetaOperator("BitShiftQuantizer", connectedGraphView, {}, name); // XXX alternative prototype
}
std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name)
{
    std::shared_ptr<Node> castPreNode = Cast(DataType::Float64, ((!name.empty()) ? name + "_PreCast" : ""));
    std::shared_ptr<Node> castPostNode = Cast(targetType, ((!name.empty()) ? name + "_PostCast" : ""));

    castPreNode->getOperator()->setDataType(DataType::Float64);
    castPostNode->getOperator()->setDataType(targetType);

    std::shared_ptr<GraphView> graphView = Sequential({castPreNode, oldQuantizer->clone(), castPostNode});

    return MetaOperator("IntQuantizer", graphView, {}, name); // XXX alternative prototype
}
void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)
......