From 80ca13c8c99f8f183dc5a9d735df1369d1d65c7d Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 12 Jul 2024 14:19:21 +0000 Subject: [PATCH] Add exportable function to check if a node is exportable instead of having complicated keys in registry. --- aidge_core/export_utils/export_registry.py | 45 ++++++++++++---------- aidge_core/export_utils/node_export.py | 10 +++++ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index cf0ba3b70..dd200ac8b 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -1,13 +1,11 @@ -from typing import Dict, Tuple, Set +from typing import Dict, List, Set import aidge_core from aidge_core.export_utils import ExportNode from enum import Enum -# Language +# Language LANGUAGE = Enum('LANGUAGE', ['Cpp/C']) -# Define new type registry_key -registry_key = Tuple[str, aidge_core.dtype, aidge_core.dformat] # TODO: very naive implementation ! # error handling should be added ! @@ -17,7 +15,7 @@ class ExportLib(): # Should be abstract ? # Lib name usefull ? _name:str = None # Registry of exportNode - _export_node_registry:Dict[registry_key, ExportNode] = {} + _export_node_registry:Dict[str, List[ExportNode]] = {} # The language type usefull ? _language: LANGUAGE = None def __init__(self) -> None: @@ -31,35 +29,42 @@ class ExportLib(): # Should be abstract ? :rtype: bool """ # TODO: should return usable error that can be catch to know if only some keys have not been respected ! - key: registry_key = (node.type(), node.dtype(), node.dformat()) - return key in cls._export_node_registry + if node.type() not in cls._export_node_registry: + return False + else: + for i in cls._export_node_registry[node.type()]: + if i.exportable(node): + return True + return False @classmethod def supported_operators(cls)->Set[str]: """ - :return: list of supported operator by this ExportLib - :rtype: List[str] + :return: Set of operator supported by this ExportLib + :rtype: Set[str] """ - operators = set() - for key in cls._export_node_registry.keys(): - operators.add(key[0]) - return operators + return cls._export_node_registry.keys() @classmethod def get_export_node(cls, node:aidge_core.Node)->ExportNode: """ :param node: Node to transform - :type node: aidge_core.Node + :type node: :py:class:`aidge_core.Node` :return: Corresponding export node. :rtype: ExportNode """ if not cls.exportable(node): - raise ValueError("Node is not exportable ...") - key: registry_key = (node.type(), node.dtype(), node.dformat()) - return key in cls._export_node_registry + raise ValueError(f"Node {node.type()} is not exportable by ExportLib {cls._name} !") + if len(cls._export_node_registry[node.type()]) != 1: + raise RuntimeError("ExportLib registry doesn't support when multiple export node are available yet ...") + else: + return cls._export_node_registry[node.type()][0](node) @classmethod - def add_export_node(cls, key:registry_key, eNode:ExportNode)->None: - cls._export_node_registry[key] = eNode + def add_export_node(cls, key:str, eNode:ExportNode)->None: + if key not in cls._export_node_registry: + cls._export_node_registry[key] = [eNode] + else: + cls._export_node_registry[key].append(eNode) -def operator_register(lib: ExportLib, key:registry_key, *args): +def operator_register(lib: ExportLib, key:str, *args): """Helper decorator to register an :py:class:`ExportNode` to an :py:class:`ExportLib` """ def decorator(operator): diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index 310d0ff8a..d311e43aa 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -181,7 +181,17 @@ class ExportNode(ABC): self.attributes["out_width"][idx] = get_width(tensor) else: print(f"No output for {self.node.name()}") + @classmethod + @abstractmethod + def exportable(cls, node: aidge_core.Node)->bool: + """Given a :py:class:`aidge_core.Node` return if the node can be exported or not. + :param node: Node to test the exportability + :type node: :py:class:`aidge_core.Node` + :return: True if the node is exportable, False oterhwise. + :rtype: bool + """ + pass class ExportNodeCpp(ExportNode): # Path to the template defining how to export the node definition -- GitLab