Skip to content
Snippets Groups Projects
Commit 2e21e5dd authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

make use of determineBackend()

parent 302ebd03
No related branches found
No related tags found
2 merge requests!54Update 0.3.1 -> 0.4.0,!48Hotfix the default backend issue
Pipeline #67732 passed
...@@ -138,6 +138,8 @@ static std::string determineBackend(std::shared_ptr<Aidge::Node> node) ...@@ -138,6 +138,8 @@ static std::string determineBackend(std::shared_ptr<Aidge::Node> node)
if (parentBackends.size() == 1) if (parentBackends.size() == 1)
return (*parentBackends.begin()); return (*parentBackends.begin());
} }
return "cpu"; // escape path when no parents are found
} }
static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
...@@ -200,7 +202,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV ...@@ -200,7 +202,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
{ {
std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
roundNode->getOperator()->setBackend("cpu"); roundNode->getOperator()->setBackend(determineBackend(node));
insertChildren(node, roundNode, graphView); insertChildren(node, roundNode, graphView);
roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0); roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0);
...@@ -367,7 +369,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali ...@@ -367,7 +369,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor); std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
scalingNode->getOperator()->setBackend("cpu"); scalingNode->getOperator()->setBackend(determineBackend(producerNode));
insertChildren(producerNode, scalingNode, graphView); insertChildren(producerNode, scalingNode, graphView);
graphView->add(scalingNode->getParent(1)); // add the scaling factor producer graphView->add(scalingNode->getParent(1)); // add the scaling factor producer
...@@ -403,7 +405,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) ...@@ -403,7 +405,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0); std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0);
residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
residualNode->getOperator()->setBackend("cpu"); residualNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(node, residualNode, i, 0, 0); graphView->insertParent(node, residualNode, i, 0, 0);
graphView->add(residualNode->getParent(1)); // add the scaling factor producer graphView->add(residualNode->getParent(1)); // add the scaling factor producer
...@@ -443,7 +445,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) ...@@ -443,7 +445,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0); std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
scalingNode->getOperator()->setBackend("cpu"); scalingNode->getOperator()->setBackend(determineBackend(parentNode));
if (parentNode->getChildren().size() > 0) { if (parentNode->getChildren().size() > 0) {
insertChildren(parentNode, scalingNode, graphView); insertChildren(parentNode, scalingNode, graphView);
...@@ -464,7 +466,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) ...@@ -464,7 +466,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0); std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0);
prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
prevScalingNode->getOperator()->setBackend("cpu"); prevScalingNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0);
graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer
...@@ -1026,7 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ ...@@ -1026,7 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name()); std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name());
quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
quantizerNode->getOperator()->setBackend("cpu"); quantizerNode->getOperator()->setBackend(determineBackend(node));
graphView->replace({node, node->getParent(1)}, {quantizerNode}); graphView->replace({node, node->getParent(1)}, {quantizerNode});
if (optimizeSigns) if (optimizeSigns)
...@@ -1080,7 +1082,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u ...@@ -1080,7 +1082,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0); mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0);
mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
mulNode->getOperator()->setBackend("cpu"); mulNode->getOperator()->setBackend(determineBackend(node));
graphView->insertParent(node, mulNode, 0, 0, 0); graphView->insertParent(node, mulNode, 0, 0, 0);
...@@ -1091,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u ...@@ -1091,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
coeffProducer->getOperator()->setOutput(0, coeffTensor); coeffProducer->getOperator()->setOutput(0, coeffTensor);
coeffProducer->getOperator()->setDataType(DataType::Float64); coeffProducer->getOperator()->setDataType(DataType::Float64);
coeffProducer->getOperator()->setBackend("cpu"); coeffProducer->getOperator()->setBackend(determineBackend(node));
graphView->add(coeffProducer); // needed ? graphView->add(coeffProducer); // needed ?
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment