Skip to content
Snippets Groups Projects
Commit db126426 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add slice operator to arm export.

parent 3b66800a
No related branches found
No related tags found
3 merge requests!17v0.1.0,!12v0.4.0,!11Export refactor
void aidge_slice_float32 (float* inputs,
float* outputs,
int* axes,
int* starts,
void aidge_slice_float32 (float* inputs,
float* outputs,
int* axes,
int* starts,
int* ends,
unsigned int input_dims,
unsigned int nb_axes)
......@@ -13,4 +11,4 @@ void aidge_slice_float32 (float* inputs,
for (int i = starts[axes[0] - 1]; i < ends[axes[0] - 1]; ++i) {
outputs[out_index++] = inputs[i];
}
}
\ No newline at end of file
}
......@@ -5,12 +5,12 @@
/* Slice layer */
{# For layer configuration -#}
#define {{ name|upper }}_NB_CHANNELS {{ nb_inputs }}
#define {{ name|upper }}_NB_OUTPUTS {{ nb_outputs }}
#define {{ name|upper }}_NB_CHANNELS {{ nb_in }}
#define {{ name|upper }}_NB_OUTPUTS {{ nb_out }}
#define {{ name|upper }}_NB_AXES {{ axes|length }}
static const int {{ name|upper }}_AXES[] = { {%- for axe in axes %}{{ axe }}, {% endfor -%} };
static const int {{ name|upper }}_STARTS[] = { {%- for start in starts %}{{ start }}, {% endfor -%} };
static const int {{ name|upper }}_ENDS[] = { {%- for end in ends %}{{ end }}, {% endfor -%} };
static const int {{ name|upper }}_AXES[] = { {{ axes | join(', ') }} };
static const int {{ name|upper }}_STARTS[] = { {{ starts | join(', ') }} };
static const int {{ name|upper }}_ENDS[] = { {{ ends | join(', ') }} };
#endif /* {{ name|upper }}_LAYER_H */
aidge_slice_{{dataformat}} ({{input_name}}, {{output_name}}, {{name|upper}}_AXES, {{name|upper}}_STARTS, {{name|upper}}_ENDS, {{name|upper}}_NB_AXES, {{name|upper}}_NB_CHANNELS);
\ No newline at end of file
aidge_slice_float32 ({{in_name[0]}}, {{out_name[0]}}, {{name|upper}}_AXES, {{name|upper}}_STARTS, {{name|upper}}_ENDS, {{name|upper}}_NB_AXES, {{name|upper}}_NB_CHANNELS);
......@@ -40,10 +40,6 @@ def export_params(name:str,
################### Actions ##################
##############################################
def set_up_output(name, dtype):
return f"{dtype}* {name} = ({dtype}*) mem + {name.upper()}_MEM_CONT_OFFSET;"
@ExportLibAidgeARM.register("Producer", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.any)))
class Producer_ARMCortexM(ExportNode):
......@@ -51,7 +47,6 @@ class Producer_ARMCortexM(ExportNode):
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.values = np.array(self.operator.get_output(0))
print(f"{node.name()}: {self.values.shape}")
if len(self.values.shape) == 4: # Note: export in HWC
self.values = np.transpose(self.values, (0, 2, 3, 1))
# The following block of code is a dirty fix for FC
......@@ -399,6 +394,19 @@ class Atan_ARMCortexM(ExportNodeCpp):
]
@ExportLibAidgeARM.register("Slice", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class Atan_ARMCortexM(ExportNodeCpp):
def __init__(self, node, mem_info, is_input, is_output):
super().__init__(node, mem_info, is_input, is_output)
self.config_template = str(ROOT / "_Aidge_Arm" / "templates" / "configuration" / "slice.jinja")
self.forward_template = str(ROOT / "_Aidge_Arm" / "templates" / "forward_call" / "slice.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "_Aidge_Arm" / "kernels" / "Slice" / "aidge_slice_float32.hpp"),
]
@ExportLibAidgeARM.register("Sigmoid", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class Sigmoid_ARMCortexM(ExportNodeCpp):
def __init__(self, node, mem_info, is_input, is_output):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment