Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • eclipse/aidge/aidge_quantization
  • bhalimi/aidge_quantization
  • noamzerah/aidge_quantization
  • hrouis/aidge_quantization
  • maab05/aidge_quantization
  • lucaslopez/aidge_quantization_ll
  • axelfarr/aidge_quantization
  • farges/aidge_quantization
8 results
Show changes
Commits on Source (7)
......@@ -111,6 +111,37 @@ static std::shared_ptr<Aidge::Node> getUniqueChildren(std::shared_ptr<Aidge::Nod
return *(childrenSet.begin());
}
static std::string determineBackend(std::shared_ptr<Aidge::Node> node)
{
std::string backend = node->getOperator()->backend();
if (backend != "")
return backend;
else
{
// gather the parent backends
std::set<std::string> parentBackends;
for (auto parent : node->getParents())
parentBackends.insert(determineBackend(parent)); // it always answers a non empty value !
// check if we have two or more different backends gathered
if (parentBackends.size() > 1)
{
Log::warn(" Unable to determine backend of node {} due to conflicting parent ones !", node->name());
return (*parentBackends.begin());
}
// if all parents have the same backend return it
if (parentBackends.size() == 1)
return (*parentBackends.begin());
}
return "cpu"; // escape path when no parents are found
}
static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
{
int index = 0;
......@@ -171,7 +202,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
{
std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
roundNode->getOperator()->setBackend(node->getOperator()->backend());
roundNode->getOperator()->setBackend(determineBackend(node));
insertChildren(node, roundNode, graphView);
roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0);
......@@ -268,17 +299,7 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose)
{
std::vector<std::shared_ptr<Node>> nodeVector;
SequentialScheduler scheduler(graphView);
if (newSchedule)
{
scheduler.resetScheduling();
scheduler.generateScheduling(); // old way : scheduler.forward();
}
nodeVector = scheduler.getStaticScheduling();
std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes();
fixScheduling(nodeVector);
nodeVector = removeMatchingNodes(nodeVector, "Producer");
......@@ -348,8 +369,7 @@ bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scali
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
auto producerOp = std::static_pointer_cast<OperatorTensor>(producerNode->getOperator());
scalingNode->getOperator()->setBackend(producerOp->getOutput(0)->backend());
scalingNode->getOperator()->setBackend(determineBackend(producerNode));
insertChildren(producerNode, scalingNode, graphView);
graphView->add(scalingNode->getParent(1)); // add the scaling factor producer
......@@ -385,7 +405,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0);
residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
residualNode->getOperator()->setBackend(parentNode->getOperator()->backend());
residualNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(node, residualNode, i, 0, 0);
graphView->add(residualNode->getParent(1)); // add the scaling factor producer
......@@ -425,7 +445,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0);
scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
scalingNode->getOperator()->setBackend(parentNode->getOperator()->backend());
scalingNode->getOperator()->setBackend(determineBackend(parentNode));
if (parentNode->getChildren().size() > 0) {
insertChildren(parentNode, scalingNode, graphView);
......@@ -446,12 +466,11 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0);
prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
prevScalingNode->getOperator()->setBackend(parentNode->getOperator()->backend());
prevScalingNode->getOperator()->setBackend(determineBackend(parentNode));
graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0);
graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer
}
}
}
}
......@@ -1009,7 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name());
quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
quantizerNode->getOperator()->setBackend(node->getOperator()->backend());
quantizerNode->getOperator()->setBackend(determineBackend(node));
graphView->replace({node, node->getParent(1)}, {quantizerNode});
if (optimizeSigns)
......@@ -1063,7 +1082,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0);
mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
mulNode->getOperator()->setBackend(node->getOperator()->backend());
mulNode->getOperator()->setBackend(determineBackend(node));
graphView->insertParent(node, mulNode, 0, 0, 0);
......@@ -1074,7 +1093,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
coeffProducer->getOperator()->setOutput(0, coeffTensor);
coeffProducer->getOperator()->setDataType(DataType::Float64);
coeffProducer->getOperator()->setBackend(node->getOperator()->backend());
coeffProducer->getOperator()->setBackend(determineBackend(node));
graphView->add(coeffProducer); // needed ?
......
......@@ -152,10 +152,8 @@ void QuantFixedQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView,
void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView)
{
SequentialScheduler scheduler(graphView);
scheduler.generateScheduling();
auto s = scheduler.getStaticScheduling();
for (std::shared_ptr<Node> node : s)
auto nodeVector = graphView->getOrderedNodes();
for (std::shared_ptr<Node> node : nodeVector)
Log::info(" name : {} ", node->name());
}
......
......@@ -200,7 +200,7 @@ TEST_CASE("[tmp] basic test") {
// //no need to do this anymore, forward does it autimatically now ...
// //scheduler.generateScheduling(true);
// std::vector<std::shared_ptr<Node>> ordered_graph_view = scheduler.getStaticScheduling();
// std::vector<std::shared_ptr<Node>> ordered_graph_view = scheduler.getSequentialStaticScheduling();
// printf("Going to quantize network :\n");
......@@ -226,7 +226,7 @@ TEST_CASE("[tmp] basic test") {
// scheduler_v2.forward();
// scheduler_v2.generateScheduling(false);
// std::vector<std::shared_ptr<Node>> ordered_graph_view_v2 = scheduler_v2.getStaticScheduling();
// std::vector<std::shared_ptr<Node>> ordered_graph_view_v2 = scheduler_v2.getSequentialStaticScheduling();
// if(verbose) {
// printf("Ordered graph after quantization :\n");
......