From bb81ff6e1fde51c03a4007c6f4708d059d99e91d Mon Sep 17 00:00:00 2001 From: vbaudelet <baudelet.vin@gmail.com> Date: Wed, 5 Mar 2025 15:49:13 +0100 Subject: [PATCH] [issue 24] Refactoring kernels --- .../_Aidge_Arm/kernels/Add/aidge_add.hpp | 8 ++ .../_Aidge_Arm/kernels/Atan/aidge_atan.hpp | 4 +- .../kernels/BatchNorm/aidge_batchnorm.hpp | 36 +++++++ ...32.c => aidge_batchnorm2d_chw_float32.hpp} | 0 .../_Aidge_Arm/kernels/Div/aidge_div.hpp | 9 ++ .../kernels/MatMul/aidge_matmul.hpp | 95 +++++++++++++++++++ .../_Aidge_Arm/kernels/Mul/aidge_mul.hpp | 8 ++ .../_Aidge_Arm/kernels/Relu/aidge_relu.hpp | 8 ++ .../kernels/Reshape/aidge_reshape.hpp | 7 ++ ..._float32.c => aidge_reshape_chw_float32.h} | 0 .../kernels/Sigmoid/aidge_sigmoid.hpp | 10 ++ .../kernels/Softmax/aidge_softmax.hpp | 48 ++++++++++ .../_Aidge_Arm/kernels/Sub/aidge_sub.hpp | 8 ++ 13 files changed, 239 insertions(+), 2 deletions(-) create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Add/aidge_add.hpp create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm.hpp rename aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/{aidge_batchnorm2d_chw_float32.c => aidge_batchnorm2d_chw_float32.hpp} (100%) create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Div/aidge_div.hpp create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/MatMul/aidge_matmul.hpp create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Mul/aidge_mul.hpp create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Relu/aidge_relu.hpp create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape.hpp rename aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/{aidge_reshape_chw_float32.c => aidge_reshape_chw_float32.h} (100%) create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sigmoid/aidge_sigmoid.hpp create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Softmax/aidge_softmax.hpp create mode 100644 aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sub/aidge_sub.hpp diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Add/aidge_add.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Add/aidge_add.hpp new file mode 100644 index 0000000..57cbc6f --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Add/aidge_add.hpp @@ -0,0 +1,8 @@ +template <unsigned int SIZE, typename T> +__attribute__((always_inline)) inline static +void aidge_add(T* __restrict input_a, T* __restrict input_b, T* __restrict output) { + for (unsigned int i = 0; i < SIZE; ++i) { + // Note : no cast to get compiler warning if we lose precision during auto cast! + output[i] = input_a[i] + input_b[i]; + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Atan/aidge_atan.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Atan/aidge_atan.hpp index d4da329..cbbe4e3 100644 --- a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Atan/aidge_atan.hpp +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Atan/aidge_atan.hpp @@ -1,8 +1,8 @@ #include <cmath> -template <unsigned int SIZE, typename Input_T, typename Output_T> +template <unsigned int SIZE, typename T> __attribute__((always_inline)) inline static -void aidge_atan(Input_T* __restrict input, Output_T* __restrict output) { +void aidge_atan(T* __restrict input, T* __restrict output) { for (unsigned int i = 0; i < SIZE; ++i) { // Note : no cast to get compiler warning if we lose precision during auto cast! output[i] = std::atan(input[i]); diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm.hpp new file mode 100644 index 0000000..9accb60 --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm.hpp @@ -0,0 +1,36 @@ +#include <cmath> + +template < + typename T, + typename MeanVar_T, + typename ScaleBias_T, + typename SpatialDims_T, + unsigned int NB_Channels, + unsigned int NB_SpatialDims + > +__attribute__((always_inline)) inline static +void aidge_batchnorm(T* __restrict inputs, + T* __restrict outputs, + MeanVar_T* __restrict input_mean, + MeanVar_T* __restrict input_var, + ScaleBias_T* __restrict scale, + ScaleBias_T* __restrict bias, + SpatialDims_T* __restrict spatial_dims, + float epsilon) +{ + int featureMapSize = 1; + for (int index = 0; index < NB_SpatialDims; ++index){ + featureMapSize *= spatial_dims[index]; + } + for (int current_channel = 0; current_channel < NB_Channels; ++current_channel){ + int ioIndex = current_channel * featureMapSize; + + for (int index = ioIndex; index < (ioIndex + featureMapSize); index++ ){ + outputs[index] = bias[current_channel]; + } + float var = sqrt(input_var[current_channel] + epsilon); + for (int current_feature = 0; current_feature < featureMapSize; ++current_feature){ + outputs[ioIndex + current_feature] += scale[current_channel] * (inputs[ioIndex + current_feature] - input_mean[current_channel]) / var; + } + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm2d_chw_float32.c b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm2d_chw_float32.hpp similarity index 100% rename from aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm2d_chw_float32.c rename to aidge_export_arm_cortexm/_Aidge_Arm/kernels/BatchNorm/aidge_batchnorm2d_chw_float32.hpp diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Div/aidge_div.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Div/aidge_div.hpp new file mode 100644 index 0000000..3fa4772 --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Div/aidge_div.hpp @@ -0,0 +1,9 @@ +template <unsigned int SIZE, typename T> +__attribute__((always_inline)) inline static +void aidge_div(T* __restrict input_a, T* __restrict input_b, T* __restrict output) { + for (unsigned int i = 0; i < SIZE; ++i) { + // Note : no cast to get compiler warning if we lose precision during auto cast! + // [TODO] : input_b[i] = 0 + output[i] = input_a[i] / input_b[i]; + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/MatMul/aidge_matmul.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/MatMul/aidge_matmul.hpp new file mode 100644 index 0000000..2a19356 --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/MatMul/aidge_matmul.hpp @@ -0,0 +1,95 @@ +template < + typename T, + typename Dim_T, + typename Size_T + > +__attribute__((always_inline)) inline static +void aidge_matmul(T* __restrict input_a, + T* __restrict input_b, + T* __restrict output, + Dim_T* __restrict dim_a, + Dim_T* __restrict dim_b, + Dim_T* __restrict dim_output, + Size_T __restrict size_aDim, + Size_T __restrict size_bDim, + Size_T __restrict size_outputDim) +{ + int ndim_a[size_outputDim]; + int ndim_b[size_outputDim]; + if (size_aDim == 1) { + ndim_a[0] = 1; + ndim_a[1] = dim_a[0]; + } + if (size_bDim == 1) { + ndim_b[0] = dim_b[0]; + ndim_b[1] = 1; + } + for (int i = 0; i < size_outputDim; ++i) { + int idx = size_outputDim - size_aDim; + ndim_a[i] = (i < idx) ? 1 : dim_a[i - idx]; + } + for (int i = 0; i < size_outputDim; ++i) { + int idx = size_outputDim - size_bDim; + ndim_b[i] = (i < idx) ? 1 : dim_b[i - idx]; + } + + int stride_post0[size_outputDim - 2]; + int stride_post1[size_outputDim - 2]; + int stride_step0[size_outputDim - 2]; + int stride_step1[size_outputDim - 2]; + + if (size_outputDim > 2) { + stride_post0[size_outputDim - 3] = 1; + stride_post1[size_outputDim - 3] = 1; + for (int i = size_outputDim - 4; i != -1; --i) { + stride_post0[i] = stride_post0[i + 1] * ndim_a[i + 1]; + stride_post1[i] = stride_post1[i + 1] * ndim_b[i + 1]; + } + for (int i = 0; i < size_outputDim - 2; ++i) { + stride_step0[i] = (ndim_a[i] == 1) ? 1 - stride_post0[i] : 1; + stride_step1[i] = (ndim_b[i] == 1) ? 1 - stride_post1[i] : 1; + } + } + + int nbMatrices = 1; + for (int i = size_outputDim - 3; i >= 0; --i) { + nbMatrices *= dim_output[i]; + } + int dim = size_outputDim - 3; + + int offsetIn0 = 0; + int offsetIn1 = 0; + int offsetOut = 0; + const int n = ndim_a[size_outputDim - 2]; + const int k = ndim_a[size_outputDim - 1]; + const int m = ndim_b[size_outputDim - 1]; + const int matrix0Size = n * k; + const int matrix1Size = k * m; + const int matrixOutSize = n * m; + + for(int stack = 0; stack < nbMatrices;){ + for (int i = 0; i < n; ++i) { + for (int j = 0; j < m; ++j) { + float sum = 0; + for (int l = 0; l < k; ++l) { + sum += (input_a[ offsetIn0*matrix0Size + i*k + l] * input_b[offsetIn1*matrix1Size + l*m + j]); + } + output[ offsetOut*matrixOutSize + i*m + j] = sum; + } + } + + if (++stack < nbMatrices) { + int tmp_stack = stack; + while(tmp_stack % dim_output[dim] == 0) { + tmp_stack /= dim_output[dim]; + dim--; + } + offsetIn0 += stride_step0[dim]; + offsetIn1 += stride_step1[dim]; + ++offsetOut; + dim = size_outputDim -3; + } + + } + +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Mul/aidge_mul.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Mul/aidge_mul.hpp new file mode 100644 index 0000000..7f522b9 --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Mul/aidge_mul.hpp @@ -0,0 +1,8 @@ +template <unsigned int SIZE, typename T> +__attribute__((always_inline)) inline static +void aidge_mul(T* __restrict input_a, T* __restrict input_b, T* __restrict output) { + for (unsigned int i = 0; i < SIZE; ++i) { + // Note : no cast to get compiler warning if we lose precision during auto cast! + output[i] = input_a[i] * input_b[i]; + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Relu/aidge_relu.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Relu/aidge_relu.hpp new file mode 100644 index 0000000..9b84a33 --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Relu/aidge_relu.hpp @@ -0,0 +1,8 @@ + +template <unsigned int SIZE, typename T> +__attribute__((always_inline)) inline static +void aidge_relu(T* __restrict input, T* __restrict output) { + for (unsigned int i = 0; i < SIZE; ++i) { + output[i] = (input[i] < 0.0f) ? 0.0f : input[i]; + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape.hpp new file mode 100644 index 0000000..8dfc52d --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape.hpp @@ -0,0 +1,7 @@ +template <unsigned int SIZE, typename T> +__attribute__((always_inline)) inline static +void aidge_reshape(T* __restrict input, T* __restrict output) { + for (unsigned int i = 0; i < SIZE; ++i) { + output[i] = input[i]; + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape_chw_float32.c b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape_chw_float32.h similarity index 100% rename from aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape_chw_float32.c rename to aidge_export_arm_cortexm/_Aidge_Arm/kernels/Reshape/aidge_reshape_chw_float32.h diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sigmoid/aidge_sigmoid.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sigmoid/aidge_sigmoid.hpp new file mode 100644 index 0000000..eb287dd --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sigmoid/aidge_sigmoid.hpp @@ -0,0 +1,10 @@ +#include <math.h> + +template <unsigned int SIZE, typename T> +__attribute__((always_inline)) inline static +void aidge_sigmoid(T* __restrict inputs, T* __restrict outputs) { + for (unsigned int i = 0; i < SIZE; ++i) { + // Note : no cast to get compiler warning if we lose precision during auto cast! + outputs[i] = 1 / (1 + exp(-inputs[i]) ); + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Softmax/aidge_softmax.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Softmax/aidge_softmax.hpp new file mode 100644 index 0000000..00db1f2 --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Softmax/aidge_softmax.hpp @@ -0,0 +1,48 @@ +#include <math.h> + +#define MAX_DIMS_AXIS_SIZE 128 /** TODO : is 128 enough or to big ? | Other possibility is to use a shared buffer as param, but this could have a side effect on Aidge's overall mechanics **/ +float exps[MAX_DIMS_AXIS_SIZE]; + +template < + typename T, + typename Dim_T, + typename Size_T + > +__attribute__((always_inline)) inline static +void aidge_softmax(T* __restrict input, + T* __restrict output, + Dim_T* __restrict dims, + Size_T __restrict dim_size, + int axis) +{ + axis += (axis >= 0 ) ? 0 : dim_size; + + int postAxisElems = 1; + for (unsigned int index = axis+1; index < dim_size; ++index) { + postAxisElems *= dims[index]; + } + int preAxisElems = 1; + for (int index = 0; index < axis; ++index) { + preAxisElems *= dims[index]; + } + + + for (int i = 0; i < preAxisElems; ++i) { + for (int j = 0; j < postAxisElems; ++j) { + float sumExp = 0.0; + + int baseIdx = i * dims[axis] * postAxisElems + j; + + for (int k = 0; k < dims[axis]; ++k) { + int inIdx = baseIdx + k * postAxisElems; + exps[k] = exp(input[inIdx]); + sumExp += exps[k]; + } + + for (int k = 0; k < dims[axis]; ++k) { + int inIdx = baseIdx + k * postAxisElems; + output[inIdx] = exps[k] / sumExp; + } + } + } +} \ No newline at end of file diff --git a/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sub/aidge_sub.hpp b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sub/aidge_sub.hpp new file mode 100644 index 0000000..8bbc194 --- /dev/null +++ b/aidge_export_arm_cortexm/_Aidge_Arm/kernels/Sub/aidge_sub.hpp @@ -0,0 +1,8 @@ +template <unsigned int SIZE, typename T> +__attribute__((always_inline)) inline static +void aidge_sub(T* __restrict input_a, T* __restrict input_b, T* __restrict output) { + for (unsigned int i = 0; i < SIZE; ++i) { + // Note : no cast to get compiler warning if we lose precision during auto cast! + output[i] = input_a[i] - input_b[i]; + } +} \ No newline at end of file -- GitLab