From 8947fddd80a71c5988ccabdcb86a0201421fd0cc Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 25 Feb 2025 15:36:46 +0000 Subject: [PATCH] Update DataFormat - upd: 'getDataFormatTranspose' to 'getPermutationMapping' - upd: use X-macro to keep DataFormat enum, string value, permutation with NCHW and description in synch - add: number of dimensions information to X-macro - upd: change 'DataFormatTransposeDict' array to 'permutationFromNCHW' and move it to a static function 'getPermutationFromNCHW' --- include/aidge/data/DataFormat.hpp | 87 +++++++++++++++++++------------ src/data/DataFormat.cpp | 87 ++++++++++++++++++++++--------- 2 files changed, 117 insertions(+), 57 deletions(-) diff --git a/include/aidge/data/DataFormat.hpp b/include/aidge/data/DataFormat.hpp index 77be8680d..4b3b949f0 100644 --- a/include/aidge/data/DataFormat.hpp +++ b/include/aidge/data/DataFormat.hpp @@ -17,63 +17,84 @@ #include "aidge/utils/logger/EnumString.hpp" +#define PERM(...) { __VA_ARGS__ } + +// (EnumName, StringLiteral, NumDims, Permutation (0-indexed), Description) +#define LIST_DATAFORMAT_ATTR(X) \ + X(Default, "Default", 0, PERM(), "Unspecified format: interpretation is implementation-dependent"), \ + X(Any, "Any", 0, PERM(), "Any format is valid"), \ + X(NCHW, "NCHW", 4, PERM(0, 1, 2, 3), "4D format: [batch][channel][height][width]"), \ + X(NHWC, "NHWC", 4, PERM(0, 2, 3, 1), "4D format: [batch][height][width][channel]"), \ + X(CHWN, "CHWN", 4, PERM(1, 2, 3, 0), "4D format: [channel][height][width][batch]"), \ + X(NCDHW, "NCDHW", 5, PERM(0, 1, 2, 3, 4), "5D format: [batch][channel][depth][height][width]"), \ + X(NDHWC, "NDHWC", 5, PERM(0, 2, 3, 4, 1), "5D format: [batch][depth][height][width][channel]"), \ + X(CDHWN, "CDHWN", 5, PERM(1, 2, 3, 4, 0), "5D format: [channel][depth][height][width][batch]") + +#define NB_DFORMAT 8 + namespace Aidge { /** - * @brief Enumeration of supported tensor data layouts + * @brief Enumeration of supported tensor data layouts. * - * Represents different memory layout formats for multi-dimensional tensors: + * Represents different memory layouts for multi-dimensional tensors. + * The dimensions typically represent: * - N: Batch size * - C: Channels * - H: Height * - W: Width * - D: Depth (for 3D tensors) + * + * The enum values are generated via the X-macro. */ -enum class DataFormat { - Default, ///< Default format, implementation dependent - NCHW, ///< 4D format: [batch][channel][height][width] - NHWC, ///< 4D format: [batch][height][width][channel] - CHWN, ///< 4D format: [channel][height][width][batch] - NCDHW, ///< 5D format: [batch][channel][depth][height][width] - NDHWC, ///< 5D format: [batch][depth][height][width][channel] - CDHWN, ///< 5D format: [channel][depth][height][width][batch] - Any ///< Unspecified format +enum class DataFormat : int { +#define X(enumName, str, nb, arr, desc) enumName + LIST_DATAFORMAT_ATTR(X) +#undef X }; -using DataFormatTranspose = std::array<std::size_t, 5>; - /** - * @brief Dictionary of transpose operations between different formats + * @brief Alias for a fixed-size array representing a permutation. * - * Contains permutation arrays to convert between different data formats. - * @warning In this array only, dimension index starts at 1 - * (0 is reserved as default value). + * A DataFormatTranspose is an array of 5 size_t values that represents a permutation mapping + * (0-indexed) for dimension reordering. Only the first NumDims elements are valid. */ -constexpr std::array<DataFormatTranspose, 7> DataFormatTransposeDict = {{ - {}, // Default - {1, 2, 3, 4}, // NCHW - {1, 3, 4, 2}, // NHWC - {2, 3, 4, 1}, // CHWN - {1, 2, 3, 4, 5}, // NCDHW - {1, 3, 4, 5, 2}, // NDHWC - {2, 3, 4, 5, 1} // CDHWN -}}; +using DataFormatTranspose = std::array<std::size_t, 5>; /** - * @brief Get the permutation array for converting between data formats + * @brief Compute the permutation mapping to convert tensor data from a source format + * to a destination format. + * + * This function performs the following steps: + * 1. If either format is DataFormat::Any, a runtime error is thrown. + * 2. If either format is DataFormat::Default or both dformat are equal, the identity permutation is returned. + * 3. Otherwise, the function retrieves the permutation mappings (from default NCHW) + * for both the source and destination formats. + * 4. It computes the inverse of the source mapping (from source to default ordering). + * 5. Finally, it composes the inverse with the destination mapping to produce the + * permutation mapping from source to destination. + * + * @param src The source data format. + * @param dst The destination data format. + * @return constexpr DataFormatTranspose The computed permutation array (0-indexed). * - * @param src Source data format - * @param dst Destination data format - * @return DataFormatTranspose Permutation array to achieve the format conversion + * @throws std::runtime_error if either src or dst is DataFormat::Any. + * @pre The source and destination formats must have the same number of dimensions. */ -DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst); +DataFormatTranspose getPermutationMapping(const DataFormat& src, const DataFormat& dst); } // namespace Aidge namespace { template <> -const char* const EnumStrings<Aidge::DataFormat>::data[] - = {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN", "Any"}; +struct EnumStrings<Aidge::DataFormat> { + static const char* const data[]; +}; +constexpr const char* const EnumStrings<Aidge::DataFormat>::data[] = { +#define X(EnumName, Str, NumDims, Perm, Desc) Str + LIST_DATAFORMAT_ATTR(X) +#undef X +}; } namespace Aidge { diff --git a/src/data/DataFormat.cpp b/src/data/DataFormat.cpp index 466da86c4..8b7460b3d 100644 --- a/src/data/DataFormat.cpp +++ b/src/data/DataFormat.cpp @@ -9,34 +9,73 @@ * ********************************************************************************/ +#include <array> +#include <cstddef> // std::size_t + #include "aidge/data/DataFormat.hpp" +#include "aidge/utils/ErrorHandling.hpp" + + +namespace Aidge { + +/** + * @brief Get the permutation array mapping from the default (NCHW) ordering to a given data format. + * + * @param dformat The target data format. + * @return const DataFormatTranspose& The permutation array (0-indexed). + * + * @note For DataFormat::Default and DataFormat::Any, an empty permutation is returned. + */ +static const DataFormatTranspose& getPermutationFromNCHW(DataFormat dformat) { + constexpr static const std::array<DataFormatTranspose, NB_DFORMAT> permutationFromNCHW = {{ +#define X(EnumName, Str, NumDims, Perm, Desc) Perm + LIST_DATAFORMAT_ATTR(X) +#undef X + }}; + return permutationFromNCHW[static_cast<std::size_t>(dformat)]; +} + + +/** + * @brief Retrieve the number of dimensions for a given data format. + * + * @param dformat The data format. + * @return constexpr std::size_t The number of dimensions. + */ +static std::size_t getNbDimensions(DataFormat dformat) { + constexpr static const std::array<std::size_t, NB_DFORMAT> nbDimensions = { +#define X(name, str, nb, arr, desc) nb + LIST_DATAFORMAT_ATTR(X) +#undef X + }; + return nbDimensions[static_cast<std::size_t>(dformat)]; +} + +DataFormatTranspose getPermutationMapping(const DataFormat& src, const DataFormat& dst) { + AIDGE_ASSERT((src != DataFormat::Any && dst != DataFormat::Any), "Permutation is not defined for DataFormat::Any"); + if (src == DataFormat::Default || dst == DataFormat::Default || src == dst) { + return {0,1,2,3,4}; + } + const std::size_t nbDims = getNbDimensions(src); + AIDGE_ASSERT(nbDims == getNbDimensions(dst), "Incompatible format conversion. Current and new data format must have the same number of dimensions."); + + // Get permutation from default (NCHW) to source and destination. + const DataFormatTranspose& nchw_to_src = getPermutationFromNCHW(src); + const DataFormatTranspose& nchw_to_dst = getPermutationFromNCHW(dst); -Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) { - // Permutation array from default format to src format - const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)]; - // Permutation array from default format to dst format - const auto dstDefToFormat = DataFormatTransposeDict[static_cast<int>(dst)]; - // Compute permutation array from src format to default format: - DataFormatTranspose srcFormatToDef{}; - for (size_t i = 0; i < srcDefToFormat.size(); ++i) { - if (srcDefToFormat[i] > 0) { - srcFormatToDef[srcDefToFormat[i] - 1] = i; - } - else { - srcFormatToDef[i] = i; - } + // Compute inverse permutation: mapping from source format to default (NCHW). + DataFormatTranspose src_to_nchw{}; + for (std::size_t i = 0; i < nbDims; ++i) { + // Since the permutations are 0-indexed, simply invert. + src_to_nchw[nchw_to_src[i]] = i; } - // Compute permutation array from src format to dst format: - DataFormatTranspose srcToDst{}; - for (size_t i = 0; i < dstDefToFormat.size(); ++i) { - if (dstDefToFormat[srcFormatToDef[i]] > 0) { - srcToDst[i] = dstDefToFormat[srcFormatToDef[i]] - 1; - } - else { - srcToDst[i] = srcFormatToDef[i]; - } + // Compute mapping from source format to destination format. + DataFormatTranspose src_to_dst{}; + for (std::size_t i = 0; i < nbDims; ++i) { + src_to_dst[i] = nchw_to_dst[src_to_nchw[i]]; } - return srcToDst; + return src_to_dst; } +} // namespace Aidge \ No newline at end of file -- GitLab