diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index b99613458c3b1d7eadbe4435da31da56e52d9fff..24c2d63c0cc33f5eda845ab7c6c8ec0427016701 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -4,7 +4,7 @@ import numpy as np from pathlib import Path from jinja2 import Environment, FileSystemLoader -from aidge_core import ExportNode +from aidge_core import ExportNode, ExportNodeCpp from aidge_core.export_utils.code_generation import * from aidge_export_cpp.utils import ROOT, operator_register from aidge_export_cpp.utils.converter import numpy_dtype2ctype @@ -84,35 +84,17 @@ class ProducerCPP(ExportNode): @operator_register("ReLU") -class ReLUCPP(ExportNode): +class ReLUCPP(ExportNodeCpp): def __init__(self, node): super().__init__(node) - - - def export(self, export_folder:Path, list_configs:list): - - copyfile(str(ROOT / "kernels" / "activation.hpp"), - str(export_folder / "include" / "kernels")) - - list_configs.append("kernels/activation.hpp") - list_configs.append(f"layers/{self.attributes['name']}.h") - - generate_file( - str(export_folder / "layers" / f"{self.attributes['name']}.h"), - str(ROOT / "templates" / "configuration" / "activation_config.jinja"), - activation="Rectifier", - rescaling="NoScaling", - **self.attributes - ) - - return list_configs - - def forward(self, list_actions:list): - list_actions.append(generate_str( - str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja"), - **self.attributes - )) - return list_actions + self.attributes["activation"] = "Rectifier" + self.attributes["rescaling"] = "NoScaling" + self.config_template = str(ROOT / "templates" / "configuration" / "activation_config.jinja") + self.forward_template = str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "activation.hpp"), + ] @operator_register("Conv")