diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 373c7da29596b163eae21c1235d6db578b755a5f..3d056f5f12cb3facb7e11cb3b6c176837abdf107 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -28,8 +28,6 @@ namespace Aidge { enum class FCAttr { Alpha, // The scalar multiplier for the product of input tensors A * B. Beta, // The scalar multiplier for the bias. - TransA, // Boolean to store whether we need to tranpose input#0 - TransB // Boolean to store whether we need to tranpose input#1 }; /** @@ -38,16 +36,9 @@ enum class FCAttr { * The Fully Connected (FC) operation applies a linear transformation to the input Tensor * by multiplying it with a weight matrix and optionally adding a bias vector: * - If `bias` is included: - * f(x) = x × weights^T + bias + * f(x) = alpha * x * weights^T + beta * bias * - If `bias` is omitted: - * f(x) = x × weights^T - * - * Attributes: - * - `inChannels`: The number of input features (or channels). Determined from the dimensions - * of the weight Tensor. This represents the size of the input vector. - * - `outChannels`: The number of output features (or channels). Determined from the dimensions - * of the weight Tensor. This represents the size of the output vector. - * - `noBias`: A boolean value indicating whether the bias vector is omitted in the operation. + * f(x) = alpha * x × weights^T * * @example: * - Input Tensor: Shape (64, 128) // Batch size of 64, 128 input features @@ -64,9 +55,7 @@ class FC_Op : public OperatorTensor, private: using Attributes_ = StaticAttributes<FCAttr, float, - float, - bool, - bool>; + float>; template <FCAttr e> using attr = typename Attributes_::template attr<e>; @@ -83,13 +72,11 @@ public: * * Initializes the operator with a type identifier and input categories. 
*/ - FC_Op(float alpha = 1.0f, float beta = 1.0f, bool transA = false, bool transB = false) + FC_Op(float alpha = 1.0f, float beta = 1.0f) : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1), mAttributes(std::make_shared<Attributes_>( attr<FCAttr::Alpha>(alpha), - attr<FCAttr::Beta>(beta), - attr<FCAttr::TransA>(transA), - attr<FCAttr::TransB>(transB))) + attr<FCAttr::Beta>(beta))) {} /** @@ -201,18 +188,6 @@ public: */ inline float& beta() const { return mAttributes->template getAttr<FCAttr::Beta>(); } - /** - * @brief Get the transA boolean. - * @return Whether input#0 needs to be transposed. - */ - inline bool& transA() const { return mAttributes->template getAttr<FCAttr::TransA>(); } - - /** - * @brief Get the transB boolean. - * @return Whether input#1 needs to be transposed. - */ - inline bool& transB() const { return mAttributes->template getAttr<FCAttr::TransB>(); } - /** * @brief Retrieves the input tensor names for the FC operator. * @return A vector of input tensor names: `{"data_input", "weight", "bias"}`. @@ -238,8 +213,6 @@ public: * @param[in] alpha Scalar multiplier for the product of input tensors A * B. * @param[in] beta Scalar multiplier for the bias. * @param[in] noBias Flag indicating whether to use a bias term (default is `false`). - * @param[in] transA Flag indicating whether input#0 needs to be transposed (default is `false`). - * @param[in] transB Flag indicating whether input#1 needs to be transposed (default is `false`). * @param[in] name Name of the operator (optional). * @return A shared pointer to the Node containing the FC operator. 
*/ @@ -248,14 +221,12 @@ std::shared_ptr<Node> FC(const DimSize_t inChannels, float alpha = 1.0f, float beta = 1.0f, bool noBias = false, - bool transA = false, - bool transB = false, const std::string& name = ""); } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::FCAttr>::data[] = {"alpha", "beta", "transA", "transB"}; +const char *const EnumStrings<Aidge::FCAttr>::data[] = {"alpha", "beta"}; } #endif /* AIDGE_CORE_OPERATOR_FC_H_ */ diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index ef087926879f129765d3e446be21e7d49baf8045..57cb56ea07b3104bf4f1b31f493f07e7b6bd61de 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -360,6 +360,43 @@ std::shared_ptr<Node> Leaky(const int nbTimeSteps, const LeakyReset resetType = LeakyReset::Subtraction, const std::string &name = ""); + +/** + * @brief Creates an FC operation with transposed inputs. + * + * This function creates a Fully Connected operation with a Transpose operation on one or both inputs. + * + * @param[in] in_channels Number of input channels. + * @param[in] out_channels Number of output channels. + * @param[in] alpha Scalar multiplier for the product of input tensors A * B. + * @param[in] beta Scalar multiplier for the bias. + * @param[in] name Optional name for the operation. + * @param[in] transposeA Flag indicating whether input#0 needs to be transposed (default is `false`). + * @param[in] transposeB Flag indicating whether input#1 needs to be transposed (default is `false`). + * @return A shared pointer to the Node representing the FC operation with transposed inputs. 
+ */ +extern std::shared_ptr<Node> TransposeFC(DimSize_t in_channels, + DimSize_t out_channels, + float alpha=1.0f, + float beta=1.0f, + const std::string& name = "", + bool no_bias = false, + bool transposeA = false, + bool transposeB = false); + +/** + * @brief Creates an FC operation with transposed inputs as a MetaOperator. + * + * This function creates a graph-based MetaOperator representing an FC operation with an optional Transpose on one or both inputs. + * + * @param[in] alpha Scalar multiplier for the product of input tensors A * B. + * @param[in] beta Scalar multiplier for the bias. + * @param[in] transposeA Flag indicating whether input#0 needs to be transposed (default is `false`). + * @param[in] transposeB Flag indicating whether input#1 needs to be transposed (default is `false`). + * @return A shared pointer to the MetaOperator_Op representing the FC operation with transposed inputs. + */ +extern std::shared_ptr<MetaOperator_Op> TransposeFC_Op(float alpha = 1.0f, float beta = 1.0f, bool transposeA = false, bool transposeB = false); + } // namespace Aidge #endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */ diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index dc3f738fbf5613edf4212c15e89df4624084cb39..f915f1fda97347adccfaa07bb653695a519f6453 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -30,13 +30,9 @@ void declare_FC(py::module &m) { :type type : :py:class:`str` )mydelimiter") .def(py::init<float, - float, - bool, - bool>(), + float>(), py::arg("alpha")=1.0, - py::arg("beta")=1.0, - py::arg("transA")=false, - py::arg("transB")=false) + py::arg("beta")=1.0) .def_static("get_inputs_name", &FC_Op::getInputsName) .def_static("get_outputs_name", &FC_Op::getOutputsName) .def_readonly_static("Type", &FC_Op::Type) @@ -53,8 +49,6 @@ void declare_FC(py::module &m) { py::arg("alpha")=1.0f, py::arg("beta")=1.0f, py::arg("no_bias") = false, - py::arg("transA") = false, - py::arg("transB") = false, 
py::arg("name") = "", R"mydelimiter( Initialize a node containing a Fully Connected (FC) operator. @@ -71,10 +65,6 @@ void declare_FC(py::module &m) { :type alpha : :py:class:`int` :param beta : The scalar multiplier for the bias. :type beta : :py:class:`int` - :param transA : Indicates whether first input needs to be transposed. - :type transA : :py:class:`bool` - :param transB : Indicates whether second input needs to be transposed. - :type transB : :py:class:`bool` )mydelimiter"); } diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index 182a5edaa522f508fe128fa2331289b46e99919c..75e04d6cddf7214f47f5dcf4482e0cd101753d14 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -502,6 +502,57 @@ void declare_LeakyOp(py::module &m) { )mydelimiter"); } +void declare_TransposeFCOp(py::module &m) { + m.def("TransposeFC", [](DimSize_t in_channels, + DimSize_t out_channels, + float alpha, + float beta, + const std::string& name, + bool no_bias, + bool transA, + bool transB) + { + return TransposeFC(in_channels, out_channels,alpha, beta, name, no_bias, transA, transB); + }, py::arg("in_channels"), + py::arg("out_channels"), + py::arg("alpha") = 1.0f, + py::arg("beta") = 1.0f, + py::arg("name") = "", + py::arg("no_bias")= false, + py::arg("transA")= false, + py::arg("transB")= false, + R"mydelimiter( + Initialize a node containing an FC operator with Transpose on one or both inputs. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + + :param no_bias: Whether to disable bias addition in the convolution. + :type no_bias: bool + :param name: Name of the node (optional). + :type name: str + :return: A node containing the FC operator with Transpose node on one or two inputs. 
+ :rtype: :py:class:`TransposeFCOp` + )mydelimiter"); + + m.def("TransposeFCOp", [](float alpha, float beta, bool transA, bool transB) + { + return TransposeFC_Op(alpha, beta, transA, transB); + }, + py::arg("alpha") = 1.0f, + py::arg("beta") = 1.0f, + py::arg("transA")= false, + py::arg("transB")= false, + R"mydelimiter( + Initialize an FC operator with Transpose on one or two inputs. + + :return: An FC with Transpose operators. + :rtype: :py:class:`TransposeFCOp` + )mydelimiter"); + } + void init_MetaOperatorDefs(py::module &m) { declare_PaddedConvOp<1>(m); declare_PaddedConvOp<2>(m); @@ -520,6 +571,7 @@ void init_MetaOperatorDefs(py::module &m) { declare_LSTMOp(m); declare_LeakyResetEnum(m); declare_LeakyOp(m); + declare_TransposeFCOp(m); py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperatorOp", py::multiple_inheritance()) .def(py::init<const char *, const std::shared_ptr<GraphView>&, const std::vector<InputCategory>&>(), diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index 13da22423ec1c5be748461f4518db87dc11f4fa6..abe94d92e83bc8b9f805808404b472a39b3b12e8 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -45,24 +45,17 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { // first check weight since it defines inChannels and outChannels AIDGE_ASSERT((getInput(1)->nbDims() == 2), "Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims()); - const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ? - getInput(1)->template dims<2>()[1]: - getInput(1)->template dims<2>()[0]; - const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ? 
- getInput(1)->template dims<2>()[0]: - getInput(1)->template dims<2>()[1]; + const DimSize_t outChannels = getInput(1)->template dims<2>()[0]; + const DimSize_t inChannels = getInput(1)->template dims<2>()[1]; // check data const std::vector<DimSize_t>& inputDims = getInput(0)->dims(); - const DimIdx_t inChannelsIdx = mAttributes->template getAttr<FCAttr::TransA>() ? 1 : 0; if (getInput(0)->nbDims() == 1) { - AIDGE_ASSERT(inputDims[inChannelsIdx] == inChannels, + AIDGE_ASSERT(inputDims[0] == inChannels, "Wrong number of input features for input data ({}), expected {}", - inputDims[inChannelsIdx], inChannels); + inputDims[0], inChannels); } else { AIDGE_ASSERT(getInput(0)->nbDims() > 1, "FC input data must have at least one dimension"); - const DimSize_t nbInputFeatures = mAttributes->template getAttr<FCAttr::TransA>() ? - inputDims[0]: - std::accumulate(inputDims.cbegin() + 1, inputDims.cend(), DimSize_t(1), std::multiplies<DimSize_t>()); + const DimSize_t nbInputFeatures = std::accumulate(inputDims.cbegin() + 1, inputDims.cend(), DimSize_t(1), std::multiplies<DimSize_t>()); AIDGE_ASSERT(nbInputFeatures == inChannels, "Wrong number of input features for input data ({}), expected {}", nbInputFeatures, inChannels); @@ -106,11 +99,9 @@ std::shared_ptr<Aidge::Node> Aidge::FC(const Aidge::DimSize_t inChannels, float alpha, float beta, bool noBias, - bool transA, - bool transB, const std::string& name) { // FIXME: properly handle default w&b initialization in every cases - auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta, transA, transB), name); + auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta), name); addProducer(fc, 1, {outChannels, inChannels}, "w"); if (!noBias) { addProducer(fc, 2, {outChannels}, "b"); // already sets bias dims diff --git a/src/operator/MetaOperatorDefs/TransposeFC.cpp b/src/operator/MetaOperatorDefs/TransposeFC.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..5dc3a5b3d9598323a32215bdafcfda2843582ab0 --- /dev/null +++ b/src/operator/MetaOperatorDefs/TransposeFC.cpp @@ -0,0 +1,86 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + + #include "aidge/operator/MetaOperatorDefs.hpp" + + #include <array> + #include <memory> + #include <vector> + + #include "aidge/graph/Node.hpp" + #include "aidge/graph/OpArgs.hpp" + #include "aidge/operator/FC.hpp" + #include "aidge/operator/MetaOperator.hpp" + #include "aidge/operator/Producer.hpp" + #include "aidge/operator/Transpose.hpp" + #include "aidge/utils/ArrayHelpers.hpp" + #include "aidge/utils/Types.h" + + std::shared_ptr<Aidge::Node> Aidge::TransposeFC(Aidge::DimSize_t in_channels, + Aidge::DimSize_t out_channels, + float alpha, + float beta, + const std::string& name, + bool no_bias, + bool transposeA, + bool transposeB) + { + auto graph = std::make_shared<GraphView>(); + auto fc = FC(in_channels, out_channels, alpha, beta, no_bias, name); + graph->add(fc); + if (transposeA) { + auto transA = Transpose(std::vector<DimSize_t>{}, name + "_transposeA"); + transA->addChild(graph->getOrderedInputs()[0].first,0,0); + graph->add(transA); + } + if (transposeB) { + auto transB = Transpose(std::vector<DimSize_t>{}, name + "_transposeB"); + transB->addChild(graph->getOrderedInputs()[1].first,0,1); + graph->add(transB); + } + + auto metaOpNode = MetaOperator("TransposeFC", graph, {}, name); + + addProducer(metaOpNode, 1, {out_channels, in_channels}, "w"); + if (!no_bias) { + addProducer(metaOpNode, 2, {out_channels}, "b"); + } + + return metaOpNode; + } + + 
std::shared_ptr<Aidge::MetaOperator_Op> Aidge::TransposeFC_Op(float alpha, + float beta, + bool transposeA, + bool transposeB) + { + auto graph = std::make_shared<GraphView>(); + auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta), ""); + graph->add(fc); + + std::vector<std::pair<NodePtr, IOIndex_t>> orderedInputs = {{fc,0}, {fc,1}, {fc,2}}; + + if (transposeA) { + auto transA = Transpose(std::vector<DimSize_t>{}, ""); + transA->addChild(graph->getOrderedInputs()[0].first,0,0); + graph->add(transA); + orderedInputs[0] = {transA, 0}; + } + if (transposeB) { + auto transB = Transpose(std::vector<DimSize_t>{}, ""); + transB->addChild(graph->getOrderedInputs()[1].first,0,1); + graph->add(transB); + orderedInputs[1] = {transB, 0}; + } + graph->setOrderedInputs(orderedInputs); + graph->setOrderedOutputs({{fc, 0}}); + return std::make_shared<MetaOperator_Op>("TransposeFC", graph); + } diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index bd684f9ea1c951396cb186810e6adc388622e0a9..ae8ae678681cb3c2c8796ff3f8f848433d9a9143 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -357,9 +357,9 @@ TEST_CASE("[core/graph] Matching") { ReLU("relu2"), Conv(4, 4, {5, 5}, "conv3"), BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"), - FC(4, 4, 1.0, 1.0, false, false, false, "fc1"), - FC(4, 4, 1.0, 1.0, false, false, false, "fc2"), - FC(4, 4, 1.0, 1.0, false, false, false, "fc3"), + FC(4, 4, 1.0, 1.0, false, "fc1"), + FC(4, 4, 1.0, 1.0, false, "fc2"), + FC(4, 4, 1.0, 1.0, false, "fc3"), ReLU("relu3"), Conv(1, 4, {5, 5}, "conv4") }); diff --git a/unit_tests/recipes/Test_ToGenericOp.cpp b/unit_tests/recipes/Test_ToGenericOp.cpp index 02d784385ee18ceb495fd1e8a2f25ed161b4fee0..4ff2bd72d52c9842742802a0c5ad059bc5139a4e 100644 --- a/unit_tests/recipes/Test_ToGenericOp.cpp +++ b/unit_tests/recipes/Test_ToGenericOp.cpp @@ -32,9 +32,9 @@ TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") { 
ReLU(), Conv(4, 3, {1, 1}, "conv3"), ReLU(), - FC(2028, 256, 1.0, 1.0, false, false, false, "fc1"), + FC(2028, 256, 1.0, 1.0, false, "fc1"), ReLU(), - FC(256, 10, 1.0, 1.0, false, false, false, "fc2")}); + FC(256, 10, 1.0, 1.0, false, "fc2")}); // NCHW - MNIST DATA like g->forwardDims({{5, 1, 28, 28}}); diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp index 655f7c7f5992902f7d73dd310f4a323d0e1eadce..2e12e0532ec0cb87a29d108fc2352ce996ba9f16 100644 --- a/unit_tests/recipes/Test_removeFlatten.cpp +++ b/unit_tests/recipes/Test_removeFlatten.cpp @@ -27,8 +27,8 @@ namespace Aidge { TEST_CASE("[cpu/recipes] RemoveFlatten", "[RemoveFlatten][recipes]") { std::shared_ptr<Node> flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten"); - std::shared_ptr<Node> fc0 = FC(10, 10, 1.0, 1.0, false, false, false, "FC_1"); - std::shared_ptr<Node> fc1 = FC(10, 10, 1.0, 1.0, false, false, false, "FC_2"); + std::shared_ptr<Node> fc0 = FC(10, 10, 1.0, 1.0, false, "FC_1"); + std::shared_ptr<Node> fc1 = FC(10, 10, 1.0, 1.0, false, "FC_2"); std::shared_ptr<Node> prod = Producer(std::array<DimSize_t, 10>(), "myProd"); SECTION("flatten last layer : nothing removed because pattern searched is "