diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index d73c8fd8cb14c26ccf080cb5ae3e107e25d2a955..54f28507b906eb435dbdb9d2cec92b71c813b760 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -45,8 +45,12 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { // first check weight since it defines inChannels and outChannels AIDGE_ASSERT((getInput(1)->nbDims() == 2), "Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims()); - const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ? getInput(1)->template dims<2>()[1]:getInput(1)->template dims<2>()[0]; - const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ? getInput(1)->template dims<2>()[0]:getInput(1)->template dims<2>()[1]; + const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ? + getInput(1)->template dims<2>()[1]: + getInput(1)->template dims<2>()[0]; + const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ? + getInput(1)->template dims<2>()[0]: + getInput(1)->template dims<2>()[1]; // check data const std::vector<DimSize_t>& inputDims = getInput(0)->dims(); const DimIdx_t inChannelsIdx = mAttributes->template getAttr<FCAttr::TransA>() ? 1 : 0; @@ -64,18 +68,11 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { nbInputFeatures, inChannels); } // check optional bias - const DimSize_t batchSize = static_cast<DimSize_t>(getInput(0)->size() / inChannels); - if(getInput(2)) - AIDGE_ASSERT((((getInput(2)->nbDims() == 1) && - (getInput(2)->template dims<1>()[0] == outChannels)) || - ((getInput(2)->nbDims() == 2)&& - (getInput(0)->nbDims() == 2)&& - (getInput(2)->template dims<2>()[0] == batchSize) && - (getInput(2)->template dims<2>()[1] == outChannels) - )), - "Wrong bias size for FC operator."); + if(getInput(2)) { + AIDGE_ASSERT(getInput(2)->size() == outChannels, "Wrong bias size for FC operator."); + } // <batch, OutChannels> - mOutputs[0]->resize({batchSize, outChannels}); + mOutputs[0]->resize({static_cast<DimSize_t>(getInput(0)->size() / inChannels), outChannels}); return true; }