Skip to content
Snippets Groups Projects
Commit 4a65163a authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

Merge branch 'dev' into 'add_matmul'

# Conflicts:
#   src/PTQ/PTQ.cpp
parents c3b64dd1 f2713d00
No related branches found
No related tags found
2 merge requests!54Update 0.3.1 -> 0.4.0,!45Add support for the MatMul operator
Pipeline #68113 passed
...@@ -165,6 +165,37 @@ static std::shared_ptr<Aidge::Node> getUniqueChild(std::shared_ptr<Aidge::Node> ...@@ -165,6 +165,37 @@ static std::shared_ptr<Aidge::Node> getUniqueChild(std::shared_ptr<Aidge::Node>
return *(childrenSet.begin()); return *(childrenSet.begin());
} }
static std::string determineBackend(std::shared_ptr<Aidge::Node> node)
{
std::string backend = node->getOperator()->backend();
if (backend != "")
return backend;
else
{
// gather the parent backends
std::set<std::string> parentBackends;
for (auto parent : node->getParents())
parentBackends.insert(determineBackend(parent)); // it always answers a non empty value !
// check if we have two or more different backends gathered
if (parentBackends.size() > 1)
{
Log::warn(" Unable to determine backend of node {} due to conflicting parent ones !", node->name());
return (*parentBackends.begin());
}
// if all parents have the same backend return it
if (parentBackends.size() == 1)
return (*parentBackends.begin());
}
return "cpu"; // escape path when no parents are found
}
static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode) static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
{ {
int index = 0; int index = 0;
...@@ -225,7 +256,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV ...@@ -225,7 +256,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
{ {
std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
roundNode->getOperator()->setBackend(node->getOperator()->backend()); roundNode->getOperator()->setBackend(determineBackend(node));
insertChildren(node, roundNode, graphView); insertChildren(node, roundNode, graphView);
addAttr(roundNode, "isProducerRounding"); addAttr(roundNode, "isProducerRounding");
...@@ -394,8 +425,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali ...@@ -394,8 +425,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor); std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
auto producerOp = std::static_pointer_cast<OperatorTensor>(producerNode->getOperator()); scalingNode->getOperator()->setBackend(determineBackend(producerNode));
scalingNode->getOperator()->setBackend(producerOp->getOutput(0)->backend());
insertChildren(producerNode, scalingNode, graphView); insertChildren(producerNode, scalingNode, graphView);
graphView->add(scalingNode->getParent(1)); // add the scaling factor producer graphView->add(scalingNode->getParent(1)); // add the scaling factor producer
...@@ -431,7 +461,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) ...@@ -431,7 +461,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0); std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0);
residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
residualNode->getOperator()->setBackend(parentNode->getOperator()->backend()); residualNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(node, residualNode, i, 0, 0); graphView->insertParent(node, residualNode, i, 0, 0);
graphView->add(residualNode->getParent(1)); // add the scaling factor producer graphView->add(residualNode->getParent(1)); // add the scaling factor producer
...@@ -471,7 +501,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) ...@@ -471,7 +501,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0); std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
scalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); scalingNode->getOperator()->setBackend(determineBackend(parentNode));
if (parentNode->getChildren().size() > 0) { if (parentNode->getChildren().size() > 0) {
insertChildren(parentNode, scalingNode, graphView); insertChildren(parentNode, scalingNode, graphView);
...@@ -492,7 +522,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView) ...@@ -492,7 +522,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0); std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0);
prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend()); prevScalingNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0); graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0);
graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer
...@@ -1087,7 +1117,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ ...@@ -1087,7 +1117,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name()); std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name());
quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
quantizerNode->getOperator()->setBackend(node->getOperator()->backend()); quantizerNode->getOperator()->setBackend(determineBackend(node));
graphView->replace({node, node->getParent(1)}, {quantizerNode}); graphView->replace({node, node->getParent(1)}, {quantizerNode});
if (optimizeSigns) if (optimizeSigns)
...@@ -1142,7 +1172,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u ...@@ -1142,7 +1172,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
addAttr(mulNode, "isCompensation"); addAttr(mulNode, "isCompensation");
mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
mulNode->getOperator()->setBackend(node->getOperator()->backend()); mulNode->getOperator()->setBackend(determineBackend(node));
graphView->insertParent(node, mulNode, 0, 0, 0); graphView->insertParent(node, mulNode, 0, 0, 0);
...@@ -1153,7 +1183,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u ...@@ -1153,7 +1183,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
coeffProducer->getOperator()->setOutput(0, coeffTensor); coeffProducer->getOperator()->setOutput(0, coeffTensor);
coeffProducer->getOperator()->setDataType(DataType::Float64); coeffProducer->getOperator()->setDataType(DataType::Float64);
coeffProducer->getOperator()->setBackend(node->getOperator()->backend()); coeffProducer->getOperator()->setBackend(determineBackend(node));
graphView->add(coeffProducer); // needed ? graphView->add(coeffProducer); // needed ?
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment