diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp
index b97874f4e0deafd685453b3ce9865e65fafe7561..0b87ff7882c6bf0745cdc52a470b83a8276ad9d5 100644
--- a/include/aidge/operator/FC.hpp
+++ b/include/aidge/operator/FC.hpp
@@ -24,26 +24,24 @@
 #include "aidge/utils/Registrar.hpp"
 
 namespace Aidge {
-enum class FCAttr { OutChannels, NoBias };
+enum class FCAttr { NoBias };
 
 class FC_Op : public OperatorTensor,
               public Registrable<FC_Op,
                                  std::string,
                                  std::shared_ptr<OperatorImpl>(const FC_Op &)>,
-              public StaticAttributes<FCAttr, DimSize_t, bool> {
+              public StaticAttributes<FCAttr, bool> {
 public:
     static const std::string Type;
 
     FC_Op() = delete;
 
-    using Attributes_ = StaticAttributes<FCAttr, DimSize_t, bool>;
+    using Attributes_ = StaticAttributes<FCAttr, bool>;
     template <FCAttr e> using attr = typename Attributes_::template attr<e>;
 
-    FC_Op(DimSize_t out_channels, bool noBias)
+    FC_Op(bool noBias)
     : OperatorTensor(Type, 1, 2, 1),
-      Attributes_(
-        attr<FCAttr::OutChannels>(out_channels),
-        attr<FCAttr::NoBias>(noBias))
+      Attributes_(attr<FCAttr::NoBias>(noBias))
     {}
 
     /**
@@ -83,9 +81,9 @@ public:
     }
 };
 
-inline std::shared_ptr<Node> FC(DimSize_t inChannels, DimSize_t outChannels, bool noBias = false, const std::string& name = "") {
+inline std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, bool noBias = false, const std::string& name = "") {
     // FIXME: properly handle default w&b initialization in every cases
-    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(outChannels, noBias), name);
+    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(noBias), name);
     addProducer(fc, 1, {outChannels, inChannels}, "w");
     addProducer(fc, 2, {(noBias ? 0 : outChannels)}, "b"); // already sets bias dims
     return fc;
@@ -94,8 +92,7 @@ inline std::shared_ptr<Node> FC(DimSize_t inChannels, DimSize_t outChannels, boo
 
 namespace {
 template <>
-const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels",
-                                                        "NoBias"};
+const char *const EnumStrings<Aidge::FCAttr>::data[] = {"NoBias"};
 }
 
 #endif /* AIDGE_CORE_OPERATOR_FC_H_ */
diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp
index ba7e29e7b6543a570ceede6158bd306286037c10..d3bfd4557044c49b452de7690541a1c0a2ac62d9 100644
--- a/src/operator/FC.cpp
+++ b/src/operator/FC.cpp
@@ -45,8 +45,31 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
         associated &= !(getInput(i)->empty());
     }
     if (associated) {
+        // first check weight since it defines inChannels and outChannels
+        AIDGE_ASSERT((getInput(1)->nbDims() == 2),
+                    "Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims());
+        const DimSize_t outChannels = getInput(1)->template dims<2>()[0];
+        const DimSize_t inChannels = getInput(1)->template dims<2>()[1];
+        // check data
+        const std::vector<DimSize_t>& inputDims = getInput(0)->dims();
+        if (getInput(0)->nbDims() == 1) {
+            AIDGE_ASSERT(inputDims[0] == inChannels,
+                "Wrong number of input features for input data ({}), expected {}",
+                inputDims[0], inChannels);
+        } else {
+            AIDGE_ASSERT(getInput(0)->nbDims() > 1, "FC input data must have at least one dimension");
+            const DimSize_t nbInputFeatures = std::accumulate(inputDims.cbegin() + 1, inputDims.cend(), DimSize_t(1), std::multiplies<DimSize_t>());
+            AIDGE_ASSERT(nbInputFeatures == inChannels,
+                "Wrong number of input features for input data ({}), expected {}",
+                nbInputFeatures, inChannels);
+        }
+        // check optional bias
+        if(!this->template getAttr<FCAttr::NoBias>())
+            AIDGE_ASSERT((getInput(2)->nbDims() == 1) &&
+                (getInput(2)->template dims<1>()[0] == outChannels),
+                "Wrong bias size for FC operator.");
         // <batch, OutChannels>
-        mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<FCAttr::OutChannels>()});
+        mOutputs[0]->resize({getInput(0)->dims()[0], outChannels});
     }
     return associated;
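Review note, not part of the patch: below is a minimal standalone sketch of the shape-inference rule that the new forwardDims() implements, written against plain standard C++ rather than Aidge tensors. The helper name inferFcOutputDims and the exception-based error handling are illustrative stand-ins (the patch uses AIDGE_ASSERT); the logic mirrors the hunk above: channel counts come from the 2-D weight <OutChannels, InChannels>, the input feature count is checked against InChannels, and the output shape is <dataDims[0], OutChannels>.

// Standalone illustration only -- not Aidge code, not part of this patch.
#include <cstddef>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

using DimSize = std::size_t;  // stand-in for Aidge::DimSize_t

// Mirrors the dimension inference added in FC_Op::forwardDims() above.
std::vector<DimSize> inferFcOutputDims(const std::vector<DimSize>& weightDims,
                                       const std::vector<DimSize>& dataDims) {
    // Weight defines the channel counts: <OutChannels, InChannels>.
    if (weightDims.size() != 2)
        throw std::invalid_argument("weight must have 2 dimensions");
    const DimSize outChannels = weightDims[0];
    const DimSize inChannels  = weightDims[1];

    if (dataDims.empty())
        throw std::invalid_argument("input data must have at least one dimension");

    if (dataDims.size() == 1) {
        // 1-D data: the single dimension is the feature dimension.
        if (dataDims[0] != inChannels)
            throw std::invalid_argument("wrong number of input features");
    } else {
        // N-D data: all dimensions past the first (batch) must flatten to InChannels.
        const DimSize nbInputFeatures =
            std::accumulate(dataDims.cbegin() + 1, dataDims.cend(),
                            DimSize(1), std::multiplies<DimSize>());
        if (nbInputFeatures != inChannels)
            throw std::invalid_argument("wrong number of input features");
    }
    // Same rule as the patch: output is <dataDims[0], OutChannels>.
    return {dataDims[0], outChannels};
}

int main() {
    // Weight <10, 75>, data <4, 3, 5, 5> (3*5*5 == 75) -> output <4, 10>.
    const std::vector<DimSize> out = inferFcOutputDims({10, 75}, {4, 3, 5, 5});
    return (out == std::vector<DimSize>{4, 10}) ? 0 : 1;
}

For example, a weight of dims {10, 75} with batched data of dims {4, 3, 5, 5} yields an output of {4, 10}, which is what mOutputs[0]->resize() now produces without needing the removed OutChannels attribute.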