From 8947fddd80a71c5988ccabdcb86a0201421fd0cc Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 25 Feb 2025 15:36:46 +0000
Subject: [PATCH] Update DataFormat - upd: 'getDataFormatTranspose' to
 'getPermutationMapping' - upd: use X-macro to keep DataFormat enum, string
 value, permutation with NCHW and description in synch - add: number of
 dimensions information to X-macro - upd: change 'DataFormatTransposeDict'
 array to 'permutationFromNCHW' and move it to a static function
 'getPermutationFromNCHW'

---
 include/aidge/data/DataFormat.hpp | 87 +++++++++++++++++++------------
 src/data/DataFormat.cpp           | 87 ++++++++++++++++++++++---------
 2 files changed, 117 insertions(+), 57 deletions(-)

diff --git a/include/aidge/data/DataFormat.hpp b/include/aidge/data/DataFormat.hpp
index 77be8680d..4b3b949f0 100644
--- a/include/aidge/data/DataFormat.hpp
+++ b/include/aidge/data/DataFormat.hpp
@@ -17,63 +17,84 @@
 
 #include "aidge/utils/logger/EnumString.hpp"
 
+#define PERM(...) { __VA_ARGS__ }
+
+//   (EnumName, StringLiteral, NumDims, Permutation (0-indexed), Description)
+#define LIST_DATAFORMAT_ATTR(X)                                                      \
+    X(Default, "Default",         0, PERM(), "Unspecified format: interpretation is implementation-dependent"), \
+    X(Any,         "Any",         0, PERM(), "Any format is valid"),                    \
+    X(NCHW,        "NCHW",        4, PERM(0, 1, 2, 3), "4D format: [batch][channel][height][width]"), \
+    X(NHWC,        "NHWC",        4, PERM(0, 2, 3, 1), "4D format: [batch][height][width][channel]"), \
+    X(CHWN,        "CHWN",        4, PERM(1, 2, 3, 0), "4D format: [channel][height][width][batch]"), \
+    X(NCDHW,       "NCDHW",       5, PERM(0, 1, 2, 3, 4), "5D format: [batch][channel][depth][height][width]"), \
+    X(NDHWC,       "NDHWC",       5, PERM(0, 2, 3, 4, 1), "5D format: [batch][depth][height][width][channel]"), \
+    X(CDHWN,       "CDHWN",       5, PERM(1, 2, 3, 4, 0), "5D format: [channel][depth][height][width][batch]")
+
+#define NB_DFORMAT 8
+
 namespace Aidge {
 
 /**
- * @brief Enumeration of supported tensor data layouts
+ * @brief Enumeration of supported tensor data layouts.
  *
- * Represents different memory layout formats for multi-dimensional tensors:
+ * Represents different memory layouts for multi-dimensional tensors.
+ * The dimensions typically represent:
  * - N: Batch size
  * - C: Channels
  * - H: Height
  * - W: Width
  * - D: Depth (for 3D tensors)
+ *
+ * The enum values are generated via the X-macro.
  */
-enum class DataFormat {
-    Default,    ///< Default format, implementation dependent
-    NCHW,      ///< 4D format: [batch][channel][height][width]
-    NHWC,      ///< 4D format: [batch][height][width][channel]
-    CHWN,      ///< 4D format: [channel][height][width][batch]
-    NCDHW,     ///< 5D format: [batch][channel][depth][height][width]
-    NDHWC,     ///< 5D format: [batch][depth][height][width][channel]
-    CDHWN,     ///< 5D format: [channel][depth][height][width][batch]
-    Any        ///< Unspecified format
+enum class DataFormat : int {
+#define X(enumName, str, nb, arr, desc) enumName
+    LIST_DATAFORMAT_ATTR(X)
+#undef X
 };
 
-using DataFormatTranspose = std::array<std::size_t, 5>;
-
 /**
- * @brief Dictionary of transpose operations between different formats
+ * @brief Alias for a fixed-size array representing a permutation.
  *
- * Contains permutation arrays to convert between different data formats.
- * @warning In this array only, dimension index starts at 1
- * (0 is reserved as default value).
+ * A DataFormatTranspose is an array of 5 size_t values that represents a permutation mapping
+ * (0-indexed) for dimension reordering. Only the first NumDims elements are valid.
  */
-constexpr std::array<DataFormatTranspose, 7> DataFormatTransposeDict = {{
-    {},                 // Default
-    {1, 2, 3, 4},      // NCHW
-    {1, 3, 4, 2},      // NHWC
-    {2, 3, 4, 1},      // CHWN
-    {1, 2, 3, 4, 5},   // NCDHW
-    {1, 3, 4, 5, 2},   // NDHWC
-    {2, 3, 4, 5, 1}    // CDHWN
-}};
+using DataFormatTranspose = std::array<std::size_t, 5>;
 
 /**
- * @brief Get the permutation array for converting between data formats
+ * @brief Compute the permutation mapping to convert tensor data from a source format
+ *        to a destination format.
+ *
+ * This function performs the following steps:
+ *   1. If either format is DataFormat::Any, a runtime error is thrown.
+ *   2. If either format is DataFormat::Default or both dformat are equal, the identity permutation is returned.
+ *   3. Otherwise, the function retrieves the permutation mappings (from default NCHW)
+ *      for both the source and destination formats.
+ *   4. It computes the inverse of the source mapping (from source to default ordering).
+ *   5. Finally, it composes the inverse with the destination mapping to produce the
+ *      permutation mapping from source to destination.
+ *
+ * @param src The source data format.
+ * @param dst The destination data format.
+ * @return constexpr DataFormatTranspose The computed permutation array (0-indexed).
  *
- * @param src Source data format
- * @param dst Destination data format
- * @return DataFormatTranspose Permutation array to achieve the format conversion
+ * @throws std::runtime_error if either src or dst is DataFormat::Any.
+ * @pre The source and destination formats must have the same number of dimensions.
  */
-DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst);
+DataFormatTranspose getPermutationMapping(const DataFormat& src, const DataFormat& dst);
 
 } // namespace Aidge
 
 namespace {
 template <>
-const char* const EnumStrings<Aidge::DataFormat>::data[]
-    = {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN", "Any"};
+struct EnumStrings<Aidge::DataFormat> {
+    static const char* const data[];
+};
+constexpr const char* const EnumStrings<Aidge::DataFormat>::data[] = {
+#define X(EnumName, Str, NumDims, Perm, Desc) Str
+    LIST_DATAFORMAT_ATTR(X)
+#undef X
+};
 }
 
 namespace Aidge {
diff --git a/src/data/DataFormat.cpp b/src/data/DataFormat.cpp
index 466da86c4..8b7460b3d 100644
--- a/src/data/DataFormat.cpp
+++ b/src/data/DataFormat.cpp
@@ -9,34 +9,73 @@
  *
  ********************************************************************************/
 
+#include <array>
+#include <cstddef>  // std::size_t
+
 #include "aidge/data/DataFormat.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+
+
+namespace Aidge {
+
+/**
+ * @brief Get the permutation array mapping from the default (NCHW) ordering to a given data format.
+ *
+ * @param dformat The target data format.
+ * @return const DataFormatTranspose& The permutation array (0-indexed).
+ *
+ * @note For DataFormat::Default and DataFormat::Any, an empty permutation is returned.
+ */
+static const DataFormatTranspose& getPermutationFromNCHW(DataFormat dformat) {
+    constexpr static const std::array<DataFormatTranspose, NB_DFORMAT> permutationFromNCHW = {{
+#define X(EnumName, Str, NumDims, Perm, Desc) Perm
+        LIST_DATAFORMAT_ATTR(X)
+#undef X
+    }};
+    return permutationFromNCHW[static_cast<std::size_t>(dformat)];
+}
+
+
+/**
+ * @brief Retrieve the number of dimensions for a given data format.
+ *
+ * @param dformat The data format.
+ * @return constexpr std::size_t The number of dimensions.
+ */
+static std::size_t getNbDimensions(DataFormat dformat) {
+    constexpr static const std::array<std::size_t, NB_DFORMAT> nbDimensions = {
+#define X(name, str, nb, arr, desc) nb
+        LIST_DATAFORMAT_ATTR(X)
+#undef X
+    };
+    return nbDimensions[static_cast<std::size_t>(dformat)];
+}
+
+DataFormatTranspose getPermutationMapping(const DataFormat& src, const DataFormat& dst) {
+    AIDGE_ASSERT((src != DataFormat::Any && dst != DataFormat::Any), "Permutation is not defined for DataFormat::Any");
+    if (src == DataFormat::Default || dst == DataFormat::Default || src == dst) {
+        return {0,1,2,3,4};
+    }
+    const std::size_t nbDims = getNbDimensions(src);
+    AIDGE_ASSERT(nbDims == getNbDimensions(dst), "Incompatible format conversion. Current and new data format must have the same number of dimensions.");
+
+    // Get permutation from default (NCHW) to source and destination.
+    const DataFormatTranspose& nchw_to_src = getPermutationFromNCHW(src);
+    const DataFormatTranspose& nchw_to_dst = getPermutationFromNCHW(dst);
 
-Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) {
-    // Permutation array from default format to src format
-    const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)];
-    // Permutation array from default format to dst format
-    const auto dstDefToFormat = DataFormatTransposeDict[static_cast<int>(dst)];
-    // Compute permutation array from src format to default format:
-    DataFormatTranspose srcFormatToDef{};
-    for (size_t i = 0; i < srcDefToFormat.size(); ++i) {
-        if (srcDefToFormat[i] > 0) {
-            srcFormatToDef[srcDefToFormat[i] - 1] = i;
-        }
-        else {
-            srcFormatToDef[i] = i;
-        }
+    // Compute inverse permutation: mapping from source format to default (NCHW).
+    DataFormatTranspose src_to_nchw{};
+    for (std::size_t i = 0; i < nbDims; ++i) {
+        // Since the permutations are 0-indexed, simply invert.
+        src_to_nchw[nchw_to_src[i]] = i;
     }
 
-    // Compute permutation array from src format to dst format:
-    DataFormatTranspose srcToDst{};
-    for (size_t i = 0; i < dstDefToFormat.size(); ++i) {
-        if (dstDefToFormat[srcFormatToDef[i]] > 0) {
-            srcToDst[i] = dstDefToFormat[srcFormatToDef[i]] - 1;
-        }
-        else {
-            srcToDst[i] = srcFormatToDef[i];
-        }
+    // Compute mapping from source format to destination format.
+    DataFormatTranspose src_to_dst{};
+    for (std::size_t i = 0; i < nbDims; ++i) {
+        src_to_dst[i] = nchw_to_dst[src_to_nchw[i]];
     }
 
-    return srcToDst;
+    return src_to_dst;
 }
+} // namespace Aidge
\ No newline at end of file
-- 
GitLab