Skip to content
Snippets Groups Projects
Commit e2f3a04b authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Adapt ExportNode to is_input/is_output parameter.

parent fd8690a8
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
......@@ -61,8 +61,8 @@ def export_params(name: str,
@operator_register(ExportLibCpp, "Producer")
class ProducerCPP(ExportNode):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.values = np.array(self.operator.get_output(0))
if len(self.values.shape) == 4: # Note: export in HWC
......@@ -70,11 +70,11 @@ class ProducerCPP(ExportNode):
def export(self, export_folder: Path, list_configs: list):
list_configs.append(f"parameters/{self.attributes['name']}.h")
list_configs.append(f"include/parameters/{self.attributes['name']}.h")
export_params(
self.attributes['out_name'][0],
self.values.reshape(-1),
str(export_folder / "parameters" / f"{self.attributes['name']}.h"))
str(export_folder / "include" / "parameters" / f"{self.attributes['name']}.h"))
return list_configs
......@@ -87,8 +87,8 @@ class ProducerCPP(ExportNode):
@operator_register(ExportLibCpp, "ReLU")
class ReLUCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.attributes["activation"] = "Rectifier"
self.attributes["rescaling"] = "NoScaling"
self.config_template = str(
......@@ -106,8 +106,8 @@ class ReLUCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "Conv")
class ConvCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
# No padding with Conv
# Use PaddedConv to add padding attribute
self.attributes["padding"] = [0, 0]
......@@ -131,8 +131,8 @@ class ConvCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "PaddedConv")
class PaddedConvCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
# TODO find a way to retrive attr for meta op
for n in self.operator.get_micro_graph().get_nodes():
if n.type() == "Pad":
......@@ -163,8 +163,8 @@ class PaddedConvCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "Add")
class AddCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.attributes["elemwise_op"] = "Add"
self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling"
......@@ -183,8 +183,8 @@ class AddCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "Sub")
class SubCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.attributes["elemwise_op"] = "Sub"
self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling"
......@@ -204,8 +204,8 @@ class SubCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "Mul")
class MulCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.attributes["elemwise_op"] = "Mul"
self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling"
......@@ -224,8 +224,8 @@ class MulCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "MaxPooling")
class MaxPoolCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
# No padding with MaxPooling
# Use PaddedMaxPooling to add padding attribute
......@@ -249,8 +249,8 @@ class MaxPoolCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "PaddedMaxPooling")
class PaddedMaxPoolCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
for n in self.operator.get_micro_graph().get_nodes():
if n.type() == "Pad":
self.attributes["padding"] = n.get_operator(
......@@ -278,8 +278,8 @@ class PaddedMaxPoolCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "GlobalAveragePooling")
class GlobalAveragePoolCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.attributes["stride_dims"] = [1, 1]
# No padding with MaxPooling
......@@ -307,8 +307,8 @@ class GlobalAveragePoolCPP(ExportNodeCpp):
@operator_register(ExportLibCpp, "FC")
class FcCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling"
self.config_template = str(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment