From 4e15652f34cbda41f505f7d9aaa6198946088ad1 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Mon, 7 Oct 2024 12:18:53 +0000
Subject: [PATCH] Adapt Export serialize to refactor.

---
 aidge_core/__init__.py                        |  3 +-
 aidge_core/aidge_export_aidge/export.py       |  7 +--
 .../operator_export/conv.py                   | 34 +++------------
 .../aidge_export_aidge/operator_export/fc.py  | 42 +++---------------
 .../operator_export/maxpooling.py             | 34 +++------------
 .../operator_export/producer.py               | 43 +++----------------
 .../operator_export/relu.py                   | 23 +++-------
 .../aidge_export_aidge/operator_export/sub.py | 25 +++--------
 aidge_core/aidge_export_aidge/registry.py     |  6 ++-
 aidge_core/export_utils/node_export.py        |  5 +--
 10 files changed, 43 insertions(+), 179 deletions(-)

diff --git a/aidge_core/__init__.py b/aidge_core/__init__.py
index 56c19a5fa..8544c5647 100644
--- a/aidge_core/__init__.py
+++ b/aidge_core/__init__.py
@@ -10,6 +10,5 @@ SPDX-License-Identifier: EPL-2.0
 from aidge_core.aidge_core import * # import so generated by PyBind
 import aidge_core.export_utils
 import aidge_core.utils
-# TODO: Commented for dev the new register system
-# from aidge_core.aidge_export_aidge import serialize_to_cpp
+from aidge_core.aidge_export_aidge import serialize_to_cpp
 from ._version import *
diff --git a/aidge_core/aidge_export_aidge/export.py b/aidge_core/aidge_export_aidge/export.py
index d26ca5dc8..747906e3e 100644
--- a/aidge_core/aidge_export_aidge/export.py
+++ b/aidge_core/aidge_export_aidge/export.py
@@ -78,22 +78,19 @@ def serialize_to_cpp(export_folder: str,
             continue
         # Next nodes to treat are children of current node
         open_nodes += list(node.get_children())
-
+        node.get_operator().set_backend(ExportSerialize._name)
         op_impl = node.get_operator().get_impl()
         if op_impl is None:
             raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an implementation.")
         if not isinstance(op_impl, ExportSerialize):
             raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an exportable backend ({op_impl}).")
-
-        node.get_operator().set_backend(ExportSerialize._name)
-
         required_specs = op_impl.get_required_spec()
         specs = op_impl.get_best_match(required_specs)
         export_node = op_impl.get_export_node(specs)
         if export_node is None:
             raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].")
         op = export_node(
-            node, [], False, False) # Note: is_input and is_output is not used for this export
+            node, None, False, False) # Note: is_input and is_output is not used for this export
 
 
         set_operator.add(node.type())
diff --git a/aidge_core/aidge_export_aidge/operator_export/conv.py b/aidge_core/aidge_export_aidge/operator_export/conv.py
index 1b97d0b41..41783b157 100644
--- a/aidge_core/aidge_export_aidge/operator_export/conv.py
+++ b/aidge_core/aidge_export_aidge/operator_export/conv.py
@@ -1,11 +1,12 @@
 from aidge_core.aidge_export_aidge.registry import ExportSerialize
 from aidge_core.aidge_export_aidge import ROOT_EXPORT
-from aidge_core.export_utils import ExportNodeCpp, operator_register
+from aidge_core.export_utils import ExportNodeCpp
+from aidge_core import ImplSpec, IOSpec, dtype
 
-@operator_register(ExportSerialize, "Conv")
+@ExportSerialize.register(["Conv1D", "Conv2D"], ImplSpec(IOSpec(dtype.any)))
 class Conv(ExportNodeCpp):
-    def __init__(self, node):
-        super().__init__(node)
+    def __init__(self, node, mem_info, is_input, is_output):
+        super().__init__(node, mem_info, is_input, is_output)
         self.config_template = str(
             ROOT_EXPORT / "templates/attributes/conv.jinja")
         self.forward_template = str(
@@ -14,28 +15,3 @@ class Conv(ExportNodeCpp):
         self.kernels_to_copy = []
         self.config_path = "include/parameters"
         self.config_extension = "hpp"
-    @classmethod
-    def exportable(cls, node):
-        return True
-
-    # def export(self, export_folder:Path, list_configs:list):
-    #     include_path = f"attributes/{self.name}.hpp"
-    #     filepath = export_folder / f"include/{include_path}"
-
-    #     generate_file(
-    #         filepath,
-    #         ROOT_EXPORT / "templates/attributes/conv.jinja",
-    #         name=self.name,
-    #         **self.attributes
-    #     )
-    #     list_configs.append(include_path)
-    #     return list_configs
-
-    # def forward(self, list_actions:list):
-    #     list_actions.append(generate_str(
-    #         ROOT_EXPORT /"templates/graph_ctor/conv.jinja",
-    #         name=self.name,
-    #         inputs=parse_node_input(self.node.inputs()),
-    #         **self.attributes
-    #     ))
-    #     return list_actions
diff --git a/aidge_core/aidge_export_aidge/operator_export/fc.py b/aidge_core/aidge_export_aidge/operator_export/fc.py
index 8161eada9..4e1bc97c1 100644
--- a/aidge_core/aidge_export_aidge/operator_export/fc.py
+++ b/aidge_core/aidge_export_aidge/operator_export/fc.py
@@ -1,13 +1,13 @@
-from aidge_core.aidge_export_aidge import ROOT_EXPORT
 from aidge_core.aidge_export_aidge.registry import ExportSerialize
+from aidge_core.aidge_export_aidge import ROOT_EXPORT
+from aidge_core.export_utils import ExportNodeCpp
+from aidge_core import ImplSpec, IOSpec, dtype
 
-from aidge_core.export_utils import ExportNodeCpp, operator_register
-from pathlib import Path
 
-@operator_register(ExportSerialize, "FC")
+@ExportSerialize.register("FC", ImplSpec(IOSpec(dtype.any)))
 class FC(ExportNodeCpp):
-    def __init__(self, node):
-        super().__init__(node)
+    def __init__(self, node, mem_info, is_input, is_output):
+        super().__init__(node, mem_info, is_input, is_output)
         self.config_template = str(
             ROOT_EXPORT / "templates/attributes/fc.jinja")
         self.forward_template = str(
@@ -16,33 +16,3 @@ class FC(ExportNodeCpp):
         self.kernels_to_copy = []
         self.config_path = "include/parameters"
         self.config_extension = "hpp"
-    @classmethod
-    def exportable(cls, node):
-        return True
-
-    # def export(self, export_folder:Path, list_configs:list):
-
-
-    #     include_path = f"attributes/{self.name}.hpp"
-    #     filepath = export_folder / f"include/{include_path}"
-
-
-    #     generate_file(
-    #         filepath,
-    #         ROOT_EXPORT / "templates/attributes/fc.jinja",
-    #         name=self.name,
-    #         InChannels=self.inputs_dims[1][1],
-    #         OutChannels=self.operator.out_channels(),
-    #         **self.attributes
-    #     )
-    #     list_configs.append(include_path)
-    #     return list_configs
-
-    # def forward(self, list_actions:list):
-    #     list_actions.append(generate_str(
-    #         ROOT_EXPORT / "templates/graph_ctor/fc.jinja",
-    #         name=self.name,
-    #         inputs=parse_node_input(self.node.inputs()),
-    #         **self.attributes
-    #     ))
-    #     return list_actions
diff --git a/aidge_core/aidge_export_aidge/operator_export/maxpooling.py b/aidge_core/aidge_export_aidge/operator_export/maxpooling.py
index 571c72a01..6cbccadd9 100644
--- a/aidge_core/aidge_export_aidge/operator_export/maxpooling.py
+++ b/aidge_core/aidge_export_aidge/operator_export/maxpooling.py
@@ -1,11 +1,12 @@
 from aidge_core.aidge_export_aidge.registry import ExportSerialize
 from aidge_core.aidge_export_aidge import ROOT_EXPORT
-from aidge_core.export_utils import ExportNodeCpp, operator_register
+from aidge_core.export_utils import ExportNodeCpp
+from aidge_core import ImplSpec, IOSpec, dtype
 
-@operator_register(ExportSerialize,"MaxPooling")
+@ExportSerialize.register(["MaxPooling1D", "MaxPooling2D", "MaxPooling3D"], ImplSpec(IOSpec(dtype.any)))
 class MaxPooling(ExportNodeCpp):
-    def __init__(self, node):
-        super().__init__(node)
+    def __init__(self, node, mem_info, is_input, is_output):
+        super().__init__(node, mem_info, is_input, is_output)
         self.config_template = str(
             ROOT_EXPORT / "templates/attributes/maxpooling.jinja")
         self.forward_template = str(
@@ -14,28 +15,3 @@ class MaxPooling(ExportNodeCpp):
         self.kernels_to_copy = []
         self.config_path = "include/parameters"
         self.config_extension = "hpp"
-    @classmethod
-    def exportable(cls, node):
-        return True
-
-    # def export(self, export_folder:Path, list_configs:list):
-    #     include_path = f"attributes/{self.name}.hpp"
-    #     filepath = export_folder / f"include/{include_path}"
-
-    #     generate_file(
-    #         filepath,
-    #         ROOT_EXPORT / "templates/attributes/maxpooling.jinja",
-    #         name=self.name,
-    #         **self.attributes
-    #     )
-    #     list_configs.append(include_path)
-    #     return list_configs
-
-    # def forward(self, list_actions:list):
-    #     list_actions.append(generate_str(
-    #         ROOT_EXPORT / "templates/graph_ctor/maxpooling.jinja",
-    #         name=self.name,
-    #         inputs=parse_node_input(self.node.inputs()),
-    #         **self.attributes
-    #     ))
-    #     return list_actions
diff --git a/aidge_core/aidge_export_aidge/operator_export/producer.py b/aidge_core/aidge_export_aidge/operator_export/producer.py
index c6ace522f..7d0233917 100644
--- a/aidge_core/aidge_export_aidge/operator_export/producer.py
+++ b/aidge_core/aidge_export_aidge/operator_export/producer.py
@@ -1,17 +1,18 @@
-from aidge_core.aidge_export_aidge import ROOT_EXPORT
-from aidge_core.aidge_export_aidge.registry import ExportSerialize
-from aidge_core.export_utils import ExportNodeCpp, operator_register
 import numpy as np
 
+from aidge_core.aidge_export_aidge.registry import ExportSerialize
+from aidge_core.aidge_export_aidge import ROOT_EXPORT
+from aidge_core.export_utils import ExportNodeCpp
+from aidge_core import ImplSpec, IOSpec, dtype
 
-@operator_register(ExportSerialize, "Producer")
+@ExportSerialize.register("Producer", ImplSpec(IOSpec(dtype.any)))
 class Producer(ExportNodeCpp):
     """
     If there is a standardization of the export operators
     then this class should be just a inheritance of ProducerCPP
     """
-    def __init__(self, node):
-        super().__init__(node)
+    def __init__(self, node, mem_info, is_input, is_output):
+        super().__init__(node, mem_info, is_input, is_output)
         child, in_idx = self.node.output(0)[0]
 
         self.values = np.array(self.operator.get_output(0))
@@ -26,34 +27,4 @@ class Producer(ExportNodeCpp):
         self.kernels_to_copy = []
         self.config_path = "include/parameters"
         self.config_extension = "hpp"
-    @classmethod
-    def exportable(cls, node):
-        return True
-
-    # def export(self, export_folder:Path, list_configs:list):
-    #     assert(len(self.node.output(0)) == 1)
-
-    #     include_path = f"parameters/{self.tensor_name}.hpp"
-    #     filepath = export_folder / f"include/{include_path}"
-
-    #     aidge_tensor = self.operator.get_output(0)
-    #     datatype = aidge2c(aidge_tensor.dtype())
-    #     generate_file(
-    #         filepath,
-    #         ROOT_EXPORT / "templates/parameter.jinja",
-    #         dims = aidge_tensor.dims(),
-    #         data_t = datatype,
-    #         name = self.tensor_name,
-    #         values = str(aidge_tensor)
-    #     )
-    #     list_configs.append(include_path)
-    #     return list_configs
 
-    # def forward(self, list_actions:list):
-    #     list_actions.append(generate_str(
-    #         ROOT_EXPORT / "templates/graph_ctor/producer.jinja",
-    #         name=self.name,
-    #         tensor_name=self.tensor_name,
-    #         **self.attributes
-    #     ))
-    #     return list_actions
diff --git a/aidge_core/aidge_export_aidge/operator_export/relu.py b/aidge_core/aidge_export_aidge/operator_export/relu.py
index 8012b2d04..b58f3dfc7 100644
--- a/aidge_core/aidge_export_aidge/operator_export/relu.py
+++ b/aidge_core/aidge_export_aidge/operator_export/relu.py
@@ -1,27 +1,14 @@
 from aidge_core.aidge_export_aidge.registry import ExportSerialize
 from aidge_core.aidge_export_aidge import ROOT_EXPORT
-from aidge_core.export_utils import ExportNodeCpp, operator_register
+from aidge_core.export_utils import ExportNodeCpp
+from aidge_core import ImplSpec, IOSpec, dtype
 
-@operator_register(ExportSerialize, "ReLU")
+@ExportSerialize.register("ReLU", ImplSpec(IOSpec(dtype.any)))
 class ReLU(ExportNodeCpp):
-    def __init__(self, node):
-        super().__init__(node)
+    def __init__(self, node, mem_info, is_input, is_output):
+        super().__init__(node, mem_info, is_input, is_output)
         self.config_template = ""
         self.forward_template = str(
             ROOT_EXPORT / "templates/graph_ctor/relu.jinja")
         self.include_list = []
         self.kernels_to_copy = []
-    @classmethod
-    def exportable(cls, node):
-        return True
-    # def export(self, export_folder:Path, list_configs:list):
-    #     return list_configs
-
-    # def forward(self, list_actions:list):
-    #     list_actions.append(generate_str(
-    #         ROOT_EXPORT / "templates/graph_ctor/relu.jinja",
-    #         name=self.name,
-    #         inputs=parse_node_input(self.node.inputs()),
-    #         **self.attributes
-    #     ))
-    #     return list_actions
diff --git a/aidge_core/aidge_export_aidge/operator_export/sub.py b/aidge_core/aidge_export_aidge/operator_export/sub.py
index 80a96c7ee..f4468d750 100644
--- a/aidge_core/aidge_export_aidge/operator_export/sub.py
+++ b/aidge_core/aidge_export_aidge/operator_export/sub.py
@@ -1,29 +1,14 @@
 from aidge_core.aidge_export_aidge.registry import ExportSerialize
 from aidge_core.aidge_export_aidge import ROOT_EXPORT
-from aidge_core.export_utils import ExportNodeCpp, operator_register
+from aidge_core.export_utils import ExportNodeCpp
+from aidge_core import ImplSpec, IOSpec, dtype
 
-@operator_register(ExportSerialize, "Sub")
+@ExportSerialize.register("Sub", ImplSpec(IOSpec(dtype.any)))
 class Sub(ExportNodeCpp):
-    def __init__(self, node):
-        super().__init__(node)
+    def __init__(self, node, mem_info, is_input, is_output):
+        super().__init__(node, mem_info, is_input, is_output)
         self.config_template = ""
         self.forward_template = str(
             ROOT_EXPORT / "templates/graph_ctor/sub.jinja")
         self.include_list = []
         self.kernels_to_copy = []
-
-    @classmethod
-    def exportable(cls, node):
-        return True
-
-    # def export(self, export_folder:Path, list_configs:list):
-    #     return list_configs
-
-    # def forward(self, list_actions:list):
-    #     list_actions.append(generate_str(
-    #         ROOT_EXPORT / "templates/graph_ctor/sub.jinja",
-    #         name=self.name,
-    #         inputs=parse_node_input(self.node.inputs()),
-    #         **self.attributes
-    #     ))
-    #     return list_actions
diff --git a/aidge_core/aidge_export_aidge/registry.py b/aidge_core/aidge_export_aidge/registry.py
index 477dc583d..fe94a2239 100644
--- a/aidge_core/aidge_export_aidge/registry.py
+++ b/aidge_core/aidge_export_aidge/registry.py
@@ -1,6 +1,10 @@
 from aidge_core.export_utils import ExportLib
 from . import ROOT_EXPORT
+import aidge_core
+
 
 class ExportSerialize(ExportLib):
-    name="export_serialize"
+    _name="export_serialize"
 
+aidge_core.register_Tensor(["export_serialize", aidge_core.dtype.float32],
+                           aidge_core.get_key_value_Tensor(["cpu", aidge_core.dtype.float32]))
diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index 15911ee58..70194a141 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -203,8 +203,7 @@ class ExportNode(ABC):
                 self.operator.input_category(idx) == aidge_core.InputCategory.OptionalData:
                 pass
             else:
-                print(self.operator.input_category(idx))
-                raise RuntimeError(f"No input for {self.node.name()} at input {idx}")
+                aidge_core.Log.notice(f"No input for {self.node.name()} at input {idx}")
         for idx, list_child_node_in_id in enumerate(self.node.outputs()):
             self.outputs += [node_in_id[0]
                              for node_in_id in list_child_node_in_id]
@@ -257,7 +256,7 @@ class ExportNode(ABC):
                     else:
                         self.attributes["mem_info_wrap_size"][idx] = 0
             else:
-                print(f"No output for {self.node.name()}")
+                aidge_core.Log.notice(f"No output for {self.node.name()}")
 
 class ExportNodeCpp(ExportNode):
     # Path to the template defining how to export the node definition
-- 
GitLab