Skip to content
Snippets Groups Projects
Commit f2fb39f6 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'dev' into allowNoInputProducer

parents 9723a29d 17efa8fb
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
......@@ -100,13 +100,18 @@ def export(export_folder_name, graphview, scheduler, mem_wrapping=False):
list_outputs_name.append((export_type, f"{node.name()}_output_0"))
# Generate forward file
# TODO: for now the mem type is bound for all intermediate results, should change.
# Note that we may have all inputs constants, hence select output type
assert len(list_outputs_name) >= 1, f"TODO: requires some output to determine mem type"
mem_ctype = list_outputs_name[0][0]
generate_file(
str(dnn_folder / "src" / "forward.cpp"),
str(ROOT / "templates" / "network" / "network_forward.jinja"),
headers=set(list_configs),
actions=list_actions,
inputs= list_inputs_name,
outputs=list_outputs_name
outputs=list_outputs_name,
mem_ctype=mem_ctype,
)
# Generate dnn API
......
......@@ -186,6 +186,8 @@ class AddCPP(ExportNode):
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
copyfile(str(ROOT / "kernels" / "activation.hpp"),
str(export_folder / "include" / "kernels"))
generate_file(
str(export_folder / "layers" / f"{self.attributes['name']}.h"),
......@@ -214,6 +216,9 @@ class SubCPP(ExportNode):
list_configs.append("kernels/elemwise.hpp")
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
copyfile(str(ROOT / "kernels" / "activation.hpp"),
str(export_folder / "include" / "kernels"))
generate_file(
str(export_folder / "layers" / f"{self.attributes['name']}.h"),
str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"),
......@@ -231,6 +236,40 @@ class SubCPP(ExportNode):
))
return list_actions
@operator_register("Mul")
class MulCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
def export(self, export_folder:str, list_configs:list):
list_configs.append(f"layers/{self.name}.h")
list_configs.append("kernels/elemwise.hpp")
copyfile(str(ROOT / "kernels" / "elemwise.hpp"),
str(export_folder / "include" / "kernels"))
copyfile(str(ROOT / "kernels" / "activation.hpp"),
str(export_folder / "include" / "kernels"))
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
str(ROOT / "templates" / "configuration" / "elemwise_config.jinja"),
name=self.name,
nb_elts=np.prod(self.inputs_dims[0]),
activation="Linear",
elemwise_op="Mul",
rescaling="NoScaling")
return list_configs
def forward(self, list_actions:list):
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
**self.attributes
))
return list_actions
@operator_register("PaddedMaxPooling")
class PaddedMaxPoolCPP(ExportNode):
......
0.1.1
\ No newline at end of file
0.1.2
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment