From 782c910261089741f2b32c258feb30e1ee7e710e Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 11 Oct 2024 11:58:58 +0000
Subject: [PATCH] Remove is_input and is_output from node_export.

---
 aidge_core/aidge_export_aidge/export.py       |  2 +-
 .../operator_export/conv.py                   |  4 ++--
 .../aidge_export_aidge/operator_export/fc.py  |  4 ++--
 .../operator_export/maxpooling.py             |  4 ++--
 .../operator_export/producer.py               |  4 ++--
 .../operator_export/relu.py                   |  4 ++--
 .../aidge_export_aidge/operator_export/sub.py |  4 ++--
 aidge_core/export_utils/node_export.py        | 21 +++++++++++--------
 aidge_core/export_utils/scheduler_export.py   |  3 +--
 9 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/aidge_core/aidge_export_aidge/export.py b/aidge_core/aidge_export_aidge/export.py
index 747906e3e..aa993c4be 100644
--- a/aidge_core/aidge_export_aidge/export.py
+++ b/aidge_core/aidge_export_aidge/export.py
@@ -90,7 +90,7 @@ def serialize_to_cpp(export_folder: str,
         if export_node is None:
             raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].")
         op = export_node(
-            node, None, False, False) # Note: is_input and is_output is not used for this export
+            node, None)
 
 
         set_operator.add(node.type())
diff --git a/aidge_core/aidge_export_aidge/operator_export/conv.py b/aidge_core/aidge_export_aidge/operator_export/conv.py
index c8f1fff3b..8805629b7 100644
--- a/aidge_core/aidge_export_aidge/operator_export/conv.py
+++ b/aidge_core/aidge_export_aidge/operator_export/conv.py
@@ -5,8 +5,8 @@ from aidge_core import ImplSpec, IOSpec, dtype
 
 @ExportSerialize.register(["Conv1D", "Conv2D"], ImplSpec(IOSpec(dtype.any)))
 class Conv(ExportNodeCpp):
