From 341d5dab58755b60dd922f8246f3668bbdc1dcf8 Mon Sep 17 00:00:00 2001
From: Axel Farrugia <axel.farrugia@cea.fr>
Date: Tue, 18 Mar 2025 16:52:54 +0100
Subject: [PATCH] [Feat](Exports) Add label export in generate_main_cpp()
 function and change the default inputs and labels destination folder from
 "ROOT" to "ROOT/data"

---
 aidge_core/export_utils/generate_main.py      | 29 ++++++++++----
 aidge_core/export_utils/templates/main.jinja  | 38 ++++++++++++++++---
 .../export_utils/templates/main_compare.jinja |  4 +-
 3 files changed, 56 insertions(+), 15 deletions(-)

diff --git a/aidge_core/export_utils/generate_main.py b/aidge_core/export_utils/generate_main.py
index 57fc68bca..b8aa14517 100644
--- a/aidge_core/export_utils/generate_main.py
+++ b/aidge_core/export_utils/generate_main.py
@@ -1,8 +1,8 @@
 import aidge_core
 from pathlib import Path
-from aidge_core.export_utils import generate_file, data_conversion
+from aidge_core.export_utils import generate_file, data_conversion, generate_input_file
 
-def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView, inputs_tensor=None) -> None:
+def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView, inputs_tensor=None, labels=None) -> None:
     """
     Generate a C++ file to manage the forward pass of a model using the given graph structure.
 
@@ -18,7 +18,10 @@ def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView, inpu
                        ordered input/output data within the computational graph.
     :type graph_view: aidge_core.graph_view
     :param inputs_tensor: **For future** argument to provide tensor to use in the main function, not implemented yet!
-    :type inputs_tensor: None
+                          By default, the input of the given graph will be exported.
+    :type inputs_tensor: aidge_core.Tensor
+    :param labels: Argument to provide labels tensor to generate and use in the main function. 
+    :type labels: aidge_core.Tensor
     :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes),
                           indicating an internal bug in the graph representation.
     """
@@ -41,7 +44,18 @@ def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView, inpu
             else:
                 aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.")
         else:
-            aidge_core.export_utils.generate_input_file(export_folder=export_folder, array_name=in_name, tensor=input_tensor)
+            # Generate input file
+            generate_input_file(
+                 export_folder=str(Path(export_folder) / "data"), 
+                 array_name=in_name, 
+                 tensor=input_tensor)
+        if labels is not None:
+             # Generate labels
+             generate_input_file(
+                  export_folder=str(Path(export_folder) / "data"),
+                  array_name="labels",
+                  tensor=labels
+             )
 
     for out_node, out_id in gv_outputs:
         outputs_name.append(f"{out_node.name()}_output_{out_id}")
@@ -60,7 +74,8 @@ def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView, inpu
         inputs_name=inputs_name,
         outputs_name=outputs_name,
         outputs_dtype=outputs_dtype,
-        outputs_size=outputs_size
+        outputs_size=outputs_size,
+        labels=(labels is not None)
     )
 
 
@@ -103,7 +118,7 @@ def generate_main_compare_cpp(export_folder: str, graph_view: aidge_core.GraphVi
             else:
                 aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.")
         else:
-            aidge_core.export_utils.generate_input_file(export_folder=export_folder, array_name=in_name, tensor=input_tensor)
+            generate_input_file(export_folder=export_folder, array_name=in_name, tensor=input_tensor)
 
     for out_node, out_id in gv_outputs:
         out_name = f"{out_node.name()}_output_{out_id}"
@@ -114,7 +129,7 @@ def generate_main_compare_cpp(export_folder: str, graph_view: aidge_core.GraphVi
         if out_tensor is None or out_tensor.undefined() or not out_tensor.has_impl():
                 aidge_core.Log.notice(f"No input tensor set for {out_name}, main generated will not be functionnal after code generation.")
         else:
-            aidge_core.export_utils.generate_input_file(export_folder=export_folder, array_name=out_name+"_expected", tensor=out_tensor)
+            generate_input_file(export_folder=export_folder, array_name=out_name+"_expected", tensor=out_tensor)
 
     if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
             raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.")
diff --git a/aidge_core/export_utils/templates/main.jinja b/aidge_core/export_utils/templates/main.jinja
index 697a97b53..b44f40a90 100644
--- a/aidge_core/export_utils/templates/main.jinja
+++ b/aidge_core/export_utils/templates/main.jinja
@@ -1,11 +1,14 @@
 
 #include <iostream>
 #include "forward.hpp"
-{% for name in inputs_name %}
-#include "{{ name }}.h"
-{% endfor %}
+{%- for name in inputs_name %}
+#include "data/{{ name }}.h"
+{%- endfor %}
+{%- if labels %}
+#include "data/labels.h"
+{%- endif %}
 
-{% set printf_formats = {
+{%- set printf_formats = {
     "double": "%lf",
     "float": "%f",
     "int8_t": "%hhd",
@@ -28,13 +31,36 @@ int main()
     // Call the forward function
     {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
 
-    // Print the results of each output
+    // Print the results
+    {%- if labels %}
+    int prediction;
+    int confidence;
+
     {%- for o in range(outputs_name | length) %}
+    prediction = 0;
+    confidence = {{ outputs_name[o] }}[0];
+
+    for (int o = 0; o < {{ outputs_size[0] }}; ++o) {
+        if ({{ outputs_name[0] }}[o] > confidence) {
+            prediction = o;
+            confidence = {{ outputs_name[0] }}[o];
+        }
+    }
+
+    printf("Prediction : %d (%d)\n", prediction, confidence);
+    printf("Label : %d\n", labels[{{ o }}]);
+
+    {%- endfor %}
+    {%- else %}
+    {%- for o in range(outputs_name | length) %}
+
     printf("{{ outputs_name[o] }}:\n");
     for (int o = 0; o < {{ outputs_size[o] }}; ++o) {
         printf("{{ printf_formats[outputs_dtype[o]] }} ", {{ outputs_name[o] }}[o]);
     }
     printf("\n");
-    {% endfor %}
+
+    {%- endfor %}
+    {%- endif %}
     return 0;
 }
diff --git a/aidge_core/export_utils/templates/main_compare.jinja b/aidge_core/export_utils/templates/main_compare.jinja
index 3cc4c986d..7113fb0f2 100644
--- a/aidge_core/export_utils/templates/main_compare.jinja
+++ b/aidge_core/export_utils/templates/main_compare.jinja
@@ -16,12 +16,12 @@
 
 // Inputs
 {% for name in inputs_name %}
-#include "{{ name }}.h"
+#include "data/{{ name }}.h"
 {% endfor %}
 
 // Outputs
 {% for name in outputs_name %}
-#include "{{ name }}_expected.h"
+#include "data/{{ name }}_expected.h"
 {% endfor %}
 
 int main()
-- 
GitLab