From f1d564040ab11e2cc0cb44ac23315d3b36c9625c Mon Sep 17 00:00:00 2001
From: Gallas Gaye <gallasko@gmail.com>
Date: Fri, 21 Feb 2025 10:22:47 +0100
Subject: [PATCH] feat: Add reshape export op

---
 aidge_export_cpp/kernels/reshape.hpp          | 27 +++++++++++++++++++
 aidge_export_cpp/operators.py                 | 13 +++++++++
 .../configuration/reshape_config.jinja        |  8 ++++++
 .../kernel_forward/reshape_forward.jinja      |  6 +++++
 4 files changed, 54 insertions(+)
 create mode 100644 aidge_export_cpp/kernels/reshape.hpp
 create mode 100644 aidge_export_cpp/templates/configuration/reshape_config.jinja
 create mode 100644 aidge_export_cpp/templates/kernel_forward/reshape_forward.jinja

diff --git a/aidge_export_cpp/kernels/reshape.hpp b/aidge_export_cpp/kernels/reshape.hpp
new file mode 100644
index 0000000..a5828da
--- /dev/null
+++ b/aidge_export_cpp/kernels/reshape.hpp
@@ -0,0 +1,27 @@
+#ifndef __AIDGE_EXPORT_CPP_KERNELS_RESHAPE__
+#define __AIDGE_EXPORT_CPP_KERNELS_RESHAPE__
+
+#include "network/typedefs.hpp"
+
+// Generic copy kernel implementing the Reshape operator
+
+template<int M,
+         typename Input_T, typename Output_T>
+__attribute__((always_inline)) inline
+void reshape_forward (
+    const Input_T* __restrict, // First input is unused: it only dictates the resulting layout of the reshape
+    const Input_T* __restrict inputs2,
+    Output_T* __restrict outputs)
+{
+    // If the input and output pointers are the same, the memory manager has already optimized this copy away, so this function is a no-op.
+    if (inputs2 == outputs)
+        return;
+
+    // In C++, a reshape is a no-op on the data itself:
+    // we only need to copy the input buffer to the output.
+    for (int m = 0; m < M; ++m) {
+        outputs[m] = inputs2[m];
+    }
+}
+
+#endif  // __AIDGE_EXPORT_CPP_KERNELS_RESHAPE__
\ No newline at end of file
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 54c3805..59ce94a 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -94,6 +94,19 @@ class ReLUCPP(ExportNodeCpp):
             str(ROOT / "kernels" / "rescaling.hpp")
         ]
 
+@ExportLibCpp.register("Reshape", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
+class ReshapeCPP(ExportNodeCpp):
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
+        self.config_template = str(
+            ROOT / "templates" / "configuration" / "reshape_config.jinja")
+        self.forward_template = str(
+            ROOT / "templates" / "kernel_forward" / "reshape_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "reshape.hpp"),
+        ]
+
 @ExportLibCpp.register("Conv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class ConvCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
diff --git a/aidge_export_cpp/templates/configuration/reshape_config.jinja b/aidge_export_cpp/templates/configuration/reshape_config.jinja
new file mode 100644
index 0000000..041cf8a
--- /dev/null
+++ b/aidge_export_cpp/templates/configuration/reshape_config.jinja
@@ -0,0 +1,8 @@
+{#- For name header -#}
+#ifndef {{ name|upper }}_LAYER_H
+#define {{ name|upper }}_LAYER_H
+
+{% include "./_def_io.jinja" %}
+{% include "./_meminfo.jinja" %}
+{# For layer configuration -#}
+#define {{ name|upper }}_NB_ELTS {{ in_dims[0]|join('*') }}
diff --git a/aidge_export_cpp/templates/kernel_forward/reshape_forward.jinja b/aidge_export_cpp/templates/kernel_forward/reshape_forward.jinja
new file mode 100644
index 0000000..f9752bc
--- /dev/null
+++ b/aidge_export_cpp/templates/kernel_forward/reshape_forward.jinja
@@ -0,0 +1,6 @@
+{% filter indent(width=4, first=False) %}
+{% include "./_mem_offset.jinja" %}
+reshape_forward<{{name|upper}}_NB_ELTS>
+                 ({{in_name[0]}}, {{in_name[1]}}, {{out_name[0]}});
+{% include "./_save_outputs.jinja" %}
+{% endfilter %}
-- 
GitLab