diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index f0c6c12d7fbc4e655e2cd5c84c7732ffeb96bbb4..d6a0df5ab472c4a728e5b5042258d6d2bd34f871 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -117,18 +117,18 @@ public: } associated &= !(getInput(i)->empty()); } - AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && + if (associated) { + AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && (getInput(0)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()), "Wrong input size for Conv operator."); - AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)) && - (getInput(1)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()) && - (getInput(1)->template dims<DIM+2>()[0] == this->template getAttr<ConvAttr::OutChannels>()), - "Wrong weight size for Conv operator."); - if(!this->template getAttr<ConvAttr::NoBias>()) - AIDGE_ASSERT((getInput(2)->nbDims() == (1)) && - (getInput(2)->template dims<1>()[0] == this->template getAttr<ConvAttr::OutChannels>()), - "Wrong bias size for Conv operator."); - if (associated) { + AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)) && + (getInput(1)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()) && + (getInput(1)->template dims<DIM+2>()[0] == this->template getAttr<ConvAttr::OutChannels>()), + "Wrong weight size for Conv operator."); + if(!this->template getAttr<ConvAttr::NoBias>()) + AIDGE_ASSERT((getInput(2)->nbDims() == (1)) && + (getInput(2)->template dims<1>()[0] == this->template getAttr<ConvAttr::OutChannels>()), + "Wrong bias size for Conv operator."); std::array<DimSize_t, DIM + 2> outputDims{}; const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 437780b959b37e0cf6b5b7796e71c9b931f25bc0..8403686d16da15e7e8ad4616029a241d6197d450 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -648,11 +648,8 @@ TEST_CASE("[GraphView] clone") { auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2"); auto conv3 = Conv(64, 10, {1, 1}, "conv3"); - auto g1 = std::make_shared<GraphView>("TestGraph"); + auto g1 = Sequential({conv1, conv2, conv3}); dataProvider->addChild(conv1, 0); - g1->add(conv1); - g1->addChild(conv2, conv1, 0); - g1->addChild(conv3, conv2, 0); g1->save("clone_g1"); SECTION("Check input-output connections") {