diff --git a/unit_tests/operator/Test_Conv_Op.cpp b/unit_tests/operator/Test_Conv_Op.cpp index de33ddd5a7613cde16b96b23722f6d2ab412f373..9b32054c35b14702850820abf9930ce811719dc4 100644 --- a/unit_tests/operator/Test_Conv_Op.cpp +++ b/unit_tests/operator/Test_Conv_Op.cpp @@ -24,13 +24,13 @@ namespace Aidge { TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv]") { SECTION("I:NCHW O:NCHW W:NCHW"){ - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450})); std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch_h,W,H const std::vector<std::size_t> expectedOutputDims({16,4,222,447}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - //Set DataFormat + //Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW); input->setDataFormat(Aidge::DataFormat::NCHW); weight->setDataFormat(Aidge::DataFormat::NCHW); @@ -43,61 +43,61 @@ TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } SECTION("I:NCHW O:NCHW W:NHWC") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); - std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, H, W, In_ch - + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 3})); // Out_ch, H, W, In_ch + const std::vector<std::size_t> expectedOutputDims({16, 4, 222, 447}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat + + // Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW); input->setDataFormat(Aidge::DataFormat::NCHW); weight->setDataFormat(Aidge::DataFormat::NHWC); // NHWC weight format - + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } - + SECTION("I:NHWC O:NHWC W:NCHW") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450})); + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 224, 450, 3})); std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, In_ch, H, W - + const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat + + // Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); input->setDataFormat(Aidge::DataFormat::NHWC); weight->setDataFormat(Aidge::DataFormat::NCHW); // NCHW weight format - + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); } - + SECTION("I:NHWC O:NHWC W:NHWC") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3,224, 450})); - std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // (Out_ch, H, W, In_ch) - + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 224, 450, 3})); + std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 4, 3})); // (Out_ch, H, W, In_ch) + const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4}); auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4}); - - // Set DataFormat + + // Set DataFormat conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC); input->setDataFormat(Aidge::DataFormat::NHWC); weight->setDataFormat(Aidge::DataFormat::NHWC); - + // Set inputs conv1.setInput(1, weight); conv1.setInput(0, input); - + REQUIRE(conv1.forwardDims()); REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims); }