diff --git a/aidge_export_cpp/export_utils.py b/aidge_export_cpp/export_utils.py index 441d21b49a5d6df53690bcc49583279699abed75..ac5901748871977bff61efedd326a6181e22231a 100644 --- a/aidge_export_cpp/export_utils.py +++ b/aidge_export_cpp/export_utils.py @@ -85,10 +85,12 @@ def set_nodes_names(scheduler): node_ids = {} # Dict holding the node type along with a counter node_it = 0 # Node Iterator + + ## MetaOps for node in scheduler.get_sequential_static_scheduling(): node_type = node.type() - if node_type != "Producer": # Producers are + if node_type != "Producer": if node.type() not in node_ids: node_ids[node_type] = 0 @@ -99,23 +101,27 @@ def set_nodes_names(scheduler): node_it += 1 # Set producers names - if node_type in ["QConv", "PadConv", "ConvAct", "PadConvAct", "QFC", "FCAct"]: - # nb_parents = len(node.get_parents()) + ## Weights & Biases producers + if get_node_from_metaop(node, "FC") or \ + get_node_from_metaop(node, "Conv2D") or \ + get_node_from_metaop(node, "ConvDepthWise2D"): + node.get_parent(1).set_name(node.name() + "_weights") if node.get_parent(2) is not None: node.get_parent(2).set_name(node.name() + "_biases") - for parent_node in node.get_parents(): - if parent_node is not None: - # [TODO] Does not work yet - # if parent_node.attributes().has_attr("quantization.ptq.CompensationCoeff"): - # parent_node.set_name(node.name() + "_coeff") - if parent_node.attributes().has_attr("quantization.ptq.ShiftAmount"): - parent_node.set_name(node.name() + "_shift") - # [Fix] Add scaling/add coeff nodes manually - elif node.type() in ["CppElemWise", ""] and parent_node.type() == "Producer": - parent_node.set_name(node.name() + "_coeff") - # [End Fix] + ## Scaling Producers + for node in scheduler.get_sequential_static_scheduling(): + """ + TODO: If multiple quantizer nodes are found, the producers will + all have the same name and this will not work properly. + """ + if node.type() == "Producer": + child_node = node.output(0)[0][0] + if node.attributes().has_attr("shift_prod"): + node.set_name(child_node.name() + "_shift") + if node.attributes().has_attr("coef_prod"): + node.set_name(child_node.name() + "_coef") @@ -124,13 +130,17 @@ def set_nodes_datatypes(graph_view: aidge_core.GraphView): The set_datatype function can't be used on Conv2D and FC nodes directly as the biases datatype is different from the other inputs. + TODO: Should be using forward_datatype() :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and ordered input/output data within the computational graph. """ for node in graph_view.get_nodes(): if node.type() != "Producer": - if node.type() in ["QConv", "PadConv", "ConvAct", "PadConvAct", "QFC", "FCAct"]: + if get_node_from_metaop(node, "FC") or \ + get_node_from_metaop(node, "Conv2D") or \ + get_node_from_metaop(node, "ConvDepthWise2D"): + node.get_operator().get_input(0).set_datatype(aidge_core.dtype.int8) # Input node.get_operator().get_input(1).set_datatype(aidge_core.dtype.int8) # Weights if node.get_parent(2) is not None: @@ -199,13 +209,14 @@ def set_scaling_attributes(export_node: aidge_core.export_utils.ExportNode, node """ QNode = get_node_from_metaop(node, "Quantizer") - - if QNode is not None: - for n in QNode.get_operator().get_micro_graph().get_nodes(): - if n.type() == "BitShift": - export_node.attributes["shift_value"] = n.get_operator().get_input(1)[0] - elif n.type() == "Mul": - export_node.attributes["coef_value"] = n.get_operator().get_input(1)[0] + if QNode: + BNode = get_node_from_metaop(QNode[0], "BitShift") + export_node.attributes["shift_value"] = BNode[0].get_operator().get_input(1)[0] + + QMulNode = get_node_from_metaop(node, "QMul") + if QMulNode: + CNode = get_node_from_metaop(QMulNode[0], "Mul") + export_node.attributes["coef_value"] = CNode[0].get_operator().get_input(1)[0]