-    def __init__(self, node, mem_info, is_input, is_output):
-        super().__init__(node, mem_info, is_input, is_output)
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
         self.config_template = str(
             ROOT_EXPORT / "templates/attributes/conv.jinja")
         self.forward_template = str(
diff --git a/aidge_core/aidge_export_aidge/operator_export/fc.py b/aidge_core/aidge_export_aidge/operator_export/fc.py
index a1207ee51..6fae97d66 100644
--- a/aidge_core/aidge_export_aidge/operator_export/fc.py
+++ b/aidge_core/aidge_export_aidge/operator_export/fc.py
@@ -6,8 +6,8 @@ from aidge_core import ImplSpec, IOSpec, dtype
 
 @ExportSerialize.register("FC", ImplSpec(IOSpec(dtype.any)))
 class FC(ExportNodeCpp):
-    def __init__(self, node, mem_info, is_input, is_output):
-        super().__init__(node, mem_info, is_input, is_output)
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
         self.config_template = str(
             ROOT_EXPORT / "templates/attributes/fc.jinja")
         self.forward_template = str(
diff --git a/aidge_core/aidge_export_aidge/operator_export/maxpooling.py b/aidge_core/aidge_export_aidge/operator_export/maxpooling.py
index ad35970e2..df53de9eb 100644
--- a/aidge_core/aidge_export_aidge/operator_export/maxpooling.py
+++ b/aidge_core/aidge_export_aidge/operator_export/maxpooling.py
@@ -5,8 +5,8 @@ from aidge_core import ImplSpec, IOSpec, dtype
 
 @ExportSerialize.register(["MaxPooling1D", "MaxPooling2D", "MaxPooling3D"], ImplSpec(IOSpec(dtype.any)))
 class MaxPooling(ExportNodeCpp):
-    def __init__(self, node, mem_info, is_input, is_output):
-        super().__init__(node, mem_info, is_input, is_output)
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
         self.config_template = str(
             ROOT_EXPORT / "templates/attributes/maxpooling.jinja")
         self.forward_template = str(
diff --git a/aidge_core/aidge_export_aidge/operator_export/producer.py b/aidge_core/aidge_export_aidge/operator_export/producer.py
index d378c0531..475d36255 100644
--- a/aidge_core/aidge_export_aidge/operator_export/producer.py
+++ b/aidge_core/aidge_export_aidge/operator_export/producer.py
@@ -11,8 +11,8 @@ class Producer(ExportNodeCpp):
     If there is a standardization of the export operators
     then this class should be just a inheritance of ProducerCPP
     """
-    def __init__(self, node, mem_info, is_input, is_output):
-        super().__init__(node, mem_info, is_input, is_output)
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
         child, in_idx = self.node.output(0)[0]
 
         self.values = np.array(self.operator.get_output(0))
diff --git a/aidge_core/aidge_export_aidge/operator_export/relu.py b/aidge_core/aidge_export_aidge/operator_export/relu.py
index b58f3dfc7..300135734 100644
--- a/aidge_core/aidge_export_aidge/operator_export/relu.py
+++ b/aidge_core/aidge_export_aidge/operator_export/relu.py
@@ -5,8 +5,8 @@ from aidge_core import ImplSpec, IOSpec, dtype
 
 @ExportSerialize.register("ReLU", ImplSpec(IOSpec(dtype.any)))
 class ReLU(ExportNodeCpp):
-    def __init__(self, node, mem_info, is_input, is_output):
-        super().__init__(node, mem_info, is_input, is_output)
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
         self.config_template = ""
         self.forward_template = str(
             ROOT_EXPORT / "templates/graph_ctor/relu.jinja")
diff --git a/aidge_core/aidge_export_aidge/operator_export/sub.py b/aidge_core/aidge_export_aidge/operator_export/sub.py
index f4468d750..b728e088d 100644
--- a/aidge_core/aidge_export_aidge/operator_export/sub.py
+++ b/aidge_core/aidge_export_aidge/operator_export/sub.py
@@ -5,8 +5,8 @@ from aidge_core import ImplSpec, IOSpec, dtype
 
 @ExportSerialize.register("Sub", ImplSpec(IOSpec(dtype.any)))
 class Sub(ExportNodeCpp):
-    def __init__(self, node, mem_info, is_input, is_output):
-        super().__init__(node, mem_info, is_input, is_output)
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
         self.config_template = ""
         self.forward_template = str(
             ROOT_EXPORT / "templates/graph_ctor/sub.jinja")
diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index 8059eeea1..98c94ba2e 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -96,6 +96,7 @@ class ExportNode(ABC):
     - **nb_in**: Number of inputs, ``int``
     - **in_name**: unique name for each input, if no input node the name is ``{node_name}_input_{in_id}``, if there is a parent, the name is ``{parent_name}_output_{out_id}``, ``list[str]``
     - **in_dims**: A list of the dimension for each inputs, ``list[list[int]]``
+    - **in_node**: A list of Node associated for each inputs, ``list[aidge_core.Node]``
     - **in_size**: A list of the size for each inputs, ``list[int]``
     - **in_chan**: A list of channel for each inputs, deduced by the dataformat, ``list[int]``
     - **in_height**: A list of height for each inputs, deduced by the dataformat, ``list[int]``
@@ -103,6 +104,7 @@ class ExportNode(ABC):
     - **in_dtype**: A list of type (Aidge format) for each input, ``List[:py:class:`aidge_core.dtype`]``
     - **in_cdtype**: A list of type (C/C++ format) for each input, ``List[str]``
     - **out_name**: unique name for each output, the name is ``{name}_output_{out_id}``, ``list[str]``
+    - **out_node**: A list of list of Node associated for each outputs, ``list[list[aidge_core.Node]]``
     - **nb_out**: Number of outputs, ``int``
     - **out_dims**: A list of the dimension for each inputs, ``list[list[int]]``
     - **out_size**: A list of the size for each outputs, ``list[int]``
@@ -111,8 +113,6 @@ class ExportNode(ABC):
     - **out_width**: A list of width for each outputs, deduced by the dataformat, ``list[int]``
     - **out_dtype**: A list of type (Aidge format) for each output, ``List[:py:class:`aidge_core.dtype`]``
     - **out_cdtype**: A list of type (C/C++ format) for each output, ``List[str]``
-    - **is_output**: True if the node is an output node, ``bool``
-    - **is_input**: True if the node is an input node, ``bool``
     - **mem_info**: True if mem_info is available for this node, ``bool``
     - **mem_info_size**: A list of memory size for each output, ``List[int]``
     - **mem_info_offset**: A list of offset to access each output, ``List[int]``
@@ -125,7 +125,7 @@ class ExportNode(ABC):
     """
 
     @abstractmethod
-    def __init__(self, aidge_node: aidge_core.Node, mem_info: List[dict]=None, is_input: bool=False, is_output: bool=False) -> None:
+    def __init__(self, aidge_node: aidge_core.Node, mem_info: List[dict]=None) -> None:
         """Create ExportNode and retrieve attributes from ``aidge_node``:
         """
 
@@ -139,15 +139,14 @@ class ExportNode(ABC):
         self.attributes["name"] = self.node.name()
         self.attributes["nb_in"] = self.node.get_nb_inputs()
         self.attributes["nb_out"] = self.node.get_nb_outputs()
-        # TODO : this check doesn't work if we export a subgraph !
-        # Maybe we need to add the graph we want to export as parameter !
-        # Actually may be mandatory for memory manager ...
-        self.attributes["is_input"] = is_input
-        self.attributes["is_output"] = is_output
+
+        # List of input nodes
         self.inputs = []
+        # List of output nodes
         self.outputs = []
 
         self.attributes["in_name"] = [None] * self.attributes["nb_in"]
+        self.attributes["in_node"] = [None] * self.attributes["nb_in"]
         self.attributes["in_dims"] = [None] * self.attributes["nb_in"]
         self.attributes["in_size"] = [None] * self.attributes["nb_in"]
         self.attributes["in_dformat"] = [None] * self.attributes["nb_in"]
@@ -159,6 +158,7 @@ class ExportNode(ABC):
         self.attributes["in_width"] = [None] * self.attributes["nb_in"]
 
         self.attributes["out_name"] = [None] * self.attributes["nb_out"]
+        self.attributes["out_nodes"] = [None] * self.attributes["nb_out"]
         self.attributes["out_dims"] = [None] * self.attributes["nb_out"]
         self.attributes["out_size"] = [None] * self.attributes["nb_out"]
         self.attributes["out_dformat"] = [None] * self.attributes["nb_out"]
@@ -189,6 +189,7 @@ class ExportNode(ABC):
             if self.operator.get_input(idx) is not None:
                 tensor = self.operator.get_input(idx)
                 self.attributes["in_name"][idx] = f"{self.attributes['name']}_input_{idx}" if parent_node is None else f"{parent_node.name()}_output_{out_id}"
+                self.attributes["in_node"][idx] = parent_node
                 self.attributes["in_dims"][idx] = tensor.dims()
                 self.attributes["in_size"][idx] = tensor.size()
                 self.attributes["in_dformat"][idx] = tensor.dformat()
@@ -205,11 +206,13 @@ class ExportNode(ABC):
             else:
                 raise RuntimeError(f"No input for {self.node.name()} at input {idx}, did you forget to forward dims?")
         for idx, list_child_node_in_id in enumerate(self.node.outputs()):
-            self.outputs += [node_in_id[0]
+            out_nodes = [node_in_id[0]
                              for node_in_id in list_child_node_in_id]
+            self.outputs += out_nodes
             if self.operator.get_output(idx) is not None:
                 tensor = self.operator.get_output(idx)
                 self.attributes["out_name"][idx] = f"{self.attributes['name']}_output_{idx}"
+                self.attributes["out_nodes"][idx] = out_nodes
                 self.attributes["out_dims"][idx] = tensor.dims()
                 self.attributes["out_size"][idx] = tensor.size()
                 self.attributes["out_dformat"][idx] = tensor.dformat()
diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py
index 467504b63..6829832fe 100644
--- a/aidge_core/export_utils/scheduler_export.py
+++ b/aidge_core/export_utils/scheduler_export.py
@@ -76,8 +76,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
             if export_node is None:
                 raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].")
             # Instanciate ExportNode
-            op = export_node(
-                node, mem_info[node], is_input, is_output)
+            op = export_node(node, mem_info[node])
 
             # For configuration files
             list_configs += op.export(dnn_folder)
-- 
GitLab