diff --git a/include/aidge/backend/cpu/operator/ReLUImpl.hpp b/include/aidge/backend/cpu/operator/ReLUImpl.hpp
index cef82482813757312c638aebac9f2afd738493db..e2ebf44616db876b462157db650ff48362dd7bac 100644
--- a/include/aidge/backend/cpu/operator/ReLUImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ReLUImpl.hpp
@@ -30,7 +30,7 @@ class ReLUImplForward_cpu
     : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class ReLUImplBackward_cpu
-    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 
 class ReLUImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
index b68ea076cb94eb9550b4a7af89ef58162ee15aea..43a9714ad2d32228fac9bf9c526191f0cec5bfa0 100644
--- a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
@@ -14,31 +14,32 @@
 
 #include <cstddef>  // std::size_t
 
-#include "aidge/utils/Registrar.hpp"
-
 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
+#include "aidge/utils/Registrar.hpp"
 
 namespace Aidge {
-template <class I, class O>
+template <class O, class GI, class GO>
 void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                     const void* input_,
-                                     void* output_) {
-
-    const I* input = static_cast<const I*>(input_);
-    O* output = static_cast<O*>(output_);
-
+                                  const void* output_, const void* grad_output_,
+                                  void* grad_input_) {
+    const O* output = static_cast<const O*>(output_);
+    const GO* grad_output = static_cast<const GO*>(grad_output_);
+    GI* grad_input = static_cast<GI*>(grad_input_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = (input[i] > I(0)) ? static_cast<O>(input[i]) : O(0);
+        grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
     }
 }
 
 namespace {
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
+    {DataType::Float32, DataType::Float32, DataType::Float32},
+    Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
-        {DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
+    {DataType::Int32, DataType::Int32, DataType::Int32},
+    Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
+    {DataType::Float64, DataType::Float64, DataType::Float64},
+    Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>);
 }  // namespace
 }  // namespace Aidge
 
diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp
index 4bba09b6fbeea1552bf5b7cc7e491291345fca45..06859f09db169946175a93140e04f2e2a99e3362 100644
--- a/src/operator/ReLUImpl.cpp
+++ b/src/operator/ReLUImpl.cpp
@@ -45,16 +45,18 @@ void Aidge::ReLUImpl_cpu::forward() {
 void Aidge::ReLUImpl_cpu::backward() {
     // reversing in and out Tensors
         const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
-    std::shared_ptr<Tensor> in0  = op_.getOutput(0)->grad();
-    std::shared_ptr<Tensor> out0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> out0  = op_.getOutput(0);
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
     AIDGE_ASSERT(out0, "current {} operator output#0 has not gradient Tensor.", op_.type());
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()
+        out0->dataType(),
+        gra_out0->dataType(),
+        gra_int0->dataType()
     });
 
     // Call kernel
-    kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
+    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }