diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp
index cd36a654772d2d641b9af32bb74b1336f4a9742d..5ff1159e6be4dc3bbd7ea3c893f1ef59eb429ae0 100644
--- a/include/aidge/aidge.hpp
+++ b/include/aidge/aidge.hpp
@@ -80,6 +80,7 @@
 #include "aidge/operator/Split.hpp"
 #include "aidge/operator/Sqrt.hpp"
 #include "aidge/operator/Sub.hpp"
+#include "aidge/operator/Sum.hpp"
 #include "aidge/operator/Transpose.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/stimuli/Stimulus.hpp"
diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp
index 393e640d60934059a9c216a9335a7018388fe9da..3d056f5f12cb3facb7e11cb3b6c176837abdf107 100644
--- a/include/aidge/operator/FC.hpp
+++ b/include/aidge/operator/FC.hpp
@@ -25,22 +25,20 @@
 
 namespace Aidge {
 
+enum class FCAttr {
+    Alpha,  // The scalar multiplier for the product of input tensors A * B.
+    Beta,   // The scalar multiplier for the bias.
+};
+
 /**
  * @brief Description of a Fully Connected (FC) operation on an input Tensor.
  *
  * The Fully Connected (FC) operation applies a linear transformation to the input Tensor
  * by multiplying it with a weight matrix and optionally adding a bias vector: 
  * - If `bias` is included:
- *   f(x) = x × weights^T + bias
+ *   f(x) = alpha * x * weights^T + beta * bias
  * - If `bias` is omitted:
- *   f(x) = x × weights^T
- *
- * Attributes:
- * - `inChannels`: The number of input features (or channels). Determined from the dimensions
- *   of the weight Tensor. This represents the size of the input vector.
- * - `outChannels`: The number of output features (or channels). Determined from the dimensions
- *   of the weight Tensor. This represents the size of the output vector.
- * - `noBias`: A boolean value indicating whether the bias vector is omitted in the operation.
+ *   f(x) = alpha * x × weights^T
  *
  * @example:
  * - Input Tensor: Shape (64, 128)  // Batch size of 64, 128 input features
@@ -54,6 +52,15 @@ class FC_Op : public OperatorTensor,
               public Registrable<FC_Op,
                                  std::string,
                                  std::function<std::shared_ptr<OperatorImpl>(const FC_Op &)>> {
+private:
+    using Attributes_ = StaticAttributes<FCAttr,
+                                        float,
+                                        float>;
+
+    template <FCAttr e>
+    using attr = typename Attributes_::template attr<e>;
+
+    const std::shared_ptr<Attributes_> mAttributes;
 public:
     /**
      * @brief Static type identifier for the FC operator.
@@ -65,8 +72,11 @@ public:
      *
      * Initializes the operator with a type identifier and input categories.
      */
-    FC_Op()
-    : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1)
+    FC_Op(float alpha = 1.0f, float beta = 1.0f)
+    : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1),
+    mAttributes(std::make_shared<Attributes_>(
+                attr<FCAttr::Alpha>(alpha),
+                attr<FCAttr::Beta>(beta)))
     {}
 
     /**
@@ -160,6 +170,24 @@ public:
         return getInput(1)->template dims<2>()[0];
     }
 
+    /**
+     * @brief Get the attributes of the operator.
+     * @return A shared pointer to the operator's attributes.
+     */
+    inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
+
+    /**
+     * @brief Get the alpha coefficient.
+     * @return The alpha coefficient.
+     */
+    inline float& alpha() const { return mAttributes->template getAttr<FCAttr::Alpha>(); }
+
+    /**
+     * @brief Get the beta coefficient.
+     * @return The beta coefficient.
+     */
+    inline float& beta() const { return mAttributes->template getAttr<FCAttr::Beta>(); }
+
     /**
      * @brief Retrieves the input tensor names for the FC operator.
      * @return A vector of input tensor names: `{"data_input", "weight", "bias"}`.
@@ -180,16 +208,25 @@ public:
 /**
  * @brief Creates a Fully Connected operation node.
  *
- * Constructs an FC operator node with the specified input and output channels.
- *
  * @param[in] inChannels Number of input channels.
  * @param[in] outChannels Number of output channels.
+ * @param[in] alpha Scalar multiplier for the product of input tensors A * B.
+ * @param[in] beta Scalar multiplier for the bias.
  * @param[in] noBias Flag indicating whether to use a bias term (default is `false`).
  * @param[in] name Name of the operator (optional).
  * @return A shared pointer to the Node containing the FC operator.
  */
