From 58c5ee5e0bc6031abdaea9c974803aecdf8539ef Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 19 Apr 2024 12:19:38 +0000 Subject: [PATCH] Add asserts to Conv forwardDims member function --- include/aidge/operator/Conv.hpp | 20 ++++++++++---------- unit_tests/graph/Test_GraphView.cpp | 5 +---- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index f0c6c12d7..d6a0df5ab 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -117,18 +117,18 @@ public: } associated &= !(getInput(i)->empty()); } - AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && + if (associated) { + AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && (getInput(0)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()), "Wrong input size for Conv operator."); - AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)) && - (getInput(1)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()) && - (getInput(1)->template dims<DIM+2>()[0] == this->template getAttr<ConvAttr::OutChannels>()), - "Wrong weight size for Conv operator."); - if(!this->template getAttr<ConvAttr::NoBias>()) - AIDGE_ASSERT((getInput(2)->nbDims() == (1)) && - (getInput(2)->template dims<1>()[0] == this->template getAttr<ConvAttr::OutChannels>()), - "Wrong bias size for Conv operator."); - if (associated) { + AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)) && + (getInput(1)->template dims<DIM+2>()[1] == this->template getAttr<ConvAttr::InChannels>()) && + (getInput(1)->template dims<DIM+2>()[0] == this->template getAttr<ConvAttr::OutChannels>()), + "Wrong weight size for Conv operator."); + if(!this->template getAttr<ConvAttr::NoBias>()) + AIDGE_ASSERT((getInput(2)->nbDims() == (1)) && + (getInput(2)->template dims<1>()[0] == this->template getAttr<ConvAttr::OutChannels>()), + "Wrong bias size for Conv operator."); std::array<DimSize_t, DIM + 2> outputDims{}; const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 437780b95..8403686d1 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -648,11 +648,8 @@ TEST_CASE("[GraphView] clone") { auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2"); auto conv3 = Conv(64, 10, {1, 1}, "conv3"); - auto g1 = std::make_shared<GraphView>("TestGraph"); + auto g1 = Sequential({conv1, conv2, conv3}); dataProvider->addChild(conv1, 0); - g1->add(conv1); - g1->addChild(conv2, conv1, 0); - g1->addChild(conv3, conv2, 0); g1->save("clone_g1"); SECTION("Check input-output connections") { -- GitLab