diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 393e640d60934059a9c216a9335a7018388fe9da..4633046732228650a2320badd895bcb7c682064e 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -25,6 +25,11 @@ namespace Aidge { +enum class FCAttr { + Alpha, // The scalar multiplier for the product of input tensors A * B. + Beta, // The scalar multiplier for the bias. +}; + /** * @brief Description of a Fully Connected (FC) operation on an input Tensor. * @@ -54,6 +59,15 @@ class FC_Op : public OperatorTensor, public Registrable<FC_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const FC_Op &)>> { +private: + using Attributes_ = StaticAttributes<FCAttr, + float, + float>; + + template <FCAttr e> + using attr = typename Attributes_::template attr<e>; + + const std::shared_ptr<Attributes_> mAttributes; public: /** * @brief Static type identifier for the FC operator. @@ -65,8 +79,11 @@ public: * * Initializes the operator with a type identifier and input categories. */ - FC_Op() - : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1) + FC_Op(float alpha = 1.0f, float beta = 1.0f) + : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1), + mAttributes(std::make_shared<Attributes_>( + attr<FCAttr::Alpha>(alpha), + attr<FCAttr::Beta>(beta))) {} /** @@ -160,6 +177,24 @@ public: return getInput(1)->template dims<2>()[0]; } + /** + * @brief Get the attributes of the operator. + * @return A shared pointer to the operator's attributes. + */ + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } + + /** + * @brief Get the alpha coefficient. + * @return The alpha coefficient. + */ + inline float& alpha() const { return mAttributes->template getAttr<FCAttr::Alpha>(); } + + /** + * @brief Get the beta coefficient. + * @return The beta coefficient. + */ + inline float& beta() const { return mAttributes->template getAttr<FCAttr::Beta>(); } + /** * @brief Retrieves the input tensor names for the FC operator. * @return A vector of input tensor names: `{"data_input", "weight", "bias"}`. @@ -180,16 +215,20 @@ public: /** * @brief Creates a Fully Connected operation node. * - * Constructs an FC operator node with the specified input and output channels. - * * @param[in] inChannels Number of input channels. * @param[in] outChannels Number of output channels. + * @param[in] alpha Scalar multiplier for the product of input tensors A * B. + * @param[in] beta Scalar multiplier for the bias. * @param[in] noBias Flag indicating whether to use a bias term (default is `false`). * @param[in] name Name of the operator (optional). * @return A shared pointer to the Node containing the FC operator. */ -std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, bool noBias = false, const std::string& name = ""); +std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, bool noBias = false, const std::string& name = "", float alpha = 1.0f, float beta = 1.0f); } // namespace Aidge +namespace { +template <> +const char *const EnumStrings<Aidge::FCAttr>::data[] = {"alpha", "beta"}; +} #endif /* AIDGE_CORE_OPERATOR_FC_H_ */ diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index c29b6e1d3723f03f6a9c9b1f03156b42160c6cf3..3dc2c1a6f198874d87b7d5f5d1b0b6725df15c99 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -40,7 +40,7 @@ void declare_FC(py::module &m) { declare_registrable<FC_Op>(m, "FCOp"); - m.def("FC", &FC, py::arg("in_channels"), py::arg("out_channels"), py::arg("no_bias") = false, py::arg("name") = "", + m.def("FC", &FC, py::arg("in_channels"), py::arg("out_channels"), py::arg("no_bias") = false, py::arg("name") = "", py::arg("alpha")=1.0f, py::arg("beta")=1.0f, R"mydelimiter( Initialize a node containing a Fully Connected (FC) operator. @@ -52,6 +52,10 @@ void declare_FC(py::module &m) { :type no_bias : :py:class:`bool` :param name : Name of the node. :type name : :py:class:`str` + :param alpha : The scalar multiplier for the term A*B. + :type alpha : :py:class:`int` + :param beta : The scalar multiplier for the bias. + :type beta : :py:class:`int` )mydelimiter"); } diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index dd3ed7aba65cf1875d691d9bc2c8c94bb03856c7..dd1d0577ead56a4c62b4e3e9ca567685221d0a12 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -98,9 +98,11 @@ std::set<std::string> Aidge::FC_Op::getAvailableBackends() const { std::shared_ptr<Aidge::Node> Aidge::FC(const Aidge::DimSize_t inChannels, const Aidge::DimSize_t outChannels, bool noBias, - const std::string& name) { + const std::string& name, + float alpha, + float beta) { // FIXME: properly handle default w&b initialization in every cases - auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(), name); + auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta), name); addProducer(fc, 1, {outChannels, inChannels}, "w"); if (!noBias) { addProducer(fc, 2, {outChannels}, "b"); // already sets bias dims