diff --git a/aidge_export_cpp/export.py b/aidge_export_cpp/export.py index de20f3b4eee728a48cdef7849cd5e4894ee0e3d1..2e836f92d7a3bc8d3cdaa0aa7fb8a7b370d7a8f4 100644 --- a/aidge_export_cpp/export.py +++ b/aidge_export_cpp/export.py @@ -7,9 +7,11 @@ from typing import List, Union from jinja2 import Environment, FileSystemLoader import aidge_core + +from aidge_core.export_utils.data_conversion import aidge2c from aidge_core.export_utils.code_generation import * from aidge_export_cpp.utils import (ROOT, OPERATORS_REGISTRY, supported_operators) -from aidge_export_cpp.utils.converter import aidge_datatype2ctype, numpy_dtype2ctype +from aidge_export_cpp.utils.converter import numpy_dtype2ctype import aidge_export_cpp.operators from aidge_export_cpp.utils.generation import * from aidge_export_cpp.memory import * @@ -72,21 +74,27 @@ def export(export_folder_name, graphview, scheduler): list_configs.append("memory/mem_info.h") # Get entry nodes - # It supposes the entry nodes are producers with constant=false # Store the datatype & name list_inputs_name = [] - for node in graphview.get_nodes(): - if node.type() == "Producer": - if not node.get_operator().get_attr("Constant"): - export_type = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) - list_inputs_name.append((export_type, node.name())) + print(graphview.get_input_nodes()) + for node in graphview.get_input_nodes(): + for node_input, outidx in node.inputs(): + + if node_input not in graphview.get_nodes(): + # Case where + export_type = aidge2c(node_input.get_operator().get_output(0).dtype()) + list_inputs_name.append((export_type, node_input.name())) + elif node_input is None: + export_type = aidge2c(node.get_operator().get_output(0).dtype()) + list_inputs_name.append((export_type, f"{node.name()}_{outidx}")) + # Get output nodes # Store the datatype & name, like entry nodes list_outputs_name = [] for node in graphview.get_nodes(): if len(node.get_children()) == 0: - export_type = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + export_type = aidge2c(node.get_operator().get_output(0).dtype()) list_outputs_name.append((export_type, node.name())) # Generate forward file diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index 455ae8174732171cf4257034f5b204cab5e518fb..198abb706e699ff5a936b8bcbf94c8032f8a9f1e 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -71,23 +71,18 @@ class ProducerCPP(ExportNode): def __init__(self, node): super().__init__(node) - self.constant = self.operator.get_attr("Constant") self.values = np.array(self.operator.get_output(0)) - if len(self.values.shape) == 4: + if len(self.values.shape) == 4: # Note: export in HWC self.values = np.transpose(self.values, (0, 2, 3, 1)) def export(self, export_folder:Path, list_configs:list): - # If not constant, it is a dataprovider - # and not a parameter provider - if (self.constant): - list_configs.append(f"parameters/{self.name}.h") - - # Export in HWC - export_params(self.name, - self.values.reshape(-1), - str(export_folder / "parameters" / f"{self.name}.h")) + list_configs.append(f"parameters/{self.name}.h") + export_params( + self.name, + self.values.reshape(-1), + str(export_folder / "parameters" / f"{self.name}.h")) return list_configs @@ -130,7 +125,7 @@ class ReLUCPP(ExportNode): list_actions.append(generate_str( str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja"), name=self.name, - input_name=self.inputs[0].name(), + input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), output_name=self.name )) return list_actions @@ -192,11 +187,11 @@ class ConvCPP(ExportNode): if not self.is_last: list_actions.append(set_up_output(self.name, "float")) - + print(self.inputs[0]) list_actions.append(generate_str( str(ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja"), name=self.name, - input_name=self.inputs[0].name(), + input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), output_name=self.name, weights_name=self.inputs[1].name(), biases_name=self.inputs[2].name() @@ -253,11 +248,11 @@ class AddCPP(ExportNode): def forward(self, list_actions:list): list_actions.append(set_up_output(self.name, "float")) - list_actions.append(generate_action( + list_actions.append(generate_str( str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"), name=self.name, - inputs1_name=self.parents[0].name(), - inputs2_name=self.parents[1].name(), + inputs1_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), + inputs2_name=f"{self.name}_0" if self.inputs[1] is None else self.inputs[1].name(), output_name=self.name )) return list_actions @@ -286,11 +281,11 @@ class SubCPP(ExportNode): def forward(self, list_actions:list): list_actions.append(set_up_output(self.name, "float")) - list_actions.append(generate_action( + list_actions.append(generate_str( str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"), name=self.name, - inputs1_name=self.inputs[0].name(), - inputs2_name=self.inputs[1].name(), + inputs1_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), + inputs2_name=f"{self.name}_1" if self.inputs[1] is None else self.inputs[1].name(), output_name=self.name )) return list_actions @@ -347,7 +342,7 @@ class MaxPoolCPP(ExportNode): list_actions.append(generate_str( str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"), name=self.name, - input_name=self.inputs[0].name(), + input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), output_name=self.name )) return list_actions @@ -404,7 +399,7 @@ class FcCPP(ExportNode): list_actions.append(generate_str( str(ROOT / "templates" / "kernel_forward" / "fullyconnected_forward.jinja"), name=self.name, - inputs_name=self.inputs[0].name(), + inputs_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), weights_name=self.inputs[1].name(), biases_name=self.inputs[2].name(), outputs_name=self.name diff --git a/aidge_export_cpp/utils/converter.py b/aidge_export_cpp/utils/converter.py index d706d5a26f103316adfca0cd16f1146284e8177c..d4af124280e2c89ec44123c90ee509347003f960 100644 --- a/aidge_export_cpp/utils/converter.py +++ b/aidge_export_cpp/utils/converter.py @@ -1,5 +1,4 @@ import numpy as np -import aidge_core def numpy_dtype2ctype(dtype): if dtype == np.int8: @@ -17,19 +16,3 @@ def numpy_dtype2ctype(dtype): # Add more dtype mappings as needed else: raise ValueError(f"Unsupported {dtype} dtype") - - -def aidge_datatype2ctype(datatype): - if datatype == aidge_core.DataType.Int8: - return "int8_t" - elif datatype == aidge_core.DataType.Int32: - return "int32_t" - elif datatype == aidge_core.DataType.Int64: - return "int64_t" - elif datatype == aidge_core.DataType.Float32: - return "float" - elif datatype == aidge_core.DataType.Float64: - return "double" - # Add more dtype mappings as needed - else: - raise ValueError(f"Unsupported {datatype} aidge datatype") \ No newline at end of file