Skip to content
Snippets Groups Projects

Fix CPP export for Add Sub and add export for Mul

2 unresolved threads
1 file
+ 48
5
Compare changes
  • Side-by-side
  • Inline
@@ -238,6 +238,8 @@ class AddCPP(ExportNode):
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
copyfile(str(ROOT / "kernels" / "activation.hpp"),
str(export_folder / "include" / "kernels"))
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
@@ -251,13 +253,13 @@ class AddCPP(ExportNode):
return list_configs
def forward(self, list_actions:list):
list_actions.append(set_up_output(self.name, "float"))
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
inputs1_name=self.parents[0].name() if self.parents[0] else self.name + "_input1",
inputs2_name=self.parents[1].name() if self.parents[1] else self.name + "_input2",
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
output_name=self.name
))
return list_actions
@@ -272,6 +274,9 @@ class SubCPP(ExportNode):
list_configs.append("kernels/elemwise.hpp")
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
copyfile(str(ROOT / "kernels" / "activation.hpp"),
str(export_folder / "include" / "kernels"))
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"),
@@ -284,8 +289,46 @@ class SubCPP(ExportNode):
return list_configs
def forward(self, list_actions:list):
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
output_name=self.name
))
return list_actions
@operator_register("Mul")
class MulCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
def export(self, export_folder:str, list_configs:list):
list_configs.append(f"layers/{self.name}.h")
list_configs.append("kernels/elemwise.hpp")
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
copyfile(str(ROOT / "kernels" / "activation.hpp"),
str(export_folder / "include" / "kernels"))
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"),
name=self.name,
nb_elts=np.prod(self.inputs_dims[0]),
activation="Linear",
elemwise_op="Mul",
rescaling="NoScaling")
return list_configs
def forward(self, list_actions:list):
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
Loading