diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 28759632207793241cf28ffdeecf9743e740d342..df203f2547e720bcfbef109e05e7ccca5ed42b9e 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -82,21 +82,20 @@ bool isNotQuantized(std::shared_ptr<Node> node) { return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); } + +std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView) +{ + return graphView->getOrderedInputs()[0].first; +} void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType) { for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes()) { - if(expandMetaOp(inputNode)) //Reset Tensor does not work on MetaOps - { - inputNode = std::static_pointer_cast<MetaOperator_Op>(inputNode->getOperator())->getMicroGraph()->getOrderedNodes()[0]; - } for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++) { - std::shared_ptr<OperatorTensor> firstOperatorTensor = std::static_pointer_cast<OperatorTensor>(inputNode->getOperator()); - firstOperatorTensor->resetInput(inputNode->getFirstFreeDataInput()); + inputNode->getOperator()->resetInput(index); } } - } bool checkArchitecture(std::shared_ptr<GraphView> graphView) { @@ -621,7 +620,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// - std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; + std::shared_ptr<Node> firstNode =getFirstNode(graphView); for (std::shared_ptr<Node> node : nodeVector) { @@ -833,7 +832,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges) { - std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; + std::shared_ptr<Node> firstNode = getFirstNode(graphView); // CREATE THE ACCUMULATED RATIO MAP /////////////////////////////////////// @@ -936,7 +935,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose) { - std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; + std::shared_ptr<Node> firstNode = getFirstNode(graphView); std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap; @@ -1373,6 +1372,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!") Log::notice("Starting to cast operators into the desired type ..."); castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding); + graphView->updateInputsOutputs(); clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType }