Skip to content
Snippets Groups Projects
Commit 4db496dc authored by Wissam Boussella's avatar Wissam Boussella
Browse files

New tests

parent a3ff2919
No related branches found
No related tags found
No related merge requests found
Pipeline #65897 passed
......@@ -42,66 +42,126 @@ TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NCHW O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4}));
const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
SECTION("I:NCHW O:NCHW W:NHWC") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, H, W, In_ch
const std::vector<std::size_t> expectedOutputDims({16, 4, 222, 447});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
input->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NHWC); // NHWC weight format
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NHWC O:NCHW W:NHWC") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,4,3,3})); // H, W, In_ch, Out_ch
const std::vector<std::size_t> expectedOutputDims({16,4,222,447});
SECTION("I:NHWC O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3, 224, 450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // Out_ch, In_ch, H, W
const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
input->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NCHW); // NCHW weight format
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NHWC O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch, H, W
const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
SECTION("I:NHWC O:NHWC W:NHWC") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16, 3,224, 450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4, 3, 3, 4})); // (Out_ch, H, W, In_ch)
const std::vector<std::size_t> expectedOutputDims({16, 222, 447, 4});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
input->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NHWC);
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
// SECTION("I:NCHW O:NHWC W:NCHW") {
// std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450}));
// std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4}));
// const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
// auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// // Set DataFormat
// conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
// input->setDataFormat(Aidge::DataFormat::NCHW);
// weight->setDataFormat(Aidge::DataFormat::NCHW);
// // Set inputs
// conv1.setInput(1, weight);
// conv1.setInput(0, input);
// REQUIRE(conv1.forwardDims());
// REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
// }
// SECTION("I:NHWC O:NCHW W:NHWC") {
// std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
// std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,4,3,3})); // H, W, In_ch, Out_ch
// const std::vector<std::size_t> expectedOutputDims({16,4,222,447});
// auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// // Set DataFormat
// conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
// input->setDataFormat(Aidge::DataFormat::NHWC);
// weight->setDataFormat(Aidge::DataFormat::NHWC);
// // Set inputs
// conv1.setInput(1, weight);
// conv1.setInput(0, input);
// REQUIRE(conv1.forwardDims());
// REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
// }
// SECTION("I:NHWC O:NHWC W:NCHW") {
// std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
// std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch, H, W
// const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
// auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// // Set DataFormat
// conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
// input->setDataFormat(Aidge::DataFormat::NHWC);
// weight->setDataFormat(Aidge::DataFormat::NCHW);
// // Set inputs
// conv1.setInput(1, weight);
// conv1.setInput(0, input);
// REQUIRE(conv1.forwardDims());
// REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
// }
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment