diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index da98ed6ba6cb4ff32207e520b173cc401d677741..28759632207793241cf28ffdeecf9743e740d342 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -20,6 +20,7 @@ #include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/Scheduler.hpp" #include "aidge/utils/Log.hpp" +#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Mul.hpp" @@ -83,12 +84,19 @@ bool isNotQuantized(std::shared_ptr<Node> node) } void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType) { - for (std::shared_ptr<Aidge::Node> input_node: graphView->inputNodes()) + for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes()) { - auto cast_node = Cast(targetType); - cast_node->addChild(input_node); - graphView->add(cast_node); + if(expandMetaOp(inputNode)) //Reset Tensor does not work on MetaOps + { + inputNode = std::static_pointer_cast<MetaOperator_Op>(inputNode->getOperator())->getMicroGraph()->getOrderedNodes()[0]; + } + for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++) + { + std::shared_ptr<OperatorTensor> firstOperatorTensor = std::static_pointer_cast<OperatorTensor>(inputNode->getOperator()); + firstOperatorTensor->resetInput(inputNode->getFirstFreeDataInput()); + } } + } bool checkArchitecture(std::shared_ptr<GraphView> graphView) {