Skip to content
Snippets Groups Projects
Commit 052f6d60 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Heavy refactor of node_export to retrieve more attributes.

parent a92ad137
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
Pipeline #50477 failed
import aidge_core
from aidge_core.export_utils import data_conversion
from abc import ABC, abstractmethod
def get_chan(tensor: aidge_core.Tensor) -> int:
"""Given a tensor return the number of channel
"""
dformat = tensor.dformat()
dims = tensor.dims()
if len(dims) == 4: # Suppose NCHW
return dims[1]
elif len(dims) == 2: # Suppose NC
return dims[1]
else:
return None
# if dformat == aidge_core.dformat.Default:
# return None
# elif dformat == aidge_core.dformat.NCHW:
# return dims[1]
# elif dformat == aidge_core.dformat.NHWC:
# return dims[3]
# elif dformat == aidge_core.dformat.CHWN:
# return dims[0]
# elif dformat == aidge_core.dformat.NCDHW:
# return dims[1]
# elif dformat == aidge_core.dformat.NDHWC:
# return dims[4]
# elif dformat == aidge_core.dformat.CDHWN:
# return dims[0]
# else:
# raise RuntimeError(f"Unknown dataformat: {dformat}")
def get_height(tensor: aidge_core.Tensor) -> int:
dformat = tensor.dformat()
dims = tensor.dims()
if len(dims) == 4: # Suppose NCHW
return dims[2]
elif len(dims) == 2: # Suppose NC
return 1
else:
return None
# TODO: use when dformat is fully supported
# if dformat == aidge_core.dformat.Default:
# return None
# elif dformat == aidge_core.dformat.NCHW:
# return dims[2]
# elif dformat == aidge_core.dformat.NHWC:
# return dims[1]
# elif dformat == aidge_core.dformat.CHWN:
# return dims[1]
# elif dformat == aidge_core.dformat.NCDHW:
# return dims[3]
# elif dformat == aidge_core.dformat.NDHWC:
# return dims[2]
# elif dformat == aidge_core.dformat.CDHWN:
# return dims[2]
# else:
# raise RuntimeError(f"Unknown dataformat: {dformat}")
def get_width(tensor: aidge_core.Tensor) -> int:
dformat = tensor.dformat()
dims = tensor.dims()
if len(dims) == 4: # Suppose NCHW
return dims[3]
elif len(dims) == 2: # Suppose NC
return 1
else:
return None
# if dformat == aidge_core.dformat.Default:
# return None
# elif dformat == aidge_core.dformat.NCHW:
# return dims[3]
# elif dformat == aidge_core.dformat.NHWC:
# return dims[2]
# elif dformat == aidge_core.dformat.CHWN:
# return dims[2]
# elif dformat == aidge_core.dformat.NCDHW:
# return dims[4]
# elif dformat == aidge_core.dformat.NDHWC:
# return dims[3]
# elif dformat == aidge_core.dformat.CDHWN:
# return dims[3]
# else:
# raise RuntimeError(f"Unknown dataformat: {dformat}")
class ExportNode(ABC):
"""Abstract class to interface node with export generation.
This class expose a dictionary ``attributes`` which contains all the information required to generate an export:
- All the attributes of the Aidge node are automatically fetch, the key to get an attribute is the attribute name in python format, example ``no_bias``
- **name**: Name of the Node, ``str``
- **nb_in**: Number of inputs, ``int``
- **in_name**: unique name for each input, if no input node the name is ``{node_name}_input_{in_id}``, if there is a parent, the name is ``{parent_name}_output_{out_id}``, ``list[str]``
- **in_dims**: A list of the dimension for each inputs, ``list[list[int]]``
- **in_chan**: A list of channel for each inputs, deduced by the dataformat, ``list[int]``
- **in_height**: A list of height for each inputs, deduced by the dataformat, ``list[int]``
- **in_width**: A list of width for each inputs, deduced by the dataformat, ``list[int]``
- **in_dtype**: A list of type (Aidge format) for each input, ``List[:py:class:`aidge_core.dtype`]``
- **in_cdtype**: A list of type (C/C++ format) for each input, ``List[str]``
- **out_name**: unique name for each output, the name is ``{name}_output_{out_id}``, ``list[str]``
- **nb_out**: Number of outputs, ``int``
- **out_dims**: A list of the dimension for each inputs, ``list[list[int]]``
- **out_chan**: A list of channel for each outputs, deduced by the dataformat, ``list[int]``
- **out_height**: A list of height for each outputs, deduced by the dataformat, ``list[int]``
- **out_width**: A list of width for each outputs, deduced by the dataformat, ``list[int]``
- **out_dtype**: A list of type (Aidge format) for each output, ``List[:py:class:`aidge_core.dtype`]``
- **out_cdtype**: A list of type (C/C++ format) for each output, ``List[str]``
- **is_output**: True if the node is an output node, ``bool``
- **is_input**: True if the node is an input node, ``bool``
"""
@abstractmethod
def __init__(self, aidge_node: aidge_core.Node) -> None:
"""Create ExportNode and retieve attirubtes from ``aidge_node``:
- name: aidge Node name
- attributes: dictionnary of attributes of the aidge Operator linked to the node, attributes name follow aidge naming convention
- parameters: List of parameters node, order in the list is the same as the one defined by the aidge operator
"""Create ExportNode and retieve attriubtes from ``aidge_node``:
"""
super().__init__()
self.node = aidge_node
self.operator = aidge_node.get_operator()
self.name = self.node.name()
self.attributes = self.operator.attr.dict() if self.operator.attr is not None else {} # Attributes are auto fetched from aidge operators
# rename is_leaf ?
self.is_last = len(self.node.get_children()) == 0
# Attributes are auto fetched from aidge operators
self.attributes = self.operator.attr.dict(
) if self.operator.attr is not None else {}
self.attributes["name"] = self.node.name()
self.attributes["nb_in"] = self.node.get_nb_inputs()
self.attributes["nb_out"] = self.node.get_nb_outputs()
# TODO : this check doesn't work if we export a subgraph !
# Maybe we need to add the graph we want to export as parameter !
# Actually may be mandatory for memory manager ...
self.attributes["is_input"] = len(self.node.get_parents()) == 0
self.attributes["is_output"] = len(self.node.get_children()) == 0
self.inputs = []
self.outputs = []
self.inputs_dims = []
self.outputs_dims = []
for idx, parent_node in enumerate(self.node.get_parents()):
self.inputs.append(parent_node)
if parent_node is not None:
self.inputs_dims.append(self.operator.get_input(idx).dims())
else:
if self.operator.get_input(idx) is not None:
self.inputs_dims.append(self.operator.get_input(idx).dims())
else:
self.inputs_dims.append(None)
for idx, child_node in enumerate(self.node.get_children()):
self.outputs.append(child_node)
self.attributes["in_name"] = [None] * self.attributes["nb_in"]
self.attributes["in_dims"] = [None] * self.attributes["nb_in"]
self.attributes["in_dformat"] = [None] * self.attributes["nb_in"]
self.attributes["in_dtype"] = [None] * self.attributes["nb_in"]
self.attributes["in_cdtype"] = [None] * self.attributes["nb_in"]
self.attributes["in_chan"] = [None] * self.attributes["nb_in"]
self.attributes["in_height"] = [None] * self.attributes["nb_in"]
self.attributes["in_width"] = [None] * self.attributes["nb_in"]
# Dirty hot fix, change it quickly
self.outputs_dims.append(self.operator.get_output(0).dims())
self.attributes["out_name"] = [None] * self.attributes["nb_out"]
self.attributes["out_dims"] = [None] * self.attributes["nb_out"]
self.attributes["out_dformat"] = [None] * self.attributes["nb_out"]
self.attributes["out_dtype"] = [None] * self.attributes["nb_out"]
self.attributes["out_cdtype"] = [None] * self.attributes["nb_out"]
self.attributes["out_chan"] = [None] * self.attributes["nb_out"]
self.attributes["out_height"] = [None] * self.attributes["nb_out"]
self.attributes["out_width"] = [None] * self.attributes["nb_out"]
for idx, parent_node_in_id in enumerate(self.node.inputs()):
parent_node, out_id = parent_node_in_id
self.inputs.append(parent_node)
if self.operator.get_input(idx) is not None:
tensor = self.operator.get_input(idx)
self.attributes["in_name"][idx] = f"{self.attributes['name']}_input_{idx}" if parent_node is None else f"{parent_node.name()}_output_{out_id}"
self.attributes["in_dims"][idx] = tensor.dims()
self.attributes["in_dformat"][idx] = tensor.dformat()
self.attributes["in_dtype"][idx] = tensor.dtype()
self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(
tensor.dtype())
self.attributes["in_chan"][idx] = get_chan(tensor)
self.attributes["in_height"][idx] = get_height(tensor)
self.attributes["in_width"][idx] = get_width(tensor)
else:
print(f"No input for {self.node.name()}")
for idx, list_child_node_in_id in enumerate(self.node.outputs()):
self.outputs += [node_in_id[0] for node_in_id in list_child_node_in_id]
if self.operator.get_output(idx) is not None:
tensor = self.operator.get_output(idx)
self.attributes["out_name"][idx] = f"{self.attributes['name']}_output_{idx}"
self.attributes["out_dims"][idx] = tensor.dims()
self.attributes["out_dformat"][idx] = tensor.dformat()
self.attributes["out_dtype"][idx] = tensor.dtype()
self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(
tensor.dtype())
self.attributes["out_chan"][idx] = get_chan(tensor)
self.attributes["out_height"][idx] = get_height(tensor)
self.attributes["out_width"][idx] = get_width(tensor)
else:
print(f"No output for {self.node.name()}")
@abstractmethod
def export(self, export_folder:str, list_configs:list):
def export(self, export_folder: str, list_configs: list):
"""Define how to export the node definition.
"""
pass
@abstractmethod
def forward(self, list_actions:list):
def forward(self, list_actions: list):
"""Define how to generate code to perform a forward pass.
"""
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment