diff --git a/include/aidge/data/DataFormat.hpp b/include/aidge/data/DataFormat.hpp index 77be8680deeda422e6cf41e612eab21ecf80c5e1..4b3b949f007ed835fc2af0be15fc53284d188ce0 100644 --- a/include/aidge/data/DataFormat.hpp +++ b/include/aidge/data/DataFormat.hpp @@ -17,63 +17,84 @@ #include "aidge/utils/logger/EnumString.hpp" +#define PERM(...) { __VA_ARGS__ } + +// (EnumName, StringLiteral, NumDims, Permutation (0-indexed), Description) +#define LIST_DATAFORMAT_ATTR(X) \ + X(Default, "Default", 0, PERM(), "Unspecified format: interpretation is implementation-dependent"), \ + X(Any, "Any", 0, PERM(), "Any format is valid"), \ + X(NCHW, "NCHW", 4, PERM(0, 1, 2, 3), "4D format: [batch][channel][height][width]"), \ + X(NHWC, "NHWC", 4, PERM(0, 2, 3, 1), "4D format: [batch][height][width][channel]"), \ + X(CHWN, "CHWN", 4, PERM(1, 2, 3, 0), "4D format: [channel][height][width][batch]"), \ + X(NCDHW, "NCDHW", 5, PERM(0, 1, 2, 3, 4), "5D format: [batch][channel][depth][height][width]"), \ + X(NDHWC, "NDHWC", 5, PERM(0, 2, 3, 4, 1), "5D format: [batch][depth][height][width][channel]"), \ + X(CDHWN, "CDHWN", 5, PERM(1, 2, 3, 4, 0), "5D format: [channel][depth][height][width][batch]") + +#define NB_DFORMAT 8 + namespace Aidge { /** - * @brief Enumeration of supported tensor data layouts + * @brief Enumeration of supported tensor data layouts. * - * Represents different memory layout formats for multi-dimensional tensors: + * Represents different memory layouts for multi-dimensional tensors. + * The dimensions typically represent: * - N: Batch size * - C: Channels * - H: Height * - W: Width * - D: Depth (for 3D tensors) + * + * The enum values are generated via the X-macro. */ -enum class DataFormat { - Default, ///< Default format, implementation dependent - NCHW, ///< 4D format: [batch][channel][height][width] - NHWC, ///< 4D format: [batch][height][width][channel] - CHWN, ///< 4D format: [channel][height][width][batch] - NCDHW, ///< 5D format: [batch][channel][depth][height][width] - NDHWC, ///< 5D format: [batch][depth][height][width][channel] - CDHWN, ///< 5D format: [channel][depth][height][width][batch] - Any ///< Unspecified format +enum class DataFormat : int { +#define X(enumName, str, nb, arr, desc) enumName + LIST_DATAFORMAT_ATTR(X) +#undef X }; -using DataFormatTranspose = std::array<std::size_t, 5>; - /** - * @brief Dictionary of transpose operations between different formats + * @brief Alias for a fixed-size array representing a permutation. * - * Contains permutation arrays to convert between different data formats. - * @warning In this array only, dimension index starts at 1 - * (0 is reserved as default value). + * A DataFormatTranspose is an array of 5 size_t values that represents a permutation mapping + * (0-indexed) for dimension reordering. Only the first NumDims elements are valid. */ -constexpr std::array<DataFormatTranspose, 7> DataFormatTransposeDict = {{ - {}, // Default - {1, 2, 3, 4}, // NCHW - {1, 3, 4, 2}, // NHWC - {2, 3, 4, 1}, // CHWN - {1, 2, 3, 4, 5}, // NCDHW - {1, 3, 4, 5, 2}, // NDHWC - {2, 3, 4, 5, 1} // CDHWN -}}; +using DataFormatTranspose = std::array<std::size_t, 5>; /** - * @brief Get the permutation array for converting between data formats + * @brief Compute the permutation mapping to convert tensor data from a source format + * to a destination format. + * + * This function performs the following steps: + * 1. If either format is DataFormat::Any, a runtime error is thrown. + * 2. If either format is DataFormat::Default or both dformat are equal, the identity permutation is returned. + * 3. Otherwise, the function retrieves the permutation mappings (from default NCHW) + * for both the source and destination formats. + * 4. It computes the inverse of the source mapping (from source to default ordering). + * 5. Finally, it composes the inverse with the destination mapping to produce the + * permutation mapping from source to destination. + * + * @param src The source data format. + * @param dst The destination data format. + * @return constexpr DataFormatTranspose The computed permutation array (0-indexed). * - * @param src Source data format - * @param dst Destination data format - * @return DataFormatTranspose Permutation array to achieve the format conversion + * @throws std::runtime_error if either src or dst is DataFormat::Any. + * @pre The source and destination formats must have the same number of dimensions. */ -DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst); +DataFormatTranspose getPermutationMapping(const DataFormat& src, const DataFormat& dst); } // namespace Aidge namespace { template <> -const char* const EnumStrings<Aidge::DataFormat>::data[] - = {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN", "Any"}; +struct EnumStrings<Aidge::DataFormat> { + static const char* const data[]; +}; +constexpr const char* const EnumStrings<Aidge::DataFormat>::data[] = { +#define X(EnumName, Str, NumDims, Perm, Desc) Str + LIST_DATAFORMAT_ATTR(X) +#undef X +}; } namespace Aidge { diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 785caaa0e8959ba34d438913a4c0e5bad3df0f86..5df59becdc41f12768935544a42aac24ffb3a333 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -431,7 +431,7 @@ public: * @brief Get the data format enum. * @return constexpr DataFormat */ - constexpr DataFormat dataFormat() const noexcept { return mDataFormat; } + const DataFormat& dataFormat() const noexcept { return mDataFormat; } /** * @brief Set the DataType of the Tensor and converts data @@ -462,13 +462,13 @@ public: * data is copy-transposed. */ void setDataFormat(const DataFormat df, bool copyTrans = true) { - if (!copyTrans || df == dataFormat()) { + if (!copyTrans || df == dataFormat() || df == DataFormat::Default || dataFormat() == DataFormat::Default) { mDataFormat = df; return; } - - const auto transpose = getDataFormatTranspose(dataFormat(), df); - + + const auto transpose = getPermutationMapping(dataFormat(), df); + if (mImpl) { copyTranspose(*this, transpose); } else { @@ -476,7 +476,7 @@ public: for (std::size_t i = 0; i < dims().size(); ++i) { newDims.push_back(dims()[transpose[i]]); } - + std::vector<std::size_t> newStrides(dims().size(), 1); for (size_t i = 0; i < dims().size(); ++i) { for (size_t j = i + 1; j < dims().size(); ++j) { @@ -486,9 +486,10 @@ public: mDims = std::move(newDims); mStrides = std::move(newStrides); } - + mDataFormat = df; } + /** * @brief Get the Impl object * @return constexpr const std::shared_ptr<TensorImpl>& diff --git a/python_binding/data/pybind_DataFormat.cpp b/python_binding/data/pybind_DataFormat.cpp index a63df321c3298284df7de8fd2c3eb0fc0cecae24..5308fb3023e15a74f0dc5f674917d7ae65cbb52f 100644 --- a/python_binding/data/pybind_DataFormat.cpp +++ b/python_binding/data/pybind_DataFormat.cpp @@ -66,7 +66,7 @@ void bindEnum(py::module& m, const std::string& name) { void init_DataFormat(py::module& m) { bindEnum<DataFormat>(m, "dformat"); m.def("format_as", (const char* (*)(DataFormat)) &format_as, py::arg("df")); - m.def("get_data_format_transpose", &getDataFormatTranspose, py::arg("src"), py::arg("dst")); + m.def("get_permutation_mapping", &getPermutationMapping, py::arg("src"), py::arg("dst")); } } // namespace Aidge diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 08f5fe671c7502a6c5fe01dbdfb7ae4c9b95ac81..480e751807d85c4f74039e35c284f13f03013650 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -196,7 +196,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im && spec.format != DataFormat::Any && required.format != spec.format) { - const auto transpose = getDataFormatTranspose(required.format, spec.format); + const auto transpose = getPermutationMapping(required.format, spec.format); std::vector<size_t> identity(transpose.size()); std::iota(std::begin(identity), std::end(identity), 0); @@ -261,7 +261,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& && IOSpec.format != DataFormat::Any && requiredIOSpec.format != IOSpec.format) { - const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format); + const auto transpose = getPermutationMapping(requiredIOSpec.format, IOSpec.format); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); transposeOp->getOperator()->setDataFormat(IOSpec.format); transposeOp->getOperator()->setDataType(requiredIOSpec.type); @@ -315,7 +315,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& && IOSpec.format != DataFormat::Any && requiredIOSpec.format != IOSpec.format) { - const auto transpose = getDataFormatTranspose(IOSpec.format, requiredIOSpec.format); + const auto transpose = getPermutationMapping(IOSpec.format, requiredIOSpec.format); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); transposeOp->getOperator()->setDataFormat(requiredIOSpec.format); transposeOp->getOperator()->setDataType(requiredIOSpec.type); diff --git a/src/data/DataFormat.cpp b/src/data/DataFormat.cpp index 466da86c469d89e5f1f4fc0895223513783b801c..8b7460b3d42de2e2ddb60a6332405141f87b297c 100644 --- a/src/data/DataFormat.cpp +++ b/src/data/DataFormat.cpp @@ -9,34 +9,73 @@ * ********************************************************************************/ +#include <array> +#include <cstddef> // std::size_t + #include "aidge/data/DataFormat.hpp" +#include "aidge/utils/ErrorHandling.hpp" + + +namespace Aidge { + +/** + * @brief Get the permutation array mapping from the default (NCHW) ordering to a given data format. + * + * @param dformat The target data format. + * @return const DataFormatTranspose& The permutation array (0-indexed). + * + * @note For DataFormat::Default and DataFormat::Any, an empty permutation is returned. + */ +static const DataFormatTranspose& getPermutationFromNCHW(DataFormat dformat) { + constexpr static const std::array<DataFormatTranspose, NB_DFORMAT> permutationFromNCHW = {{ +#define X(EnumName, Str, NumDims, Perm, Desc) Perm + LIST_DATAFORMAT_ATTR(X) +#undef X + }}; + return permutationFromNCHW[static_cast<std::size_t>(dformat)]; +} + + +/** + * @brief Retrieve the number of dimensions for a given data format. + * + * @param dformat The data format. + * @return constexpr std::size_t The number of dimensions. + */ +static std::size_t getNbDimensions(DataFormat dformat) { + constexpr static const std::array<std::size_t, NB_DFORMAT> nbDimensions = { +#define X(name, str, nb, arr, desc) nb + LIST_DATAFORMAT_ATTR(X) +#undef X + }; + return nbDimensions[static_cast<std::size_t>(dformat)]; +} + +DataFormatTranspose getPermutationMapping(const DataFormat& src, const DataFormat& dst) { + AIDGE_ASSERT((src != DataFormat::Any && dst != DataFormat::Any), "Permutation is not defined for DataFormat::Any"); + if (src == DataFormat::Default || dst == DataFormat::Default || src == dst) { + return {0,1,2,3,4}; + } + const std::size_t nbDims = getNbDimensions(src); + AIDGE_ASSERT(nbDims == getNbDimensions(dst), "Incompatible format conversion. Current and new data format must have the same number of dimensions."); + + // Get permutation from default (NCHW) to source and destination. + const DataFormatTranspose& nchw_to_src = getPermutationFromNCHW(src); + const DataFormatTranspose& nchw_to_dst = getPermutationFromNCHW(dst); -Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) { - // Permutation array from default format to src format - const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)]; - // Permutation array from default format to dst format - const auto dstDefToFormat = DataFormatTransposeDict[static_cast<int>(dst)]; - // Compute permutation array from src format to default format: - DataFormatTranspose srcFormatToDef{}; - for (size_t i = 0; i < srcDefToFormat.size(); ++i) { - if (srcDefToFormat[i] > 0) { - srcFormatToDef[srcDefToFormat[i] - 1] = i; - } - else { - srcFormatToDef[i] = i; - } + // Compute inverse permutation: mapping from source format to default (NCHW). + DataFormatTranspose src_to_nchw{}; + for (std::size_t i = 0; i < nbDims; ++i) { + // Since the permutations are 0-indexed, simply invert. + src_to_nchw[nchw_to_src[i]] = i; } - // Compute permutation array from src format to dst format: - DataFormatTranspose srcToDst{}; - for (size_t i = 0; i < dstDefToFormat.size(); ++i) { - if (dstDefToFormat[srcFormatToDef[i]] > 0) { - srcToDst[i] = dstDefToFormat[srcFormatToDef[i]] - 1; - } - else { - srcToDst[i] = srcFormatToDef[i]; - } + // Compute mapping from source format to destination format. + DataFormatTranspose src_to_dst{}; + for (std::size_t i = 0; i < nbDims; ++i) { + src_to_dst[i] = nchw_to_dst[src_to_nchw[i]]; } - return srcToDst; + return src_to_dst; } +} // namespace Aidge \ No newline at end of file diff --git a/src/recipes/ExplicitTranspose.cpp b/src/recipes/ExplicitTranspose.cpp index 7ff971b7e436219d5dfbb7cbadbaf780d3f1aeda..c4e2c425c93f6306373a49c29b1d117a03af04ae 100644 --- a/src/recipes/ExplicitTranspose.cpp +++ b/src/recipes/ExplicitTranspose.cpp @@ -94,14 +94,14 @@ void Aidge::explicitTranspose(std::shared_ptr<GraphView> graph) { else { // Case 2: change of format // => compute the new permutation array - const auto transpose = getDataFormatTranspose(parentInput->dataFormat(), output->dataFormat()); + const auto transpose = getPermutationMapping(parentInput->dataFormat(), output->dataFormat()); auto transposeOp = std::static_pointer_cast<Transpose_Op>(parent.first->getOperator()); transposeOp->setDataFormat(output->dataFormat()); transposeOp->outputDimsOrder() = std::vector<DimSize_t>(transpose.begin(), transpose.end()); } } else { - const auto transpose = getDataFormatTranspose(input->dataFormat(), output->dataFormat()); + const auto transpose = getPermutationMapping(input->dataFormat(), output->dataFormat()); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); transposeOp->getOperator()->setDataFormat(output->dataFormat()); transposeOp->getOperator()->setDataType(output->dataType()); diff --git a/unit_tests/operator/Test_Conv_Op.cpp b/unit_tests/operator/Test_Conv_Op.cpp index de33ddd5a7613cde16b96b23722f6d2ab412f373..9b32054c35b14702850820abf9930ce811719dc4 100644 --- a/unit_tests/operator/Test_Conv_Op.cpp +++ b/unit_tests/operator/Test_Conv_Op.cpp @@ -24,13 +24,13 @@ namespace Aidge { TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv]") { SECTION("I:NCHW O:NCHW W:NCHW"){ - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch_h,W,H const std::vector<std::size_t> expectedOutputDims({16,4,222,447}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - //Set DataFormat + //Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW); input->setDataFormat(Aidge::DataFormat::NCHW); weight->setDataFormat(Aidge::DataFormat::NCHW); @@ -43,61 +43,61 @@ TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } SECTION("I:NCHW O:NCHW W:NHWC") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); - std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, H, W, In_ch - + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 3})); // Out_ch, H, W, In_ch + const std::vector<std::size_t> expectedOutputDims({16, 4, 222, 447}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat + + // Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW); input->setDataFormat(Aidge::DataFormat::NCHW); weight->setDataFormat(Aidge::DataFormat::NHWC); // NHWC weight format - + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } - + SECTION("I:NHWC O:NHWC W:NCHW") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 224, 450, 3})); std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, In_ch, H, W - + const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat + + // Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); input->setDataFormat(Aidge::DataFormat::NHWC); weight->setDataFormat(Aidge::DataFormat::NCHW); // NCHW weight format - + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } - + SECTION("I:NHWC O:NHWC W:NHWC") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3,224, 450})); - std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // (Out_ch, H, W, In_ch) - + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 224, 450, 3})); + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 4, 3})); // (Out_ch, H, W, In_ch) + const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat + + // Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); input->setDataFormat(Aidge::DataFormat::NHWC); weight->setDataFormat(Aidge::DataFormat::NHWC); - + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); }