Skip to content
Snippets Groups Projects
Commit 5bc72539 authored by Maxence Naud's avatar Maxence Naud
Browse files

Add asserts in Conv::forwardDims() member function

parent 62de2285
No related branches found
No related tags found
No related merge requests found
...@@ -117,6 +117,17 @@ public: ...@@ -117,6 +117,17 @@ public:
} }
associated &= !(getInput(i)->empty()); associated &= !(getInput(i)->empty());
} }
AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&
(getInput(0)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()),
"Wrong input size for Conv operator.");
AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)) &&
(getInput(1)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()) &&
(getInput(1)->template dims<DIM+2>()[0] == this->template getAttr<ConvAttr::OutChannels>()),
"Wrong weight size for Conv operator.");
if(!this->template getAttr<ConvAttr::NoBias>())
AIDGE_ASSERT((getInput(2)->nbDims() == (1)) &&
(getInput(2)->template dims<1>()[0] == this->template getAttr<ConvAttr::OutChannels>()),
"Wrong bias size for Conv operator.");
if (associated) { if (associated) {
std::array<DimSize_t, DIM + 2> outputDims{}; std::array<DimSize_t, DIM + 2> outputDims{};
const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment