Skip to content
Snippets Groups Projects
Commit 86c33bf4 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix operator.py with new merge.

parent c1fbb783
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
......@@ -126,7 +126,7 @@ class ReLUCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja"),
name=self.name,
input_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input",
input_name=f"{self.name}_input" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name
))
return list_actions
......@@ -191,10 +191,10 @@ class ConvCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja"),
name=self.name,
input_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input",
input_name=f"{self.name}_input_0" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name,
weights_name=self.inputs[1].name(),
biases_name=self.inputs[2].name()
weights_name=f"{self.name}_input_1" if self.inputs[1] is None else self.inputs[1].name(),
biases_name=f"{self.name}_input_2" if self.inputs[2] is None else self.inputs[2].name()
))
return list_actions
......@@ -251,8 +251,8 @@ class AddCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
inputs1_name=self.parents[0].name() if self.parents[0] else self.name + "_input1",
inputs2_name=self.parents[1].name() if self.parents[1] else self.name + "_input2",
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
output_name=self.name
))
return list_actions
......@@ -284,15 +284,15 @@ class SubCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
inputs1_name=f"{self.name}_input_0" if self.inputs[0] is None else self.inputs[0].name(),
inputs2_name=f"{self.name}_input_1" if self.inputs[1] is None else self.inputs[1].name(),
output_name=self.name
))
return list_actions
@operator_register("PaddedMaxPooling")
class MaxPoolCPP(ExportNode):
class PaddedMaxPoolCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
for n in self.operator.get_micro_graph().get_nodes():
......@@ -342,9 +342,7 @@ class MaxPoolCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"),
name=self.name,
input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
input_name=f"{self.name}_input_0" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name
))
return list_actions
......@@ -457,9 +455,9 @@ class FcCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "fullyconnected_forward.jinja"),
name=self.name,
inputs_name= self.inputs[0].name() if (self.inputs[0] is not None) else self.name + '_input',
inputs_name=f"{self.name}_input" if self.inputs[0] is None else self.inputs[0].name(),
weights_name=self.inputs[1].name(),
biases_name=self.inputs[2].name(),
biases_name=self.inputs[2].name(), # TODO we should check if bias
outputs_name=self.name
))
return list_actions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment