diff --git a/aidge_export_cpp/static/include/network/utils.hpp b/aidge_export_cpp/static/include/network/utils.hpp index 11c8e060eb5b7607e1d90be62114780828176d3d..e2bfbe2f35b3522d0600f10e8481e0879338f43a 100644 --- a/aidge_export_cpp/static/include/network/utils.hpp +++ b/aidge_export_cpp/static/include/network/utils.hpp @@ -1,6 +1,13 @@ #ifndef __AIDGE_EXPORT_CPP_NETWORK_UTILS__ #define __AIDGE_EXPORT_CPP_NETWORK_UTILS__ +#ifdef SAVE_OUTPUTS +#include <sys/types.h> +#include <sys/stat.h> +#include <cstdio> // fprintf +#include <type_traits> // std::is_floating_point +#endif + /** * @brief Integer clamping * @param[in] v Value to be clamped @@ -41,4 +48,102 @@ int min (int lhs, int rhs) return (lhs <= rhs) ? lhs : rhs; } -#endif // __AIDGE_EXPORT_CPP_NETWORK_UTILS__ + +#ifdef SAVE_OUTPUTS +enum class Format { + Default, + NCHW, + NHWC, + CHWN, + NCDHW, + NDHWC, + CDHWN +}; + + +template<typename Output_T> +inline void saveOutputs( + int NB_OUTPUTS, + int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH, + int OUTPUT_MEM_CONT_OFFSET, + int OUTPUT_MEM_CONT_SIZE, + int OUTPUT_MEM_WRAP_OFFSET, + int OUTPUT_MEM_WRAP_SIZE, + int OUTPUT_MEM_STRIDE, + const Output_T* __restrict outputs, + FILE* pFile, + Format format) +{ + // default is NHCW ! + if (format == Format::NHWC) { + fprintf(pFile, "("); + for(int oy = 0; oy < OUTPUTS_HEIGHT; oy++) { + fprintf(pFile, "("); + + for(int ox = 0; ox < OUTPUTS_WIDTH; ox++) { + fprintf(pFile, "("); + + const int oPos = (ox + OUTPUTS_WIDTH * oy); + int oOffset = OUTPUT_MEM_STRIDE * oPos; + + if (OUTPUT_MEM_WRAP_SIZE > 0 + && oOffset >= OUTPUT_MEM_CONT_SIZE) + { + oOffset += OUTPUT_MEM_WRAP_OFFSET - OUTPUT_MEM_CONT_OFFSET + - OUTPUT_MEM_CONT_SIZE; + } + + for (int output = 0; output < NB_OUTPUTS; output++) { + if (std::is_floating_point<Output_T>::value) + fprintf(pFile, "%f", static_cast<float>(outputs[oOffset + output])); + else + fprintf(pFile, "%d", static_cast<int>(outputs[oOffset + output])); + + fprintf(pFile, ", "); + } + + fprintf(pFile, "), \n"); + } + + fprintf(pFile, "), \n"); + } + + fprintf(pFile, ")\n"); + } + else if (format == Format::NCHW || format == Format::Default) { + for(int output = 0; output < NB_OUTPUTS; output++) { + fprintf(pFile, "%d:\n", output); + for(int oy = 0; oy < OUTPUTS_HEIGHT; oy++) { + for(int ox = 0; ox < OUTPUTS_WIDTH; ox++) { + const int oPos = (ox + OUTPUTS_WIDTH * oy); + int oOffset = OUTPUT_MEM_STRIDE * oPos; + if (OUTPUT_MEM_WRAP_SIZE > 0 + && oOffset >= OUTPUT_MEM_CONT_SIZE) + { + oOffset += OUTPUT_MEM_WRAP_OFFSET + - OUTPUT_MEM_CONT_OFFSET - OUTPUT_MEM_CONT_SIZE; + } + + if (std::is_floating_point<Output_T>::value) + fprintf(pFile, "%f", static_cast<float>(outputs[oOffset + output])); + else + fprintf(pFile, "%d", static_cast<int>(outputs[oOffset + output])); + + fprintf(pFile, " "); + } + + fprintf(pFile, "\n"); + } + + fprintf(pFile, "\n"); + } + + fprintf(pFile, "\n"); + } + else { + printf("Warning unsupported dataformat.\n"); + } +} +#endif // SAVE_OUTPUTS + +#endif // __AIDGE_EXPORT_CPP_NETWORK_UTILS__