Skip to content
Snippets Groups Projects
Commit 95b279ec authored by Axel Farrugia's avatar Axel Farrugia
Browse files

[Refactor] Adapt the way to set the scaling attributes, datatypes and node names

parent 0cf74856
No related branches found
No related tags found
No related merge requests found
......@@ -85,10 +85,12 @@ def set_nodes_names(scheduler):
node_ids = {} # Dict holding the node type along with a counter
node_it = 0 # Node Iterator
## MetaOps
for node in scheduler.get_sequential_static_scheduling():
node_type = node.type()
if node_type != "Producer": # Producers are
if node_type != "Producer":
if node.type() not in node_ids:
node_ids[node_type] = 0
......@@ -99,23 +101,27 @@ def set_nodes_names(scheduler):
node_it += 1
# Set producers names
if node_type in ["QConv", "PadConv", "ConvAct", "PadConvAct", "QFC", "FCAct"]:
# nb_parents = len(node.get_parents())
## Weights & Biases producers
if get_node_from_metaop(node, "FC") or \
get_node_from_metaop(node, "Conv2D") or \
get_node_from_metaop(node, "ConvDepthWise2D"):
node.get_parent(1).set_name(node.name() + "_weights")
if node.get_parent(2) is not None:
node.get_parent(2).set_name(node.name() + "_biases")
for parent_node in node.get_parents():
if parent_node is not None:
# [TODO] Does not work yet
# if parent_node.attributes().has_attr("quantization.ptq.CompensationCoeff"):
# parent_node.set_name(node.name() + "_coeff")
if parent_node.attributes().has_attr("quantization.ptq.ShiftAmount"):
parent_node.set_name(node.name() + "_shift")
# [Fix] Add scaling/add coeff nodes manually
elif node.type() in ["CppElemWise", ""] and parent_node.type() == "Producer":
parent_node.set_name(node.name() + "_coeff")
# [End Fix]
## Scaling Producers
for node in scheduler.get_sequential_static_scheduling():
"""
TODO: If multiple quantizer nodes are found, the producers will
all have the same name and this will not work properly.
"""
if node.type() == "Producer":
child_node = node.output(0)[0][0]
if node.attributes().has_attr("shift_prod"):
node.set_name(child_node.name() + "_shift")
if node.attributes().has_attr("coef_prod"):
node.set_name(child_node.name() + "_coef")
......@@ -124,13 +130,17 @@ def set_nodes_datatypes(graph_view: aidge_core.GraphView):
The set_datatype function can't be used on Conv2D and FC nodes directly
as the biases datatype is different from the other inputs.
TODO: Should be using forward_datatype()
:param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and
ordered input/output data within the computational graph.
"""
for node in graph_view.get_nodes():
if node.type() != "Producer":
if node.type() in ["QConv", "PadConv", "ConvAct", "PadConvAct", "QFC", "FCAct"]:
if get_node_from_metaop(node, "FC") or \
get_node_from_metaop(node, "Conv2D") or \
get_node_from_metaop(node, "ConvDepthWise2D"):
node.get_operator().get_input(0).set_datatype(aidge_core.dtype.int8) # Input
node.get_operator().get_input(1).set_datatype(aidge_core.dtype.int8) # Weights
if node.get_parent(2) is not None:
......@@ -199,13 +209,14 @@ def set_scaling_attributes(export_node: aidge_core.export_utils.ExportNode, node
"""
QNode = get_node_from_metaop(node, "Quantizer")
if QNode is not None:
for n in QNode.get_operator().get_micro_graph().get_nodes():
if n.type() == "BitShift":
export_node.attributes["shift_value"] = n.get_operator().get_input(1)[0]
elif n.type() == "Mul":
export_node.attributes["coef_value"] = n.get_operator().get_input(1)[0]
if QNode:
BNode = get_node_from_metaop(QNode[0], "BitShift")
export_node.attributes["shift_value"] = BNode[0].get_operator().get_input(1)[0]
QMulNode = get_node_from_metaop(node, "QMul")
if QMulNode:
CNode = get_node_from_metaop(QMulNode[0], "Mul")
export_node.attributes["coef_value"] = CNode[0].get_operator().get_input(1)[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment