diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index 1359cd7fe1f3081ef510e7e90581ce968d7810f0..b5fbc21e2f683d209ec5e367f19c583a738fc385 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -14,6 +14,7 @@
 
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/AddImpl.hpp"
+#include "aidge/backend/cuda/operator/SubImpl.hpp"
 #include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/BatchNormImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
@@ -26,4 +27,4 @@
 #include "aidge/backend/cuda/operator/SigmoidImpl.hpp"
 #include "aidge/backend/cuda/operator/TanhImpl.hpp"
 
-#endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
\ No newline at end of file
+#endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
diff --git a/include/aidge/backend/cuda/operator/SubImpl.hpp b/include/aidge/backend/cuda/operator/SubImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..0c244621b1940b2268dbf2f369acbac81c20ad8f
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/SubImpl.hpp
@@ -0,0 +1,54 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_SUBIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_SUBIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Sub.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+class SubImpl_cuda : public OperatorImpl {
+private:
+
+
+public:
+    SubImpl_cuda(const Sub_Op &op) : OperatorImpl(op, "cuda") {}
+
+    static std::unique_ptr<SubImpl_cuda> create(const Sub_Op &op) {
+        return std::make_unique<SubImpl_cuda>(op);
+    }
+
+public:
+    void forward();
+    // ~SubImpl_cuda();
+private:
+    template <class T> void forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
+};
+
+namespace {
+// add cuda backend to Sub_Op implementation registry
+static Registrar<Sub_Op> registrarSubImpl_cuda("cuda", Aidge::SubImpl_cuda::create);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_SUBIMPL_H_ */
diff --git a/src/operator/SubImpl.cpp b/src/operator/SubImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a5856c11350770e926df104e68a6795460561cfc
--- /dev/null
+++ b/src/operator/SubImpl.cpp
@@ -0,0 +1,106 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <algorithm>
+#include <cassert>
+#include <numeric>
+#include <vector>
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/SubImpl.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/Sub.hpp"
+#include "aidge/utils/Types.h"
+
+void Aidge::SubImpl_cuda::forward() {
+    const Sub_Op& op = static_cast<const Sub_Op&>(mOp);
+    // Check inputs
+    AIDGE_ASSERT(op.getInput(0), "missing input in Sub operator");
+    AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run Sub forward because the 0-th input has no implementation.");
+    DataType datatypeFirstInput = op.getInput(0)->dataType();
+    for (IOIndex_t i = 1; i < op.nbInputs(); ++i) {
+        AIDGE_ASSERT(op.getInput(i), "missing input in Sub operator");
+        AIDGE_ASSERT(op.getInput(i)->hasImpl(), "cannot run Sub forward because the {}-th input has no implementation.", i);
+        AIDGE_ASSERT(op.getInput(i)->dataType() == datatypeFirstInput, "Cannot add inputs with two differents data type.");
+    }
+
+    std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
+    std::vector<Tensor> inputs(op.nbInputs());
+    std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
+    std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
+    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
+        inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
+
+        // Get tensor dims and broadcast them
+        std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
+        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
+
+        // Compute the corresponding strides
+        std::vector<int> tensorStrides(dims[i].size());
+        int product = 1;
+        for (size_t j = dims[i].size(); j > 0; --j) {
+            tensorStrides[j - 1] = product;
+            product *= dims[i][j - 1];
+        }
+        strides[i] = tensorStrides;
+    }
+
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+            forward_<double>(inputs, dims, strides);
+            break;
+        case DataType::Float32:
+            forward_<float>(inputs, dims, strides);
+            break;
+        case DataType::Float16:
+            forward_<half>(inputs, dims, strides);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    }
+}
+
+template <class T>
+void Aidge::SubImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type gamma = -1.0f;
+    // Create a Tensor descriptor with the broadcasted dims and strides
+    cudnnTensorDescriptor_t tensorDesc;
+    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
+    // Add first input
+    CHECK_CUDNN_STATUS(
+        cudnnAddTensor(CudaContext::cudnnHandle(),
+                       &alpha,
+                       tensorDesc,
+                       inputs[0].getImpl()->rawPtr(),
+                       &beta,
+                       std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                       std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
+    );
+    // Add other inputs if there are any
+    for (size_t i = 1; i < op.nbInputs(); ++i)
+    {
+        CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[i].size(), inputsDims[i].data(), inputsStrides[i].data()));
+        CHECK_CUDNN_STATUS(
+            cudnnAddTensor(CudaContext::cudnnHandle(),
+                        &gamma,
+                        tensorDesc,
+                        inputs[i].getImpl()->rawPtr(),
+                        &alpha,
+                        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                        std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
+        );
+    }
+    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
+}