Skip to content
Snippets Groups Projects

Real quantization cast for PTQ

Merged Noam Zerah requested to merge noamzerah/aidge_quantization:quantization_cast into dev
2 files
+ 13
52
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 9
12
@@ -81,13 +81,15 @@ bool isNotQuantized(std::shared_ptr<Node> node)
{
return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
}
void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView)
void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType)
{
for (std::shared_ptr<Aidge::Node> input_node: graphView->inputNodes())
{
for (Aidge::IOIndex_t index = input_node->getFirstFreeDataInput();index < input_node->getNbFreeDataInputs(); index++)
{
std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->resetInput(index);
if(input_node->type() != "IntQuantizer" && input_node->type() != "BitShiftQuantizer") {
std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->getInput(index)->setDataType(targetType);
}
}
}
}
@@ -436,10 +438,6 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView>
return nodeVector;
}
static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
return retrieveNodeVector(graphView)[0];
}
// TODO : enhance this by modifying OperatorImpl in "core" ...
static DataType getDataType(std::shared_ptr<Node> node)
@@ -619,7 +617,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
// ITERATE OVER THE GRAPH /////////////////////////////////////////////////
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
for (std::shared_ptr<Node> node : nodeVector)
{
@@ -764,8 +762,6 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
std::unordered_map<std::shared_ptr<Node>, double> valueRanges;
std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
// std::shared_ptr<Node> inputNode = getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeSet)
if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
valueRanges.insert(std::make_pair(node, 0));
@@ -833,7 +829,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
{
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
// CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
@@ -936,7 +932,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
{
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
@@ -1373,6 +1369,8 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
Log::notice("Starting to cast operators into the desired type ...");
castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding);
graphView->updateInputsOutputs();
clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType
}
else
{
@@ -1388,7 +1386,6 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
//Clearing input nodes
Log::notice("Clearing all input nodes ...");
clearGraphViewInputNodes(graphView);
if (verbose)
printScalingFactors(graphView);
Loading