Commit f1d56404 authored by Gallas Gaye, committed by Gallas Gaye

feat: Add reshape export op

parent f3a546b9
Merge requests: !39 Update 0.2.1 -> 0.3.0, !31 Add missing operators for basic onnx model exporting
kernels/reshape.hpp:

#ifndef __AIDGE_EXPORT_CPP_KERNELS_RESHAPE__
#define __AIDGE_EXPORT_CPP_KERNELS_RESHAPE__

#include "network/typedefs.hpp"

// Generic function for the reshape kernel
template<int M,
         typename Input_T, typename Output_T>
__attribute__((always_inline)) inline
void reshape_forward (
    const Input_T* __restrict,  // The first input is unused: it only dictates the resulting layout of the reshape
    const Input_T* __restrict inputs2,
    Output_T* __restrict outputs)
{
    // If the input and output pointers are the same, the memory manager has
    // already optimized this layer away, so the function is a no-op.
    if (inputs2 == outputs)
        return;

    // On a flat buffer, a reshape amounts to a no-op: the data layout does not
    // change, so we only need to copy the input buffer to the output.
    for (int m = 0; m < M; ++m) {
        outputs[m] = inputs2[m];
    }
}

#endif // __AIDGE_EXPORT_CPP_KERNELS_RESHAPE__
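For context, the following standalone sketch (not part of the commit) shows the kernel's behaviour: the shape input is ignored and the data is copied flat. The buffer names and the 2x3 -> 3x2 example are purely illustrative.

#include <cassert>

// Standalone copy of the kernel above so this sketch compiles without the
// export headers; the behaviour is identical.
template<int M, typename Input_T, typename Output_T>
inline void reshape_forward(const Input_T* __restrict,
                            const Input_T* __restrict inputs2,
                            Output_T* __restrict outputs)
{
    if (inputs2 == outputs)
        return;
    for (int m = 0; m < M; ++m)
        outputs[m] = inputs2[m];
}

int main()
{
    // Reshaping a 2x3 tensor to 3x2: same 6 values, same flat layout.
    const float shape_input[2] = {3.f, 2.f};               // only describes the target layout, never read
    const float data[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};  // 2x3 input, row-major
    float out[6];                                           // 3x2 output, row-major

    reshape_forward<6>(shape_input, data, out);

    for (int i = 0; i < 6; ++i)
        assert(out[i] == data[i]);                          // plain element-wise copy
    return 0;
}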
@@ -94,6 +94,19 @@ class ReLUCPP(ExportNodeCpp):
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register("Reshape", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class ReshapeCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
self.config_template = str(
ROOT / "templates" / "configuration" / "reshape_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "reshape_forward.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "kernels" / "reshape.hpp"),
]
@ExportLibCpp.register("Conv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class ConvCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
templates/configuration/reshape_config.jinja:
{#- For name header -#}
#ifndef {{ name|upper }}_LAYER_H
#define {{ name|upper }}_LAYER_H
{% include "./_def_io.jinja" %}
{% include "./_meminfo.jinja" %}
{# For layer configuration -#}
#define {{ name|upper }}_NB_ELTS {{ in_dims[0]|join('*') }}

#endif // {{ name|upper }}_LAYER_H

templates/kernel_forward/reshape_forward.jinja:
{% filter indent(width=4, first=False) %}
{% include "./_mem_offset.jinja" %}
reshape_forward<{{name|upper}}_NB_ELTS>
({{in_name[0]}}, {{in_name[1]}}, {{out_name[0]}});
{% include "./_save_outputs.jinja" %}
{% endfilter %}
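To show how the configuration and forward templates fit together, here is a hedged sketch of roughly what they could render to for a hypothetical layer named reshape0 with a single 1x2x3x4 float input. The include path, the wrapper function and the buffer names are illustrative assumptions; the real export also emits the _def_io, _meminfo, _mem_offset and _save_outputs fragments, which are omitted here.

// Approximate output of reshape_config.jinja for a layer named "reshape0"
// whose first input has dims [1, 2, 3, 4] (in_dims[0] joined with '*').
#ifndef RESHAPE0_LAYER_H
#define RESHAPE0_LAYER_H

#define RESHAPE0_NB_ELTS 1*2*3*4

#endif // RESHAPE0_LAYER_H

#include "kernels/reshape.hpp"  // assumed include path for the copied kernel

// Hypothetical wrapper, only here to make the sketch self-contained: the real
// export inlines the call below into its generated forward pass, surrounded by
// the memory-offset and save-output fragments.
static void reshape0_forward(const float* reshape0_input_0,  // shape input (layout only)
                             const float* reshape0_input_1,  // data input
                             float* reshape0_output_0)       // output buffer
{
    // Approximate output of reshape_forward.jinja.
    reshape_forward<RESHAPE0_NB_ELTS>
        (reshape0_input_0, reshape0_input_1, reshape0_output_0);
}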