Skip to content
Snippets Groups Projects

[Fix] Tensor::setDataFormat handling of DataFormat::Default

Merged Maxence Naud requested to merge fix_314-dataformat-default into dev
7 files
+ 155
94
Compare changes
  • Side-by-side
  • Inline
Files
7
@@ -17,63 +17,84 @@
#include "aidge/utils/logger/EnumString.hpp"
#define PERM(...) { __VA_ARGS__ }
// (EnumName, StringLiteral, NumDims, Permutation (0-indexed), Description)
#define LIST_DATAFORMAT_ATTR(X) \
X(Default, "Default", 0, PERM(), "Unspecified format: interpretation is implementation-dependent"), \
X(Any, "Any", 0, PERM(), "Any format is valid"), \
X(NCHW, "NCHW", 4, PERM(0, 1, 2, 3), "4D format: [batch][channel][height][width]"), \
X(NHWC, "NHWC", 4, PERM(0, 2, 3, 1), "4D format: [batch][height][width][channel]"), \
X(CHWN, "CHWN", 4, PERM(1, 2, 3, 0), "4D format: [channel][height][width][batch]"), \
X(NCDHW, "NCDHW", 5, PERM(0, 1, 2, 3, 4), "5D format: [batch][channel][depth][height][width]"), \
X(NDHWC, "NDHWC", 5, PERM(0, 2, 3, 4, 1), "5D format: [batch][depth][height][width][channel]"), \
X(CDHWN, "CDHWN", 5, PERM(1, 2, 3, 4, 0), "5D format: [channel][depth][height][width][batch]")
#define NB_DFORMAT 8
namespace Aidge {
/**
* @brief Enumeration of supported tensor data layouts
* @brief Enumeration of supported tensor data layouts.
*
* Represents different memory layout formats for multi-dimensional tensors:
* Represents different memory layouts for multi-dimensional tensors.
* The dimensions typically represent:
* - N: Batch size
* - C: Channels
* - H: Height
* - W: Width
* - D: Depth (for 3D tensors)
*
* The enum values are generated via the X-macro.
*/
enum class DataFormat {
Default, ///< Default format, implementation dependent
NCHW, ///< 4D format: [batch][channel][height][width]
NHWC, ///< 4D format: [batch][height][width][channel]
CHWN, ///< 4D format: [channel][height][width][batch]
NCDHW, ///< 5D format: [batch][channel][depth][height][width]
NDHWC, ///< 5D format: [batch][depth][height][width][channel]
CDHWN, ///< 5D format: [channel][depth][height][width][batch]
Any ///< Unspecified format
enum class DataFormat : int {
#define X(enumName, str, nb, arr, desc) enumName
LIST_DATAFORMAT_ATTR(X)
#undef X
};
using DataFormatTranspose = std::array<std::size_t, 5>;
/**
* @brief Dictionary of transpose operations between different formats
* @brief Alias for a fixed-size array representing a permutation.
*
* Contains permutation arrays to convert between different data formats.
* @warning In this array only, dimension index starts at 1
* (0 is reserved as default value).
* A DataFormatTranspose is an array of 5 size_t values that represents a permutation mapping
* (0-indexed) for dimension reordering. Only the first NumDims elements are valid.
*/
constexpr std::array<DataFormatTranspose, 7> DataFormatTransposeDict = {{
{}, // Default
{1, 2, 3, 4}, // NCHW
{1, 3, 4, 2}, // NHWC
{2, 3, 4, 1}, // CHWN
{1, 2, 3, 4, 5}, // NCDHW
{1, 3, 4, 5, 2}, // NDHWC
{2, 3, 4, 5, 1} // CDHWN
}};
using DataFormatTranspose = std::array<std::size_t, 5>;
/**
* @brief Get the permutation array for converting between data formats
* @brief Compute the permutation mapping to convert tensor data from a source format
* to a destination format.
*
* This function performs the following steps:
* 1. If either format is DataFormat::Any, a runtime error is thrown.
* 2. If either format is DataFormat::Default or both dformat are equal, the identity permutation is returned.
* 3. Otherwise, the function retrieves the permutation mappings (from default NCHW)
* for both the source and destination formats.
* 4. It computes the inverse of the source mapping (from source to default ordering).
* 5. Finally, it composes the inverse with the destination mapping to produce the
* permutation mapping from source to destination.
*
* @param src The source data format.
* @param dst The destination data format.
* @return constexpr DataFormatTranspose The computed permutation array (0-indexed).
*
* @param src Source data format
* @param dst Destination data format
* @return DataFormatTranspose Permutation array to achieve the format conversion
* @throws std::runtime_error if either src or dst is DataFormat::Any.
* @pre The source and destination formats must have the same number of dimensions.
*/
DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst);
DataFormatTranspose getPermutationMapping(const DataFormat& src, const DataFormat& dst);
} // namespace Aidge
namespace {
template <>
const char* const EnumStrings<Aidge::DataFormat>::data[]
= {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN", "Any"};
struct EnumStrings<Aidge::DataFormat> {
static const char* const data[];
};
constexpr const char* const EnumStrings<Aidge::DataFormat>::data[] = {
#define X(EnumName, Str, NumDims, Perm, Desc) Str
LIST_DATAFORMAT_ATTR(X)
#undef X
};
}
namespace Aidge {
Loading