diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index e74ac8f0058b84a13bd06a4b886297ebd1caedbc..11bf5a30c167f929072cd426b7ce7f756055e9a1 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -58,21 +58,20 @@ bool isNotQuantized(std::shared_ptr<Node> node) { return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); } + +std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView) +{ + return graphView->getOrderedInputs()[0].first; +} void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType) { for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes()) { - if(expandMetaOp(inputNode)) //Reset Tensor does not work on MetaOps - { - inputNode = std::static_pointer_cast<MetaOperator_Op>(inputNode->getOperator())->getMicroGraph()->getOrderedNodes()[0]; - } for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++) { - std::shared_ptr<OperatorTensor> firstOperatorTensor = std::static_pointer_cast<OperatorTensor>(inputNode->getOperator()); - firstOperatorTensor->resetInput(inputNode->getFirstFreeDataInput()); + inputNode->getOperator()->resetInput(index); } } - } bool checkArchitecture(std::shared_ptr<GraphView> graphView) { @@ -564,7 +563,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// - std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; + std::shared_ptr<Node> firstNode =getFirstNode(graphView); for (std::shared_ptr<Node> node : nodeVector) { @@ -762,7 +761,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges) { - std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; + std::shared_ptr<Node> firstNode = getFirstNode(graphView); // CREATE THE ACCUMULATED RATIO MAP /////////////////////////////////////// @@ -856,7 +855,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose) { - std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; + std::shared_ptr<Node> firstNode = getFirstNode(graphView); std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap; @@ -1291,6 +1290,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!") Log::notice("Starting to cast operators into the desired type ..."); castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding); + graphView->updateInputsOutputs(); clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType }