diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index b4462895a5e9d9f8874066024f1bd74ee860a58f..2490bfa2e11cc37ef9f40d9762795c713133bd13 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -84,7 +84,6 @@ class ExportLib(aidge_core.OperatorImpl): return export_node return None - # Decorator to register kernels for this export @classmethod def register(cls, op_type, spec): """Decorator to register an operator implementation for a specified operator type. @@ -127,7 +126,6 @@ class ExportLib(aidge_core.OperatorImpl): return Wrapper return decorator - # Decorator to register kernels for this export @classmethod def register_metaop(cls, op_type, spec): """Decorator to register a MetaOperator with the export library. @@ -161,3 +159,38 @@ class ExportLib(aidge_core.OperatorImpl): spec.attrs.add_attr("type", type_name) # MetaOperator specs need to verify the type return Wrapper return decorator + + + @classmethod + def register_generic(cls, op_type, spec): + """Decorator to register a GenericOperator with the export library. + + Registers a GenericOperator under a given operator type and specification. This decorator + is intended for operator types that are grouped as meta operators. + + :param op_type: Operator type(s) to register as a ``GenericOperator``. + :type op_type: Union[str, List[str]] + :param spec: Implementation specification for the GenericOperator. + :type spec: aidge_core.ImplSpec + :return: A wrapper class that initializes the registered GenericOperator. + :rtype: Callable + """ + def decorator(operator): + class Wrapper(operator): + def __init__(self, *args, **kwargs): + return operator(*args, **kwargs) + type_list = [] + if isinstance(op_type, list): + type_list = op_type + elif isinstance(op_type, str): + type_list = [op_type] + else: + raise TypeError("Argument 'op_type' of register method should be of type 'List[str]' or 'str', got {type(type)}") + for type_name in type_list: + if (type_name not in cls._export_node_registry): + cls._export_node_registry[type_name] = [] + cls._export_node_registry[type_name].append((spec, operator)) + aidge_core.register_GenericOperatorOp([cls._name, type_name], cls) + spec.attrs.add_attr("type", type_name) # GenericOperator specs need to verify the type + return Wrapper + return decorator