-std::shared_ptr<Node> FC(const DimSize_t inChannels, const DimSize_t outChannels, bool noBias = false, const std::string& name = "");
+std::shared_ptr<Node> FC(const DimSize_t inChannels,
+                         const DimSize_t outChannels,
+                         float alpha = 1.0f,
+                         float beta = 1.0f,
+                         bool noBias = false,
+                         const std::string& name = "");
 
 } // namespace Aidge
 
+namespace {
+template <>
+const char *const EnumStrings<Aidge::FCAttr>::data[] = {"alpha", "beta"};
+}
 #endif /* AIDGE_CORE_OPERATOR_FC_H_ */
diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp
index ef087926879f129765d3e446be21e7d49baf8045..57cb56ea07b3104bf4f1b31f493f07e7b6bd61de 100644
--- a/include/aidge/operator/MetaOperatorDefs.hpp
+++ b/include/aidge/operator/MetaOperatorDefs.hpp
@@ -360,6 +360,43 @@ std::shared_ptr<Node> Leaky(const int nbTimeSteps,
                             const LeakyReset resetType = LeakyReset::Subtraction,
                             const std::string &name = "");
 
+
+/**
+ * @brief Creates a FC operation with transposed inputs.
+ *
+ * This function creates a Fully Connected operation with transpose Operation of 1 or both inputs.
+ *
+ * @param[in] inChannels Number of input channels.
+ * @param[in] outChannels Number of output channels.
+ * @param[in] alpha Scalar multiplier for the product of input tensors A * B.
+ * @param[in] beta Scalar multiplier for the bias.
+ * @param[in] name Optional name for the operation.
+ * @param[in] transposeA Flag indicating whether input#0 needs to be transposed (default is `false`).
+ * @param[in] transposeB Flag indicating whether input#1 needs to be transposed (default is `false`).
+ * @return A shared pointer to the Node representing the padded average pooling operation.
+ */
+extern std::shared_ptr<Node> TransposeFC(DimSize_t in_channels,
+                                        DimSize_t out_channels,
+                                        float alpha=1.0f,
+                                        float beta=1.0f,
+                                        const std::string& name = "",
+                                        bool no_bias = false,
+                                        bool transposeA = false,
+                                        bool transposeB = false);
+
+/**
+ * @brief Creates a padded convolution operation as a MetaOperator.
+ *
+ * This function creates a graph-based MetaOperator representing a padded convolution operation (Conv2D/Conv3D).
+ *
+ * @param[in] alpha Scalar multiplier for the product of input tensors A * B.
+ * @param[in] beta Scalar multiplier for the bias.
+ * @param[in] transposeA Flag indicating whether input#0 needs to be transposed (default is `false`).
+ * @param[in] transposeB Flag indicating whether input#1 needs to be transposed (default is `false`).
+ * @return A shared pointer to the MetaOperator_Op representing the padded convolution operation.
+ */
+extern std::shared_ptr<MetaOperator_Op> TransposeFC_Op(float alpha = 1.0f, float beta = 1.0f, bool transposeA = false, bool transposeB = false);
+
 } // namespace Aidge
 
 #endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */
diff --git a/include/aidge/operator/Sum.hpp b/include/aidge/operator/Sum.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..6718f4179f09c5594c99859bb39e75610de32bba
--- /dev/null
+++ b/include/aidge/operator/Sum.hpp
@@ -0,0 +1,90 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+ #ifndef AIDGE_CORE_OPERATOR_SUM_H_
+ #define AIDGE_CORE_OPERATOR_SUM_H_
+ 
+ #include <memory>
+ #include <string>
+ #include <vector>
+ 
+ #include "aidge/operator/OperatorTensor.hpp"
+ #include "aidge/graph/Node.hpp"
+ #include "aidge/utils/ErrorHandling.hpp"
+ #include "aidge/utils/Types.h"
+ #include "aidge/utils/Registrar.hpp"
+ 
+ namespace Aidge {
+ 
+ /**
+  * @brief Description of an element-wise Sum operation on multiple input Tensors,
+  * supporting NumPy broadcasting.
+  *
+  * For each N of elements x0, x1, ..., xN from the input Tensors, the function 
+  * is defined as:
+  * `f(x0, ..., xN) = x0 + x1 + ... + xN`
+  *
+  * Broadcasting adjusts shapes of the input Tensors to make them compatible:
+  * - Tensors are aligned from the rightmost dimensions.
+  * - Dimensions are compatible if they are equal, one of them is 1, or missing.
+  *
+  * The output Tensor shape is determined by taking the maximum size along 
+  * each dimension of the input Tensors after broadcasting.
+  *
+  * @example Input 1: (3, 4, 2), Input 2: (2), Output: (3, 4, 2)
+  * @example Input 1: (1, 5, 3), Input 2: (2, 1, 3), Input 3 : (2), Output: (2, 5, 3)
+  *
+  * @see OperatorTensor
+  * @see Registrable
+  */
+ class Sum_Op : public OperatorTensor,
+     public Registrable<Sum_Op,
+                        std::string,
+                        std::function<std::shared_ptr<OperatorImpl>(const Sum_Op&)>>
+ {
+ public:
+     static const std::string Type;
+ 
+    Sum_Op() = delete;
+    Sum_Op(const IOIndex_t nbIn);
+ 
+     /**
+      * @brief Copy-constructor.
+      * @param op Sum_Op to copy.
+      * @details Copies the operator attributes and its output tensor(s), but not
+      * its input tensors. The new operator has no associated input.
+      */
+     Sum_Op(const Sum_Op& op);
+ 
+     /**
+      * @brief Clone the operator using its copy-constructor.
+      * @see Operator::Sum_Op
+      */
+     std::shared_ptr<Operator> clone() const override;
+ 
+     bool forwardDims(bool allowDataDependency = false) override final;
+ 
+     void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
+     std::set<std::string> getAvailableBackends() const override;
+ 
+     static const std::vector<std::string> getInputsName() {
+         return {"data_input_0", "data_input_n"};
+     }
+     static const std::vector<std::string> getOutputsName() {
+         return {"data_output"};
+     }
+ };
+ 
+ std::shared_ptr<Node> Sum(const IOIndex_t nbIn, const std::string& name = "");
+ }
+ 
+ #endif /* AIDGE_CORE_OPERATOR_SUM_H_ */
+ 
\ No newline at end of file
diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp
index c29b6e1d3723f03f6a9c9b1f03156b42160c6cf3..f915f1fda97347adccfaa07bb653695a519f6453 100644
--- a/python_binding/operator/pybind_FC.cpp
+++ b/python_binding/operator/pybind_FC.cpp
@@ -29,7 +29,10 @@ void declare_FC(py::module &m) {
     :param type : The type of the Fully Connected operation.
     :type type : :py:class:`str`
     )mydelimiter")
-    .def(py::init<>())
+    .def(py::init<float,
+      float>(),
+      py::arg("alpha")=1.0,
+      py::arg("beta")=1.0)
     .def_static("get_inputs_name", &FC_Op::getInputsName)
     .def_static("get_outputs_name", &FC_Op::getOutputsName)
     .def_readonly_static("Type", &FC_Op::Type)
@@ -40,7 +43,13 @@ void declare_FC(py::module &m) {
 
   declare_registrable<FC_Op>(m, "FCOp");
 
-  m.def("FC", &FC, py::arg("in_channels"), py::arg("out_channels"), py::arg("no_bias") = false, py::arg("name") = "",
+  m.def("FC", &FC, 
+        py::arg("in_channels"),
+        py::arg("out_channels"),
+        py::arg("alpha")=1.0f,
+        py::arg("beta")=1.0f,
+        py::arg("no_bias") = false,
+        py::arg("name") = "",
     R"mydelimiter(
     Initialize a node containing a Fully Connected (FC) operator.
 
@@ -52,6 +61,10 @@ void declare_FC(py::module &m) {
     :type no_bias : :py:class:`bool`
     :param name : Name of the node.
     :type name : :py:class:`str`
+    :param alpha : The scalar multiplier for the term A*B.
+    :type alpha : :py:class:`int`
+    :param beta : The scalar multiplier for the bias.
+    :type beta : :py:class:`int`
     )mydelimiter");
 }
 
diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp
index 182a5edaa522f508fe128fa2331289b46e99919c..75e04d6cddf7214f47f5dcf4482e0cd101753d14 100644
--- a/python_binding/operator/pybind_MetaOperatorDefs.cpp
+++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp
@@ -502,6 +502,57 @@ void declare_LeakyOp(py::module &m) {
     )mydelimiter");
 }
 
+void declare_TransposeFCOp(py::module &m) {
+    m.def("TransposeFC", [](DimSize_t in_channels,
+                            DimSize_t out_channels,
+                            float alpha,
+                            float beta,
+                            const std::string& name,
+                            bool no_bias,
+                            bool transA,
+                            bool transB)
+      {
+          return TransposeFC(in_channels, out_channels,alpha, beta, name, no_bias, transA, transB);
+      }, py::arg("in_channels"),
+         py::arg("out_channels"),
+         py::arg("alpha") = 1.0f,
+         py::arg("beta") = 1.0f,
+         py::arg("name") = "",
+         py::arg("no_bias")= false,
+         py::arg("transA")= false,
+         py::arg("transB")= false,
+      R"mydelimiter(
+          Initialize a node containing an FC operator with Transpose on one or both inputs.
+
+          :param in_channels: Number of input channels.
+          :type in_channels: int
+          :param out_channels: Number of output channels.
+          :type out_channels: int
+
+          :param no_bias: Whether to disable bias addition in the convolution.
+          :type no_bias: bool
+          :param name: Name of the node (optional).
+          :type name: str
+          :return: A node containing the FC operator with Transpose node on one or two inputs.
+          :rtype: :py:class:`TransposeFCOp`
+      )mydelimiter");
+
+      m.def("TransposeFCOp", [](float alpha, float beta, bool transA, bool transB)
+      {
+        return TransposeFC_Op(alpha, beta, transA, transB);
+      },
+        py::arg("alpha") = 1.0f,
+        py::arg("beta") = 1.0f,
+        py::arg("transA")= false,
+        py::arg("transB")= false,
+      R"mydelimiter(
+          Initialize an FC operator with Transpose on one or two inputs.
+
+          :return: An FC with Transpose operators.
+          :rtype: :py:class:`TransposeFCOp`
+      )mydelimiter");
+  }
+
 void init_MetaOperatorDefs(py::module &m) {
   declare_PaddedConvOp<1>(m);
   declare_PaddedConvOp<2>(m);
@@ -520,6 +571,7 @@ void init_MetaOperatorDefs(py::module &m) {
   declare_LSTMOp(m);
   declare_LeakyResetEnum(m);
   declare_LeakyOp(m);
+  declare_TransposeFCOp(m);
 
   py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperatorOp", py::multiple_inheritance())
   .def(py::init<const char *, const std::shared_ptr<GraphView>&, const std::vector<InputCategory>&>(),
diff --git a/python_binding/operator/pybind_Sum.cpp b/python_binding/operator/pybind_Sum.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d09d4736c89c931512b15990e86fe7c83e17619
--- /dev/null
+++ b/python_binding/operator/pybind_Sum.cpp
@@ -0,0 +1,67 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+ #include <memory>
+
+ #include <pybind11/pybind11.h>
+ 
+ #include "aidge/operator/Sum.hpp"
+ #include "aidge/operator/OperatorTensor.hpp"
+ #include "aidge/utils/Types.h"
+ 
+ namespace py = pybind11;
+ namespace Aidge {
+ 
+ void declare_Sum(py::module &m) {
+   py::class_<Sum_Op, std::shared_ptr<Sum_Op>, OperatorTensor>(m, "SumOp", py::multiple_inheritance(),
+     R"mydelimiter(
+     Initialize a Sum operator.
+     This operator performs element-wise addition between multiple input tensors.
+     The operation is defined as:
+         Output = Input1 + Input2 + ... + InputN
+     The output tensor shape is determined by taking the maximum size along each dimension of the input tensors after broadcasting.
+     Examples:
+         Input 1: (3, 4, 2), Input 2: (2), Output: (3, 4, 2)
+         Input 1: (1, 5, 3), Input 2: (2, 1, 3), Input 3: (2), Output: (2, 5, 3)
+     :param name : Name of the node (optional).
+     :type name : str
+     )mydelimiter")
+     .def(py::init<const IOIndex_t>(), py::arg("nb_inputs"))
+     .def_static("get_inputs_name", &Sum_Op::getInputsName)
+     .def_static("get_outputs_name", &Sum_Op::getOutputsName)
+     .def_readonly_static("Type", &Sum_Op::Type);
+ 
+   declare_registrable<Sum_Op>(m, "SumOp");
+ 
+   m.def("Sum", &Sum, py::arg("nb_inputs"),  py::arg("name") = "",
+     R"mydelimiter(
+     Initialize a node containing a sum operator that performs element-wise addition between multiple tensors.
+     The operation is defined as:
+         Output = Input1 + Input2 + ... + InputN
+     The output tensor shape is determined by taking the maximum size along each dimension of the input tensors after broadcasting.
+     Examples:
+         Input 1: (3, 4, 2), Input 2: (2), Output: (3, 4, 2)
+         Input 1: (1, 5, 3), Input 2: (2, 1, 3), Input 3: (2), Output: (2, 5, 3)
+     :param nb_inputs : number of inputs to sum.
+     :type nb_inputs : int
+     :param name : Name of the node (optional).
+     :type name : str
+     :return: A node containing the Sum operator.
+     :rtype: :py:class:`SumOp`
+     )mydelimiter");
+ }
+ 
+ void init_Sum(py::module &m) {
+   declare_Sum(m);
+ }
+ 
+ } // namespace Aidge
+ 
\ No newline at end of file
diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp
index dd3ed7aba65cf1875d691d9bc2c8c94bb03856c7..abe94d92e83bc8b9f805808404b472a39b3b12e8 100644
--- a/src/operator/FC.cpp
+++ b/src/operator/FC.cpp
@@ -61,12 +61,11 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
                     nbInputFeatures, inChannels);
         }
         // check optional bias
-        if(getInput(2))
-            AIDGE_ASSERT((getInput(2)->nbDims() == 1) &&
-                    (getInput(2)->template dims<1>()[0] == outChannels),
-                    "Wrong bias size for FC operator.");
+        if(getInput(2)) {
+            AIDGE_ASSERT(getInput(2)->size() == outChannels, "Wrong bias size for FC operator.");
+        }
         // <batch, OutChannels>
-        mOutputs[0]->resize({getInput(0)->dims()[0], outChannels});
+        mOutputs[0]->resize({static_cast<DimSize_t>(getInput(0)->size() / inChannels), outChannels});
         return true;
     }
 
@@ -97,10 +96,12 @@ std::set<std::string> Aidge::FC_Op::getAvailableBackends() const {
 
 std::shared_ptr<Aidge::Node> Aidge::FC(const Aidge::DimSize_t inChannels,
                                        const Aidge::DimSize_t outChannels,
+                                       float alpha,
+                                       float beta,
                                        bool noBias,
                                        const std::string& name) {
     // FIXME: properly handle default w&b initialization in every cases
-    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(), name);
+    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta), name);
     addProducer(fc, 1, {outChannels, inChannels}, "w");
     if (!noBias) {
         addProducer(fc, 2, {outChannels}, "b"); // already sets bias dims
diff --git a/src/operator/MetaOperatorDefs/TransposeFC.cpp b/src/operator/MetaOperatorDefs/TransposeFC.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5dc3a5b3d9598323a32215bdafcfda2843582ab0
--- /dev/null
+++ b/src/operator/MetaOperatorDefs/TransposeFC.cpp
@@ -0,0 +1,86 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+ #include "aidge/operator/MetaOperatorDefs.hpp"
+
+ #include <array>
+ #include <memory>
+ #include <vector>
+ 
+ #include "aidge/graph/Node.hpp"
+ #include "aidge/graph/OpArgs.hpp"
+ #include "aidge/operator/FC.hpp"
+ #include "aidge/operator/MetaOperator.hpp"
+ #include "aidge/operator/Producer.hpp"
+ #include "aidge/operator/Transpose.hpp"
+ #include "aidge/utils/ArrayHelpers.hpp"
+ #include "aidge/utils/Types.h"
+ 
+ std::shared_ptr<Aidge::Node> Aidge::TransposeFC(Aidge::DimSize_t in_channels,
+                                   Aidge::DimSize_t out_channels,
+                                   float alpha,
+                                   float beta,
+                                   const std::string& name,
+                                   bool no_bias,
+                                   bool transposeA,
+                                   bool transposeB)
+ {
+    auto graph = std::make_shared<GraphView>();
+    auto fc = FC(in_channels, out_channels, alpha, beta, no_bias, name);
+    graph->add(fc);
+    if (transposeA) {
+        auto transA = Transpose(std::vector<DimSize_t>{}, name + "_transposeA");
+        transA->addChild(graph->getOrderedInputs()[0].first,0,0);
+        graph->add(transA);
+    }
+    if (transposeB) {
+        auto transB = Transpose(std::vector<DimSize_t>{}, name + "_transposeB");
+        transB->addChild(graph->getOrderedInputs()[1].first,0,1);
+        graph->add(transB);
+    }
+
+    auto metaOpNode = MetaOperator("TransposeFC", graph, {}, name);
+
+    addProducer(metaOpNode, 1, {out_channels, in_channels}, "w");
+    if (!no_bias) {
+        addProducer(metaOpNode, 2, {out_channels}, "b");
+    }
+
+     return metaOpNode;
+ }
+
+ std::shared_ptr<Aidge::MetaOperator_Op> Aidge::TransposeFC_Op(float alpha,
+    float beta,
+    bool transposeA,
+    bool transposeB)
+ {
+    auto graph = std::make_shared<GraphView>();
+    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta), "");
+    graph->add(fc);
+
+    std::vector<std::pair<NodePtr, IOIndex_t>> orderedInputs = {{fc,0}, {fc,1}, {fc,2}};
+
+    if (transposeA) {
+        auto transA = Transpose(std::vector<DimSize_t>{}, "");
+        transA->addChild(graph->getOrderedInputs()[0].first,0,0);
+        graph->add(transA);
+        orderedInputs[0] = {transA, 0};
+    }
+    if (transposeB) {
+        auto transB = Transpose(std::vector<DimSize_t>{}, "");
+        transB->addChild(graph->getOrderedInputs()[1].first,0,1);
+        graph->add(transB);
+        orderedInputs[1] = {transB, 0};
+    }
+    graph->setOrderedInputs(orderedInputs);
+    graph->setOrderedOutputs({{fc, 0}});
+    return std::make_shared<MetaOperator_Op>("TransposeFC", graph);
+ }
diff --git a/src/operator/Sum.cpp b/src/operator/Sum.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6c6e5fe2921f9549110909003355c97d420c74cb
--- /dev/null
+++ b/src/operator/Sum.cpp
@@ -0,0 +1,95 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+ #include <cstddef>    // std::size_t
+ #include <stdexcept>  // std::runtime_error
+ #include <string>
+ #include <vector>
+ 
+ #include "aidge/data/Tensor.hpp"
+ #include "aidge/operator/Sum.hpp"
+ #include "aidge/utils/Types.h"
+ #include "aidge/utils/ErrorHandling.hpp"
+ #include "aidge/utils/Registrar.hpp"
+ 
+ const std::string Aidge::Sum_Op::Type = "Sum";
+ 
+ Aidge::Sum_Op::Sum_Op(const IOIndex_t nbIn)
+ : OperatorTensor(Type, std::vector<InputCategory>(nbIn, InputCategory::Data), 1) {
+    if (nbIn == 0) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Sum operator should have at least one input.");
+    }    
+ }
+ 
+ Aidge::Sum_Op::Sum_Op(const Sum_Op& op)
+     : OperatorTensor(op)
+ {
+     if (op.mImpl) {
+         SET_IMPL_MACRO(Sum_Op, *this, op.backend());
+     } else {
+         mImpl = nullptr;
+     }
+ }
+ 
+ std::shared_ptr<Aidge::Operator> Aidge::Sum_Op::clone() const {
+     return std::make_shared<Sum_Op>(*this);
+ }
+ 
+ bool Aidge::Sum_Op::forwardDims(bool /*allowDataDependency*/) {
+     if (inputsAssociated()) {
+         std::vector<std::vector<std::size_t>> inputsDims(nbInputs());
+         for (std::size_t i = 0; i < nbInputs(); i++) {
+             inputsDims[i] = getInput(i)->dims();
+         }
+ 
+         std::size_t outNbDims = 1;
+         for(std::size_t i = 0; i < nbInputs(); ++i) {
+             outNbDims = (inputsDims[i].size() > outNbDims) ? inputsDims[i].size() : outNbDims;
+         }
+ 
+         std::vector<std::size_t> outDims(outNbDims, 1);
+ 
+         for (auto it = outDims.rbegin(); it != outDims.rend(); ++it) {
+             for (std::size_t i = 0; i < nbInputs(); ++i) {
+                 if(!inputsDims[i].empty()) {
+                     const std::size_t dim = inputsDims[i].back();
+                     inputsDims[i].pop_back();
+                     if (*it == 1) {
+                         *it = dim;
+                     }
+                     else if ((dim != *it) && (dim != 1)) {
+                         AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible Tensor shape for Add Operation: {} for previous inputs vs {} for input#{}",
+                             outDims, getInput(i)->dims(), i);
+                     }
+                 }
+             }
+         }
+         mOutputs[0]->resize(outDims);
+         return true;
+     }
+ 
+     return false;
+ }
+ 
+ void Aidge::Sum_Op::setBackend(const std::string& name, DeviceIdx_t device) {
+     SET_IMPL_MACRO(Sum_Op, *this, name);
+     mOutputs[0]->setBackend(name, device);
+ }
+ 
+ std::set<std::string> Aidge::Sum_Op::getAvailableBackends() const {
+     return Registrar<Sum_Op>::getKeys();
+ }
+ 
+ ////////////////////////////////////////////////////////////////////////////////
+ 
+ std::shared_ptr<Aidge::Node> Aidge::Sum(const IOIndex_t nbIn, const std::string& name) {
+     return std::make_shared<Node>(std::make_shared<Sum_Op>(nbIn), name);
+ }
\ No newline at end of file
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index 582c73565a4ef7bfc96e493e1e6029b1683676ab..ae8ae678681cb3c2c8796ff3f8f848433d9a9143 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -357,9 +357,9 @@ TEST_CASE("[core/graph] Matching") {
         ReLU("relu2"),
         Conv(4, 4, {5, 5}, "conv3"),
         BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
-        FC(4, 4, false, "fc1"),
-        FC(4, 4, false, "fc2"),
-        FC(4, 4, false, "fc3"),
+        FC(4, 4, 1.0, 1.0, false, "fc1"),
+        FC(4, 4, 1.0, 1.0, false, "fc2"),
+        FC(4, 4, 1.0, 1.0, false, "fc3"),
         ReLU("relu3"),
         Conv(1, 4, {5, 5}, "conv4")
     });
diff --git a/unit_tests/recipes/Test_ToGenericOp.cpp b/unit_tests/recipes/Test_ToGenericOp.cpp
index cb75fdb1072dee476c88c1f6d502a792b2e6abd9..4ff2bd72d52c9842742802a0c5ad059bc5139a4e 100644
--- a/unit_tests/recipes/Test_ToGenericOp.cpp
+++ b/unit_tests/recipes/Test_ToGenericOp.cpp
@@ -32,9 +32,9 @@ TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") {
                     ReLU(),
                     Conv(4, 3, {1, 1}, "conv3"),
                     ReLU(),
-                    FC(2028, 256, false, "fc1"),
+                    FC(2028, 256, 1.0, 1.0, false, "fc1"),
                     ReLU(),
-                    FC(256, 10, false, "fc2")});
+                    FC(256, 10, 1.0, 1.0, false, "fc2")});
     
     // NCHW - MNIST DATA like
     g->forwardDims({{5, 1, 28, 28}});
diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp
index 1b5e2783813da890b1e79744582f54bb5c932772..2e12e0532ec0cb87a29d108fc2352ce996ba9f16 100644
--- a/unit_tests/recipes/Test_removeFlatten.cpp
+++ b/unit_tests/recipes/Test_removeFlatten.cpp
@@ -27,8 +27,8 @@ namespace Aidge {
 TEST_CASE("[cpu/recipes] RemoveFlatten", "[RemoveFlatten][recipes]") {
   std::shared_ptr<Node> flatten =
       GenericOperator("Flatten", 1, 0, 1, "myFlatten");
-  std::shared_ptr<Node> fc0 = FC(10, 10, false, "FC_1");
-  std::shared_ptr<Node> fc1 = FC(10, 10, false, "FC_2");
+  std::shared_ptr<Node> fc0 = FC(10, 10, 1.0, 1.0, false, "FC_1");
+  std::shared_ptr<Node> fc1 = FC(10, 10, 1.0, 1.0, false, "FC_2");
   std::shared_ptr<Node> prod = Producer(std::array<DimSize_t, 10>(), "myProd");
 
   SECTION("flatten last layer : nothing removed because pattern searched is "