diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index ee179912068f126409e98574221cf4dcf7916934..91a003d55c0234d9edb9b173d655e1872361d5b6 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -165,6 +165,37 @@ static std::shared_ptr<Aidge::Node> getUniqueChild(std::shared_ptr<Aidge::Node> return *(childrenSet.begin()); } +static std::string determineBackend(std::shared_ptr<Aidge::Node> node) +{ + std::string backend = node->getOperator()->backend(); + + if (backend != "") + return backend; + else + { + // gather the parent backends + + std::set<std::string> parentBackends; + for (auto parent : node->getParents()) + parentBackends.insert(determineBackend(parent)); // it always answers a non empty value ! + + // check if we have two or more different backends gathered + + if (parentBackends.size() > 1) + { + Log::warn(" Unable to determine backend of node {} due to conflicting parent ones !", node->name()); + return (*parentBackends.begin()); + } + + // if all parents have the same backend return it + + if (parentBackends.size() == 1) + return (*parentBackends.begin()); + } + + return "cpu"; // escape path when no parents are found +} + static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) { int index = 0; @@ -225,7 +256,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV { std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - roundNode->getOperator()->setBackend(node->getOperator()->backend()); + roundNode->getOperator()->setBackend(determineBackend(node)); insertChildren(node, roundNode, graphView); addAttr(roundNode, "isProducerRounding"); @@ -394,8 +425,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - auto producerOp = std::static_pointer_cast<OperatorTensor>(producerNode->getOperator()); - scalingNode->getOperator()->setBackend(producerOp->getOutput(0)->backend()); + scalingNode->getOperator()->setBackend(determineBackend(producerNode)); insertChildren(producerNode, scalingNode, graphView); graphView->add(scalingNode->getParent(1)); // add the scaling factor producer @@ -431,7 +461,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0); residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - residualNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + residualNode->getOperator()->setBackend(determineBackend(parentNode)); graphView->insertParent(node, residualNode, i, 0, 0); graphView->add(residualNode->getParent(1)); // add the scaling factor producer @@ -471,7 +501,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0); scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - scalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + scalingNode->getOperator()->setBackend(determineBackend(parentNode)); if (parentNode->getChildren().size() > 0) { insertChildren(parentNode, scalingNode, graphView); @@ -492,7 +522,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0); prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); + prevScalingNode->getOperator()->setBackend(determineBackend(parentNode)); graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer @@ -1087,7 +1117,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name()); quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - quantizerNode->getOperator()->setBackend(node->getOperator()->backend()); + quantizerNode->getOperator()->setBackend(determineBackend(node)); graphView->replace({node, node->getParent(1)}, {quantizerNode}); if (optimizeSigns) @@ -1142,7 +1172,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u addAttr(mulNode, "isCompensation"); mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) - mulNode->getOperator()->setBackend(node->getOperator()->backend()); + mulNode->getOperator()->setBackend(determineBackend(node)); graphView->insertParent(node, mulNode, 0, 0, 0); @@ -1153,7 +1183,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u coeffProducer->getOperator()->setOutput(0, coeffTensor); coeffProducer->getOperator()->setDataType(DataType::Float64); - coeffProducer->getOperator()->setBackend(node->getOperator()->backend()); + coeffProducer->getOperator()->setBackend(determineBackend(node)); graphView->add(coeffProducer); // needed ?