diff --git a/unit_tests/operator/Test_Conv_Op.cpp b/unit_tests/operator/Test_Conv_Op.cpp index 103b8c624455a453d98898b6cb9c227102577382..002e52522da145135ed5d2497ecf07e3fd27986c 100644 --- a/unit_tests/operator/Test_Conv_Op.cpp +++ b/unit_tests/operator/Test_Conv_Op.cpp @@ -42,66 +42,126 @@ TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } - - SECTION("I:NCHW O:NHWC W:NCHW") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); - std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); - - const std::vector<std::size_t> expectedOutputDims({16,222,447,4}); + SECTION("I:NCHW O:NCHW W:NHWC") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, H, W, In_ch + + const std::vector<std::size_t> expectedOutputDims({16, 4, 222, 447}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat - conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); + + // Set DataFormat + conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW); input->setDataFormat(Aidge::DataFormat::NCHW); - weight->setDataFormat(Aidge::DataFormat::NCHW); - + weight->setDataFormat(Aidge::DataFormat::NHWC); // NHWC weight format + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } - - SECTION("I:NHWC O:NCHW W:NHWC") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3})); - std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,4,3,3})); // H, W, In_ch, Out_ch - - const std::vector<std::size_t> expectedOutputDims({16,4,222,447}); + + SECTION("I:NHWC O:NHWC W:NCHW") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, In_ch, H, W + + const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat - conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW); + + // Set DataFormat + conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); input->setDataFormat(Aidge::DataFormat::NHWC); - weight->setDataFormat(Aidge::DataFormat::NHWC); - + weight->setDataFormat(Aidge::DataFormat::NCHW); // NCHW weight format + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } - - SECTION("I:NHWC O:NHWC W:NCHW") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3})); - std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch, H, W - - const std::vector<std::size_t> expectedOutputDims({16,222,447,4}); + + SECTION("I:NHWC O:NHWC W:NHWC") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3,224, 450})); + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // (Out_ch, H, W, In_ch) + + const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat + + // Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); input->setDataFormat(Aidge::DataFormat::NHWC); - weight->setDataFormat(Aidge::DataFormat::NCHW); - + weight->setDataFormat(Aidge::DataFormat::NHWC); + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } + + + // SECTION("I:NCHW O:NHWC W:NCHW") { + // std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); + // std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); + + // const std::vector<std::size_t> expectedOutputDims({16,222,447,4}); + // auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); + + // // Set DataFormat + // conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); + // input->setDataFormat(Aidge::DataFormat::NCHW); + // weight->setDataFormat(Aidge::DataFormat::NCHW); + + // // Set inputs + // conv1.setInput(1, weight); + // conv1.setInput(0, input); + + // REQUIRE(conv1.forwardDims()); + // REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); + // } + + // SECTION("I:NHWC O:NCHW W:NHWC") { + // std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3})); + // std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,4,3,3})); // H, W, In_ch, Out_ch + + // const std::vector<std::size_t> expectedOutputDims({16,4,222,447}); + // auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); + + // // Set DataFormat + // conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW); + // input->setDataFormat(Aidge::DataFormat::NHWC); + // weight->setDataFormat(Aidge::DataFormat::NHWC); + + // // Set inputs + // conv1.setInput(1, weight); + // conv1.setInput(0, input); + + // REQUIRE(conv1.forwardDims()); + // REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); + // } + + // SECTION("I:NHWC O:NHWC W:NCHW") { + // std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3})); + // std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch, H, W + + // const std::vector<std::size_t> expectedOutputDims({16,222,447,4}); + // auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); + + // // Set DataFormat + // conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); + // input->setDataFormat(Aidge::DataFormat::NHWC); + // weight->setDataFormat(Aidge::DataFormat::NCHW); + + // // Set inputs + // conv1.setInput(1, weight); + // conv1.setInput(0, input); + + // REQUIRE(conv1.forwardDims()); + // REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); + // } }