Skip to content
Snippets Groups Projects
Commit cc2f1ab8 authored by Maxence Naud's avatar Maxence Naud
Browse files

Remove the outChannels attribute from FC_Op as it is only used for Weight and Bias creation

parent e89f537c
No related branches found
No related tags found
No related merge requests found
...@@ -24,26 +24,24 @@ ...@@ -24,26 +24,24 @@
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
namespace Aidge { namespace Aidge {
enum class FCAttr { OutChannels, NoBias }; enum class FCAttr { NoBias };
class FC_Op : public OperatorTensor, class FC_Op : public OperatorTensor,
public Registrable<FC_Op, public Registrable<FC_Op,
std::string, std::string,
std::shared_ptr<OperatorImpl>(const FC_Op &)>, std::shared_ptr<OperatorImpl>(const FC_Op &)>,
public StaticAttributes<FCAttr, DimSize_t, bool> { public StaticAttributes<FCAttr, bool> {
public: public:
static const std::string Type; static const std::string Type;
FC_Op() = delete; FC_Op() = delete;
using Attributes_ = StaticAttributes<FCAttr, DimSize_t, bool>; using Attributes_ = StaticAttributes<FCAttr, bool>;
template <FCAttr e> using attr = typename Attributes_::template attr<e>; template <FCAttr e> using attr = typename Attributes_::template attr<e>;
FC_Op(DimSize_t out_channels, bool noBias) FC_Op(bool noBias)
: OperatorTensor(Type, 1, 2, 1), : OperatorTensor(Type, 1, 2, 1),
Attributes_( Attributes_(attr<FCAttr::NoBias>(noBias))
attr<FCAttr::OutChannels>(out_channels),
attr<FCAttr::NoBias>(noBias))
{} {}
/** /**
...@@ -83,9 +81,9 @@ public: ...@@ -83,9 +81,9 @@ public:
} }
}; };
inline std::shared_ptr<Node> FC(DimSize_t inChannels, DimSize_t outChannels, bool noBias = false, const std::string& name = "") { inline std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, bool noBias = false, const std::string& name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(outChannels, noBias), name); auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(noBias), name);
addProducer(fc, 1, {outChannels, inChannels}, "w"); addProducer(fc, 1, {outChannels, inChannels}, "w");
addProducer(fc, 2, {(noBias ? 0 : outChannels)}, "b"); // already sets bias dims addProducer(fc, 2, {(noBias ? 0 : outChannels)}, "b"); // already sets bias dims
return fc; return fc;
...@@ -94,8 +92,7 @@ inline std::shared_ptr<Node> FC(DimSize_t inChannels, DimSize_t outChannels, boo ...@@ -94,8 +92,7 @@ inline std::shared_ptr<Node> FC(DimSize_t inChannels, DimSize_t outChannels, boo
namespace { namespace {
template <> template <>
const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels", const char *const EnumStrings<Aidge::FCAttr>::data[] = {"NoBias"};
"NoBias"};
} }
#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ #endif /* AIDGE_CORE_OPERATOR_FC_H_ */
...@@ -45,8 +45,31 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -45,8 +45,31 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
associated &= !(getInput(i)->empty()); associated &= !(getInput(i)->empty());
} }
if (associated) { if (associated) {
// first check weight since it defines inChannels and outChannels
AIDGE_ASSERT((getInput(1)->nbDims() == 2),
"Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims());
const DimSize_t outChannels = getInput(1)->template dims<2>()[0];
const DimSize_t inChannels = getInput(1)->template dims<2>()[1];
// check data
const std::vector<DimSize_t>& inputDims = getInput(0)->dims();
if (getInput(0)->nbDims() == 1) {
AIDGE_ASSERT(inputDims[0] == inChannels,
"Wrong number of input features for input data ({}), expected {}",
inputDims[0], inChannels);
} else {
AIDGE_ASSERT(getInput(0)->nbDims() > 1, "FC input data must have at least one dimension");
const DimSize_t nbInputFeatures = std::accumulate(inputDims.cbegin() + 1, inputDims.cend(), DimSize_t(1), std::multiplies<DimSize_t>());
AIDGE_ASSERT(nbInputFeatures == inChannels,
"Wrong number of input features for input data ({}), expected {}",
nbInputFeatures, inChannels);
}
// check optional bias
if(!this->template getAttr<FCAttr::NoBias>())
AIDGE_ASSERT((getInput(2)->nbDims() == 1) &&
(getInput(2)->template dims<1>()[0] == outChannels),
"Wrong bias size for FC operator.");
// <batch, OutChannels> // <batch, OutChannels>
mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<FCAttr::OutChannels>()}); mOutputs[0]->resize({getInput(0)->dims()[0], outChannels});
} }
return associated; return associated;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment