Skip to content
Snippets Groups Projects
Commit f2713d00 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'fix_backend' into 'dev'

Hotfix the default backend issue

See merge request !48
parents e9463363 2e21e5dd
No related branches found
No related tags found
2 merge requests!54Update 0.3.1 -> 0.4.0,!48Hotfix the default backend issue
Pipeline #67902 failed
......@@ -111,6 +111,37 @@ static std::shared_ptr<Aidge::Node> getUniqueChildren(std::shared_ptr<Aidge::Nod
return *(childrenSet.begin());
}
static std::string determineBackend(std::shared_ptr<Aidge::Node> node)
{
std::string backend = node->getOperator()->backend();
if (backend != "")
return backend;
else
{
// gather the parent backends
std::set<std::string> parentBackends;
for (auto parent : node->getParents())
parentBackends.insert(determineBackend(parent)); // it always answers a non empty value !
// check if we have two or more different backends gathered
if (parentBackends.size() > 1)
{
Log::warn(" Unable to determine backend of node {} due to conflicting parent ones !", node->name());
return (*parentBackends.begin());
}
// if all parents have the same backend return it
if (parentBackends.size() == 1)
return (*parentBackends.begin());
}
return "cpu"; // escape path when no parents are found
}
static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
{
int index = 0;
......@@ -171,7 +202,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
{
std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
roundNode->getOperator()->setBackend(node->getOperator()->backend());
roundNode->getOperator()->setBackend(determineBackend(node));
insertChildren(node, roundNode, graphView);
roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0);
......@@ -268,17 +299,7 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose)
{
std::vector<std::shared_ptr<Node>> nodeVector;
SequentialScheduler scheduler(graphView);
if (newSchedule)
{
scheduler.resetScheduling();
scheduler.generateScheduling(); // old way : scheduler.forward();
}
nodeVector = scheduler.getSequentialStaticScheduling();
std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes();
fixScheduling(nodeVector);
nodeVector = removeMatchingNodes(nodeVector, "Producer");
......@@ -348,8 +369,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
auto producerOp = std::static_pointer_cast<OperatorTensor>(producerNode->getOperator());
scalingNode->getOperator()->setBackend(producerOp->getOutput(0)->backend());
scalingNode->getOperator()->setBackend(determineBackend(producerNode));
insertChildren(producerNode, scalingNode, graphView);
graphView->add(scalingNode->getParent(1)); // add the scaling factor producer
......@@ -385,7 +405,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0);
residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
residualNode->getOperator()->setBackend(parentNode->getOperator()->backend());
residualNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(node, residualNode, i, 0, 0);
graphView->add(residualNode->getParent(1)); // add the scaling factor producer
......@@ -425,7 +445,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
scalingNode->getOperator()->setBackend(parentNode->getOperator()->backend());
scalingNode->getOperator()->setBackend(determineBackend(parentNode));
if (parentNode->getChildren().size() > 0) {
insertChildren(parentNode, scalingNode, graphView);
......@@ -446,12 +466,11 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0);
prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend());
prevScalingNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0);
graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer
}
}
}
}
......@@ -1009,7 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name());
quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
quantizerNode->getOperator()->setBackend(node->getOperator()->backend());
quantizerNode->getOperator()->setBackend(determineBackend(node));
graphView->replace({node, node->getParent(1)}, {quantizerNode});
if (optimizeSigns)
......@@ -1063,7 +1082,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0);
mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
mulNode->getOperator()->setBackend(node->getOperator()->backend());
mulNode->getOperator()->setBackend(determineBackend(node));
graphView->insertParent(node, mulNode, 0, 0, 0);
......@@ -1074,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
coeffProducer->getOperator()->setOutput(0, coeffTensor);
coeffProducer->getOperator()->setDataType(DataType::Float64);
coeffProducer->getOperator()->setBackend(node->getOperator()->backend());
coeffProducer->getOperator()->setBackend(determineBackend(node));
graphView->add(coeffProducer); // needed ?
......
......@@ -152,10 +152,8 @@ void QuantFixedQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView,
void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView)
{
SequentialScheduler scheduler(graphView);
scheduler.generateScheduling();
auto s = scheduler.getSequentialStaticScheduling();
for (std::shared_ptr<Node> node : s)
auto nodeVector = graphView->getOrderedNodes();
for (std::shared_ptr<Node> node : nodeVector)
Log::info(" name : {} ", node->name());
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment