Skip to content
Snippets Groups Projects
Commit 9b322f59 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add possibility to register multiple type to the same ExportNode

parent 7776fc6a
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
......@@ -7,7 +7,7 @@ import aidge_core.export_utils
from . import ROOT_EXPORT
from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.export_utils import ExportNode, generate_file
from aidge_core.export_utils import generate_file
def serialize_to_cpp(export_folder: str,
graph_view: aidge_core.GraphView,
......@@ -79,10 +79,23 @@ def serialize_to_cpp(export_folder: str,
# Next nodes to treat are children of current node
open_nodes += list(node.get_children())
if not ExportSerialize.exportable(node):
#raise RuntimeError
print(f"Node {node.name()} (of type [{node.type()}]) is not exportable !")
op = ExportSerialize.get_export_node(node)(node)
op_impl = node.get_operator().get_impl()
if op_impl is None:
raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an implementation.")
if not isinstance(op_impl, ExportSerialize):
raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an exportable backend ({op_impl}).")
node.get_operator().set_backend(ExportSerialize._name)
required_specs = op_impl.get_required_spec()
specs = op_impl.get_best_match(required_specs)
export_node = op_impl.get_export_node(specs)
if export_node is None:
raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].")
op = export_node(
node, [], False, False) # Note: is_input and is_output is not used for this export
set_operator.add(node.type())
# TODO: list_configs and list_actions don't need to be passed by argument
......
......@@ -68,37 +68,53 @@ class ExportLib(aidge_core.OperatorImpl):
# Decorator to register kernels for this export
@classmethod
def register(cls, type, spec):
def register(cls, op_type, spec):
def decorator(operator):
class Wrapper(operator):
def __init__(self, *args, **kwargs):
return operator(*args, **kwargs)
if (type not in cls._export_node_registry):
cls._export_node_registry[type] = []
cls._export_node_registry[type].append((spec, operator))
register_func: str = f"register_{type}Op"
# If operator is not defined, then it means we try to register a MetaOperator
if register_func not in dir(aidge_core):
raise ValueError(f"Operator of type: {type} is not declared as registrable!\nHint: If you try to register a MetaOperator use register_metaop instead.")
type_list = []
if isinstance(op_type, list):
type_list = op_type
elif isinstance(op_type, str):
type_list = [op_type]
else:
# Equivalent to aidge_core.register_ConvOp("ExportLibX", ExportLibX)
aidge_core.__getattribute__(register_func)(cls._name, cls)
raise TypeError("Argument type of register method should be of type 'List[str]' or 'str', got {type(type)}")
for type_name in type_list:
if (type_name not in cls._export_node_registry):
cls._export_node_registry[type_name] = []
cls._export_node_registry[type_name].append((spec, operator))
register_func: str = f"register_{type_name}Op"
# If operator is not defined, then it means we try to register a MetaOperator
if register_func not in dir(aidge_core):
raise ValueError(f"Operator of type: {type_name} is not declared as registrable!\nHint: If you try to register a MetaOperator use register_metaop instead.")
else:
# Equivalent to aidge_core.register_ConvOp("ExportLibX", ExportLibX)
aidge_core.__getattribute__(register_func)(cls._name, cls)
return Wrapper
return decorator
# Decorator to register kernels for this export
@classmethod
def register_metaop(cls, type, spec):
def register_metaop(cls, op_type, spec):
def decorator(operator):
class Wrapper(operator):
def __init__(self, *args, **kwargs):
return operator(*args, **kwargs)
if (type not in cls._export_node_registry):
cls._export_node_registry[type] = []
cls._export_node_registry[type].append((spec, operator))
aidge_core.register_MetaOperatorOp([cls._name, type], cls)
spec.attrs.add_attr("type", type) # MetaOperator specs need to verify the type
type_list = []
if isinstance(op_type, list):
type_list = op_type
elif isinstance(op_type, str):
type_list = [op_type]
else:
raise TypeError("Argument 'op_type' of register method should be of type 'List[str]' or 'str', got {type(type)}")
for type_name in type_list:
if (type_name not in cls._export_node_registry):
cls._export_node_registry[type_name] = []
cls._export_node_registry[type_name].append((spec, operator))
aidge_core.register_MetaOperatorOp([cls._name, type_name], cls)
spec.attrs.add_attr("type", type_name) # MetaOperator specs need to verify the type
return Wrapper
return decorator
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment