Skip to content
Snippets Groups Projects
Commit 95b279ec authored by Axel Farrugia's avatar Axel Farrugia
Browse files

[Refactor] Adapt the way to set the scaling attributes, datatypes and node names

parent 0cf74856
No related branches found
No related tags found
No related merge requests found
...@@ -85,10 +85,12 @@ def set_nodes_names(scheduler): ...@@ -85,10 +85,12 @@ def set_nodes_names(scheduler):
node_ids = {} # Dict holding the node type along with a counter node_ids = {} # Dict holding the node type along with a counter
node_it = 0 # Node Iterator node_it = 0 # Node Iterator
## MetaOps
for node in scheduler.get_sequential_static_scheduling(): for node in scheduler.get_sequential_static_scheduling():
node_type = node.type() node_type = node.type()
if node_type != "Producer": # Producers are if node_type != "Producer":
if node.type() not in node_ids: if node.type() not in node_ids:
node_ids[node_type] = 0 node_ids[node_type] = 0
...@@ -99,23 +101,27 @@ def set_nodes_names(scheduler): ...@@ -99,23 +101,27 @@ def set_nodes_names(scheduler):
node_it += 1 node_it += 1
# Set producers names # Set producers names
if node_type in ["QConv", "PadConv", "ConvAct", "PadConvAct", "QFC", "FCAct"]: ## Weights & Biases producers
# nb_parents = len(node.get_parents()) if get_node_from_metaop(node, "FC") or \
get_node_from_metaop(node, "Conv2D") or \
get_node_from_metaop(node, "ConvDepthWise2D"):
node.get_parent(1).set_name(node.name() + "_weights") node.get_parent(1).set_name(node.name() + "_weights")
if node.get_parent(2) is not None: if node.get_parent(2) is not None:
node.get_parent(2).set_name(node.name() + "_biases") node.get_parent(2).set_name(node.name() + "_biases")
for parent_node in node.get_parents(): ## Scaling Producers
if parent_node is not None: for node in scheduler.get_sequential_static_scheduling():
# [TODO] Does not work yet """
# if parent_node.attributes().has_attr("quantization.ptq.CompensationCoeff"): TODO: If multiple quantizer nodes are found, the producers will
# parent_node.set_name(node.name() + "_coeff") all have the same name and this will not work properly.
if parent_node.attributes().has_attr("quantization.ptq.ShiftAmount"): """
parent_node.set_name(node.name() + "_shift") if node.type() == "Producer":
# [Fix] Add scaling/add coeff nodes manually child_node = node.output(0)[0][0]
elif node.type() in ["CppElemWise", ""] and parent_node.type() == "Producer": if node.attributes().has_attr("shift_prod"):
parent_node.set_name(node.name() + "_coeff") node.set_name(child_node.name() + "_shift")
# [End Fix] if node.attributes().has_attr("coef_prod"):
node.set_name(child_node.name() + "_coef")
...@@ -124,13 +130,17 @@ def set_nodes_datatypes(graph_view: aidge_core.GraphView): ...@@ -124,13 +130,17 @@ def set_nodes_datatypes(graph_view: aidge_core.GraphView):
The set_datatype function can't be used on Conv2D and FC nodes directly The set_datatype function can't be used on Conv2D and FC nodes directly
as the biases datatype is different from the other inputs. as the biases datatype is different from the other inputs.
TODO: Should be using forward_datatype()
:param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and
ordered input/output data within the computational graph. ordered input/output data within the computational graph.
""" """
for node in graph_view.get_nodes(): for node in graph_view.get_nodes():
if node.type() != "Producer": if node.type() != "Producer":
if node.type() in ["QConv", "PadConv", "ConvAct", "PadConvAct", "QFC", "FCAct"]: if get_node_from_metaop(node, "FC") or \
get_node_from_metaop(node, "Conv2D") or \
get_node_from_metaop(node, "ConvDepthWise2D"):
node.get_operator().get_input(0).set_datatype(aidge_core.dtype.int8) # Input node.get_operator().get_input(0).set_datatype(aidge_core.dtype.int8) # Input
node.get_operator().get_input(1).set_datatype(aidge_core.dtype.int8) # Weights node.get_operator().get_input(1).set_datatype(aidge_core.dtype.int8) # Weights
if node.get_parent(2) is not None: if node.get_parent(2) is not None:
...@@ -199,13 +209,14 @@ def set_scaling_attributes(export_node: aidge_core.export_utils.ExportNode, node ...@@ -199,13 +209,14 @@ def set_scaling_attributes(export_node: aidge_core.export_utils.ExportNode, node
""" """
QNode = get_node_from_metaop(node, "Quantizer") QNode = get_node_from_metaop(node, "Quantizer")
if QNode:
if QNode is not None: BNode = get_node_from_metaop(QNode[0], "BitShift")
for n in QNode.get_operator().get_micro_graph().get_nodes(): export_node.attributes["shift_value"] = BNode[0].get_operator().get_input(1)[0]
if n.type() == "BitShift":
export_node.attributes["shift_value"] = n.get_operator().get_input(1)[0] QMulNode = get_node_from_metaop(node, "QMul")
elif n.type() == "Mul": if QMulNode:
export_node.attributes["coef_value"] = n.get_operator().get_input(1)[0] CNode = get_node_from_metaop(QMulNode[0], "Mul")
export_node.attributes["coef_value"] = CNode[0].get_operator().get_input(1)[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment