diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp index 42c47cb81d95569306beaf25f28bb7bd56bfe734..994388fa8f78d4dfb4362361dabb2f2d9344e34b 100644 --- a/src/operator/Conv.cpp +++ b/src/operator/Conv.cpp @@ -44,15 +44,21 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { if (!inputsAssociated()) return false; // first check weight since it defines inChannels and outChannels - if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){ - AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && - (getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()), - "Wrong input channel size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); + if (getInput(0)->nbDims() != (DIM+2)) { + Log::error("Wrong number of dimensions for input '{}'.", getInputsName()[0]); + return false; + } + if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) { + if (getInput(0)->template dims<DIM+2>()[DIM+1] != inChannels()) { + Log::error("Wrong input channel size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); + return false; + } } else{ //For dataFormat in NCHW or Default Format - AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && - (getInput(0)->template dims<DIM+2>()[1] == inChannels()), - "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", ")); + if(getInput(0)->template dims<DIM+2>()[1] != inChannels()) { + Log::error("Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", ")); + return false; + } } AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)),