Skip to content
Snippets Groups Projects
Commit 4db496dc authored by Wissam Boussella's avatar Wissam Boussella
Browse files

New tests

parent a3ff2919
Branches master
No related tags found
No related merge requests found
Pipeline #65897 passed
This commit is part of merge request !314. Comments created here will be created in the context of that merge request.
......@@ -42,66 +42,126 @@ TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NCHW O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4}));
const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
SECTION("I:NCHW O:NCHW W:NHWC") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, H, W, In_ch
const std::vector<std::size_t> expectedOutputDims({16, 4, 222, 447});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
input->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NHWC); // NHWC weight format
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NHWC O:NCHW W:NHWC") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,4,3,3})); // H, W, In_ch, Out_ch
const std::vector<std::size_t> expectedOutputDims({16,4,222,447});
SECTION("I:NHWC O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, In_ch, H, W
const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
input->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NCHW); // NCHW weight format
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NHWC O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch, H, W
const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
SECTION("I:NHWC O:NHWC W:NHWC") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3,224, 450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // (Out_ch, H, W, In_ch)
const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
input->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NHWC);
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
// SECTION("I:NCHW O:NHWC W:NCHW") {
// std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450}));
// std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4}));
// const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
// auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// // Set DataFormat
// conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
// input->setDataFormat(Aidge::DataFormat::NCHW);
// weight->setDataFormat(Aidge::DataFormat::NCHW);
// // Set inputs
// conv1.setInput(1, weight);
// conv1.setInput(0, input);
// REQUIRE(conv1.forwardDims());
// REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
// }
// SECTION("I:NHWC O:NCHW W:NHWC") {
// std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
// std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,4,3,3})); // H, W, In_ch, Out_ch
// const std::vector<std::size_t> expectedOutputDims({16,4,222,447});
// auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// // Set DataFormat
// conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
// input->setDataFormat(Aidge::DataFormat::NHWC);
// weight->setDataFormat(Aidge::DataFormat::NHWC);
// // Set inputs
// conv1.setInput(1, weight);
// conv1.setInput(0, input);
// REQUIRE(conv1.forwardDims());
// REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
// }
// SECTION("I:NHWC O:NHWC W:NCHW") {
// std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
// std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch, H, W
// const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
// auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// // Set DataFormat
// conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
// input->setDataFormat(Aidge::DataFormat::NHWC);
// weight->setDataFormat(Aidge::DataFormat::NCHW);
// // Set inputs
// conv1.setInput(1, weight);
// conv1.setInput(0, input);
// REQUIRE(conv1.forwardDims());
// REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
// }
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment