Skip to content
Snippets Groups Projects
Commit a5382e31 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

support only bias of size outChannels for factory func limitations

parent aab35aab
No related branches found
No related tags found
No related merge requests found
......@@ -45,8 +45,12 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
// first check weight since it defines inChannels and outChannels
AIDGE_ASSERT((getInput(1)->nbDims() == 2),
"Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims());
const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ? getInput(1)->template dims<2>()[1]:getInput(1)->template dims<2>()[0];
const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ? getInput(1)->template dims<2>()[0]:getInput(1)->template dims<2>()[1];
const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ?
getInput(1)->template dims<2>()[1]:
getInput(1)->template dims<2>()[0];
const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ?
getInput(1)->template dims<2>()[0]:
getInput(1)->template dims<2>()[1];
// check data
const std::vector<DimSize_t>& inputDims = getInput(0)->dims();
const DimIdx_t inChannelsIdx = mAttributes->template getAttr<FCAttr::TransA>() ? 1 : 0;
......@@ -64,18 +68,11 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
nbInputFeatures, inChannels);
}
// check optional bias
const DimSize_t batchSize = static_cast<DimSize_t>(getInput(0)->size() / inChannels);
if(getInput(2))
AIDGE_ASSERT((((getInput(2)->nbDims() == 1) &&
(getInput(2)->template dims<1>()[0] == outChannels)) ||
((getInput(2)->nbDims() == 2)&&
(getInput(0)->nbDims() == 2)&&
(getInput(2)->template dims<2>()[0] == batchSize) &&
(getInput(2)->template dims<2>()[1] == outChannels)
)),
"Wrong bias size for FC operator.");
if(getInput(2)) {
AIDGE_ASSERT(getInput(2)->size() == outChannels, "Wrong bias size for FC operator.");
}
// <batch, OutChannels>
mOutputs[0]->resize({batchSize, outChannels});
mOutputs[0]->resize({static_cast<DimSize_t>(getInput(0)->size() / inChannels), outChannels});
return true;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment