diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 05ab48594709f79099ad643b2960b03c96ef7e75..dbf4b6234d2a64c7fd9572c77882a1dcf64f3f49 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -111,6 +111,37 @@ static std::shared_ptr<Aidge::Node> getUniqueChildren(std::shared_ptr<Aidge::Nod return *(childrenSet.begin()); } +static std::string determineBackend(std::shared_ptr<Aidge::Node> node) +{ + std::string backend = node->getOperator()->backend(); + + if (backend != "") + return backend; + else + { + // gather the parent backends + + std::set<std::string> parentBackends; + for (auto parent : node->getParents()) + parentBackends.insert(determineBackend(parent)); // it always answers a non empty value ! + + // check if we have two or more different backends gathered + + if (parentBackends.size() > 1) + { + Log::warn(" Unable to determine backend of node {} due to conflicting parent ones !", node->name()); + return (*parentBackends.begin()); + } + + // if all parents have the same backend return it + + if (parentBackends.size() == 1) + return (*parentBackends.begin()); + } + + return "cpu"; // escape path when no parents are found +} + static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) { int index = 0; @@ -171,7 +202,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV { std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - roundNode->getOperator()->setBackend(node->getOperator()->backend()); + roundNode->getOperator()->setBackend(determineBackend(node)); insertChildren(node, roundNode, graphView); roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0); @@ -268,17 +299,7 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node) std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose) { - std::vector<std::shared_ptr<Node>> nodeVector; - - SequentialScheduler scheduler(graphView); - - if (newSchedule) - { - scheduler.resetScheduling(); - scheduler.generateScheduling(); // old way : scheduler.forward(); - } - - nodeVector = scheduler.getSequentialStaticScheduling(); + std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes(); fixScheduling(nodeVector); nodeVector = removeMatchingNodes(nodeVector, "Producer"); @@ -348,8 +369,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - auto producerOp = std::static_pointer_cast<OperatorTensor>(producerNode->getOperator()); - scalingNode->getOperator()->setBackend(producerOp->getOutput(0)->backend()); + scalingNode->getOperator()->setBackend(determineBackend(producerNode)); insertChildren(producerNode, scalingNode, graphView); graphView->add(scalingNode->getParent(1)); // add the scaling factor producer @@ -385,7 +405,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0); residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - residualNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + residualNode->getOperator()->setBackend(determineBackend(parentNode)); graphView->insertParent(node, residualNode, i, 0, 0); graphView->add(residualNode->getParent(1)); // add the scaling factor producer @@ -425,7 +445,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - scalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + scalingNode->getOperator()->setBackend(determineBackend(parentNode)); if (parentNode->getChildren().size() > 0) { insertChildren(parentNode, scalingNode, graphView); @@ -446,12 +466,11 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0); prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + prevScalingNode->getOperator()->setBackend(determineBackend(parentNode)); graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer } - } } } @@ -1009,7 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name()); quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - quantizerNode->getOperator()->setBackend(node->getOperator()->backend()); + quantizerNode->getOperator()->setBackend(determineBackend(node)); graphView->replace({node, node->getParent(1)}, {quantizerNode}); if (optimizeSigns) @@ -1063,7 +1082,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0); mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - mulNode->getOperator()->setBackend(node->getOperator()->backend()); + mulNode->getOperator()->setBackend(determineBackend(node)); graphView->insertParent(node, mulNode, 0, 0, 0); @@ -1074,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u coeffProducer->getOperator()->setOutput(0, coeffTensor); coeffProducer->getOperator()->setDataType(DataType::Float64); - coeffProducer->getOperator()->setBackend(node->getOperator()->backend()); + coeffProducer->getOperator()->setBackend(determineBackend(node)); graphView->add(coeffProducer); // needed ? diff --git a/src/QAT/QAT_FixedQ.cpp b/src/QAT/QAT_FixedQ.cpp index c182d6cc5b8402dabbf33c706ba8f406d4e6a162..b51123cc0d612714b280423e515fbf25006bbc72 100644 --- a/src/QAT/QAT_FixedQ.cpp +++ b/src/QAT/QAT_FixedQ.cpp @@ -152,10 +152,8 @@ void QuantFixedQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView) { - SequentialScheduler scheduler(graphView); - scheduler.generateScheduling(); - auto s = scheduler.getSequentialStaticScheduling(); - for (std::shared_ptr<Node> node : s) + auto nodeVector = graphView->getOrderedNodes(); + for (std::shared_ptr<Node> node : nodeVector) Log::info(" name : {} ", node->name()); }