// TODO: support border types other than a constant fill (Reflect, Wrap...).

/**
 * @brief Pad the two innermost dimensions (rows, columns) of a tensor with a
 *        constant value.
 *
 * Both buffers are indexed as dense [batch][channel][row][column] arrays.
 * The output has CHANNELS_HEIGHT + PADDING_TOP + PADDING_BOTTOM rows and
 * CHANNELS_WIDTH + PADDING_LEFT + PADDING_RIGHT columns per channel.
 * A padding may be negative, in which case the matching side of the input is
 * cropped instead of extended.
 *
 * @tparam NB_BATCHES       Number of batches in both tensors.
 * @tparam NB_CHANNELS      Number of channels in both tensors.
 * @tparam CHANNELS_HEIGHT  Number of rows of one input channel.
 * @tparam CHANNELS_WIDTH   Number of columns of one input channel.
 * @tparam NB_OUTPUTS       Not used by this kernel.
 * @tparam OUTPUTS_HEIGHT   Not used by this kernel (the output height is
 *                          derived from the input height and paddings).
 * @tparam OUTPUTS_WIDTH    Not used by this kernel (the output width is
 *                          derived from the input width and paddings).
 * @tparam PADDING_TOP      Rows added before the first input row.
 * @tparam PADDING_LEFT     Columns added before the first input column.
 * @tparam PADDING_BOTTOM   Rows added after the last input row.
 * @tparam PADDING_RIGHT    Columns added after the last input column.
 * @tparam Input_T          Element type of the input buffer.
 * @tparam Output_T         Element type of the output buffer.
 *
 * @param borderValue Value written to every padded element, converted once to
 *                    Output_T.
 * @param inputs      Input buffer, NB_BATCHES * NB_CHANNELS * CHANNELS_HEIGHT
 *                    * CHANNELS_WIDTH elements.
 * @param outputs     Output buffer, NB_BATCHES * NB_CHANNELS * (padded height)
 *                    * (padded width) elements. Must not alias @p inputs
 *                    (both pointers are declared __restrict).
 */
template <int NB_BATCHES,
          int NB_CHANNELS,
          int CHANNELS_HEIGHT,
          int CHANNELS_WIDTH,
          int NB_OUTPUTS,
          int OUTPUTS_HEIGHT,
          int OUTPUTS_WIDTH,
          int PADDING_TOP,
          int PADDING_LEFT,
          int PADDING_BOTTOM,
          int PADDING_RIGHT,
          typename Input_T,
          typename Output_T>
__attribute__((always_inline)) inline void
pad_forward(double borderValue,
            const Input_T *__restrict inputs,
            Output_T *__restrict outputs) {
    // Everything below is computed with signed integers: coordinates relative
    // to the input are negative inside the top/left border, and mixing them
    // with unsigned values would rely on wrap-around conversions.
    constexpr int outHeight = CHANNELS_HEIGHT + PADDING_TOP + PADDING_BOTTOM;
    constexpr int outWidth = CHANNELS_WIDTH + PADDING_LEFT + PADDING_RIGHT;
    static_assert(outHeight >= 0 and outWidth >= 0,
                  "pad_forward: negative padding exceeds the input size");

    // Half-open range [rowBegin, rowEnd) of output rows backed by an input
    // row, i.e. PADDING_TOP and PADDING_TOP + CHANNELS_HEIGHT clamped to the
    // output extent so that negative (cropping) paddings are handled too.
    constexpr int rowBegin =
        PADDING_TOP < 0 ? 0
                        : (PADDING_TOP > outHeight ? outHeight : PADDING_TOP);
    constexpr int rowLimit = PADDING_TOP + CHANNELS_HEIGHT;
    constexpr int rowEnd =
        rowLimit < rowBegin ? rowBegin
                            : (rowLimit > outHeight ? outHeight : rowLimit);

    // Same for the columns: [colBegin, colEnd) is copied from the input, the
    // rest of the row is border.
    constexpr int colBegin =
        PADDING_LEFT < 0 ? 0
                         : (PADDING_LEFT > outWidth ? outWidth : PADDING_LEFT);
    constexpr int colLimit = PADDING_LEFT + CHANNELS_WIDTH;
    constexpr int colEnd =
        colLimit < colBegin ? colBegin
                            : (colLimit > outWidth ? outWidth : colLimit);

    // Batches and channels are contiguous in both buffers, so they are walked
    // as a single sequence of 2D planes.
    constexpr int planeCount = NB_BATCHES * NB_CHANNELS;
    constexpr int inPlaneSize = CHANNELS_HEIGHT * CHANNELS_WIDTH;

    // Convert the border value once instead of on every store.
    const Output_T border = static_cast<Output_T>(borderValue);

    // The output is produced row after row, in memory order.
    Output_T *outRow = outputs;
    for (int plane = 0; plane < planeCount; ++plane) {
        const int inPlaneOffset = plane * inPlaneSize;

        for (int oY = 0; oY < outHeight; ++oY, outRow += outWidth) {
            if (oY < rowBegin or oY >= rowEnd) {
                // Row lies entirely in the top or bottom border.
                for (int oX = 0; oX < outWidth; ++oX) {
                    outRow[oX] = border;
                }
                continue;
            }

            // Offset such that inputs[inRowOffset + oX] is the input element
            // under output column oX. It may be negative on its own, but it
            // is only ever used with oX in [colBegin, colEnd), where the sum
            // is a valid input index.
            const int inRowOffset = inPlaneOffset
                                    + (oY - PADDING_TOP) * CHANNELS_WIDTH
                                    - PADDING_LEFT;

            for (int oX = 0; oX < colBegin; ++oX) {
                outRow[oX] = border;
            }
            for (int oX = colBegin; oX < colEnd; ++oX) {
                outRow[oX] = inputs[inRowOffset + oX];
            }
            for (int oX = colEnd; oX < outWidth; ++oX) {
                outRow[oX] = border;
            }
        }
    }
}