Skip to content
Snippets Groups Projects
Commit a07a87bb authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Export can be done if there is no Producer. Producer doesn't need ot be static any more.

parent 81d562fa
No related branches found
No related tags found
No related merge requests found
......@@ -7,9 +7,11 @@ from typing import List, Union
from jinja2 import Environment, FileSystemLoader
import aidge_core
from aidge_core.export_utils.data_conversion import aidge2c
from aidge_core.export_utils.code_generation import *
from aidge_export_cpp.utils import (ROOT, OPERATORS_REGISTRY, supported_operators)
from aidge_export_cpp.utils.converter import aidge_datatype2ctype, numpy_dtype2ctype
from aidge_export_cpp.utils.converter import numpy_dtype2ctype
import aidge_export_cpp.operators
from aidge_export_cpp.utils.generation import *
from aidge_export_cpp.memory import *
......@@ -72,21 +74,27 @@ def export(export_folder_name, graphview, scheduler):
list_configs.append("memory/mem_info.h")
# Get entry nodes
# It supposes the entry nodes are producers with constant=false
# Store the datatype & name
list_inputs_name = []
for node in graphview.get_nodes():
if node.type() == "Producer":
if not node.get_operator().get_attr("Constant"):
export_type = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, node.name()))
print(graphview.get_input_nodes())
for node in graphview.get_input_nodes():
for node_input, outidx in node.inputs():
if node_input not in graphview.get_nodes():
# Case where
export_type = aidge2c(node_input.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, node_input.name()))
elif node_input is None:
export_type = aidge2c(node.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, f"{node.name()}_{outidx}"))
# Get output nodes
# Store the datatype & name, like entry nodes
list_outputs_name = []
for node in graphview.get_nodes():
if len(node.get_children()) == 0:
export_type = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
export_type = aidge2c(node.get_operator().get_output(0).dtype())
list_outputs_name.append((export_type, node.name()))
# Generate forward file
......
......@@ -71,23 +71,18 @@ class ProducerCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
self.constant = self.operator.get_attr("Constant")
self.values = np.array(self.operator.get_output(0))
if len(self.values.shape) == 4:
if len(self.values.shape) == 4: # Note: export in HWC
self.values = np.transpose(self.values, (0, 2, 3, 1))
def export(self, export_folder:Path, list_configs:list):
# If not constant, it is a dataprovider
# and not a parameter provider
if (self.constant):
list_configs.append(f"parameters/{self.name}.h")
# Export in HWC
export_params(self.name,
self.values.reshape(-1),
str(export_folder / "parameters" / f"{self.name}.h"))
list_configs.append(f"parameters/{self.name}.h")
export_params(
self.name,
self.values.reshape(-1),
str(export_folder / "parameters" / f"{self.name}.h"))
return list_configs
......@@ -130,7 +125,7 @@ class ReLUCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja"),
name=self.name,
input_name=self.inputs[0].name(),
input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name
))
return list_actions
......@@ -192,11 +187,11 @@ class ConvCPP(ExportNode):
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
print(self.inputs[0])
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja"),
name=self.name,
input_name=self.inputs[0].name(),
input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name,
weights_name=self.inputs[1].name(),
biases_name=self.inputs[2].name()
......@@ -253,11 +248,11 @@ class AddCPP(ExportNode):
def forward(self, list_actions:list):
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_action(
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
inputs1_name=self.parents[0].name(),
inputs2_name=self.parents[1].name(),
inputs1_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
inputs2_name=f"{self.name}_0" if self.inputs[1] is None else self.inputs[1].name(),
output_name=self.name
))
return list_actions
......@@ -286,11 +281,11 @@ class SubCPP(ExportNode):
def forward(self, list_actions:list):
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_action(
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja"),
name=self.name,
inputs1_name=self.inputs[0].name(),
inputs2_name=self.inputs[1].name(),
inputs1_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
inputs2_name=f"{self.name}_1" if self.inputs[1] is None else self.inputs[1].name(),
output_name=self.name
))
return list_actions
......@@ -347,7 +342,7 @@ class MaxPoolCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"),
name=self.name,
input_name=self.inputs[0].name(),
input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name
))
return list_actions
......@@ -404,7 +399,7 @@ class FcCPP(ExportNode):
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "fullyconnected_forward.jinja"),
name=self.name,
inputs_name=self.inputs[0].name(),
inputs_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
weights_name=self.inputs[1].name(),
biases_name=self.inputs[2].name(),
outputs_name=self.name
......
import numpy as np
import aidge_core
def numpy_dtype2ctype(dtype):
if dtype == np.int8:
......@@ -17,19 +16,3 @@ def numpy_dtype2ctype(dtype):
# Add more dtype mappings as needed
else:
raise ValueError(f"Unsupported {dtype} dtype")
def aidge_datatype2ctype(datatype):
if datatype == aidge_core.DataType.Int8:
return "int8_t"
elif datatype == aidge_core.DataType.Int32:
return "int32_t"
elif datatype == aidge_core.DataType.Int64:
return "int64_t"
elif datatype == aidge_core.DataType.Float32:
return "float"
elif datatype == aidge_core.DataType.Float64:
return "double"
# Add more dtype mappings as needed
else:
raise ValueError(f"Unsupported {datatype} aidge datatype")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment