Commit ad111c7f authored by Wissam Boussella, committed by Olivier BICHLER

New unit test for forward_conv and some fixes for the NHWC format

parent 7ff77b54
@@ -163,6 +163,7 @@ public:
         if (!getInput(1)) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed.");
         }
+        if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC)
+            return getInput(1)->template dims<DIM+2>()[DIM+1];
         return getInput(1)->template dims<DIM+2>()[1];
@@ -177,8 +178,6 @@ public:
         if (!getInput(1)) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed.");
         }
-        if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC)
-            return getInput(1)->template dims<DIM+2>()[DIM+1];
         return getInput(1)->template dims<DIM+2>()[0];
     }
......
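For context on the header change above: after this patch the weight tensor's data format only affects where the input-channel count is read from, while the output-channel count is always taken from axis 0. The snippet below is a minimal standalone sketch of that indexing; the DataFormat enum and inChannelAxis helper are illustrative stand-ins rather than Aidge APIs, and the NHWC weight layout is assumed to be [Out, H, W, In] to match the dims<DIM+2>()[DIM+1] access.

#include <array>
#include <cstddef>
#include <iostream>

// Illustrative stand-ins, not Aidge types.
enum class DataFormat { NCHW, NHWC };

// For an NCHW weight [Out, In, H, W] the input channels sit at axis 1;
// for an NHWC-style weight (assumed [Out, H, W, In]) they sit at axis DIM + 1.
// Output channels are read from axis 0 in both cases, which is why the NHWC
// branch could be dropped from outChannels().
template <std::size_t DIM>
std::size_t inChannelAxis(DataFormat fmt) {
    return (fmt == DataFormat::NHWC) ? DIM + 1 : 1;
}

int main() {
    constexpr std::size_t DIM = 2;
    const std::array<std::size_t, DIM + 2> nchwWeight{4, 3, 3, 4};  // Out, In, H, W
    const std::array<std::size_t, DIM + 2> nhwcWeight{4, 3, 4, 3};  // assumed Out, H, W, In

    std::cout << "NCHW in channels: " << nchwWeight[inChannelAxis<DIM>(DataFormat::NCHW)] << '\n';  // 3
    std::cout << "NHWC in channels: " << nhwcWeight[inChannelAxis<DIM>(DataFormat::NHWC)] << '\n';  // 3
    std::cout << "out channels:     " << nchwWeight[0] << '\n';                                     // 4
    return 0;
}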
@@ -63,7 +63,7 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
                     "Wrong bias size ({}) for Conv operator. Expected dims are [{}].", getInput(2)->dims(), outChannels());
         const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
-        std::array<DimSize_t, DIM + 2> outputDims;
+        std::array<DimSize_t, DIM + 2> outputDims{};
         unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2;
......
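The in_dims_index line above is the NHWC-awareness part of forwardDims: the input's spatial extents start at axis 1 for NHWC and at axis 2 for NCHW. As a rough sketch (not Aidge code; convSpatialDims and its parameters are made-up names), the output spatial extents then follow the usual convolution formula; with stride 1, dilation 1 and no padding, as in the unit tests below, it reduces to in - kernel + 1.

#include <array>
#include <cstddef>
#include <iostream>

// Minimal sketch of the indexing idea above: spatial dims start at axis 1
// for NHWC inputs and at axis 2 for NCHW inputs, then each output extent is
// (in - dilation*(kernel - 1) - 1) / stride + 1.
enum class DataFormat { NCHW, NHWC };

template <std::size_t DIM>
std::array<std::size_t, DIM> convSpatialDims(
        const std::array<std::size_t, DIM + 2>& inputDims,
        const std::array<std::size_t, DIM>& kernelDims,
        DataFormat fmt,
        std::size_t stride = 1,
        std::size_t dilation = 1) {
    const std::size_t first = (fmt == DataFormat::NHWC) ? 1 : 2;  // same role as in_dims_index
    std::array<std::size_t, DIM> out{};
    for (std::size_t d = 0; d < DIM; ++d) {
        out[d] = (inputDims[first + d] - dilation * (kernelDims[d] - 1) - 1) / stride + 1;
    }
    return out;
}

int main() {
    // NCHW input {16, 3, 224, 450} with a {3, 4} kernel -> spatial dims {222, 447},
    // matching the expectedOutputDims in the tests below.
    const auto spatial = convSpatialDims<2>({16, 3, 224, 450}, {3, 4}, DataFormat::NCHW);
    std::cout << spatial[0] << " x " << spatial[1] << '\n';  // 222 x 447
    return 0;
}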
@@ -22,6 +22,90 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
TEST_CASE("[core/operator] Conv_Op(ForwardDims) ", "[Operator][ForwardDims][Conv]") {
SECTION("I:NCHW O:NCHW W:NCHW"){
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch_h,W,H
const std::vector<std::size_t> expectedOutputDims({16,4,222,447});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
//Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
input->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NCHW);
//Set inputs
conv1.setInput(1,weight);
conv1.setInput(0,input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NCHW O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,3,224,450}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4}));
const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
input->setDataFormat(Aidge::DataFormat::NCHW);
weight->setDataFormat(Aidge::DataFormat::NCHW);
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NHWC O:NCHW W:NHWC") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,4,3,3})); // H, W, In_ch, Out_ch
const std::vector<std::size_t> expectedOutputDims({16,4,222,447});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NCHW);
input->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NHWC);
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
SECTION("I:NHWC O:NHWC W:NCHW") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({16,224,450,3}));
std::shared_ptr<Tensor> weight = std::make_shared<Tensor>(std::vector<std::size_t>({4,3,3,4})); // Out_ch, In_ch, H, W
const std::vector<std::size_t> expectedOutputDims({16,222,447,4});
auto conv1 = Conv_Op<2>(std::array<size_t, 2>{3, 4});
// Set DataFormat
conv1.getOutput(0)->setDataFormat(Aidge::DataFormat::NHWC);
input->setDataFormat(Aidge::DataFormat::NHWC);
weight->setDataFormat(Aidge::DataFormat::NCHW);
// Set inputs
conv1.setInput(1, weight);
conv1.setInput(0, input);
REQUIRE(conv1.forwardDims());
REQUIRE(conv1.getOutput(0)->dims() == expectedOutputDims);
}
}
TEST_CASE("[core/operator] Conv_Op(computeReceptiveField)", "[Operator][computeReceptiveField][Conv]") { TEST_CASE("[core/operator] Conv_Op(computeReceptiveField)", "[Operator][computeReceptiveField][Conv]") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {5, 5}, "conv1"); // output dims: {16, 32, 220, 220} auto conv1 = Conv(3, 32, {5, 5}, "conv1"); // output dims: {16, 32, 220, 220}
......
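The new ForwardDims sections above also pin down how the final dims vector is ordered: it follows the output tensor's data format, not the input's. Below is a minimal sketch of that assembly; buildOutputDims is an illustrative name only, not an Aidge function.

#include <array>
#include <cstddef>
#include <iostream>

// Sketch (not Aidge code): the output tensor's data format decides whether the
// channel count goes right after the batch (NCHW) or last (NHWC).
enum class DataFormat { NCHW, NHWC };

template <std::size_t DIM>
std::array<std::size_t, DIM + 2> buildOutputDims(std::size_t batch,
                                                 std::size_t channels,
                                                 const std::array<std::size_t, DIM>& spatial,
                                                 DataFormat fmt) {
    std::array<std::size_t, DIM + 2> dims{};
    dims[0] = batch;
    const std::size_t spatialStart = (fmt == DataFormat::NHWC) ? 1 : 2;
    for (std::size_t d = 0; d < DIM; ++d) {
        dims[spatialStart + d] = spatial[d];
    }
    dims[(fmt == DataFormat::NHWC) ? DIM + 1 : 1] = channels;
    return dims;
}

int main() {
    // Reproduces the "I:NCHW O:NHWC" expectation above: {16, 222, 447, 4}.
    const auto dims = buildOutputDims<2>(16, 4, {222, 447}, DataFormat::NHWC);
    for (auto d : dims) std::cout << d << ' ';
    std::cout << '\n';
    return 0;
}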