Skip to content
Snippets Groups Projects
Commit 19e303d5 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

support only bias of size outChannels for factory func limitations

parent f2de6fbb
No related branches found
No related tags found
1 merge request!350Draft: Fix import tests
...@@ -45,8 +45,12 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -45,8 +45,12 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
// first check weight since it defines inChannels and outChannels // first check weight since it defines inChannels and outChannels
AIDGE_ASSERT((getInput(1)->nbDims() == 2), AIDGE_ASSERT((getInput(1)->nbDims() == 2),
"Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims()); "Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims());
const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ? getInput(1)->template dims<2>()[1]:getInput(1)->template dims<2>()[0]; const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ?
const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ? getInput(1)->template dims<2>()[0]:getInput(1)->template dims<2>()[1]; getInput(1)->template dims<2>()[1]:
getInput(1)->template dims<2>()[0];
const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ?
getInput(1)->template dims<2>()[0]:
getInput(1)->template dims<2>()[1];
// check data // check data
const std::vector<DimSize_t>& inputDims = getInput(0)->dims(); const std::vector<DimSize_t>& inputDims = getInput(0)->dims();
const DimIdx_t inChannelsIdx = mAttributes->template getAttr<FCAttr::TransA>() ? 1 : 0; const DimIdx_t inChannelsIdx = mAttributes->template getAttr<FCAttr::TransA>() ? 1 : 0;
...@@ -64,18 +68,11 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -64,18 +68,11 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
nbInputFeatures, inChannels); nbInputFeatures, inChannels);
} }
// check optional bias // check optional bias
const DimSize_t batchSize = static_cast<DimSize_t>(getInput(0)->size() / inChannels); if(getInput(2)) {
if(getInput(2)) AIDGE_ASSERT(getInput(2)->size() == outChannels, "Wrong bias size for FC operator.");
AIDGE_ASSERT((((getInput(2)->nbDims() == 1) && }
(getInput(2)->template dims<1>()[0] == outChannels)) ||
((getInput(2)->nbDims() == 2)&&
(getInput(0)->nbDims() == 2)&&
(getInput(2)->template dims<2>()[0] == batchSize) &&
(getInput(2)->template dims<2>()[1] == outChannels)
)),
"Wrong bias size for FC operator.");
// <batch, OutChannels> // <batch, OutChannels>
mOutputs[0]->resize({batchSize, outChannels}); mOutputs[0]->resize({static_cast<DimSize_t>(getInput(0)->size() / inChannels), outChannels});
return true; return true;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment