diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index 963c694e01d9f2813eb95622c1c475eb40bf7bb1..4743f6e0e222bd35557ef61d7c1ee7040dc25cba 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -15,6 +15,7 @@
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
+#include "aidge/backend/cuda/operator/FCImpl.hpp"
 #include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/ProducerImpl.hpp"
 #include "aidge/backend/cuda/operator/ReLUImpl.hpp"
diff --git a/include/aidge/backend/cuda/operator/FCImpl.hpp b/include/aidge/backend/cuda/operator/FCImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..b06d42e64ce5272e975f0dcf4039ccd78f24f78a
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/FCImpl.hpp
@@ -0,0 +1,59 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_FCIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_FCIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+class FCImplForward_cuda : public Registrable<FCImplForward_cuda,
+                                                 std::tuple<DataType>,
+                                                 void(unsigned int , unsigned int , unsigned int, bool, const void* , const void* , const void* , void*)> {};
+class FCImpl_cuda : public OperatorImpl {
+private:
+    // CuDNN specific variables
+
+
+public:
+    FCImpl_cuda(const FC_Op &op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<FCImpl_cuda> create(const FC_Op &op) {
+        return std::make_unique<FCImpl_cuda>(op);
+    }
+
+public:
+    void forward();
+    // ~FCImpl_cuda();
+
+private:
+    template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, bool noBias, DimSize_t outChannels);
+};
+
+namespace {
+// add cuda backend to FC_Op implementation registry
+static Registrar<FC_Op> registrarFCImpl_cuda("cuda", Aidge::FCImpl_cuda::create);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_FCIMPL_H_ */
diff --git a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..fac838f5ee7b11c67736e6ed83df4aa876b2825b
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
@@ -0,0 +1,35 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_
+#define AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_
+
+#include <stdexcept>
+#include <cfloat>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/backend/cuda/operator/FCImpl.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+
+template<class T>
+void fc_forward_cuda(DimSize_t nbInputs, DimSize_t inChannels, DimSize_t outChannels, bool noBias, const void *input, const void *weights, const void *bias, void *output);
+
+namespace {
+static Registrar<FCImplForward_cuda> registrarFCImpl2DForward_cuda_Float32({DataType::Float32}, Aidge::fc_forward_cuda<float>);
+static Registrar<FCImplForward_cuda> registrarFCImpl2DForward_cuda_Float64({DataType::Float64}, Aidge::fc_forward_cuda<double>);
+}  // namespace
+
+}
+#endif /* AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_ */
\ No newline at end of file
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5b2183b56208f8f9d8d1d972dc26fe9c03835694
--- /dev/null
+++ b/src/operator/FCImpl.cpp
@@ -0,0 +1,60 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono>  // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread>  // std::this_thread::sleep_for
+#include <vector>
+#include <iostream>
+#include "aidge/utils/Types.h"
+#include "aidge/operator/FC.hpp"
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
+#include "aidge/backend/cuda/operator/FCImpl.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+
+
+void Aidge::FCImpl_cuda::forward() {
+    assert(mOp.getRawInput(0) && "missing input #0");
+    assert(mOp.getRawInput(1) && "missing input #1");
+    assert(mOp.getRawInput(2) && "missing input #2");
+
+    std::shared_ptr<Tensor> inputFallback, input1Fallback, input2Fallback;
+    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+
+    const FC_Op& fcOp = static_cast<const FC_Op&>(mOp);
+    bool noBias = fcOp.template getAttr<FCAttr::NoBias>();
+    DimSize_t outChannels = fcOp.template getAttr<FCAttr::OutChannels>();
+    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+        forward_<double>(input0, input1, input2, noBias, outChannels);
+    }
+    else {
+        forward_<float>(input0, input1, input2, noBias, outChannels);
+    }
+}
+
+template<class T>
+void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, bool noBias, DimSize_t outChannels)
+{
+    Aidge::fc_forward_cuda<T>(
+                           input0.dims()[0],
+                           input0.size() / input0.dims()[0],
+                           outChannels,
+                           noBias,
+                           input0.getImpl()->rawPtr(), 
+                           input1.getImpl()->rawPtr(),
+                           input2.getImpl()->rawPtr(),
+                           std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+}
\ No newline at end of file
diff --git a/src/operator/FCImpl_CUDA_kernels.cu b/src/operator/FCImpl_CUDA_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..14d731f33e1a8edc5a6126d5d1f026b7d2af64c9
--- /dev/null
+++ b/src/operator/FCImpl_CUDA_kernels.cu
@@ -0,0 +1,48 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <stdio.h>
+
+#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
+
+template<class T>
+__global__
+void fc_forward_cuda_kernel(std::size_t nbInputs, std::size_t inChannels, std::size_t outChannels, bool noBias,const T* input, const T* weights, const T* bias, T *output)
+{
+    const std::size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    for(std::size_t batch=idx; batch<nbInputs; ++batch)
+    {
+        for (std::size_t out = 0; out < outChannels; ++out) {
+            T sum = 0;
+            for (std::size_t in = 0; in < inChannels; ++in) {
+                sum += input[batch * inChannels + in] * weights[out * inChannels + in];
+            }
+            output[batch * outChannels + out] = sum + (noBias ? 0 : bias[out]);
+        }
+    }
+}
+
+namespace Aidge{
+template<class T>
+void fc_forward_cuda(DimSize_t nbInputs, DimSize_t inChannels, DimSize_t outChannels, bool noBias, const void* input_, const void* weights_, const void* bias_, void* output_)
+{
+    const T* input = static_cast<const T*>(input_);
+    const T* weights = static_cast<const T*>(weights_);
+    const T* bias = static_cast<const T*>(bias_);
+    T * output = static_cast<T*>(output_);
+
+    const dim3 blocksPerGrid = {(static_cast<unsigned int>(inChannels) + 255) / 256, 1, static_cast<unsigned int>(outChannels)};
+    const dim3 threadsPerBlocks = {256, 1, 1};
+
+    fc_forward_cuda_kernel<<<blocksPerGrid, threadsPerBlocks>>>(nbInputs, inChannels, outChannels, noBias, input, weights, bias, output);
+    CHECK_CUDA_STATUS(cudaPeekAtLastError());
+}
+}
diff --git a/unit_tests/Test_FCImpl.cpp b/unit_tests/Test_FCImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..05624d96b17a73ce84e9978599da849ad20c2764
--- /dev/null
+++ b/unit_tests/Test_FCImpl.cpp
@@ -0,0 +1,134 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <array>
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "Test_cuda.hpp"
+
+#include "aidge/data/Tensor.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/backend/cuda.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[gpu/operator] FC(forward)", "[FC][GPU]") {
+    std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array2D<float, 5, 75>{
+            {{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,
+              5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,
+              9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+              13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15},
+             {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,
+              5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,
+              9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+              13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15},
+             {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,
+              5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,
+              9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+              13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15},
+             {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,
+              5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,
+              9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+              13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15},
+             {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,
+              5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,
+              9,  10, 11, 12, 13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+              13, 14, 15, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15}}});
+    std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float, 5>{{1, 2, 3, 4, 5}});
+    std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array2D<float, 2, 5>{
+            {{23601, 23602, 23603, 23604, 23605}, {68601, 68602, 68603, 68604, 68605}}});
+    myWeights->setBackend("cuda");
+    myBias->setBackend("cuda");
+    std::shared_ptr<Node> myFC = FC(75, 5, false, "myfc");
+    auto op = std::static_pointer_cast<OperatorTensor>(myFC -> getOperator());
+    op -> associateInput(1, myWeights);
+    op -> associateInput(2, myBias);
+    SECTION("2D input") {
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array2D<float, 2, 75>{
+                {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16, 17, 18,
+                  19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
+                  38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
+                  57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74},
+                 {75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
+                  90,  91,  92,  93,  94,  95,  96,  97,  98,  99,  100, 101, 102, 103, 104,
+                  105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
+                  120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134,
+                  135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149}}});
+        myInput->setBackend("cuda");
+        op->associateInput(0, myInput);
+        op -> setDataType(DataType::Float32);
+        op -> setBackend("cuda");
+        op->computeOutputDims();
+        myFC->forward();
+
+        float* computedOutput   = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            std::cout << "targetOutput " << targetOutput << ",  out " << computedOutput[i]<<std::endl;
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+    SECTION("4D input") {
+        std::shared_ptr<Tensor> myInput =
+                std::make_shared<Tensor>(Array4D<float, 2, 3, 5, 5>{{{{{0, 1, 2, 3, 4},
+                                                                     {5, 6, 7, 8, 9},
+                                                                     {10, 11, 12, 13, 14},
+                                                                     {15, 16, 17, 18, 19},
+                                                                     {20, 21, 22, 23, 24}},
+                                                                    {{25, 26, 27, 28, 29},
+                                                                     {30, 31, 32, 33, 34},
+                                                                     {35, 36, 37, 38, 39},
+                                                                     {40, 41, 42, 43, 44},
+                                                                     {45, 46, 47, 48, 49}},
+                                                                    {{50, 51, 52, 53, 54},
+                                                                     {55, 56, 57, 58, 59},
+                                                                     {60, 61, 62, 63, 64},
+                                                                     {65, 66, 67, 68, 69},
+                                                                     {70, 71, 72, 73, 74}}},
+                                                                   {{{75, 76, 77, 78, 79},
+                                                                     {80, 81, 82, 83, 84},
+                                                                     {85, 86, 87, 88, 89},
+                                                                     {90, 91, 92, 93, 94},
+                                                                     {95, 96, 97, 98, 99}},
+                                                                    {{100, 101, 102, 103, 104},
+                                                                     {105, 106, 107, 108, 109},
+                                                                     {110, 111, 112, 113, 114},
+                                                                     {115, 116, 117, 118, 119},
+                                                                     {120, 121, 122, 123, 124}},
+                                                                    {{125, 126, 127, 128, 129},
+                                                                     {130, 131, 132, 133, 134},
+                                                                     {135, 136, 137, 138, 139},
+                                                                     {140, 141, 142, 143, 144},
+                                                                     {145, 146, 147, 148, 149}}}}});
+        myInput->setBackend("cuda");
+        op->associateInput(0, myInput);
+        op -> setDataType(DataType::Float32);
+        op -> setBackend("cuda");
+        op->computeOutputDims();
+        myFC->forward();
+
+        float* computedOutput   = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+}
\ No newline at end of file