Skip to content
Snippets Groups Projects

Real quantization cast for PTQ

Merged Noam Zerah requested to merge noamzerah/aidge_quantization:quantization_cast into dev
1 file
+ 10
10
Compare changes
  • Side-by-side
  • Inline
+ 10
10
@@ -58,21 +58,20 @@ bool isNotQuantized(std::shared_ptr<Node> node)
{
return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
}
std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
return graphView->getOrderedInputs()[0].first;
}
void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType)
{
for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes())
{
if(expandMetaOp(inputNode)) //Reset Tensor does not work on MetaOps
{
inputNode = std::static_pointer_cast<MetaOperator_Op>(inputNode->getOperator())->getMicroGraph()->getOrderedNodes()[0];
}
for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++)
{
std::shared_ptr<OperatorTensor> firstOperatorTensor = std::static_pointer_cast<OperatorTensor>(inputNode->getOperator());
firstOperatorTensor->resetInput(inputNode->getFirstFreeDataInput());
inputNode->getOperator()->resetInput(index);
}
}
}
bool checkArchitecture(std::shared_ptr<GraphView> graphView)
{
@@ -564,7 +563,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
// ITERATE OVER THE GRAPH /////////////////////////////////////////////////
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
std::shared_ptr<Node> firstNode =getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
@@ -762,7 +761,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
{
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
// CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
@@ -856,7 +855,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
{
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
@@ -1291,6 +1290,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
Log::notice("Starting to cast operators into the desired type ...");
castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding);
graphView->updateInputsOutputs();
clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType
}
Loading