From 9b322f5978018a1966265a6128d02917b1812f8a Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Wed, 2 Oct 2024 14:54:33 +0000
Subject: [PATCH] Add possibility to register multiple type to the same
 ExportNode

---
 aidge_core/aidge_export_aidge/export.py    | 23 +++++++---
 aidge_core/export_utils/export_registry.py | 52 ++++++++++++++--------
 2 files changed, 52 insertions(+), 23 deletions(-)

diff --git a/aidge_core/aidge_export_aidge/export.py b/aidge_core/aidge_export_aidge/export.py
index cb5aa3ef2..d26ca5dc8 100644
--- a/aidge_core/aidge_export_aidge/export.py
+++ b/aidge_core/aidge_export_aidge/export.py
@@ -7,7 +7,7 @@ import aidge_core.export_utils
 from . import ROOT_EXPORT
 from aidge_core.aidge_export_aidge.registry import ExportSerialize
 
-from aidge_core.export_utils  import ExportNode, generate_file
+from aidge_core.export_utils  import generate_file
 
 def serialize_to_cpp(export_folder: str,
            graph_view: aidge_core.GraphView,
@@ -79,10 +79,23 @@ def serialize_to_cpp(export_folder: str,
         # Next nodes to treat are children of current node
         open_nodes += list(node.get_children())
 
-        if not ExportSerialize.exportable(node):
-            #raise RuntimeError
-            print(f"Node {node.name()} (of type [{node.type()}]) is not exportable !")
-        op = ExportSerialize.get_export_node(node)(node)
+        op_impl = node.get_operator().get_impl()
+        if op_impl is None:
+            raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an implementation.")
+        if not isinstance(op_impl, ExportSerialize):
+            raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an exportable backend ({op_impl}).")
+
+        node.get_operator().set_backend(ExportSerialize._name)
+
+        required_specs = op_impl.get_required_spec()
+        specs = op_impl.get_best_match(required_specs)
+        export_node = op_impl.get_export_node(specs)
+        if export_node is None:
+            raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].")
+        op = export_node(
+            node, [], False, False) # Note: is_input and is_output is not used for this export
+
+
         set_operator.add(node.type())
 
         # TODO: list_configs and list_actions don't need to be passed by argument
diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py
index 18ad964ea..6cdba2851 100644
--- a/aidge_core/export_utils/export_registry.py
+++ b/aidge_core/export_utils/export_registry.py
@@ -68,37 +68,53 @@ class ExportLib(aidge_core.OperatorImpl):
 
     # Decorator to register kernels for this export
     @classmethod
-    def register(cls, type, spec):
+    def register(cls, op_type, spec):
         def decorator(operator):
             class Wrapper(operator):
                 def __init__(self, *args, **kwargs):
                     return operator(*args, **kwargs)
-            if (type not in cls._export_node_registry):
-                cls._export_node_registry[type] = []
-            cls._export_node_registry[type].append((spec, operator))
-
-            register_func: str = f"register_{type}Op"
-            # If operator is not defined, then it means we try to register a MetaOperator
-            if register_func not in dir(aidge_core):
-                raise ValueError(f"Operator of type: {type} is not declared as registrable!\nHint: If you try to register a MetaOperator use register_metaop instead.")
+            type_list = []
+            if isinstance(op_type, list):
+                type_list = op_type
+            elif isinstance(op_type, str):
+                type_list = [op_type]
             else:
-                # Equivalent to aidge_core.register_ConvOp("ExportLibX", ExportLibX)
-                aidge_core.__getattribute__(register_func)(cls._name, cls)
+                raise TypeError("Argument type of register method should be of type 'List[str]' or 'str', got {type(type)}")
+
+            for type_name in type_list:
+                if (type_name not in cls._export_node_registry):
+                    cls._export_node_registry[type_name] = []
+                cls._export_node_registry[type_name].append((spec, operator))
+
+                register_func: str = f"register_{type_name}Op"
+                # If operator is not defined, then it means we try to register a MetaOperator
+                if register_func not in dir(aidge_core):
+                    raise ValueError(f"Operator of type: {type_name} is not declared as registrable!\nHint: If you try to register a MetaOperator use register_metaop instead.")
+                else:
+                    # Equivalent to aidge_core.register_ConvOp("ExportLibX", ExportLibX)
+                    aidge_core.__getattribute__(register_func)(cls._name, cls)
             return Wrapper
         return decorator
 
     # Decorator to register kernels for this export
     @classmethod
-    def register_metaop(cls, type, spec):
+    def register_metaop(cls, op_type, spec):
         def decorator(operator):
             class Wrapper(operator):
                 def __init__(self, *args, **kwargs):
                     return operator(*args, **kwargs)
-            if (type not in cls._export_node_registry):
-                cls._export_node_registry[type] = []
-
-            cls._export_node_registry[type].append((spec, operator))
-            aidge_core.register_MetaOperatorOp([cls._name, type], cls)
-            spec.attrs.add_attr("type", type) # MetaOperator specs need to verify the type
+            type_list = []
+            if isinstance(op_type, list):
+                type_list = op_type
+            elif isinstance(op_type, str):
+                type_list = [op_type]
+            else:
+                raise TypeError("Argument 'op_type' of register method should be of type 'List[str]' or 'str', got {type(type)}")
+            for type_name in type_list:
+                if (type_name not in cls._export_node_registry):
+                    cls._export_node_registry[type_name] = []
+                cls._export_node_registry[type_name].append((spec, operator))
+                aidge_core.register_MetaOperatorOp([cls._name, type_name], cls)
+                spec.attrs.add_attr("type", type_name) # MetaOperator specs need to verify the type
             return Wrapper
         return decorator
-- 
GitLab