Skip to content
Snippets Groups Projects
Commit 2e21e5dd authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

make use of determineBackend()

parent 302ebd03
Branches fix_backend
No related tags found
2 merge requests!54Update 0.3.1 -> 0.4.0,!48Hotfix the default backend issue
Pipeline #67732 passed
......@@ -138,6 +138,8 @@ static std::string determineBackend(std::shared_ptr<Aidge::Node> node)
if (parentBackends.size() == 1)
return (*parentBackends.begin());
}
return "cpu"; // escape path when no parents are found
}
static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
......@@ -200,7 +202,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
{
std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
roundNode->getOperator()->setBackend("cpu");
roundNode->getOperator()->setBackend(determineBackend(node));
insertChildren(node, roundNode, graphView);
roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0);
......@@ -367,7 +369,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
scalingNode->getOperator()->setBackend("cpu");
scalingNode->getOperator()->setBackend(determineBackend(producerNode));
insertChildren(producerNode, scalingNode, graphView);
graphView->add(scalingNode->getParent(1)); // add the scaling factor producer
......@@ -403,7 +405,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0);
residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
residualNode->getOperator()->setBackend("cpu");
residualNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(node, residualNode, i, 0, 0);
graphView->add(residualNode->getParent(1)); // add the scaling factor producer
......@@ -443,7 +445,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
scalingNode->getOperator()->setBackend("cpu");
scalingNode->getOperator()->setBackend(determineBackend(parentNode));
if (parentNode->getChildren().size() > 0) {
insertChildren(parentNode, scalingNode, graphView);
......@@ -464,7 +466,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0);
prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
prevScalingNode->getOperator()->setBackend("cpu");
prevScalingNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0);
graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer
......@@ -1026,7 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name());
quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
quantizerNode->getOperator()->setBackend("cpu");
quantizerNode->getOperator()->setBackend(determineBackend(node));
graphView->replace({node, node->getParent(1)}, {quantizerNode});
if (optimizeSigns)
......@@ -1080,7 +1082,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0);
mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
mulNode->getOperator()->setBackend("cpu");
mulNode->getOperator()->setBackend(determineBackend(node));
graphView->insertParent(node, mulNode, 0, 0, 0);
......@@ -1091,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
coeffProducer->getOperator()->setOutput(0, coeffTensor);
coeffProducer->getOperator()->setDataType(DataType::Float64);
coeffProducer->getOperator()->setBackend("cpu");
coeffProducer->getOperator()->setBackend(determineBackend(node));
graphView->add(coeffProducer); // needed ?
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment