Skip to content
Snippets Groups Projects

Fix CPP export for Add Sub and add export for Mul

2 unresolved threads
1 file
+ 48
5
Compare changes
  • Side-by-side
  • Inline
@@ -238,6 +238,8 @@ class AddCPP(ExportNode):
@@ -238,6 +238,8 @@ class AddCPP(ExportNode):
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
str(export_folder / "include" / "kernels"))
 
copyfile(str(ROOT / "kernels" / "activation.hpp"),
 
str(export_folder / "include" / "kernels"))
generate_file(
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
str(export_folder / "layers" / f"{self.name}.h"),
@@ -251,13 +253,13 @@ class AddCPP(ExportNode):
@@ -251,13 +253,13 @@ class AddCPP(ExportNode):
return list_configs
return list_configs
def forward(self, list_actions:list):
def forward(self, list_actions:list):
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
name=self.name,
inputs1_name=self.parents[0].name() if self.parents[0] else self.name + "_input1",
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
inputs2_name=self.parents[1].name() if self.parents[1] else self.name + "_input2",
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
output_name=self.name
output_name=self.name
))
))
return list_actions
return list_actions
@@ -272,6 +274,9 @@ class SubCPP(ExportNode):
@@ -272,6 +274,9 @@ class SubCPP(ExportNode):
list_configs.append("kernels/elemwise.hpp")
list_configs.append("kernels/elemwise.hpp")
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
str(export_folder / "include" / "kernels"))
 
copyfile(str(ROOT / "kernels" / "activation.hpp"),
 
str(export_folder / "include" / "kernels"))
 
generate_file(
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
str(export_folder / "layers" / f"{self.name}.h"),
str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"),
str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"),
@@ -284,8 +289,46 @@ class SubCPP(ExportNode):
@@ -284,8 +289,46 @@ class SubCPP(ExportNode):
return list_configs
return list_configs
def forward(self, list_actions:list):
def forward(self, list_actions:list):
 
if not self.is_last:
 
list_actions.append(set_up_output(self.name, "float"))
 
 
list_actions.append(generate_str(
 
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
 
name=self.name,
 
inputs1_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input1",
 
inputs2_name=self.inputs[1].name() if self.inputs[1] else self.name + "_input2",
 
output_name=self.name
 
))
 
return list_actions
 
 
@operator_register("Mul")
 
class MulCPP(ExportNode):
 
def __init__(self, node):
 
super().__init__(node)
 
 
def export(self, export_folder:str, list_configs:list):
 
list_configs.append(f"layers/{self.name}.h")
 
list_configs.append("kernels/elemwise.hpp")
 
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
 
str(export_folder / "include" / "kernels"))
 
copyfile(str(ROOT / "kernels" / "activation.hpp"),
 
str(export_folder / "include" / "kernels"))
 
 
generate_file(
 
str(export_folder / "layers" / f"{self.name}.h"),
 
str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"),
 
name=self.name,
 
nb_elts=np.prod(self.inputs_dims[0]),
 
activation="Linear",
 
elemwise_op="Mul",
 
rescaling="NoScaling")
 
 
return list_configs
 
 
def forward(self, list_actions:list):
 
if not self.is_last:
 
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
name=self.name,
Loading