From 9b322f5978018a1966265a6128d02917b1812f8a Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Wed, 2 Oct 2024 14:54:33 +0000 Subject: [PATCH] Add possibility to register multiple type to the same ExportNode --- aidge_core/aidge_export_aidge/export.py | 23 +++++++--- aidge_core/export_utils/export_registry.py | 52 ++++++++++++++-------- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/aidge_core/aidge_export_aidge/export.py b/aidge_core/aidge_export_aidge/export.py index cb5aa3ef2..d26ca5dc8 100644 --- a/aidge_core/aidge_export_aidge/export.py +++ b/aidge_core/aidge_export_aidge/export.py @@ -7,7 +7,7 @@ import aidge_core.export_utils from . import ROOT_EXPORT from aidge_core.aidge_export_aidge.registry import ExportSerialize -from aidge_core.export_utils import ExportNode, generate_file +from aidge_core.export_utils import generate_file def serialize_to_cpp(export_folder: str, graph_view: aidge_core.GraphView, @@ -79,10 +79,23 @@ def serialize_to_cpp(export_folder: str, # Next nodes to treat are children of current node open_nodes += list(node.get_children()) - if not ExportSerialize.exportable(node): - #raise RuntimeError - print(f"Node {node.name()} (of type [{node.type()}]) is not exportable !") - op = ExportSerialize.get_export_node(node)(node) + op_impl = node.get_operator().get_impl() + if op_impl is None: + raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an implementation.") + if not isinstance(op_impl, ExportSerialize): + raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an exportable backend ({op_impl}).") + + node.get_operator().set_backend(ExportSerialize._name) + + required_specs = op_impl.get_required_spec() + specs = op_impl.get_best_match(required_specs) + export_node = op_impl.get_export_node(specs) + if export_node is None: + raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].") + op = export_node( + node, [], False, False) # Note: is_input and is_output is not used for this export + + set_operator.add(node.type()) # TODO: list_configs and list_actions don't need to be passed by argument diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index 18ad964ea..6cdba2851 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -68,37 +68,53 @@ class ExportLib(aidge_core.OperatorImpl): # Decorator to register kernels for this export @classmethod - def register(cls, type, spec): + def register(cls, op_type, spec): def decorator(operator): class Wrapper(operator): def __init__(self, *args, **kwargs): return operator(*args, **kwargs) - if (type not in cls._export_node_registry): - cls._export_node_registry[type] = [] - cls._export_node_registry[type].append((spec, operator)) - - register_func: str = f"register_{type}Op" - # If operator is not defined, then it means we try to register a MetaOperator - if register_func not in dir(aidge_core): - raise ValueError(f"Operator of type: {type} is not declared as registrable!\nHint: If you try to register a MetaOperator use register_metaop instead.") + type_list = [] + if isinstance(op_type, list): + type_list = op_type + elif isinstance(op_type, str): + type_list = [op_type] else: - # Equivalent to aidge_core.register_ConvOp("ExportLibX", ExportLibX) - aidge_core.__getattribute__(register_func)(cls._name, cls) + raise TypeError("Argument type of register method should be of type 'List[str]' or 'str', got {type(type)}") + + for type_name in type_list: + if (type_name not in cls._export_node_registry): + cls._export_node_registry[type_name] = [] + cls._export_node_registry[type_name].append((spec, operator)) + + register_func: str = f"register_{type_name}Op" + # If operator is not defined, then it means we try to register a MetaOperator + if register_func not in dir(aidge_core): + raise ValueError(f"Operator of type: {type_name} is not declared as registrable!\nHint: If you try to register a MetaOperator use register_metaop instead.") + else: + # Equivalent to aidge_core.register_ConvOp("ExportLibX", ExportLibX) + aidge_core.__getattribute__(register_func)(cls._name, cls) return Wrapper return decorator # Decorator to register kernels for this export @classmethod - def register_metaop(cls, type, spec): + def register_metaop(cls, op_type, spec): def decorator(operator): class Wrapper(operator): def __init__(self, *args, **kwargs): return operator(*args, **kwargs) - if (type not in cls._export_node_registry): - cls._export_node_registry[type] = [] - - cls._export_node_registry[type].append((spec, operator)) - aidge_core.register_MetaOperatorOp([cls._name, type], cls) - spec.attrs.add_attr("type", type) # MetaOperator specs need to verify the type + type_list = [] + if isinstance(op_type, list): + type_list = op_type + elif isinstance(op_type, str): + type_list = [op_type] + else: + raise TypeError("Argument 'op_type' of register method should be of type 'List[str]' or 'str', got {type(type)}") + for type_name in type_list: + if (type_name not in cls._export_node_registry): + cls._export_node_registry[type_name] = [] + cls._export_node_registry[type_name].append((spec, operator)) + aidge_core.register_MetaOperatorOp([cls._name, type_name], cls) + spec.attrs.add_attr("type", type_name) # MetaOperator specs need to verify the type return Wrapper return decorator -- GitLab