diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 42bac4051eadef07aa46e153bb3de654b6f4ea31..dbf4b6234d2a64c7fd9572c77882a1dcf64f3f49 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -138,6 +138,8 @@ static std::string determineBackend(std::shared_ptr<Aidge::Node> node) if (parentBackends.size() == 1) return (*parentBackends.begin()); } + + return "cpu"; // escape path when no parents are found } static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) @@ -200,7 +202,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV { std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - roundNode->getOperator()->setBackend("cpu"); + roundNode->getOperator()->setBackend(determineBackend(node)); insertChildren(node, roundNode, graphView); roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0); @@ -367,7 +369,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - scalingNode->getOperator()->setBackend("cpu"); + scalingNode->getOperator()->setBackend(determineBackend(producerNode)); insertChildren(producerNode, scalingNode, graphView); graphView->add(scalingNode->getParent(1)); // add the scaling factor producer @@ -403,7 +405,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0); residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - residualNode->getOperator()->setBackend("cpu"); + residualNode->getOperator()->setBackend(determineBackend(parentNode)); graphView->insertParent(node, residualNode, i, 0, 0); graphView->add(residualNode->getParent(1)); // add the scaling factor producer @@ -443,7 +445,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - scalingNode->getOperator()->setBackend("cpu"); + scalingNode->getOperator()->setBackend(determineBackend(parentNode)); if (parentNode->getChildren().size() > 0) { insertChildren(parentNode, scalingNode, graphView); @@ -464,7 +466,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0); prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - prevScalingNode->getOperator()->setBackend("cpu"); + prevScalingNode->getOperator()->setBackend(determineBackend(parentNode)); graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer @@ -1026,7 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name()); quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - quantizerNode->getOperator()->setBackend("cpu"); + quantizerNode->getOperator()->setBackend(determineBackend(node)); graphView->replace({node, node->getParent(1)}, {quantizerNode}); if (optimizeSigns) @@ -1080,7 +1082,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0); mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - mulNode->getOperator()->setBackend("cpu"); + mulNode->getOperator()->setBackend(determineBackend(node)); graphView->insertParent(node, mulNode, 0, 0, 0); @@ -1091,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u coeffProducer->getOperator()->setOutput(0, coeffTensor); coeffProducer->getOperator()->setDataType(DataType::Float64); - coeffProducer->getOperator()->setBackend("cpu"); + coeffProducer->getOperator()->setBackend(determineBackend(node)); graphView->add(coeffProducer); // needed ?