diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp
index 373c7da29596b163eae21c1235d6db578b755a5f..3d056f5f12cb3facb7e11cb3b6c176837abdf107 100644
--- a/include/aidge/operator/FC.hpp
+++ b/include/aidge/operator/FC.hpp
@@ -28,8 +28,6 @@ namespace Aidge {
 enum class FCAttr {
     Alpha,  // The scalar multiplier for the product of input tensors A * B.
     Beta,   // The scalar multiplier for the bias.
-    TransA, // Boolean to store whether we need to tranpose input#0
-    TransB  // Boolean to store whether we need to tranpose input#1
 };
 
 /**
@@ -38,16 +36,9 @@ enum class FCAttr {
  * The Fully Connected (FC) operation applies a linear transformation to the input Tensor
  * by multiplying it with a weight matrix and optionally adding a bias vector: 
  * - If `bias` is included:
- *   f(x) = x × weights^T + bias
+ *   f(x) = alpha * x * weights^T + beta * bias
  * - If `bias` is omitted:
- *   f(x) = x × weights^T
- *
- * Attributes:
- * - `inChannels`: The number of input features (or channels). Determined from the dimensions
- *   of the weight Tensor. This represents the size of the input vector.
- * - `outChannels`: The number of output features (or channels). Determined from the dimensions
- *   of the weight Tensor. This represents the size of the output vector.
- * - `noBias`: A boolean value indicating whether the bias vector is omitted in the operation.
+ *   f(x) = alpha * x × weights^T
  *
  * @example:
  * - Input Tensor: Shape (64, 128)  // Batch size of 64, 128 input features
@@ -64,9 +55,7 @@ class FC_Op : public OperatorTensor,
 private:
     using Attributes_ = StaticAttributes<FCAttr,
                                         float,
-                                        float,
-                                        bool,
-                                        bool>;
+                                        float>;
 
     template <FCAttr e>
     using attr = typename Attributes_::template attr<e>;
@@ -83,13 +72,11 @@ public:
      *
      * Initializes the operator with a type identifier and input categories.
      */
-    FC_Op(float alpha = 1.0f, float beta = 1.0f, bool transA = false, bool transB = false)
+    FC_Op(float alpha = 1.0f, float beta = 1.0f)
     : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1),
     mAttributes(std::make_shared<Attributes_>(
                 attr<FCAttr::Alpha>(alpha),
-                attr<FCAttr::Beta>(beta),
-                attr<FCAttr::TransA>(transA),
-                attr<FCAttr::TransB>(transB)))
+                attr<FCAttr::Beta>(beta)))
     {}
 
     /**
@@ -201,18 +188,6 @@ public:
      */
     inline float& beta() const { return mAttributes->template getAttr<FCAttr::Beta>(); }
 
-    /**
-     * @brief Get the transA boolean.
-     * @return Whether input#0 needs to be transposed.
-     */
-    inline bool& transA() const { return mAttributes->template getAttr<FCAttr::TransA>(); }
-
-    /**
-     * @brief Get the transB boolean.
-     * @return Whether input#1 needs to be transposed.
-     */
-    inline bool& transB() const { return mAttributes->template getAttr<FCAttr::TransB>(); }
-
     /**
      * @brief Retrieves the input tensor names for the FC operator.
      * @return A vector of input tensor names: `{"data_input", "weight", "bias"}`.
@@ -238,8 +213,6 @@ public:
  * @param[in] alpha Scalar multiplier for the product of input tensors A * B.
  * @param[in] beta Scalar multiplier for the bias.
  * @param[in] noBias Flag indicating whether to use a bias term (default is `false`).
- * @param[in] transA Flag indicating whether input#0 needs to be transposed (default is `false`).
- * @param[in] transB Flag indicating whether input#1 needs to be transposed (default is `false`).
  * @param[in] name Name of the operator (optional).
  * @return A shared pointer to the Node containing the FC operator.
  */
@@ -248,14 +221,12 @@ std::shared_ptr<Node> FC(const DimSize_t inChannels,
                          float alpha = 1.0f,
                          float beta = 1.0f,
                          bool noBias = false,
-                         bool transA = false,
-                         bool transB = false,
                          const std::string& name = "");
 
 } // namespace Aidge
 
 namespace {
 template <>
-const char *const EnumStrings<Aidge::FCAttr>::data[] = {"alpha", "beta", "transA", "transB"};
+const char *const EnumStrings<Aidge::FCAttr>::data[] = {"alpha", "beta"};
 }
 #endif /* AIDGE_CORE_OPERATOR_FC_H_ */
diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp
index ef087926879f129765d3e446be21e7d49baf8045..57cb56ea07b3104bf4f1b31f493f07e7b6bd61de 100644
--- a/include/aidge/operator/MetaOperatorDefs.hpp
+++ b/include/aidge/operator/MetaOperatorDefs.hpp
@@ -360,6 +360,43 @@ std::shared_ptr<Node> Leaky(const int nbTimeSteps,
                             const LeakyReset resetType = LeakyReset::Subtraction,
                             const std::string &name = "");
 
+
+/**
+ * @brief Creates a FC operation with transposed inputs.
+ *
+ * This function creates a Fully Connected operation with transpose Operation of 1 or both inputs.
+ *
+ * @param[in] inChannels Number of input channels.
+ * @param[in] outChannels Number of output channels.
+ * @param[in] alpha Scalar multiplier for the product of input tensors A * B.
+ * @param[in] beta Scalar multiplier for the bias.
+ * @param[in] name Optional name for the operation.
+ * @param[in] transposeA Flag indicating whether input#0 needs to be transposed (default is `false`).
+ * @param[in] transposeB Flag indicating whether input#1 needs to be transposed (default is `false`).
+ * @return A shared pointer to the Node representing the padded average pooling operation.
+ */
+extern std::shared_ptr<Node> TransposeFC(DimSize_t in_channels,
+                                        DimSize_t out_channels,
+                                        float alpha=1.0f,
+                                        float beta=1.0f,
+                                        const std::string& name = "",
+                                        bool no_bias = false,
+                                        bool transposeA = false,
+                                        bool transposeB = false);
+
+/**
+ * @brief Creates a padded convolution operation as a MetaOperator.
+ *
+ * This function creates a graph-based MetaOperator representing a padded convolution operation (Conv2D/Conv3D).
+ *
+ * @param[in] alpha Scalar multiplier for the product of input tensors A * B.
+ * @param[in] beta Scalar multiplier for the bias.
+ * @param[in] transposeA Flag indicating whether input#0 needs to be transposed (default is `false`).
+ * @param[in] transposeB Flag indicating whether input#1 needs to be transposed (default is `false`).
+ * @return A shared pointer to the MetaOperator_Op representing the padded convolution operation.
+ */
+extern std::shared_ptr<MetaOperator_Op> TransposeFC_Op(float alpha = 1.0f, float beta = 1.0f, bool transposeA = false, bool transposeB = false);
+
 } // namespace Aidge
 
 #endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */
diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp
index dc3f738fbf5613edf4212c15e89df4624084cb39..f915f1fda97347adccfaa07bb653695a519f6453 100644
--- a/python_binding/operator/pybind_FC.cpp
+++ b/python_binding/operator/pybind_FC.cpp
@@ -30,13 +30,9 @@ void declare_FC(py::module &m) {
     :type type : :py:class:`str`
     )mydelimiter")
     .def(py::init<float,
-      float,
-      bool,
-      bool>(),
+      float>(),
       py::arg("alpha")=1.0,
-      py::arg("beta")=1.0,
-      py::arg("transA")=false,
-      py::arg("transB")=false)
+      py::arg("beta")=1.0)
     .def_static("get_inputs_name", &FC_Op::getInputsName)
     .def_static("get_outputs_name", &FC_Op::getOutputsName)
     .def_readonly_static("Type", &FC_Op::Type)
@@ -53,8 +49,6 @@ void declare_FC(py::module &m) {
         py::arg("alpha")=1.0f,
         py::arg("beta")=1.0f,
         py::arg("no_bias") = false,
-        py::arg("transA") = false,
-        py::arg("transB") = false,
         py::arg("name") = "",
     R"mydelimiter(
     Initialize a node containing a Fully Connected (FC) operator.
@@ -71,10 +65,6 @@ void declare_FC(py::module &m) {
     :type alpha : :py:class:`int`
     :param beta : The scalar multiplier for the bias.
     :type beta : :py:class:`int`
-    :param transA : Indicates whether first input needs to be transposed.
-    :type transA : :py:class:`bool`
-    :param transB : Indicates whether second input needs to be transposed.
-    :type transB : :py:class:`bool`
     )mydelimiter");
 }
 
diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp
index 182a5edaa522f508fe128fa2331289b46e99919c..75e04d6cddf7214f47f5dcf4482e0cd101753d14 100644
--- a/python_binding/operator/pybind_MetaOperatorDefs.cpp
+++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp
@@ -502,6 +502,57 @@ void declare_LeakyOp(py::module &m) {
     )mydelimiter");
 }
 
+void declare_TransposeFCOp(py::module &m) {
+    m.def("TransposeFC", [](DimSize_t in_channels,
+                            DimSize_t out_channels,
+                            float alpha,
+                            float beta,
+                            const std::string& name,
+                            bool no_bias,
+                            bool transA,
+                            bool transB)
+      {
+          return TransposeFC(in_channels, out_channels,alpha, beta, name, no_bias, transA, transB);
+      }, py::arg("in_channels"),
+         py::arg("out_channels"),
+         py::arg("alpha") = 1.0f,
+         py::arg("beta") = 1.0f,
+         py::arg("name") = "",
+         py::arg("no_bias")= false,
+         py::arg("transA")= false,
+         py::arg("transB")= false,
+      R"mydelimiter(
+          Initialize a node containing an FC operator with Transpose on one or both inputs.
+
+          :param in_channels: Number of input channels.
+          :type in_channels: int
+          :param out_channels: Number of output channels.
+          :type out_channels: int
+
+          :param no_bias: Whether to disable bias addition in the convolution.
+          :type no_bias: bool
+          :param name: Name of the node (optional).
+          :type name: str
+          :return: A node containing the FC operator with Transpose node on one or two inputs.
+          :rtype: :py:class:`TransposeFCOp`
+      )mydelimiter");
+
+      m.def("TransposeFCOp", [](float alpha, float beta, bool transA, bool transB)
+      {
+        return TransposeFC_Op(alpha, beta, transA, transB);
+      },
+        py::arg("alpha") = 1.0f,
+        py::arg("beta") = 1.0f,
+        py::arg("transA")= false,
+        py::arg("transB")= false,
+      R"mydelimiter(
+          Initialize an FC operator with Transpose on one or two inputs.
+
+          :return: An FC with Transpose operators.
+          :rtype: :py:class:`TransposeFCOp`
+      )mydelimiter");
+  }
+
 void init_MetaOperatorDefs(py::module &m) {
   declare_PaddedConvOp<1>(m);
   declare_PaddedConvOp<2>(m);
@@ -520,6 +571,7 @@ void init_MetaOperatorDefs(py::module &m) {
   declare_LSTMOp(m);
   declare_LeakyResetEnum(m);
   declare_LeakyOp(m);
+  declare_TransposeFCOp(m);
 
   py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperatorOp", py::multiple_inheritance())
   .def(py::init<const char *, const std::shared_ptr<GraphView>&, const std::vector<InputCategory>&>(),
diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp
index 13da22423ec1c5be748461f4518db87dc11f4fa6..abe94d92e83bc8b9f805808404b472a39b3b12e8 100644
--- a/src/operator/FC.cpp
+++ b/src/operator/FC.cpp
@@ -45,24 +45,17 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) {
         // first check weight since it defines inChannels and outChannels
         AIDGE_ASSERT((getInput(1)->nbDims() == 2),
                     "Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims());
-        const DimSize_t outChannels = mAttributes->template getAttr<FCAttr::TransB>() ?
-                                      getInput(1)->template dims<2>()[1]:
-                                      getInput(1)->template dims<2>()[0];
-        const DimSize_t inChannels = mAttributes->template getAttr<FCAttr::TransB>() ?
-                                     getInput(1)->template dims<2>()[0]:
-                                     getInput(1)->template dims<2>()[1];
+        const DimSize_t outChannels = getInput(1)->template dims<2>()[0];
+        const DimSize_t inChannels = getInput(1)->template dims<2>()[1];
         // check data
         const std::vector<DimSize_t>& inputDims = getInput(0)->dims();
-        const DimIdx_t inChannelsIdx = mAttributes->template getAttr<FCAttr::TransA>() ? 1 : 0;
         if (getInput(0)->nbDims() == 1) {
-            AIDGE_ASSERT(inputDims[inChannelsIdx] == inChannels,
+            AIDGE_ASSERT(inputDims[0] == inChannels,
                 "Wrong number of input features for input data ({}), expected {}",
-                inputDims[inChannelsIdx], inChannels);
+                inputDims[0], inChannels);
         } else {
             AIDGE_ASSERT(getInput(0)->nbDims() > 1, "FC input data must have at least one dimension");
-            const DimSize_t nbInputFeatures = mAttributes->template getAttr<FCAttr::TransA>() ?
-                                                            inputDims[0]:
-                                                            std::accumulate(inputDims.cbegin() + 1, inputDims.cend(), DimSize_t(1), std::multiplies<DimSize_t>());
+            const DimSize_t nbInputFeatures = std::accumulate(inputDims.cbegin() + 1, inputDims.cend(), DimSize_t(1), std::multiplies<DimSize_t>());
             AIDGE_ASSERT(nbInputFeatures == inChannels,
                     "Wrong number of input features for input data ({}), expected {}",
                     nbInputFeatures, inChannels);
@@ -106,11 +99,9 @@ std::shared_ptr<Aidge::Node> Aidge::FC(const Aidge::DimSize_t inChannels,
                                        float alpha,
                                        float beta,
                                        bool noBias,
-                                       bool transA,
-                                       bool transB,
                                        const std::string& name) {
     // FIXME: properly handle default w&b initialization in every cases
-    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta, transA, transB), name);
+    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta), name);
     addProducer(fc, 1, {outChannels, inChannels}, "w");
     if (!noBias) {
         addProducer(fc, 2, {outChannels}, "b"); // already sets bias dims
diff --git a/src/operator/MetaOperatorDefs/TransposeFC.cpp b/src/operator/MetaOperatorDefs/TransposeFC.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5dc3a5b3d9598323a32215bdafcfda2843582ab0
--- /dev/null
+++ b/src/operator/MetaOperatorDefs/TransposeFC.cpp
@@ -0,0 +1,86 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+ #include "aidge/operator/MetaOperatorDefs.hpp"
+
+ #include <array>
+ #include <memory>
+ #include <vector>
+ 
+ #include "aidge/graph/Node.hpp"
+ #include "aidge/graph/OpArgs.hpp"
+ #include "aidge/operator/FC.hpp"
+ #include "aidge/operator/MetaOperator.hpp"
+ #include "aidge/operator/Producer.hpp"
+ #include "aidge/operator/Transpose.hpp"
+ #include "aidge/utils/ArrayHelpers.hpp"
+ #include "aidge/utils/Types.h"
+ 
+ std::shared_ptr<Aidge::Node> Aidge::TransposeFC(Aidge::DimSize_t in_channels,
+                                   Aidge::DimSize_t out_channels,
+                                   float alpha,
+                                   float beta,
+                                   const std::string& name,
+                                   bool no_bias,
+                                   bool transposeA,
+                                   bool transposeB)
+ {
+    auto graph = std::make_shared<GraphView>();
+    auto fc = FC(in_channels, out_channels, alpha, beta, no_bias, name);
+    graph->add(fc);
+    if (transposeA) {
+        auto transA = Transpose(std::vector<DimSize_t>{}, name + "_transposeA");
+        transA->addChild(graph->getOrderedInputs()[0].first,0,0);
+        graph->add(transA);
+    }
+    if (transposeB) {
+        auto transB = Transpose(std::vector<DimSize_t>{}, name + "_transposeB");
+        transB->addChild(graph->getOrderedInputs()[1].first,0,1);
+        graph->add(transB);
+    }
+
+    auto metaOpNode = MetaOperator("TransposeFC", graph, {}, name);
+
+    addProducer(metaOpNode, 1, {out_channels, in_channels}, "w");
+    if (!no_bias) {
+        addProducer(metaOpNode, 2, {out_channels}, "b");
+    }
+
+     return metaOpNode;
+ }
+
+ std::shared_ptr<Aidge::MetaOperator_Op> Aidge::TransposeFC_Op(float alpha,
+    float beta,
+    bool transposeA,
+    bool transposeB)
+ {
+    auto graph = std::make_shared<GraphView>();
+    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(alpha, beta), "");
+    graph->add(fc);
+
+    std::vector<std::pair<NodePtr, IOIndex_t>> orderedInputs = {{fc,0}, {fc,1}, {fc,2}};
+
+    if (transposeA) {
+        auto transA = Transpose(std::vector<DimSize_t>{}, "");
+        transA->addChild(graph->getOrderedInputs()[0].first,0,0);
+        graph->add(transA);
+        orderedInputs[0] = {transA, 0};
+    }
+    if (transposeB) {
+        auto transB = Transpose(std::vector<DimSize_t>{}, "");
+        transB->addChild(graph->getOrderedInputs()[1].first,0,1);
+        graph->add(transB);
+        orderedInputs[1] = {transB, 0};
+    }
+    graph->setOrderedInputs(orderedInputs);
+    graph->setOrderedOutputs({{fc, 0}});
+    return std::make_shared<MetaOperator_Op>("TransposeFC", graph);
+ }
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index bd684f9ea1c951396cb186810e6adc388622e0a9..ae8ae678681cb3c2c8796ff3f8f848433d9a9143 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -357,9 +357,9 @@ TEST_CASE("[core/graph] Matching") {
         ReLU("relu2"),
         Conv(4, 4, {5, 5}, "conv3"),
         BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
-        FC(4, 4, 1.0, 1.0, false, false, false, "fc1"),
-        FC(4, 4, 1.0, 1.0, false, false, false, "fc2"),
-        FC(4, 4, 1.0, 1.0, false, false, false, "fc3"),
+        FC(4, 4, 1.0, 1.0, false, "fc1"),
+        FC(4, 4, 1.0, 1.0, false, "fc2"),
+        FC(4, 4, 1.0, 1.0, false, "fc3"),
         ReLU("relu3"),
         Conv(1, 4, {5, 5}, "conv4")
     });
diff --git a/unit_tests/recipes/Test_ToGenericOp.cpp b/unit_tests/recipes/Test_ToGenericOp.cpp
index 02d784385ee18ceb495fd1e8a2f25ed161b4fee0..4ff2bd72d52c9842742802a0c5ad059bc5139a4e 100644
--- a/unit_tests/recipes/Test_ToGenericOp.cpp
+++ b/unit_tests/recipes/Test_ToGenericOp.cpp
@@ -32,9 +32,9 @@ TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") {
                     ReLU(),
                     Conv(4, 3, {1, 1}, "conv3"),
                     ReLU(),
-                    FC(2028, 256, 1.0, 1.0, false, false, false, "fc1"),
+                    FC(2028, 256, 1.0, 1.0, false, "fc1"),
                     ReLU(),
-                    FC(256, 10, 1.0, 1.0, false, false, false, "fc2")});
+                    FC(256, 10, 1.0, 1.0, false, "fc2")});
     
     // NCHW - MNIST DATA like
     g->forwardDims({{5, 1, 28, 28}});
diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp
index 655f7c7f5992902f7d73dd310f4a323d0e1eadce..2e12e0532ec0cb87a29d108fc2352ce996ba9f16 100644
--- a/unit_tests/recipes/Test_removeFlatten.cpp
+++ b/unit_tests/recipes/Test_removeFlatten.cpp
@@ -27,8 +27,8 @@ namespace Aidge {
 TEST_CASE("[cpu/recipes] RemoveFlatten", "[RemoveFlatten][recipes]") {
   std::shared_ptr<Node> flatten =
       GenericOperator("Flatten", 1, 0, 1, "myFlatten");
-  std::shared_ptr<Node> fc0 = FC(10, 10, 1.0, 1.0, false, false, false, "FC_1");
-  std::shared_ptr<Node> fc1 = FC(10, 10, 1.0, 1.0, false, false, false, "FC_2");
+  std::shared_ptr<Node> fc0 = FC(10, 10, 1.0, 1.0, false, "FC_1");
+  std::shared_ptr<Node> fc1 = FC(10, 10, 1.0, 1.0, false, "FC_2");
   std::shared_ptr<Node> prod = Producer(std::array<DimSize_t, 10>(), "myProd");
 
   SECTION("flatten last layer : nothing removed because pattern searched is "