diff --git a/src/data/DataFormat.cpp b/src/data/DataFormat.cpp index e8dd605bbe82199be29bb9dab2b6a95816b99faf..095ef9368e13b1d31322e73d98d1ec63ab0fda3b 100644 --- a/src/data/DataFormat.cpp +++ b/src/data/DataFormat.cpp @@ -52,7 +52,7 @@ static std::size_t getNbDimensions(DataFormat dformat) { } /** - * @brief Get the DataFormat corresponding to a given permutation vector. + * @brief Get the DataFormat corresponding to a given ``DataFormatTranspose``. * * @param perm The permutation vector. * @return DataFormat The matching DataFormat, or DataFormat::Default if not found. diff --git a/unit_tests/operator/Test_TransposeImpl.cpp b/unit_tests/operator/Test_TransposeImpl.cpp index 75adc9fd571c1a4402c0317bcc23bd07209d1380..f852c3fe0ebf2e6c73dddd358bf778b67779bffc 100644 --- a/unit_tests/operator/Test_TransposeImpl.cpp +++ b/unit_tests/operator/Test_TransposeImpl.cpp @@ -284,7 +284,7 @@ TEST_CASE("[cpu/operator] Transpose DataFormat") { REQUIRE(op->getOutput(0)->dataFormat() == DataFormat::NCHW); } - SECTION("Invalid Format Handling") { + SECTION("Default to Default") { std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); input->setDataFormat(Aidge::DataFormat::Default); auto transposeNode = Transpose({0,2,3,1}); @@ -293,7 +293,19 @@ TEST_CASE("[cpu/operator] Transpose DataFormat") { op->forwardDims(); REQUIRE(op->getOutput(0)->dataFormat() == DataFormat::Default); - //Should throw a Warning + //Should print a Log::Warning + } + + SECTION("Invalid Format Handling") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); + input->setDataFormat(Aidge::DataFormat::NCHW); + auto transposeNode = Transpose({0,2,1,3}); + auto op = std::static_pointer_cast<OperatorTensor>(transposeNode->getOperator()); + op->associateInput(0, input); + + op->forwardDims(); + REQUIRE(op->getOutput(0)->dataFormat() == DataFormat::Default); + //Should print a Log::Warning } }