From 84af78981722d87d422d1468071d60217a1c490b Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 10 Oct 2024 06:42:54 +0000 Subject: [PATCH] Add saveOutputs function to cpp export. --- .../static/include/network/utils.hpp | 107 +++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/aidge_export_cpp/static/include/network/utils.hpp b/aidge_export_cpp/static/include/network/utils.hpp index 11c8e06..e2bfbe2 100644 --- a/aidge_export_cpp/static/include/network/utils.hpp +++ b/aidge_export_cpp/static/include/network/utils.hpp @@ -1,6 +1,13 @@ #ifndef __AIDGE_EXPORT_CPP_NETWORK_UTILS__ #define __AIDGE_EXPORT_CPP_NETWORK_UTILS__ +#ifdef SAVE_OUTPUTS +#include <sys/types.h> +#include <sys/stat.h> +#include <cstdio> // fprintf +#include <type_traits> // std::is_floating_point +#endif + /** * @brief Integer clamping * @param[in] v Value to be clamped @@ -41,4 +48,102 @@ int min (int lhs, int rhs) return (lhs <= rhs) ? lhs : rhs; } -#endif // __AIDGE_EXPORT_CPP_NETWORK_UTILS__ + +#ifdef SAVE_OUTPUTS +enum class Format { + Default, + NCHW, + NHWC, + CHWN, + NCDHW, + NDHWC, + CDHWN +}; + + +template<typename Output_T> +inline void saveOutputs( + int NB_OUTPUTS, + int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH, + int OUTPUT_MEM_CONT_OFFSET, + int OUTPUT_MEM_CONT_SIZE, + int OUTPUT_MEM_WRAP_OFFSET, + int OUTPUT_MEM_WRAP_SIZE, + int OUTPUT_MEM_STRIDE, + const Output_T* __restrict outputs, + FILE* pFile, + Format format) +{ + // default is NHCW ! + if (format == Format::NHWC) { + fprintf(pFile, "("); + for(int oy = 0; oy < OUTPUTS_HEIGHT; oy++) { + fprintf(pFile, "("); + + for(int ox = 0; ox < OUTPUTS_WIDTH; ox++) { + fprintf(pFile, "("); + + const int oPos = (ox + OUTPUTS_WIDTH * oy); + int oOffset = OUTPUT_MEM_STRIDE * oPos; + + if (OUTPUT_MEM_WRAP_SIZE > 0 + && oOffset >= OUTPUT_MEM_CONT_SIZE) + { + oOffset += OUTPUT_MEM_WRAP_OFFSET - OUTPUT_MEM_CONT_OFFSET + - OUTPUT_MEM_CONT_SIZE; + } + + for (int output = 0; output < NB_OUTPUTS; output++) { + if (std::is_floating_point<Output_T>::value) + fprintf(pFile, "%f", static_cast<float>(outputs[oOffset + output])); + else + fprintf(pFile, "%d", static_cast<int>(outputs[oOffset + output])); + + fprintf(pFile, ", "); + } + + fprintf(pFile, "), \n"); + } + + fprintf(pFile, "), \n"); + } + + fprintf(pFile, ")\n"); + } + else if (format == Format::NCHW || format == Format::Default) { + for(int output = 0; output < NB_OUTPUTS; output++) { + fprintf(pFile, "%d:\n", output); + for(int oy = 0; oy < OUTPUTS_HEIGHT; oy++) { + for(int ox = 0; ox < OUTPUTS_WIDTH; ox++) { + const int oPos = (ox + OUTPUTS_WIDTH * oy); + int oOffset = OUTPUT_MEM_STRIDE * oPos; + if (OUTPUT_MEM_WRAP_SIZE > 0 + && oOffset >= OUTPUT_MEM_CONT_SIZE) + { + oOffset += OUTPUT_MEM_WRAP_OFFSET + - OUTPUT_MEM_CONT_OFFSET - OUTPUT_MEM_CONT_SIZE; + } + + if (std::is_floating_point<Output_T>::value) + fprintf(pFile, "%f", static_cast<float>(outputs[oOffset + output])); + else + fprintf(pFile, "%d", static_cast<int>(outputs[oOffset + output])); + + fprintf(pFile, " "); + } + + fprintf(pFile, "\n"); + } + + fprintf(pFile, "\n"); + } + + fprintf(pFile, "\n"); + } + else { + printf("Warning unsupported dataformat.\n"); + } +} +#endif // SAVE_OUTPUTS + +#endif // __AIDGE_EXPORT_CPP_NETWORK_UTILS__ -- GitLab