diff --git a/aidge_export_cpp/export.py b/aidge_export_cpp/export.py index d9dc2969d712deeeda977dbfa7309fde00aefc84..2a535ddd4a355560e9037b6567cdb90a52c66df3 100644 --- a/aidge_export_cpp/export.py +++ b/aidge_export_cpp/export.py @@ -100,13 +100,18 @@ def export(export_folder_name, graphview, scheduler, mem_wrapping=False): list_outputs_name.append((export_type, f"{node.name()}_output_0")) # Generate forward file + # TODO: for now the mem type is bound for all intermediate results, should change. + # Note that we may have all inputs constants, hence select output type + assert len(list_outputs_name) >= 1, f"TODO: requires some output to determine mem type" + mem_ctype = list_outputs_name[0][0] generate_file( str(dnn_folder / "src" / "forward.cpp"), str(ROOT / "templates" / "network" / "network_forward.jinja"), headers=set(list_configs), actions=list_actions, inputs= list_inputs_name, - outputs=list_outputs_name + outputs=list_outputs_name, + mem_ctype=mem_ctype, ) # Generate dnn API diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index 9edf188fa2fc48376536ef66141dc7dce3ba10de..b99613458c3b1d7eadbe4435da31da56e52d9fff 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -186,6 +186,8 @@ class AddCPP(ExportNode): copyfile(str(ROOT / "kernels" / "elemwise.hpp"), str(export_folder / "include" / "kernels")) + copyfile(str(ROOT / "kernels" / "activation.hpp"), + str(export_folder / "include" / "kernels")) generate_file( str(export_folder / "layers" / f"{self.attributes['name']}.h"), @@ -214,6 +216,9 @@ class SubCPP(ExportNode): list_configs.append("kernels/elemwise.hpp") copyfile(str(ROOT / "kernels" / "elemwise.hpp"), str(export_folder / "include" / "kernels")) + copyfile(str(ROOT / "kernels" / "activation.hpp"), + str(export_folder / "include" / "kernels")) + generate_file( str(export_folder / "layers" / f"{self.attributes['name']}.h"), str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"), @@ -231,6 +236,40 @@ class SubCPP(ExportNode): )) return list_actions +@operator_register("Mul") +class MulCPP(ExportNode): + def __init__(self, node): + super().__init__(node) + + def export(self, export_folder:str, list_configs:list): + list_configs.append(f"layers/{self.name}.h") + list_configs.append("kernels/elemwise.hpp") + copyfile(str(ROOT / "kernels" / "elemwise.hpp"), + str(export_folder / "include" / "kernels")) + copyfile(str(ROOT / "kernels" / "activation.hpp"), + str(export_folder / "include" / "kernels")) + + generate_file( + str(export_folder / "layers" / f"{self.name}.h"), + str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"), + name=self.name, + nb_elts=np.prod(self.inputs_dims[0]), + activation="Linear", + elemwise_op="Mul", + rescaling="NoScaling") + + return list_configs + + def forward(self, list_actions:list): + if not self.is_last: + list_actions.append(set_up_output(self.name, "float")) + + list_actions.append(generate_str( + str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"), + **self.attributes + )) + return list_actions + @operator_register("PaddedMaxPooling") class PaddedMaxPoolCPP(ExportNode): diff --git a/version.txt b/version.txt index 6da28dde76d6550e3d398a70a9a8231256774669..8294c184368c0ec9f84fbcc80c6b36326940c770 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.1 \ No newline at end of file +0.1.2 \ No newline at end of file