Skip to content
Snippets Groups Projects
Commit 86c33bf4 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix operator.py with new merge.

parent c1fbb783
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
...@@ -126,7 +126,7 @@ class ReLUCPP(ExportNode): ...@@ -126,7 +126,7 @@ class ReLUCPP(ExportNode):
list_actions.append(generate_str( list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja"), str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja"),
name=self.name, name=self.name,
input_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input", input_name=f"{self.name}_input" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name output_name=self.name
)) ))
return list_actions return list_actions
...@@ -191,10 +191,10 @@ class ConvCPP(ExportNode): ...@@ -191,10 +191,10 @@ class ConvCPP(ExportNode):
list_actions.append(generate_str( list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja"), str(ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja"),
name=self.name, name=self.name,
input_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input", input_name=f"{self.name}_input_0" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name, output_name=self.name,
weights_name=self.inputs[1].name(), weights_name=f"{self.name}_input_1" if self.inputs[1] is None else self.inputs[1].name(),
biases_name=self.inputs[2].name() biases_name=f"{self.name}_input_2" if self.inputs[2] is None else self.inputs[2].name()
)) ))
return list_actions return list_actions
...@@ -251,8 +251,8 @@ class AddCPP(ExportNode): ...@@ -251,8 +251,8 @@ class AddCPP(ExportNode):
list_actions.append(generate_str( list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"), str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name, name=self.name,
inputs1_name=self.parents[0].name() if self.parents[0] else self.name + "_input1", inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.parents[1].name() if self.parents[1] else self.name + "_input2", inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
output_name=self.name output_name=self.name
)) ))
return list_actions return list_actions
...@@ -284,15 +284,15 @@ class SubCPP(ExportNode): ...@@ -284,15 +284,15 @@ class SubCPP(ExportNode):
list_actions.append(generate_str( list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"), str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name, name=self.name,
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1", inputs1_name=f"{self.name}_input_0" if self.inputs[0] is None else self.inputs[0].name(),
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2", inputs2_name=f"{self.name}_input_1" if self.inputs[1] is None else self.inputs[1].name(),
output_name=self.name output_name=self.name
)) ))
return list_actions return list_actions
@operator_register("PaddedMaxPooling") @operator_register("PaddedMaxPooling")
class MaxPoolCPP(ExportNode): class PaddedMaxPoolCPP(ExportNode):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
for n in self.operator.get_micro_graph().get_nodes(): for n in self.operator.get_micro_graph().get_nodes():
...@@ -342,9 +342,7 @@ class MaxPoolCPP(ExportNode): ...@@ -342,9 +342,7 @@ class MaxPoolCPP(ExportNode):
list_actions.append(generate_str( list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"), str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"),
name=self.name, name=self.name,
input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), input_name=f"{self.name}_input_0" if self.inputs[0] is None else self.inputs[0].name(),
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
output_name=self.name output_name=self.name
)) ))
return list_actions return list_actions
...@@ -457,9 +455,9 @@ class FcCPP(ExportNode): ...@@ -457,9 +455,9 @@ class FcCPP(ExportNode):
list_actions.append(generate_str( list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "fullyconnected_forward.jinja"), str(ROOT / "templates" / "kernel_forward" / "fullyconnected_forward.jinja"),
name=self.name, name=self.name,
inputs_name= self.inputs[0].name() if (self.inputs[0] is not None) else self.name + '_input', inputs_name=f"{self.name}_input" if self.inputs[0] is None else self.inputs[0].name(),
weights_name=self.inputs[1].name(), weights_name=self.inputs[1].name(),
biases_name=self.inputs[2].name(), biases_name=self.inputs[2].name(), # TODO we should check if bias
outputs_name=self.name outputs_name=self.name
)) ))
return list_actions return list_actions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment