diff --git a/include/aidge/backend/cpu.hpp b/include/aidge/backend/cpu.hpp
index 80574b4a46fef0c843c9511836f162e02de5aab3..5c1f9b111f41a435aa477d0647fa66fb29a058fb 100644
--- a/include/aidge/backend/cpu.hpp
+++ b/include/aidge/backend/cpu.hpp
@@ -27,6 +27,7 @@
 #include "aidge/backend/cpu/operator/ClipImpl.hpp"
 #include "aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp"
 #include "aidge/backend/cpu/operator/ConvImpl.hpp"
+#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp"
 #include "aidge/backend/cpu/operator/ConstantOfShapeImpl.hpp"
 #include "aidge/backend/cpu/operator/CryptoHashImpl.hpp"
 #include "aidge/backend/cpu/operator/DivImpl.hpp"
diff --git a/include/aidge/backend/cpu/operator/ConvImpl.hpp b/include/aidge/backend/cpu/operator/ConvImpl.hpp
index c06d0912f419909013f930867ce3c3238c1a5555..e480697b6452440f043901140a07cb643f3cbdb6 100644
--- a/include/aidge/backend/cpu/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ConvImpl.hpp
@@ -13,45 +13,64 @@
 #define AIDGE_CPU_OPERATOR_CONVIMPL_H_
 
 #include <array>
-#include <memory>
-#include <tuple>
-#include <vector>
 
 #include "aidge/backend/cpu/operator/OperatorImpl.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 namespace Aidge {
+
 // Operator implementation entry point for the backend
 using Conv1D_Op = Conv_Op<1>;
 using ConvImpl1D_cpu = OperatorImpl_cpu<Conv_Op<1>,
-    void(const std::array<DimSize_t, 1>&,
-        const std::array<DimSize_t, 1>&,
-        const std::array<DimSize_t, 1>&,
-        const std::array<DimSize_t, 3> &,
-        DimSize_t,
-        const void *,
-        const void *,
-        const void *,
-        void *)>;
+                                        void(const std::array<DimSize_t, 1> &,
+                                             const std::array<DimSize_t, 1> &,
+                                             const std::array<DimSize_t, 1> &,
+                                             const std::array<DimSize_t, 3> &,
+                                             DimSize_t,
+                                             const void *,
+                                             const void *,
+                                             const void *,
+                                             void *),
+                                        void(const std::array<DimSize_t, 1> &,
+                                             const std::array<DimSize_t, 1> &,
+                                             const std::array<DimSize_t, 1> &,
+                                             const std::array<DimSize_t, 3> &,
+                                             const std::array<DimSize_t, 3> &,
+                                             const void *,
+                                             const void *,
+                                             const void *,
+                                             void *,
+                                             void *,
+                                             void *)>;
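+// NOTE: OperatorImpl_cpu is given two kernel signatures: the forward
+// kernel first, then the backward kernel (stride, dilation, kernelDims,
+// inputDims, outputDims, input, weights, oGrad, iGrad, weightsGrad,
+// biasesGrad).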
 
 using Conv2D_Op = Conv_Op<2>;
-using ConvImpl2D_cpu = OperatorImpl_cpu<Conv_Op<2>,
-    void(const std::array<DimSize_t, 2>&,
-        const std::array<DimSize_t, 2>&,
-        const std::array<DimSize_t, 2>&,
-        const std::array<DimSize_t, 4> &,
-        DimSize_t,
-        const void *,
-        const void *,
-        const void *,
-        void *)>;
+using ConvImpl2D_cpu = OperatorImpl_cpu<Conv2D_Op,
+                                        void(const std::array<DimSize_t, 2> &,
+                                             const std::array<DimSize_t, 2> &,
+                                             const std::array<DimSize_t, 2> &,
+                                             const std::array<DimSize_t, 4> &,
+                                             DimSize_t,
+                                             const void *,
+                                             const void *,
+                                             const void *,
+                                             void *),
+                                        void(const std::array<DimSize_t, 2> &,
+                                             const std::array<DimSize_t, 2> &,
+                                             const std::array<DimSize_t, 2> &,
+                                             const std::array<DimSize_t, 4> &,
+                                             const std::array<DimSize_t, 4> &,
+                                             const void *,
+                                             const void *,
+                                             const void *,
+                                             void *,
+                                             void *,
+                                             void *)>;
 
 // Implementation entry point registration to Operator
 REGISTRAR(Conv1D_Op, "cpu", Aidge::ConvImpl1D_cpu::create);
 REGISTRAR(Conv2D_Op, "cpu", Aidge::ConvImpl2D_cpu::create);
-}  // namespace Aidge
+} // namespace Aidge
 
 #endif /* AIDGE_CPU_OPERATOR_CONVIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp
index 1229d5714e6b0cbae4e42ece9130c2c2305f133e..7ae9e45fe4f5d7436e3f08447c69bef3c16b6218 100644
--- a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp
@@ -13,18 +13,16 @@
 #define AIDGE_CPU_OPERATOR_CONVIMPL_KERNELS_H_
 
 #include <array>
-#include <memory>
-#include <tuple>
-#include <vector>
+#include <algorithm>  // std::fill
+#include <cmath>      // std::floor
+#include <cstdint>
+#include <type_traits>  // std::make_signed
 
-#include "aidge/backend/cpu/operator/OperatorImpl.hpp"
 #include "aidge/backend/cpu/operator/ConvImpl.hpp"
-#include "aidge/operator/Conv.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 namespace Aidge {
+using std::array;
+
 /**
  * @brief Forward kernel for 1D Convolution on CPU backend.
  * @tparam I Input data type.
@@ -39,16 +37,15 @@ namespace Aidge {
  * @param output_ Output Tensor.
  */
 template <class I, class W, class B, class O>
-void ConvImpl1D_cpu_forward_kernel(const std::array<DimSize_t, 1>& strideDims,
-                            const std::array<DimSize_t, 1>& dilationDims,
-                            const std::array<DimSize_t, 1>& kernelDims,
-                            const std::array<DimSize_t, 3>& inputDims,
-                            DimSize_t outChannels,
-                            const void *input_,
-                            const void *weights_,
-                            const void *biases_,
-                            void *output_)
-{
+void ConvImpl1D_cpu_forward_kernel(const array<DimSize_t, 1> &strideDim,
+                                   const array<DimSize_t, 1> &dilationDim,
+                                   const array<DimSize_t, 1> &kernelDim,
+                                   const array<DimSize_t, 3> &inputDims,
+                                   DimSize_t outChannels,
+                                   const void *input_,
+                                   const void *weights_,
+                                   const void *biases_,
+                                   void *output_) {
     // FIXME: missing convolution attributes as arguments
     const I *input = static_cast<const I *>(input_);
     const W *weights = static_cast<const W *>(weights_);
@@ -56,38 +53,38 @@ void ConvImpl1D_cpu_forward_kernel(const std::array<DimSize_t, 1>& strideDims,
     O *output = static_cast<O *>(output_);
 
     // output H size
-    const std::size_t oxSize =
-            static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilationDims[0]*(kernelDims[0] - 1) - 1 + strideDims[0]) /
-                                static_cast<float>(strideDims[0])));
-    const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1;
+    const std::size_t oxSize = static_cast<std::size_t>(std::floor(
+        static_cast<float>(inputDims[2] - dilationDim[0] * (kernelDim[0] - 1) -
+                           1 + strideDim[0]) /
+        static_cast<float>(strideDim[0])));
+    const DimSize_t dilated_kernel_x = dilationDim[0] * (kernelDim[0] - 1) + 1;
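+    // e.g. kernelDim = {3}, dilation = {2} -> dilated_kernel_x = 5: the 3
+    // kernel taps span 5 input elements (x, x + 2, x + 4).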
 
-    // TODO: kernel computation
-    // output (batch, outCh, Xout, Yout)
-    // input  (batch, inCh, Xin, Yin)
-    // weight (outCh, inCh, kernelX, kernelY)
-    // does not take Dilation attribute into account
     using signedsize = std::make_signed<std::size_t>::type;
     for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
         for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
-            const std::size_t oIndex = (outCh + batch*outChannels) * oxSize;
+            const std::size_t oIndex = (outCh + batch * outChannels) * oxSize;
             // If bias = nullptr, set B(0)
             B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
-            std::fill(output + oIndex, output+(oIndex+oxSize), biasVal);
+            std::fill(output + oIndex, output + (oIndex + oxSize), biasVal);
             for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
-                const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2];
-                const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0];
+                const std::size_t iIndex =
+                    (inCh + batch * inputDims[1]) * inputDims[2];
+                const std::size_t wIndex =
+                    (inCh + outCh * inputDims[1]) * kernelDim[0];
                 for (std::size_t ox = 0; ox < oxSize; ++ox) {
-                    // const signedsize difx = static_cast<signedsize>(- ox * strideDims[0]);
-                    // const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0)));
-                    // const std::size_t sxMax = (static_cast<signedsize>(inputDims[2]) + difx) < 0 ? 0 : ((inputDims[2] + difx) > kernelDims[0] ? kernelDims[0] : inputDims[2] + difx);
                     const std::size_t sxMin = 0;
                     const std::size_t sxMax = dilated_kernel_x;
                     const std::size_t oIndexFull = oIndex + ox;
-                    const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
+                    const signedsize ix =
+                        static_cast<signedsize>(ox * strideDim[0]);
 
-                    for (std::size_t sx = sxMin; sx*dilationDims[0] < sxMax; ++sx) {
-                        output[oIndexFull] += weights[wIndex + sx] *
-                                                input[iIndex + static_cast<std::size_t>(ix+static_cast<signedsize>(sx*dilationDims[0]))];
+                    for (std::size_t sx = sxMin; sx * dilationDim[0] < sxMax;
+                         ++sx) {
+                        output[oIndexFull] +=
+                            weights[wIndex + sx] *
+                            input[iIndex + static_cast<std::size_t>(
+                                               ix + static_cast<signedsize>(
+                                                        sx * dilationDim[0]))];
                     }
                 }
             }
@@ -95,20 +92,342 @@ void ConvImpl1D_cpu_forward_kernel(const std::array<DimSize_t, 1>& strideDims,
     }
 }
 
+/**
+ * @brief perform 1D backpropagation for the data input
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ * @note formula :
+ * for i in 0..input_size:
+ *  for n in 0..weight_size:
+ *    dL        dYn   dL
+ *   ---- = Σn ----- -----
+ *    dXi      dXi   dYn
+ * with : dYn / dXi = w_k
+ * for each input value
+ * for each weight
+ * for each output
+ * multiply the weight with the associated value
+ * @note kernel, stride & dilation attributes are arrays of length 1 for
+ * the 1D case
+ * @note reminder that kernel dimensions are
+ * {outChannels, inChannels, {kernelDims}}
+ * <=> {oDims[1], iDims[1], kernelDim}
+ * @tparam I Input data type.
+ * @tparam W Weight data type.
+ * @tparam O Output data type.
+ * @param[in] stride stride parameter of the convolution operator
+ * @param[in] dilation dilation parameter of the convolution operator
+ * @param[in] kDim dimension of the kernel
+ * @param[in] kStrides nb of elements contained per dimension of the kernel
+ * @param[in] weights kernel weights
+ * @param[in] oDims dimensions of the output
+ * @param[in] oStrides nb of elements contained per dimension of the output
+ * @param[in] oGrad output gradient
+ * @param[in] iDims input dimensions
+ * @param[in] iStrides nb of elements contained per dimension of the input
+ * @param[inout] iGrad gradients of the input to update
+ */
+template <class I, class W, class O>
+void conv1DBackwardInput(const array<DimSize_t, 1> &stride,
+                         const array<DimSize_t, 1> &dilation,
+                         const array<DimSize_t, 1> &kDim,
+                         const array<DimSize_t, 2> &kStrides,
+                         const W *weights,
+                         const array<DimSize_t, 3> &oDims,
+                         const array<DimSize_t, 2> &oStrides,
+                         const O *oGrad,
+                         const array<DimSize_t, 3> &iDims,
+                         const array<DimSize_t, 2> &iStrides,
+                         I *iGrad) {
+
+    array<DimSize_t, 2> iOffsets{0, 0};
+    array<DimSize_t, 2> oOffsets{0, 0};
+    array<DimSize_t, 2> kOffsets{0, 0};
+
+    for (std::size_t batch = 0; batch < iDims[0]; ++batch) {
+        iOffsets[0] = batch * iStrides[0];
+        oOffsets[0] = batch * oStrides[0];
+
+        for (DimSize_t oChannel = 0; oChannel < oDims[1]; oChannel++) {
+            oOffsets[1] = (oChannel * oStrides[1]) + oOffsets[0];
+            kOffsets[0] = oChannel * kStrides[0];
+
+            for (std::size_t iChannel = 0; iChannel < iDims[1]; ++iChannel) {
+                iOffsets[1] = (iChannel * iStrides[1]) + iOffsets[0];
+                kOffsets[1] = iChannel * kStrides[1] + kOffsets[0];
+
+                for (DimSize_t oX = 0; oX < oDims[2]; ++oX) {
+                    auto iX = oX * stride[0];
+                    auto inIdx = iX + iOffsets[1];
+
+                    for (DimSize_t kX = 0; kX < kDim[0]; ++kX) {
+                        auto dilatedKernelIdx = kX * dilation[0];
+
+                        iGrad[inIdx + dilatedKernelIdx] +=
+                            weights[kOffsets[1] + kX] *
+                            oGrad[oOffsets[1] + oX];
+                    }
+                }
+            }
+        }
+    }
+}
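+// Worked example (hypothetical sizes): with stride = {1}, dilation = {1},
+// kDim = {3}, iDims = {1, 1, 5} and oDims = {1, 1, 3}, the loops above
+// reduce to iGrad[oX + kX] += weights[kX] * oGrad[oX]: each output
+// gradient is scattered back over the 3 input positions its window read.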
+
+/**
+ * @brief computes weight backpropagation for conv1D
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ * @note formula (weight grad):
+ * for i in 0..weight_size:
+ *  for n in 0..output_size:
+ *    dL        dYn   dL
+ *   ---- = Σn ----- -----
+ *   dwi       dwi   dYn
+ * with : dYn / dwi = x_k
+ * @tparam I Input data type.
+ * @tparam W Weight data type.
+ * @tparam O Output data type.
+ * @param[in] stride stride parameter of the convolution operator
+ * @param[in] dilation dilation parameter of the convolution operator
+ * @param[in] iDims input dimensions
+ * @param[in] iStrides nb of elements contained per dimension of the input
+ * @param[in] input input data
+ * @param[in] oDims dimensions of the output
+ * @param[in] oStrides nb of elements contained per dimension of the output
+ * @param[in] oGrad output gradient
+ * @param[in] kDim dimension of the kernel
+ * @param[in] kStrides nb of elements contained per dimension of the kernel
+ * @param[inout] weightsGrad gradients of the kernel weights to update
+ */
+template <class I, class W, class O>
+static void conv1DBackwardWeights(const array<DimSize_t, 1> &stride,
+                                  const array<DimSize_t, 1> &dilation,
+                                  const array<DimSize_t, 3> &iDims,
+                                  const array<DimSize_t, 2> &iStrides,
+                                  const I *input,
+                                  const array<DimSize_t, 3> &oDims,
+                                  const array<DimSize_t, 2> &oStrides,
+                                  const O *oGrad,
+                                  const array<DimSize_t, 1> &kDim,
+                                  const array<DimSize_t, 2> &kStrides,
+                                  W *weightsGrad) {
+
+    array<DimSize_t, 2> iOffsets{0, 0};
+    array<DimSize_t, 2> oOffsets{0, 0};
+    array<DimSize_t, 2> kOffsets{0, 0};
+
+    for (DimSize_t batch = 0; batch < oDims[0]; ++batch) {
+        iOffsets[0] = batch * iStrides[0];
+        oOffsets[0] = batch * oStrides[0];
+
+        for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) {
+            oOffsets[1] = oChannel * oStrides[1] + oOffsets[0];
+            kOffsets[0] = oChannel * kStrides[0];
+
+            for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel) {
+                kOffsets[1] = iChannel * kStrides[1] + kOffsets[0];
+                iOffsets[1] = iChannel * iStrides[1] + iOffsets[0];
+
+                for (DimSize_t kX = 0; kX < kDim[0]; ++kX) {
+
+                    for (DimSize_t oX = 0; oX < oDims[2]; ++oX) {
+                        const DimSize_t iX = oX * stride[0] + kX * dilation[0];
+
+                        weightsGrad[kOffsets[1] + kX] +=
+                            input[iOffsets[1] + iX] * oGrad[oOffsets[1] + oX];
+                    }
+                }
+            }
+        }
+    }
+}
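+// Worked example (hypothetical sizes): with stride = {2}, dilation = {1}
+// and kDim = {3}, the inner loops above accumulate, for each tap kX,
+// weightsGrad[kX] += input[2 * oX + kX] * oGrad[oX] summed over oX.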
+
+/**
+ * @brief computes bias backpropagation for conv1D operation
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ * @note formula :
+ * Bias grad:
+ * for i in 0..bias_size:
+ *  for n in 0..output_size:
+ *    dL        dYn   dL
+ *   ---- = Σn ----- -----
+ *   dbi       dbi   dYn
+ * with : dYn / dbi = 1
+ *
+ * Hence the partial derivative of the loss wrt each bias is simply the
+ * sum of the output gradients of its channel, over the batch and the
+ * spatial dimension
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] oDims output tensor dimensions
+ * @param[in] oStrides nb of elements contained per dimension of the output
+ * tensor
+ * @param[in] oGrad output tensor gradients
+ * @param[inout] biasesGrad biases gradients
+ */
+template <class B, class O>
+static void conv1DBackwardBias(const array<DimSize_t, 3> &oDims,
+                               const array<DimSize_t, 2> &oStrides,
+                               const O *oGrad,
+                               B *biasesGrad) {
+    array<DimSize_t, 2> oOffsets{0, 0};
+
+    for (DimSize_t batchIdx = 0; batchIdx < oDims[0]; ++batchIdx) {
+        oOffsets[0] = batchIdx * oStrides[0];
+
+        for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) {
+            oOffsets[1] = oChannel * oStrides[1] + oOffsets[0];
+
+            for (DimSize_t oIdx = 0; oIdx < oDims[2]; ++oIdx) {
+                biasesGrad[oChannel] += oGrad[oOffsets[1] + oIdx];
+            }
+        }
+    }
+}
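+// e.g. for oDims = {2, 4, 8} (hypothetical sizes), each biasesGrad[oChannel]
+// accumulates the 2 * 8 = 16 output gradients of that channel.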
+
+/**
+ * @brief Backward kernel for 1D Convolution on CPU backend.
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ *
+ * @tparam I Input data type.
+ * @tparam W Weight data type.
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] stride stride attribute of the conv operator
+ * @param[in] dilation dilation attribute of the conv operator
+ * @param[in] kernelDim kernel dimension
+ * @param[in] inputDims input data dimensions
+ * @param[in] outputDims output data dimensions
+ * @param[in] input_ input tensor
+ * @param[in] weights_ weight tensor
+ * @param[in] oGrad_ gradients of output data
+ * @param[inout] iGrad_ gradients of input data
+ * @param[inout] weightsGrad_ gradients of the kernel weights
+ * @param[inout] biasesGrad_ gradients of the kernel biases (ignored if
+ * nullptr)
+ */
+template <class I, class W, class B, class O>
+void ConvImpl1D_cpu_backward_kernel(const array<DimSize_t, 1> &stride,
+                                    const array<DimSize_t, 1> &dilation,
+                                    const array<DimSize_t, 1> &kernelDim,
+                                    const array<DimSize_t, 3> &inputDims,
+                                    const array<DimSize_t, 3> &outputDims,
+                                    const void *input_,
+                                    const void *weights_,
+                                    const void *oGrad_,
+                                    void *iGrad_,
+                                    void *weightsGrad_,
+                                    void *biasesGrad_) {
+
+    const I *input = static_cast<const I *>(input_);
+    I *iGrad = static_cast<I *>(iGrad_);
+    const O *oGrad = static_cast<const O *>(oGrad_);
+    const W *weights = static_cast<const W *>(weights_);
+    W *weightsGrad = static_cast<W *>(weightsGrad_);
+
+    //////////////////////////////
+    // COMPUTING STRIDES
+    //////////////////////////////
+    // NOTE: The ...Strides arrays hold the number of elements contained in
+    // each dimension; they are used to compute the index offset of
+    // values while iterating over each tensor.
+    // NOTE: They are 1 item shorter than their corresponding tensor, as the
+    // total number of elements is only needed for gradient initialization.
+
+    // {batch_stride, channel_stride}
+    const array<DimSize_t, 2> inputStrides{inputDims[1] * inputDims[2],
+                                           inputDims[2]};
+    const DimSize_t nbEltsInput = inputDims[0] * inputStrides[0];
+
+    // {batch_stride, channel_stride}
+    const array<DimSize_t, 2> outputStrides{outputDims[1] * outputDims[2],
+                                            outputDims[2]};
+
+    // NOTE: kernel dims = {oChannel, iChannel, kernelDim0}
+    // kernel_strides = {oChannel_stride, iChannel_stride}
+    const array<DimSize_t, 2> kernelStrides{
+        inputDims[1] * kernelDim[0],
+        kernelDim[0],
+    };
+    const DimSize_t nbEltsKernel = outputDims[1] * kernelStrides[0];
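+    // e.g. (hypothetical sizes) inputDims = {2, 3, 8} gives
+    // inputStrides = {24, 8} and nbEltsInput = 48; with kernelDim = {5} and
+    // outputDims[1] = 4, kernelStrides = {15, 5} and nbEltsKernel = 60.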
+
+    std::fill(iGrad, iGrad + nbEltsInput, I(0));
+    std::fill(weightsGrad, weightsGrad + nbEltsKernel, W(0));
+
+    conv1DBackwardInput(stride,
+                        dilation,
+                        kernelDim,
+                        kernelStrides,
+                        weights,
+                        outputDims,
+                        outputStrides,
+                        oGrad,
+                        inputDims,
+                        inputStrides,
+                        iGrad);
+
+    conv1DBackwardWeights(stride,
+                          dilation,
+                          inputDims,
+                          inputStrides,
+                          input,
+                          outputDims,
+                          outputStrides,
+                          oGrad,
+                          kernelDim,
+                          kernelStrides,
+                          weightsGrad);
+
+    if (biasesGrad_ != nullptr) {
+        B *biasesGrad = static_cast<B *>(biasesGrad_);
+        std::fill(biasesGrad, biasesGrad + outputDims[1], B(0));
+        conv1DBackwardBias(outputDims, outputStrides, oGrad, biasesGrad);
+    }
+}
+
 // Kernels registration to implementation entry point
 REGISTRAR(ConvImpl1D_cpu,
-    {{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<float, float, float, float>, nullptr});
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvImpl1D_cpu_forward_kernel<float, float, float, float>,
+           ConvImpl1D_cpu_backward_kernel<float, float, float, float>});
 REGISTRAR(ConvImpl1D_cpu,
-    {{DataType::Any, DataFormat::NCHW}, {DataType::Float16, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<half_float::half, half_float::half, half_float::half, half_float::half>, nullptr});
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float16, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvImpl1D_cpu_forward_kernel<half_float::half,
+                                         half_float::half,
+                                         half_float::half,
+                                         half_float::half>,
+           ConvImpl1D_cpu_backward_kernel<half_float::half,
+                                          half_float::half,
+                                          half_float::half,
+                                          half_float::half>});
 REGISTRAR(ConvImpl1D_cpu,
-    {{DataType::Any, DataFormat::NCHW}, {DataType::Int32, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<int32_t, int32_t, int32_t, int32_t>, nullptr});
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float64, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvImpl1D_cpu_forward_kernel<double, double, double, double>,
+           ConvImpl1D_cpu_backward_kernel<double, double, double, double>});
 REGISTRAR(ConvImpl1D_cpu,
-    {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<double, double, double, double>, nullptr});
-
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Int32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvImpl1D_cpu_forward_kernel<std::int32_t,
+                                         std::int32_t,
+                                         std::int32_t,
+                                         std::int32_t>,
+           ConvImpl1D_cpu_backward_kernel<std::int32_t,
+                                          std::int32_t,
+                                          std::int32_t,
+                                          std::int32_t>});
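+// Usage sketch (hypothetical call, not part of this header): the backward
+// entry point registered above is invoked with raw tensor pointers, e.g.:
+//   ConvImpl1D_cpu_backward_kernel<float, float, float, float>(
+//       {1}, {1}, {3}, {N, C, X}, {N, outC, outX},
+//       input, weights, oGrad, iGrad, weightsGrad, biasesGrad);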
 
 /**
  * @brief Forward kernel for 2D Convolution on CPU backend.
@@ -124,16 +443,15 @@ REGISTRAR(ConvImpl1D_cpu,
  * @param output_ Output Tensor.
  */
 template <class I, class W, class B, class O>
-void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
-                            const std::array<DimSize_t, 2>& dilationDims,
-                            const std::array<DimSize_t, 2>& kernelDims,
-                            const std::array<DimSize_t, 4> &inputDims,
-                            DimSize_t outChannels,
-                            const void *input_,
-                            const void *weights_,
-                            const void *biases_,
-                            void *output_)
-{
+void ConvImpl2D_cpu_forward_kernel(const array<DimSize_t, 2> &strideDims,
+                                   const array<DimSize_t, 2> &dilationDims,
+                                   const array<DimSize_t, 2> &kernelDims,
+                                   const array<DimSize_t, 4> &inputDims,
+                                   DimSize_t outChannels,
+                                   const void *input_,
+                                   const void *weights_,
+                                   const void *biases_,
+                                   void *output_) {
     // FIXME: missing convolution attributes as arguments
     const I *input = static_cast<const I *>(input_);
     const W *weights = static_cast<const W *>(weights_);
@@ -141,59 +459,102 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
     O *output = static_cast<O *>(output_);
 
     // output H size
-    const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1;
-    const std::size_t oxSize =
-            static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) /
-                                static_cast<float>(strideDims[0])));
+    const DimSize_t dilated_kernel_x =
+        dilationDims[0] * (kernelDims[0] - 1) + 1;
+    const std::size_t oxSize = static_cast<std::size_t>(std::floor(
+        static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) /
+        static_cast<float>(strideDims[0])));
     // output W size
-    const DimSize_t dilated_kernel_y = dilationDims[1]*(kernelDims[1] - 1) + 1;
-    const std::size_t oySize =
-            static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) /
-                                static_cast<float>(strideDims[1])));
-
+    const DimSize_t dilated_kernel_y =
+        dilationDims[1] * (kernelDims[1] - 1) + 1;
+    const std::size_t oySize = static_cast<std::size_t>(std::floor(
+        static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) /
+        static_cast<float>(strideDims[1])));
 
     // TODO: kernel computation
     // output (batch, outCh, Xout, Yout)
     // input  (batch, inCh, Xin, Yin)
     // weight (outCh, inCh, kernelX, kernelY)
     // does not take Dilation attribute into account
-    const std::size_t outChannels_s =  oxSize * oySize;
+    const std::size_t outChannels_s = oxSize * oySize;
 
     if (dilated_kernel_x == 3 && dilated_kernel_y == 3) {
         for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
             for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
                 // If bias = nullptr, set B(0)
                 B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
-                std::fill(output, output+outChannels_s, biasVal);
+                std::fill(output, output + outChannels_s, biasVal);
                 for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
-                    std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
-                    const std::size_t wIndex = (inCh + outCh*inputDims[1]) * 9;
-                    if (strideDims[0] == 1 && strideDims[1]==1) {
-                        for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=inputDims[3]) {
+                    std::size_t iIndex = (inCh + batch * inputDims[1]) *
+                                         inputDims[2] * inputDims[3];
+                    const std::size_t wIndex =
+                        (inCh + outCh * inputDims[1]) * 9;
+                    if (strideDims[0] == 1 && strideDims[1] == 1) {
+                        for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
+                             ++ox, oIndex += oySize, iIndex -= inputDims[3]) {
                             for (std::size_t oy = 0; oy < oySize; ++oy) {
-                                output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+1]+weights[wIndex+2]*input[iIndex+oy+2];
+                                output[oIndex + oy] +=
+                                    weights[wIndex + 0] * input[iIndex + oy] +
+                                    weights[wIndex + 1] *
+                                        input[iIndex + oy + 1] +
+                                    weights[wIndex + 2] *
+                                        input[iIndex + oy + 2];
                             }
-                            iIndex+=inputDims[3];
+                            iIndex += inputDims[3];
                             for (std::size_t oy = 0; oy < oySize; ++oy) {
-                                output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+1]+weights[wIndex+5]*input[iIndex+oy+2];
+                                output[oIndex + oy] +=
+                                    weights[wIndex + 3] * input[iIndex + oy] +
+                                    weights[wIndex + 4] *
+                                        input[iIndex + oy + 1] +
+                                    weights[wIndex + 5] *
+                                        input[iIndex + oy + 2];
                             }
-                            iIndex+=inputDims[3];
+                            iIndex += inputDims[3];
                             for (std::size_t oy = 0; oy < oySize; ++oy) {
-                                output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+1]+weights[wIndex+8]*input[iIndex+oy+2];
+                                output[oIndex + oy] +=
+                                    weights[wIndex + 6] * input[iIndex + oy] +
+                                    weights[wIndex + 7] *
+                                        input[iIndex + oy + 1] +
+                                    weights[wIndex + 8] *
+                                        input[iIndex + oy + 2];
                             }
                         }
                     } else {
-                        for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) {
+                        for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox,
+                                         oIndex += oySize,
+                                         iIndex += (strideDims[0] -
+                                                    2) * inputDims[3]) {
                             for (std::size_t oy = 0; oy < oySize; ++oy) {
-                                output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2];
+                                output[oIndex + oy] +=
+                                    weights[wIndex + 0] *
+                                        input[iIndex + oy * strideDims[1]] +
+                                    weights[wIndex + 1] *
+                                        input[iIndex + oy * strideDims[1] +
+                                              1] +
+                                    weights[wIndex + 2] *
+                                        input[iIndex + oy * strideDims[1] + 2];
                             }
-                            iIndex+=inputDims[3];
+                            iIndex += inputDims[3];
                             for (std::size_t oy = 0; oy < oySize; ++oy) {
-                                output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2];
+                                output[oIndex + oy] +=
+                                    weights[wIndex + 3] *
+                                        input[iIndex + oy * strideDims[1]] +
+                                    weights[wIndex + 4] *
+                                        input[iIndex + oy * strideDims[1] +
+                                              1] +
+                                    weights[wIndex + 5] *
+                                        input[iIndex + oy * strideDims[1] + 2];
                             }
-                            iIndex+=inputDims[3];
+                            iIndex += inputDims[3];
                             for (std::size_t oy = 0; oy < oySize; ++oy) {
-                                output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2];
+                                output[oIndex + oy] +=
+                                    weights[wIndex + 6] *
+                                        input[iIndex + oy * strideDims[1]] +
+                                    weights[wIndex + 7] *
+                                        input[iIndex + oy * strideDims[1] +
+                                              1] +
+                                    weights[wIndex + 8] *
+                                        input[iIndex + oy * strideDims[1] + 2];
                             }
                         }
                     }
@@ -206,18 +567,26 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
             for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
                 // If bias = nullptr, set B(0)
                 B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
-                std::fill(output, output+outChannels_s, biasVal);
+                std::fill(output, output + outChannels_s, biasVal);
                 for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
-                    std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
-                    const std::size_t wIndex = (inCh + outCh*inputDims[1]);
+                    std::size_t iIndex = (inCh + batch * inputDims[1]) *
+                                         inputDims[2] * inputDims[3];
+                    const std::size_t wIndex = (inCh + outCh * inputDims[1]);
                     if (strideDims[0] == 1 && strideDims[1] == 1) {
-                        for (std::size_t oIndex = 0; oIndex < oxSize*oySize; ++oIndex, ++iIndex) {
+                        for (std::size_t oIndex = 0; oIndex < oxSize * oySize;
+                             ++oIndex, ++iIndex) {
                             output[oIndex] += weights[wIndex] * input[iIndex];
                         }
-                    } else  {
-                        for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=inputDims[3]*strideDims[0]) {
-                            for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) {
-                                output[oIndex + oy] += weights[wIndex+0]*input[iIndex+iy];
+                    } else {
+                        for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
+                             ++ox,
+                                         oIndex += oySize,
+                                         iIndex +=
+                                         inputDims[3] * strideDims[0]) {
+                            for (std::size_t oy = 0, iy = 0; oy < oySize;
+                                 ++oy, iy += strideDims[1]) {
+                                output[oIndex + oy] +=
+                                    weights[wIndex + 0] * input[iIndex + iy];
                             }
                         }
                     }
@@ -230,21 +599,36 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
             for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
                 // If bias = nullptr, set B(0)
                 B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
-                std::fill(output, output+outChannels_s, biasVal);
+                std::fill(output, output + outChannels_s, biasVal);
                 for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
-                    std::size_t iIndex_channel = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
-                    const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1];
+                    std::size_t iIndex_channel =
+                        (inCh + batch * inputDims[1]) * inputDims[2] *
+                        inputDims[3];
+                    const std::size_t wIndex = (inCh + outCh * inputDims[1]) *
+                                               kernelDims[0] * kernelDims[1];
 
                     // loop over each ouput line
-                    for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex_channel+=inputDims[3]*strideDims[0]) {
+                    for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
+                         ++ox,
+                                     oIndex += oySize,
+                                     iIndex_channel +=
+                                     inputDims[3] * strideDims[0]) {
                         // loop over associated input line
-                        for (std::size_t ky = 0, ix = 0; ky < kernelDims[0]; ++ky, ix += inputDims[3]*dilationDims[0]) {
+                        for (std::size_t ky = 0, ix = 0; ky < kernelDims[0];
+                             ++ky, ix += inputDims[3] * dilationDims[0]) {
                             // loop over the entire line
-                            for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) {
-                                const std::size_t iIndex = iIndex_channel + ix + iy;
-                                // loop over elements assosicated with one output
-                                for (std::size_t kx = 0;  kx < kernelDims[0]; ++kx) {
-                                    output[oIndex + oy] += weights[wIndex+kernelDims[0]*ky+kx]*input[iIndex+kx*dilationDims[1]];
+                            for (std::size_t oy = 0, iy = 0; oy < oySize;
+                                 ++oy, iy += strideDims[1]) {
+                                const std::size_t iIndex =
+                                    iIndex_channel + ix + iy;
+                                // loop over elements associated with one
+                                // output
+                                for (std::size_t kx = 0; kx < kernelDims[0];
+                                     ++kx) {
+                                    output[oIndex + oy] +=
+                                        weights[wIndex + kernelDims[0] * ky +
+                                                kx] *
+                                        input[iIndex + kx * dilationDims[1]];
                                 }
                             }
                         }
@@ -256,21 +640,380 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
     }
 }
 
+/**
+ * @brief perform backpropagation for the input
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ * @note formula :
+ * for i in 0..input_size:
+ *  for n in 0..weight_size:
+ *    dL        dYn   dL
+ *   ---- = Σn ----- -----
+ *    dXi      dXi   dYn
+ * with : dYn / dXi = w_k
+ * for each input value
+ * for each weight
+ * for each output
+ * multiply the weight with the associated value
+ * @note kernel, stride & dilation attributes are arrays of length 2 for
+ * the 2D case
+ * @note reminder that kernel dimensions are
+ * {outChannels, inChannels, {kernelDims}}
+ * <=> {oDims[1], iDims[1], kernelDim}
+ * @tparam I Input data type.
+ * @tparam W Weight data type.
+ * @tparam O Output data type.
+ * @param[in] stride stride parameter of the convolution operator
+ * @param[in] dilation dilation parameter of the convolution operator
+ * @param[in] kDims dimension of the kernel
+ * @param[in] kStrides nb of elements contained per dimension of the kernel
+ * @param[in] weights weights values
+ * @param[in] oDims dimensions of the output
+ * @param[in] oStrides nb of elements contained per dimension of the output
+ * @param[in] oGrad output gradient
+ * @param[in] iDims input dimensions
+ * @param[in] iStrides nb of elements contained per dimension of the input
+ * @param[inout] iGrad gradients of the input to update
+ */
+template <class I, class W, class O>
+void conv2DBackwardInput(const array<DimSize_t, 2> &stride,
+                         const array<DimSize_t, 2> &dilation,
+                         const array<DimSize_t, 2> &kDims,
+                         const array<DimSize_t, 3> &kStrides,
+                         const W *weights,
+                         const array<DimSize_t, 4> &oDims,
+                         const array<DimSize_t, 3> &oStrides,
+                         const O *oGrad,
+                         const array<DimSize_t, 4> &iDims,
+                         const array<DimSize_t, 3> &iStrides,
+                         I *iGrad) {
+    // records index offsets for each dimension that has a stride (== all
+    // dimensions except the last) for every parsed tensor
+    array<DimSize_t, 3> kOffset{};
+    array<DimSize_t, 3> iOffset{};
+    array<DimSize_t, 3> oOffset{};
+
+    for (std::size_t batch = 0; batch < iDims[0]; ++batch) {
+        iOffset[0] = batch * iStrides[0];
+        oOffset[0] = batch * oStrides[0];
+
+        for (DimSize_t oChannel = 0; oChannel < oDims[1]; oChannel++) {
+            oOffset[1] = (oChannel * oStrides[1]) + oOffset[0];
+            kOffset[0] = (oChannel * kStrides[0]);
+
+            for (std::size_t iChannel = 0; iChannel < iDims[1]; ++iChannel) {
+                iOffset[1] = (iChannel * iStrides[1]) + iOffset[0];
+                kOffset[1] = iChannel * kStrides[1] + kOffset[0];
+
+                for (DimSize_t oX = 0; oX < oDims[2]; ++oX) {
+                    oOffset[2] = (oX * oStrides[2]) + oOffset[1];
+
+                    auto iX = oX * stride[0];
+                    iOffset[2] = (iX * iStrides[2]) + iOffset[1];
 
+                    for (DimSize_t oY = 0; oY < oDims[3]; ++oY) {
+                        auto oIdx = oOffset[2] + oY;
+
+                        auto iY = oY * stride[1];
+                        auto iIdx = iOffset[2] + iY;
+
+                        for (DimSize_t kX = 0; kX < kDims[0]; ++kX) {
+                            auto kDilX = kX * dilation[0];
+                            auto iDilKXOffset = kDilX * iStrides[2];
+
+                            kOffset[2] = (kX * kStrides[2]) + kOffset[1];
+
+                            for (DimSize_t kY = 0; kY < kDims[1]; ++kY) {
+                                auto kDilY = kY * dilation[1];
+
+                                iGrad[iIdx + iDilKXOffset + kDilY] +=
+                                    weights[kOffset[2] + kY] * oGrad[oIdx];
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
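+// Worked example (hypothetical sizes): with stride = {1, 1},
+// dilation = {1, 1} and kDims = {3, 3}, the innermost loops above reduce to
+// iGrad[(iX + kX) * iStrides[2] + iY + kY] +=
+//     weights[kX * 3 + kY] * oGrad[oIdx]
+// for each of the 9 taps of the window anchored at (iX, iY) = (oX, oY).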
+
+/**
+ * @brief computes weight backpropagation for conv2D operation
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ * @note formula (weight grad):
+ * for i in 0..weight_size:
+ *  for n in 0..output_size:
+ *    dL        dYn   dL
+ *   ---- = Σn ----- -----
+ *   dwi       dwi   dYn
+ * with : dYn / dwi = x_k
+ * @tparam I input dtype
+ * @tparam W weight dtype
+ * @tparam O output dtype
+ * @param[in] iDims input data dimensions
+ * @param[in] iStrides nb of elements contained per dimension of the input
+ * @param[in] input input data
+ * @param[in] oDims output data dimensions
+ * @param[in] oStrides nb of elements contained per dimension of the output
+ * @param[in] oGrad gradients of output data
+ * @param[in] kDim dimension of the kernel
+ * @param[in] kStrides nb of elements contained per dimension of the kernel
+ * @param[in] stride stride attribute of the conv operator
+ * @param[in] dilation dilation attribute of the conv operator
+ * @param[inout] weightsGrad gradients of the kernel weights
+ */
+template <class I, class W, class O>
+void conv2DBackwardWeights(const array<DimSize_t, 4> &iDims,
+                           const array<DimSize_t, 3> &iStrides,
+                           const I *input,
+                           const array<DimSize_t, 4> &oDims,
+                           const array<DimSize_t, 3> &oStrides,
+                           const O *oGrad,
+                           const array<DimSize_t, 2> &kDim,
+                           const array<DimSize_t, 3> &kStrides,
+                           const array<DimSize_t, 2> &stride,
+                           const array<DimSize_t, 2> &dilation,
+                           W *weightsGrad) {
+    // records index offsets for each dimension that has a stride (== all
+    // dimensions except the last) for every parsed tensor
+    array<DimSize_t, 3> iOffsets{0, 0, 0};
+    array<DimSize_t, 3> oOffsets{0, 0, 0};
+    array<DimSize_t, 3> kOffsets{0, 0, 0};
+
+    for (DimSize_t batchIdx = 0; batchIdx < oDims[0]; ++batchIdx) {
+        iOffsets[0] = batchIdx * iStrides[0];
+        oOffsets[0] = batchIdx * oStrides[0];
+
+        for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel) {
+            iOffsets[1] = iChannel * iStrides[1] + iOffsets[0];
+
+            for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) {
+                oOffsets[1] = oChannel * oStrides[1] + oOffsets[0];
+                // weights layout is {oChannel, iChannel, kX, kY}, so the
+                // output channel moves along kStrides[0] and the input
+                // channel along kStrides[1]
+                kOffsets[0] = oChannel * kStrides[0];
+                kOffsets[1] = iChannel * kStrides[1] + kOffsets[0];
+
+                for (DimSize_t kX = 0; kX < kDim[0]; ++kX) {
+                    kOffsets[2] = kX * kStrides[2] + kOffsets[1];
+                    for (DimSize_t kY = 0; kY < kDim[1]; ++kY) {
+
+                        for (DimSize_t oX = 0; oX < oDims[2]; ++oX) {
+                            const DimSize_t iX =
+                                oX * stride[0] + kX * dilation[0];
+
+                            oOffsets[2] = oX * oStrides[2] + oOffsets[1];
+                            iOffsets[2] = iX * iStrides[2] + iOffsets[1];
+
+                            for (DimSize_t oY = 0; oY < oDims[3]; ++oY) {
+                                const DimSize_t iY =
+                                    oY * stride[1] + kY * dilation[1];
+
+                                weightsGrad[kOffsets[2] + kY] +=
+                                    input[iOffsets[2] + iY] *
+                                    oGrad[oOffsets[2] + oY];
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
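+// Worked example (hypothetical sizes): with stride = {2, 2} and
+// dilation = {1, 1}, each tap (kX, kY) accumulates
+// weightsGrad[kX][kY] += input[2 * oX + kX][2 * oY + kY] * oGrad[oX][oY]
+// summed over every output position (oX, oY) and the batch.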
+
+/**
+ * @brief computes bias backpropagation for conv2D operation
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ * @note formula :
+ * Bias grad:
+ * for i in 0..bias_size:
+ *  for n in 0..output_size:
+ *    dL        dYn   dL
+ *   ---- = Σn ----- -----
+ *   dbi       dbi   dYn
+ * with : dYn / dbi = 1
+ *
+ * Hence the partial derivative of the loss wrt each bias is simply the
+ * sum of the output gradients of its channel, over the batch and the
+ * spatial dimensions
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] oDims output tensor dimensions
+ * @param[in] oStrides nb of elements contained per dimension of the
+ * output
+ * @param[in] oGrad output tensor gradients
+ * @param[inout] biasesGrad biases gradients
+ */
+template <class B, class O>
+static void conv2DBackwardBias(const array<DimSize_t, 4> &oDims,
+                               const array<DimSize_t, 3> &oStrides,
+                               const O *oGrad,
+                               B *biasesGrad) {
+    // records all index offsets for output tensor
+    array<DimSize_t, 3> oOffsets{};
+    for (DimSize_t batchIdx = 0; batchIdx < oDims[0]; ++batchIdx) {
+        oOffsets[0] = batchIdx * oStrides[0];
+
+        for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) {
+            oOffsets[1] = oChannel * oStrides[1] + oOffsets[0];
+
+            for (DimSize_t oX = 0; oX < oDims[2]; ++oX) {
+                oOffsets[2] = oX * oStrides[2] + oOffsets[1];
+
+                for (DimSize_t oY = 0; oY < oDims[3]; ++oY) {
+                    biasesGrad[oChannel] += oGrad[oOffsets[2] + oY];
+                }
+            }
+        }
+    }
+}
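+// e.g. for oDims = {2, 4, 8, 8} (hypothetical sizes), each
+// biasesGrad[oChannel] accumulates the 2 * 8 * 8 = 128 output gradients of
+// that channel.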
+
+/**
+ * @brief Backward kernel for 2D Convolution on CPU backend.
+ * @note INPUT & OUTPUT convention is the same as in the
+ * forward function
+ *
+ * @tparam I Input data type.
+ * @tparam W Weight data type.
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] stride stride attribute of the conv operator
+ * @param[in] dilation dilation attribute of the conv operator
+ * @param[in] kernelDims kernel dimensions
+ * @param[in] inputDims input data dimensions
+ * @param[in] outputDims output data dimensions
+ * @param[in] input_ input tensor
+ * @param[in] weights_ kernel tensor
+ * @param[in] oGrad_ output tensor gradient
+ * @param[inout] iGrad_ input tensor gradient
+ * @param[inout] weightsGrad_ kernel weights tensor gradients
+ * @param[inout] biasesGrad_ kernel biases tensor gradients (ignored if
+ * nullptr)
+ */
+template <class I, class W, class B, class O>
+void ConvImpl2D_cpu_backward_kernel(const array<DimSize_t, 2> &stride,
+                                    const array<DimSize_t, 2> &dilation,
+                                    const array<DimSize_t, 2> &kernelDims,
+                                    const array<DimSize_t, 4> &inputDims,
+                                    const array<DimSize_t, 4> &outputDims,
+                                    const void *input_,
+                                    const void *weights_,
+                                    const void *oGrad_,
+                                    void *iGrad_,
+                                    void *weightsGrad_,
+                                    void *biasesGrad_) {
+
+    const I *input = static_cast<const I *>(input_);
+    I *iGrad = static_cast<I *>(iGrad_);
+    const O *outputGrad = static_cast<const O *>(oGrad_);
+    const W *weights = static_cast<const W *>(weights_);
+    W *weightsGrad = static_cast<W *>(weightsGrad_);
+
+    //////////////////////////////
+    // COMPUTING STRIDES
+    //////////////////////////////
+    // NOTE: The ...Strides arrays hold the number of elements contained in
+    // each dimension; they are used to compute the index offset of
+    // values while iterating over each tensor.
+    // NOTE: They are 1 item shorter than their corresponding tensor, as the
+    // total number of elements is only needed for gradient initialization.
+
+    // {batch_stride, channel_stride, dim0_stride}
+    const array<DimSize_t, 3> inputStrides{
+        inputDims[1] * inputDims[2] * inputDims[3],
+        inputDims[2] * inputDims[3],
+        inputDims[3]};
+    const DimSize_t nbEltsInput = inputDims[0] * inputStrides[0];
+
+    // {batch_stride, channel_stride, dim0_stride}
+    const array<DimSize_t, 3> outputStrides{
+        outputDims[1] * outputDims[2] * outputDims[3],
+        outputDims[2] * outputDims[3],
+        outputDims[3]};
+
+    // NOTE: kernel dims = {oChannel, iChannel, kernelDim0, kernelDim1}
+    // kernel_strides = {oChannel_stride, iChannel_stride, kernelDim0_stride}
+    const array<DimSize_t, 3> kernelStrides{
+        inputDims[1] * kernelDims[0] * kernelDims[1],
+        kernelDims[0] * kernelDims[1],
+        kernelDims[1]};
+
+    const DimSize_t nbEltsKernel = outputDims[1] * kernelStrides[0];
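+    // e.g. (hypothetical sizes) inputDims = {2, 3, 8, 8} gives
+    // inputStrides = {192, 64, 8}; with kernelDims = {3, 3} and
+    // outputDims[1] = 4, kernelStrides = {27, 9, 3} and nbEltsKernel = 108.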
+
+    ////////////////////////////
+    // prepping gradient arrays
+    std::fill(iGrad, iGrad + nbEltsInput, I(0));
+    std::fill(weightsGrad, weightsGrad + nbEltsKernel, W(0));
+
+    conv2DBackwardInput(stride,
+                        dilation,
+                        kernelDims,
+                        kernelStrides,
+                        weights,
+                        outputDims,
+                        outputStrides,
+                        outputGrad,
+                        inputDims,
+                        inputStrides,
+                        iGrad);
+
+    conv2DBackwardWeights(inputDims,
+                          inputStrides,
+                          input,
+                          outputDims,
+                          outputStrides,
+                          outputGrad,
+                          kernelDims,
+                          kernelStrides,
+                          stride,
+                          dilation,
+                          weightsGrad);
+
+    if (biasesGrad_ != nullptr) {
+        B *biasesGrad = static_cast<B *>(biasesGrad_);
+        std::fill(biasesGrad, biasesGrad + outputDims[1], B(0));
+        conv2DBackwardBias(outputDims, outputStrides, outputGrad, biasesGrad);
+    }
+}
 
 // Kernels registration to implementation entry point
 REGISTRAR(ConvImpl2D_cpu,
-    {{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<float, float, float, float>, nullptr});
-REGISTRAR(ConvImpl2D_cpu,
-    {{DataType::Any, DataFormat::NCHW}, {DataType::Float16, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<half_float::half, half_float::half, half_float::half, half_float::half>, nullptr});
-REGISTRAR(ConvImpl2D_cpu,
-    {{DataType::Any, DataFormat::NCHW}, {DataType::Int32, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<int32_t, int32_t, int32_t, int32_t>, nullptr});
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           Aidge::ConvImpl2D_cpu_forward_kernel<float, float, float, float>,
+           Aidge::ConvImpl2D_cpu_backward_kernel<float, float, float, float>});
 REGISTRAR(ConvImpl2D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float16, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           Aidge::ConvImpl2D_cpu_forward_kernel<half_float::half,
+                                                half_float::half,
+                                                half_float::half,
+                                                half_float::half>,
+           Aidge::ConvImpl2D_cpu_backward_kernel<half_float::half,
+                                                 half_float::half,
+                                                 half_float::half,
+                                                 half_float::half>});
+REGISTRAR(
+    ConvImpl2D_cpu,
     {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}},
-    {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<double, double, double, double>, nullptr});
-}  // namespace Aidge
+    {ProdConso::inPlaceModel,
+     Aidge::ConvImpl2D_cpu_forward_kernel<double, double, double, double>,
+     Aidge::ConvImpl2D_cpu_backward_kernel<double, double, double, double>});
+REGISTRAR(ConvImpl2D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Int32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvImpl2D_cpu_forward_kernel<std::int32_t,
+                                         std::int32_t,
+                                         std::int32_t,
+                                         std::int32_t>,
+           ConvImpl2D_cpu_backward_kernel<std::int32_t,
+                                          std::int32_t,
+                                          std::int32_t,
+                                          std::int32_t>});
+} // namespace Aidge
 
 #endif /* AIDGE_CPU_OPERATOR_CONVIMPL_KERNELS_H_ */
diff --git a/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp b/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..7604a96a18e7be44f4c2e8970a0b60b1c4ad918b
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp
@@ -0,0 +1,59 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_H_
+#define AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_H_
+
+#include <array>
+
+#include "aidge/backend/cpu/operator/OperatorImpl.hpp"
+#include "aidge/operator/ConvTranspose.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
+using std::array;
+
+// Operator implementation entry point for the backend
+using ConvTranspose1D_Op = ConvTranspose_Op<1>;
+using ConvTransposeImpl1D_cpu =
+    OperatorImpl_cpu<ConvTranspose1D_Op,
+                     void(const array<DimSize_t, 1> &,
+                          const array<DimSize_t, 1> &,
+                          const array<DimSize_t, 1> &,
+                          const array<DimSize_t, 3> &,
+                          const array<DimSize_t, 3> &,
+                          const void *,
+                          const void *,
+                          const void *,
+                          void *)>;
+
+using ConvTranspose2D_Op = ConvTranspose_Op<2>;
+using ConvTransposeImpl2D_cpu =
+    OperatorImpl_cpu<ConvTranspose2D_Op,
+                     void(const array<DimSize_t, 2> &,
+                          const array<DimSize_t, 2> &,
+                          const array<DimSize_t, 2> &,
+                          const array<DimSize_t, 4> &,
+                          const array<DimSize_t, 4> &,
+                          const void *,
+                          const void *,
+                          const void *,
+                          void *)>;
+
+// Implementation entry point registration to Operator
+REGISTRAR(ConvTranspose1D_Op, "cpu", ConvTransposeImpl1D_cpu::create);
+REGISTRAR(ConvTranspose2D_Op, "cpu", ConvTransposeImpl2D_cpu::create);
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..e11dd2625ae1645a8e7c5482b1635b85fb475b06
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp
@@ -0,0 +1,305 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_KERNELS_H_
+#define AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_KERNELS_H_
+
+#include <array>
+
+#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include <aidge/backend/cpu/operator/ConvImpl_kernels.hpp>
+#include <aidge/data/Data.hpp>
+#include <aidge/data/half.hpp>
+#include <aidge/scheduler/ProdConso.hpp>
+#include <aidge/utils/Types.h>
+
+namespace Aidge {
+
+using std::array;
+
+////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////
+// 1D
+////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////
+
+/**
+ * @brief Performs the bias-add step of the ConvTranspose forward pass.
+ *
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] biases bias values (nullptr means no bias, i.e. 0)
+ * @param[in] oDims dimensions of the output
+ * @param[in] oStrides number of elements contained per dimension of the output
+ * @param[out] output pointer to the output buffer
+ */
+template <class B, class O>
+static void convTranspose1DForwardBias(const B *biases,
+                                       const array<DimSize_t, 3> &oDims,
+                                       const array<DimSize_t, 2> &oStrides,
+                                       O *output) {
+    array<DimSize_t, 2> outOffsets{0, 0};
+    for (DimSize_t batch = 0; batch < oDims[0]; ++batch) {
+        outOffsets[0] = batch * oStrides[0];
+        for (DimSize_t outCh = 0; outCh < oDims[1]; ++outCh) {
+            outOffsets[1] = outCh * oStrides[1] + outOffsets[0];
+            // If biases is nullptr, use B(0)
+            B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
+            std::fill(output + outOffsets[1],
+                      output + (outOffsets[1] + oDims[2]),
+                      biasVal);
+        }
+    }
+}
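+
+// NOTE: pre-filling the output with the bias (or zero) provides the initial
+// accumulator for the transposed-convolution kernel invoked right after it.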
+
+/**
+ * @brief Forward kernel for ConvTranspose.
+ * @note The ConvTranspose forward pass is simply the convolution backward
+ * (input-gradient) kernel. Check the convolution functions for more
+ * in-depth details on how the subfunctions are built.
+ * @tparam I Input data type.
+ * @tparam W Weight data type.
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] stride stride parameter of the ConvTranspose operator
+ * @param[in] dilation dilation parameter of the ConvTranspose operator
+ * @param[in] kernelDim dimensions of the kernel
+ * @param[in] inputDims input tensor dimensions
+ * @param[in] outputDims output tensor dimensions
+ * @param[in] input_ input values
+ * @param[in] weights_ weight values
+ * @param[in] biases_ bias values (may be nullptr)
+ * @param[out] output_ pointer to the output buffer
+ */
+template <class I, class W, class B, class O>
+void ConvTransposeImpl1D_cpu_forward_kernel(
+    const array<DimSize_t, 1> &stride,
+    const array<DimSize_t, 1> &dilation,
+    const array<DimSize_t, 1> &kernelDim,
+    const array<DimSize_t, 3> &inputDims,
+    const array<DimSize_t, 3> &outputDims,
+    const void *input_,
+    const void *weights_,
+    const void *biases_,
+    void *output_) {
+
+    const I *input = static_cast<const I *>(input_);
+    const W *weights = static_cast<const W *>(weights_);
+    O *output = static_cast<O *>(output_);
+
+    // {batch_stride, channel_stride}
+    const array<DimSize_t, 2> inputStrides{inputDims[1] * inputDims[2],
+                                           inputDims[2]};
+
+    // {batch_stride, channel_stride}
+    const array<DimSize_t, 2> outputStrides{outputDims[1] * outputDims[2],
+                                            outputDims[2]};
+
+    // NOTE: kernel dims = {inChannels, outChannels, kernelDim[0]}
+    const array<DimSize_t, 2> kernelStrides{
+        outputDims[1] * kernelDim[0],
+        kernelDim[0],
+    };
+
+    // Always initialize the output (with the bias values, or zero when
+    // biases_ is nullptr) since conv1DBackwardInput accumulates into it.
+    convTranspose1DForwardBias(static_cast<const B *>(biases_),
+                               outputDims,
+                               outputStrides,
+                               output);
+
+    conv1DBackwardInput(stride,
+                        dilation,
+                        kernelDim,
+                        kernelStrides,
+                        weights,
+                        inputDims,
+                        inputStrides,
+                        input,
+                        outputDims,
+                        outputStrides,
+                        output);
+}
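+
+// Output size per spatial dimension (sketch, no padding):
+//   out = (in - 1) * stride + dilation * (kernel - 1) + 1
+// Example invocation (sketch; in/w/b/out are hypothetical float buffers):
+//   std::array<DimSize_t, 1> stride{1}, dilation{1}, kernel{3};
+//   std::array<DimSize_t, 3> inDims{1, 2, 8};   // {N, Cin, Lin}
+//   std::array<DimSize_t, 3> outDims{1, 4, 10}; // Lout = (8-1)*1 + 1*(3-1) + 1
+//   ConvTransposeImpl1D_cpu_forward_kernel<float, float, float, float>(
+//       stride, dilation, kernel, inDims, outDims, in, w, b, out);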
+
+REGISTRAR(ConvTransposeImpl1D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Int32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvTransposeImpl1D_cpu_forward_kernel<std::int32_t,
+                                                  std::int32_t,
+                                                  std::int32_t,
+                                                  std::int32_t>,
+           nullptr});
+REGISTRAR(ConvTransposeImpl1D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvTransposeImpl1D_cpu_forward_kernel<float, float, float, float>,
+           nullptr});
+REGISTRAR(ConvTransposeImpl1D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float16, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvTransposeImpl1D_cpu_forward_kernel<half_float::half,
+                                                  half_float::half,
+                                                  half_float::half,
+                                                  half_float::half>,
+           nullptr});
+REGISTRAR(
+    ConvTransposeImpl1D_cpu,
+    {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}},
+    {ProdConso::inPlaceModel,
+     ConvTransposeImpl1D_cpu_forward_kernel<double, double, double, double>,
+     nullptr});
+
+////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////
+// 2D
+////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////
+
+/**
+ * @brief Performs the bias-add step of the ConvTranspose forward pass.
+ *
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] biases bias values (nullptr means no bias, i.e. 0)
+ * @param[in] oDims dimensions of the output
+ * @param[in] oStrides number of elements contained per dimension of the output
+ * @param[out] output pointer to the output buffer
+ */
+template <class B, class O>
+static void convTranspose2DForwardBias(const B *biases,
+                                       const array<DimSize_t, 4> &oDims,
+                                       const array<DimSize_t, 3> &oStrides,
+                                       O *output) {
+    array<DimSize_t, 2> outOffsets{0, 0};
+
+    for (DimSize_t batch = 0; batch < oDims[0]; ++batch) {
+        outOffsets[0] = batch * oStrides[0];
+
+        for (DimSize_t outCh = 0; outCh < oDims[1]; ++outCh) {
+            outOffsets[1] = outCh * oStrides[1] + outOffsets[0];
+            // If biases is nullptr, use B(0)
+            B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
+            std::fill(output + outOffsets[1],
+                      (output + outOffsets[1]) + oStrides[1],
+                      biasVal);
+        }
+    }
+}
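+
+// NOTE: each output channel plane (oStrides[1] = dim0 * dim1 elements) is
+// filled with its bias value, giving the initial accumulator for the
+// transposed-convolution kernel below.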
+
+/**
+ * @brief Forward kernel for ConvTranspose.
+ * @note The ConvTranspose forward pass is simply the convolution backward
+ * (input-gradient) kernel. Check the convolution functions for more
+ * in-depth details on how the subfunctions are built.
+ * @tparam I Input data type.
+ * @tparam W Weight data type.
+ * @tparam B Bias data type.
+ * @tparam O Output data type.
+ * @param[in] stride stride parameter of the ConvTranspose operator
+ * @param[in] dilation dilation parameter of the ConvTranspose operator
+ * @param[in] kernelDims dimensions of the kernel
+ * @param[in] inputDims input tensor dimensions
+ * @param[in] outputDims output tensor dimensions
+ * @param[in] input_ input values
+ * @param[in] weights_ weight values
+ * @param[in] biases_ bias values (may be nullptr)
+ * @param[out] output_ pointer to the output buffer
+ */
+template <class I, class W, class B, class O>
+void ConvTransposeImpl2D_cpu_forward_kernel(
+    const array<DimSize_t, 2> &stride,
+    const array<DimSize_t, 2> &dilation,
+    const array<DimSize_t, 2> &kernelDims,
+    const array<DimSize_t, 4> &inputDims,
+    const array<DimSize_t, 4> &outputDims,
+    const void *input_,
+    const void *weights_,
+    const void *biases_,
+    void *output_) {
+
+    auto input = static_cast<const I *>(input_);
+    auto weights = static_cast<const W *>(weights_);
+    auto output = static_cast<O *>(output_);
+
+    // {batch_stride, channel_stride, dim0_stride}
+    const array<DimSize_t, 3> inputStrides{
+        inputDims[1] * inputDims[2] * inputDims[3],
+        inputDims[2] * inputDims[3],
+        inputDims[3]};
+
+    // {batch_stride, channel_stride, dim0_stride}
+    const array<DimSize_t, 3> outputStrides{
+        outputDims[1] * outputDims[2] * outputDims[3],
+        outputDims[2] * outputDims[3],
+        outputDims[3]};
+
+    // NOTE: kernel dims = {inChannels, outChannels, kernelDims[0],
+    // kernelDims[1]}
+    const array<DimSize_t, 3> kernelStrides{
+        outputDims[1] * kernelDims[0] * kernelDims[1],
+        kernelDims[0] * kernelDims[1],
+        kernelDims[1],
+    };
+
+    // Always initialize the output (with the bias values, or zero when
+    // biases_ is nullptr) since conv2DBackwardInput accumulates into it.
+    convTranspose2DForwardBias(static_cast<const B *>(biases_),
+                               outputDims,
+                               outputStrides,
+                               output);
+
+    conv2DBackwardInput(stride,
+                        dilation,
+                        kernelDims,
+                        kernelStrides,
+                        weights,
+                        inputDims,
+                        inputStrides,
+                        input,
+                        outputDims,
+                        outputStrides,
+                        output);
+}
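+
+// Shape reminder (sketch): with input {N, Cin, Hin, Win} and weights
+// {Cin, Cout, Kh, Kw}, each input value is scattered into the output window
+// it would have been gathered from by the corresponding regular Conv.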
+
+REGISTRAR(ConvTransposeImpl2D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Int32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvTransposeImpl2D_cpu_forward_kernel<std::int32_t,
+                                                  std::int32_t,
+                                                  std::int32_t,
+                                                  std::int32_t>,
+           nullptr});
+REGISTRAR(ConvTransposeImpl2D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float16, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvTransposeImpl2D_cpu_forward_kernel<half_float::half,
+                                                  half_float::half,
+                                                  half_float::half,
+                                                  half_float::half>,
+           nullptr});
+REGISTRAR(ConvTransposeImpl2D_cpu,
+          {{DataType::Any, DataFormat::NCHW},
+           {DataType::Float32, DataFormat::NCHW}},
+          {ProdConso::inPlaceModel,
+           ConvTransposeImpl2D_cpu_forward_kernel<float, float, float, float>,
+           nullptr});
+REGISTRAR(
+    ConvTransposeImpl2D_cpu,
+    {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}},
+    {ProdConso::inPlaceModel,
+     ConvTransposeImpl2D_cpu_forward_kernel<double, double, double, double>,
+     nullptr});
+
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_KERNELS_H_ */
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index fdfe19fbf4bf3e71c86aa28b966cfb21a1b5ba40..d23a9968ffb424b4639e0fcd2629a3a1cc2e11c3 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -13,14 +13,11 @@
 #include "aidge/backend/cpu/operator/ConvImpl_kernels.hpp"
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
-#include <vector>
 
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 #include "aidge/operator/Conv.hpp"
-#include "aidge/utils/Types.h"
+
+namespace Aidge {
 
 template <>
 void Aidge::ConvImpl1D_cpu::forward() {
@@ -43,21 +40,60 @@ void Aidge::ConvImpl1D_cpu::forward() {
     const auto& input2 = (op_.getInput(2)) ? op_.getInput(2)->refCastFrom(input2Fallback, *op_.getOutput(0)) : Tensor();
 
     // Call kernel
-    impl.forward(op_.strideDims(),
-            op_.dilationDims(),
-            op_.kernelDims(),
-            op_.getInput(0)->template dims<3>(), // input dimensions
-            dynamic_cast<const Conv_Op<1>&>(mOp).outChannels(), // outChannels
-            input0.getImpl()->rawPtr(), // input
-            input1.getImpl()->rawPtr(), // weight
-            op_.getInput(2) ? input2.getImpl()->rawPtr() : nullptr, // bias
-            getCPUPtr(mOp.getRawOutput(0)) // output
-            );
+    impl.forward(
+        op_.strideDims(),
+        op_.dilationDims(),
+        op_.kernelDims(),
+        op_.getInput(0)->template dims<3>(), // input dimensions
+        dynamic_cast<const Conv_Op<1> &>(mOp).outChannels(),    // outChannels
+        input0.getImpl()->rawPtr(),                             // input
+        input1.getImpl()->rawPtr(),                             // weight
+        op_.getInput(2) ? input2.getImpl()->rawPtr() : nullptr, // bias
+        getCPUPtr(mOp.getRawOutput(0))                          // output
+    );
 }
 
-template <>
-void Aidge::ConvImpl1D_cpu::backward() {
-    AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for Conv_Op<1> on backend cpu");
+template <> void ConvImpl1D_cpu::backward() {
+    const auto &op = dynamic_cast<const Conv1D_Op &>(mOp);
+    const auto &outputGrad = op.getOutput(0)->grad();
+    AIDGE_ASSERT(outputGrad, "{}: missing ouput #0 gradient", op.type());
+    AIDGE_ASSERT(op.getInput(0)->grad(),
+                 "{}: missing data input(#0) gradient",
+                 op.type());
+    AIDGE_ASSERT(op.getInput(1)->grad(),
+                 "{}: missing weight input(#1) gradient",
+                 op.type());
+
+    std::shared_ptr<Tensor> inputDataGradFallback, inputWeightGradFallback,
+        inputBiasGradFallback;
+    const auto &inputDataGrad =
+        op.getInput(0)->grad()->refCastFrom(inputDataGradFallback,
+                                            *(op.getOutput(0)));
+    const auto &inputWeightGrad =
+        op.getInput(1)->grad()->refCastFrom(inputWeightGradFallback,
+                                            *(op.getOutput(0)));
+    const auto &inputBiasGrad =
+        (op.getInput(2) && op.getInput(2)->grad())
+            ? op.getInput(2)->grad()->refCastFrom(inputBiasGradFallback,
+                                                  *(op.getOutput(0)))
+            : Tensor();
+
+    // Call kernel
+    const auto impl =
+        Registrar<ConvImpl1D_cpu>::create(getBestMatch(getRequiredSpec()));
+    impl.backward(
+        op.strideDims(),
+        op.dilationDims(),
+        op.kernelDims(),
+        op.getInput(0)->template dims<3>(),
+        op.getOutput(0)->template dims<3>(),
+
+        getCPUPtr(op.getInput(0)),
+        getCPUPtr(op.getInput(1)),
+        getCPUPtr(outputGrad),
+        inputDataGrad.getImpl()->rawPtr(),
+        inputWeightGrad.getImpl()->rawPtr(),
+        op.getInput(2) ? inputBiasGrad.getImpl()->rawPtr() : nullptr);
 }
 
 template <>
@@ -93,7 +129,48 @@ void Aidge::ConvImpl2D_cpu::forward() {
             );
 }
 
-template <>
-void Aidge::ConvImpl2D_cpu::backward() {
-    AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for Conv_Op<2> on backend cpu");
+
+template <> void ConvImpl2D_cpu::backward() {
+    const auto &op = dynamic_cast<const Conv2D_Op &>(mOp);
+    const auto &outputGrad = op.getOutput(0)->grad();
+    AIDGE_ASSERT(outputGrad, "{}: missing ouput #0 gradient", op.type());
+    AIDGE_ASSERT(op.getInput(0)->grad(),
+                 "{}: missing data input(#0) gradient",
+                 op.type());
+    AIDGE_ASSERT(op.getInput(1)->grad(),
+                 "{}: missing weight input(#1) gradient",
+                 op.type());
+
+    std::shared_ptr<Tensor> inputDataGradFallback, inputWeightGradFallback,
+        inputBiasGradFallback;
+    const auto &inputDataGrad =
+        op.getInput(0)->grad()->refCastFrom(inputDataGradFallback,
+                                            *(op.getOutput(0)));
+    const auto &inputWeightGrad =
+        op.getInput(1)->grad()->refCastFrom(inputWeightGradFallback,
+                                            *(op.getOutput(0)));
+    const auto &inputBiasGrad =
+        (op.getInput(2) && op.getInput(2)->grad())
+            ? op.getInput(2)->grad()->refCastFrom(inputBiasGradFallback,
+                                                  *(op.getOutput(0)))
+            : Tensor();
+
+    // Call kernel
+    const auto impl =
+        Registrar<ConvImpl2D_cpu>::create(getBestMatch(getRequiredSpec()));
+    impl.backward(
+        op.strideDims(),
+        op.dilationDims(),
+        op.kernelDims(),
+        op.getInput(0)->template dims<4>(),
+        op.getOutput(0)->template dims<4>(),
+
+        getCPUPtr(op.getInput(0)),
+        getCPUPtr(op.getInput(1)),
+        getCPUPtr(outputGrad),
+        inputDataGrad.getImpl()->rawPtr(),
+        inputWeightGrad.getImpl()->rawPtr(),
+        op.getInput(2) ? inputBiasGrad.getImpl()->rawPtr() : nullptr);
 }
+
+} // namespace Aidge
diff --git a/src/operator/ConvTransposeImpl.cpp b/src/operator/ConvTransposeImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d1135cc92dd3c68746b9dcf80739f4f65acdad2e
--- /dev/null
+++ b/src/operator/ConvTransposeImpl.cpp
@@ -0,0 +1,91 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp"
+#include "aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp"
+
+template <> void Aidge::ConvTransposeImpl1D_cpu::forward() {
+    const auto &op = static_cast<const ConvTranspose_Op<1> &>(mOp);
+
+    AIDGE_ASSERT(op.getInput(0), "{}: missing data input (#0).", op.type());
+    AIDGE_ASSERT(op.getInput(1), "{}: missing bias input (#1).", op.type());
+    AIDGE_ASSERT(op.getInput(2), "{}: missing weight input (#1).", op.type());
+
+    std::shared_ptr<Tensor> inputDataFallback, inputWeightFallback,
+        inputBiasFallback;
+    const auto &inputData =
+        op.getInput(0)->refCastFrom(inputDataFallback, *op.getOutput(0));
+    const auto &inputWeight =
+        op.getInput(1)->refCastFrom(inputWeightFallback, *op.getOutput(0));
+    const auto &inputBias =
+        (op.getInput(2))
+            ? op.getInput(2)->refCastFrom(inputBiasFallback, *op.getOutput(0))
+            : Tensor();
+
+    // Call kernel
+    const auto impl = Registrar<ConvTransposeImpl1D_cpu>::create(
+        getBestMatch(getRequiredSpec()));
+    impl.forward(op.strideDims(),
+                 op.dilationDims(),
+                 op.kernelDims(),
+                 op.getInput(0)->template dims<3>(),
+                 op.getOutput(0)->template dims<3>(),
+                 inputData.getImpl()->hostPtr(),
+                 inputWeight.getImpl()->hostPtr(),
+                 op.getInput(2) ? inputBias.getImpl()->hostPtr() : nullptr,
+                 op.getOutput(0)->getImpl()->rawPtr());
+}
+
+template <> void Aidge::ConvTransposeImpl1D_cpu::backward() {
+    AIDGE_THROW_OR_ABORT(
+        std::runtime_error,
+        "Backward not yet implemented for Conv_Op<1> on backend cpu");
+}
+
+template <> void Aidge::ConvTransposeImpl2D_cpu::forward() {
+    const auto &op = static_cast<const ConvTranspose_Op<2> &>(mOp);
+
+    AIDGE_ASSERT(op.getInput(0), "{}: missing data input (#0).", op.type());
+    AIDGE_ASSERT(op.getInput(1), "{}: missing bias input (#1).", op.type());
+    AIDGE_ASSERT(op.getInput(2), "{}: missing weight input (#1).", op.type());
+
+    std::shared_ptr<Tensor> inputDataFallback, inputWeightFallback,
+        inputBiasFallback;
+    const auto &inputData =
+        op.getInput(0)->refCastFrom(inputDataFallback, *op.getOutput(0));
+    const auto &inputWeight =
+        op.getInput(1)->refCastFrom(inputWeightFallback, *op.getOutput(0));
+    const auto &inputBias =
+        (op.getInput(2))
+            ? op.getInput(2)->refCastFrom(inputBiasFallback, *op.getOutput(0))
+            : Tensor();
+
+    // Call kernel
+    const auto impl = Registrar<ConvTransposeImpl2D_cpu>::create(
+        getBestMatch(getRequiredSpec()));
+
+    impl.forward(op.strideDims(),
+                 op.dilationDims(),
+                 op.kernelDims(),
+                 op.getInput(0)->template dims<4>(),
+                 op.getOutput(0)->template dims<4>(),
+                 inputData.getImpl()->hostPtr(),
+                 inputWeight.getImpl()->hostPtr(),
+                 op.getInput(2) ? inputBias.getImpl()->hostPtr() : nullptr,
+                 op.getOutput(0)->getImpl()->rawPtr());
+}
+
+template <> void Aidge::ConvTransposeImpl2D_cpu::backward() {
+    AIDGE_THROW_OR_ABORT(
+        std::runtime_error,
+        "Backward not yet implemented for Conv_Op<2> on backend cpu");
+}
+
diff --git a/unit_tests/operator/Test_ClipImpl.cpp b/unit_tests/operator/Test_ClipImpl.cpp
index 99147ac93bd659dd91897f6b7f1f3f33e5552ef6..3d75ad78807d0e4d23ec231f5df485e8574a03ee 100644
--- a/unit_tests/operator/Test_ClipImpl.cpp
+++ b/unit_tests/operator/Test_ClipImpl.cpp
@@ -315,5 +315,5 @@ TEST_CASE("[cpu/operator] Clip", "[Clip][CPU]")
         Log::info("total time: {}\n", duration.count());
     }
  }
-} // namespace Aidge
-}
\ No newline at end of file
+}
+}  // namespace Aidge
diff --git a/unit_tests/operator/Test_ConvImpl.cpp b/unit_tests/operator/Test_ConvImpl.cpp
index f7be338c0b9c5bb1d5af6bfa09ed7855c17fb6c0..59ec16dd80ee98c09c79d5943c503e945abf5cdb 100644
--- a/unit_tests/operator/Test_ConvImpl.cpp
+++ b/unit_tests/operator/Test_ConvImpl.cpp
@@ -17,6 +17,7 @@
 #include "aidge/backend/cpu/operator/ConvImpl.hpp"
 #include "aidge/data/Data.hpp"  // DataType
 #include "aidge/data/Tensor.hpp"
+#include "aidge/filler/Filler.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/utils/TensorUtils.hpp"
@@ -1645,4 +1646,1000 @@ TEST_CASE("[cpu/operator] Conv(forward)", "[Conv][CPU]") {
             REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-6f));
         }
     }
-}
\ No newline at end of file
+}
+
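+// Helper for the backward tests below: builds a Conv node, associates the
+// given input / weight / bias tensors, and forward-propagates the dimensions.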
+template <DimSize_t DIM>
+std::shared_ptr<OperatorTensor>
+setupTestConv(const DimSize_t batchSize,
+              const DimSize_t inChannels,
+              const DimSize_t outChannels,
+              const std::array<DimSize_t, DIM> kernelSize,
+              const std::array<DimSize_t, DIM> dataSize,
+              const std::array<DimSize_t, DIM> stride,
+              const std::array<DimSize_t, DIM> dilation,
+              const std::array<DimSize_t, 2 * DIM> padding,
+              const std::shared_ptr<Tensor> input,
+              const std::shared_ptr<Tensor> weights,
+              const std::shared_ptr<Tensor> biases) {
+    input->setBackend("cpu");
+    weights->setBackend("cpu");
+    biases->setBackend("cpu");
+    std::shared_ptr<Node> convNode = Conv(inChannels,
+                                          outChannels,
+                                          kernelSize,
+                                          "myconv",
+                                          stride,
+                                          dilation);
+    auto op =
+        std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
+
+    op->setDataType(DataType::Float32);
+    op->setBackend("cpu");
+
+    op->associateInput(0, input);
+    op->associateInput(1, weights);
+    op->associateInput(2, biases);
+
+    REQUIRE_NOTHROW(op->forwardDims(true));
+
+    return op;
+}
+
+TEST_CASE("[cpu/operator] Conv(backward)", "[Conv][CPU]") {
+    SECTION("1D") {
+        const std::size_t DIM = 1;
+        SECTION("no stride & no dilation, outChannels > inChannels") {
+
+            const DimSize_t batchSize = 1;
+            const DimSize_t inChannels = 2;
+            const DimSize_t outChannels = 3;
+            const DimSize_t kernelSize = 4;
+            const DimSize_t inDataSize = 12;
+
+            const DimSize_t stride = 1;
+            const DimSize_t dilation = 1;
+            const std::array<DimSize_t, 2 * DIM> padding({0, 0});
+
+            auto inputSize =
+                std::vector<DimSize_t>({batchSize, inChannels, inDataSize});
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize>(
+                    {{{{1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000},
+                       {1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000,
+                        1.000000}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, outChannels, inChannels, kernelSize>(
+                    {{{{0.100000, 0.100000, 0.100000, 0.100000},
+                       {0.100000, 0.100000, 0.100000, 0.100000}},
+                      {{0.100000, 0.100000, 0.100000, 0.100000},
+                       {0.100000, 0.100000, 0.100000, 0.100000}},
+                      {{0.100000, 0.100000, 0.100000, 0.100000},
+                       {0.100000, 0.100000, 0.100000, 0.100000}}}
+
+                    }));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({0.010000, 0.010000, 0.010000}));
+
+            auto op = setupTestConv<DIM>(
+                batchSize,
+                inChannels,
+                outChannels,
+                std::array<DimSize_t, DIM>({kernelSize}),
+                std::array<DimSize_t, DIM>({inDataSize}),
+                std::array<DimSize_t, DIM>({stride}),
+                std::array<DimSize_t, DIM>({dilation}),
+                padding,
+                input,
+                weights,
+                biases);
+
+            ////////////////////////////////////
+            // setup gradients for backward
+            auto outputGrad =
+                std::make_shared<Tensor>(op->getOutput(0)->dims());
+            outputGrad->setDataType(DataType::Float32);
+            outputGrad->setBackend("cpu");
+            constantFiller(outputGrad, 1.f);
+            op->getOutput(0)->setGrad(outputGrad);
+
+            ////////////////////////////////////
+            // run the backward pass
+            REQUIRE_NOTHROW(op->backward());
+
+            SECTION("Input Grad") {
+                auto expectedInputGrad = std::make_shared<Tensor>(
+                    Array3D<float, batchSize, inChannels, inDataSize>(
+                        {{{{0.3000,
+                            0.6000,
+                            0.9000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            0.9000,
+                            0.6000,
+                            0.3000},
+                           {0.3000,
+                            0.6000,
+                            0.9000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            1.2000,
+                            0.9000,
+                            0.6000,
+                            0.3000}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(0)->grad(),
+                                             *expectedInputGrad));
+            }
+            SECTION("Weight grad") {
+                std::vector<DimSize_t> weightsSize(
+                    {outChannels, inChannels, kernelSize});
+                auto expectedWeightsGrad =
+                    std::make_shared<Tensor>(weightsSize);
+                expectedWeightsGrad->setBackend("cpu");
+                expectedWeightsGrad->setDataType(DataType::Float32);
+                constantFiller<float>(expectedWeightsGrad, 9.);
+
+                CHECK(approxEq<float, float>(*op->getInput(1)->grad(),
+                                             *expectedWeightsGrad));
+            }
+            SECTION("Bias Grad") {
+                std::vector<DimSize_t> biasesSize({outChannels});
+                auto expectedBiasGrad = std::make_shared<Tensor>(biasesSize);
+                expectedBiasGrad->setBackend("cpu");
+                expectedBiasGrad->setDataType(DataType::Float32);
+                constantFiller<float>(expectedBiasGrad, 9.);
+                CHECK(approxEq<float, float>(*op->getInput(2)->grad(),
+                                             *expectedBiasGrad));
+            }
+        }
+
+        SECTION("stride and no dilation, inChannel > outChannels") {
+            const DimSize_t batchSize = 2;
+            const DimSize_t inChannels = 3;
+            const DimSize_t outChannels = 1;
+            const DimSize_t kernelSize = 2;
+            const DimSize_t inDataSize = 8;
+            const DimSize_t stride = 3;
+            const DimSize_t dilation = 1;
+            const std::array<DimSize_t, 2 * DIM> padding({0, 0});
+
+            auto inputSize =
+                std::vector<DimSize_t>({batchSize, inChannels, inDataSize});
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize>(
+                    {{{{1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.}},
+
+                      {{1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.}}}}));
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, outChannels, inChannels, kernelSize>(
+                    {{{{0.1000, 0.1000},
+                       {0.1000, 0.1000},
+                       {0.1000, 0.1000}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({0.060000}));
+
+            auto op = setupTestConv<DIM>(
+                batchSize,
+                inChannels,
+                outChannels,
+                std::array<DimSize_t, DIM>({kernelSize}),
+                std::array<DimSize_t, DIM>({inDataSize}),
+                std::array<DimSize_t, DIM>({stride}),
+                std::array<DimSize_t, DIM>({dilation}),
+                padding,
+                input,
+                weights,
+                biases);
+
+            ////////////////////////////////////
+            // setup gradients for backward
+            auto outputGrad =
+                std::make_shared<Tensor>(op->getOutput(0)->dims());
+            outputGrad->setDataType(DataType::Float32);
+            outputGrad->setBackend("cpu");
+            constantFiller(outputGrad, 1.f);
+            op->getOutput(0)->setGrad(outputGrad);
+
+            ////////////////////////////////////
+            // run the backward pass
+            REQUIRE_NOTHROW(op->backward());
+
+            SECTION("Input Grad") {
+                auto expectedInputGrad = std::make_shared<Tensor>(
+                    Array3D<float, batchSize, inChannels, inDataSize>(
+                        {{{{0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000}},
+
+                          {{0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000,
+                            0.0000,
+                            0.1000,
+                            0.1000}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(0)->grad(),
+                                             *expectedInputGrad));
+            }
+            SECTION("Weight grad") {
+                auto expectedWeightsGrad = std::make_shared<Tensor>(
+                    Array3D<float, outChannels, inChannels, kernelSize>(
+                        {{{{6., 6.}, {6., 6.}, {6., 6.}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(1)->grad(),
+                                             *expectedWeightsGrad));
+            }
+            SECTION("Bias Grad") {
+                auto expectedBiasesGrad = std::make_shared<Tensor>(
+                    Array1D<float, outChannels>({6.}));
+                CHECK(approxEq<float, float>(*op->getInput(2)->grad(),
+                                             *expectedBiasesGrad));
+            }
+        }
+
+        SECTION("dilation, no stride") {
+            const DimSize_t batchSize = 2;
+            const DimSize_t inChannels = 3;
+            const DimSize_t outChannels = 1;
+            const DimSize_t kernelSize = 2;
+            const DimSize_t inDataSize = 8;
+
+            const DimSize_t stride = 1;
+            const DimSize_t dilation = 2;
+            const std::array<DimSize_t, 2 * DIM> padding({0, 0});
+
+            auto inputSize =
+                std::vector<DimSize_t>({batchSize, inChannels, inDataSize});
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize>(
+                    {{{{1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.}},
+
+                      {{1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.},
+                       {1., 1., 1., 1., 1., 1., 1., 1.}}}}));
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, outChannels, inChannels, kernelSize>(
+                    {{{{0.1000, 0.1000},
+                       {0.1000, 0.1000},
+                       {0.1000, 0.1000}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({0.060000}));
+
+            auto op = setupTestConv<DIM>(
+                batchSize,
+                inChannels,
+                outChannels,
+                std::array<DimSize_t, DIM>({kernelSize}),
+                std::array<DimSize_t, DIM>({inDataSize}),
+                std::array<DimSize_t, DIM>({stride}),
+                std::array<DimSize_t, DIM>({dilation}),
+                padding,
+                input,
+                weights,
+                biases);
+
+            ////////////////////////////////////
+            // setup gradients for backward
+            auto outputGrad =
+                std::make_shared<Tensor>(op->getOutput(0)->dims());
+            outputGrad->setDataType(DataType::Float32);
+            outputGrad->setBackend("cpu");
+            constantFiller(outputGrad, 1.f);
+            op->getOutput(0)->setGrad(outputGrad);
+
+            ////////////////////////////////////
+            // run the backward pass
+            REQUIRE_NOTHROW(op->backward());
+
+            SECTION("Input Grad") {
+                auto expectedInputGrad = std::make_shared<Tensor>(
+                    Array3D<float, batchSize, inChannels, inDataSize>(
+                        {{{{0.1000,
+                            0.1000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.1000,
+                            0.1000}},
+
+                          {{0.1000,
+                            0.1000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.1000,
+                            0.1000},
+                           {0.1000,
+                            0.1000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.2000,
+                            0.1000,
+                            0.1000}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(0)->grad(),
+                                             *expectedInputGrad));
+            }
+            SECTION("Weight grad") {
+                auto expectedWeightsGrad = std::make_shared<Tensor>(
+                    Array3D<float, outChannels, inChannels, kernelSize>(
+                        {{{{12., 12.}, {12., 12.}, {12., 12.}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(1)->grad(),
+                                             *expectedWeightsGrad));
+            }
+            SECTION("Bias Grad") {
+                auto expectedBiasesGrad = std::make_shared<Tensor>(
+                    Array1D<float, outChannels>({12.}));
+                CHECK(approxEq<float, float>(*op->getInput(2)->grad(),
+                                             *expectedBiasesGrad));
+            }
+        }
+        SECTION("stride & dilation") {
+            const DimSize_t batchSize = 1;
+            const DimSize_t inChannels = 4;
+            const DimSize_t outChannels = 4;
+            const DimSize_t kernelSize = 3;
+            const DimSize_t inDataSize = 13;
+
+            const DimSize_t stride = 4;
+            const DimSize_t dilation = 3;
+            const std::array<DimSize_t, 2 * DIM> padding({0, 0});
+
+            auto inputSize =
+                std::vector<DimSize_t>({batchSize, inChannels, inDataSize});
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize>(
+                {{{{1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
+                   {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
+                   {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
+                   {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}}}}));
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, outChannels, inChannels, kernelSize>(
+                    {{{{0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000}},
+
+                      {{0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000}},
+
+                      {{0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000}},
+
+                      {{0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000},
+                       {0.1000, 0.1000, 0.1000}}}}));
+
+            auto biases = std::make_shared<Tensor>(Array1D<float, outChannels>(
+                {{0.0100, 0.0100, 0.0100, 0.0100}}));
+
+            auto op = setupTestConv<DIM>(
+                batchSize,
+                inChannels,
+                outChannels,
+                std::array<DimSize_t, DIM>({kernelSize}),
+                std::array<DimSize_t, DIM>({inDataSize}),
+                std::array<DimSize_t, DIM>({stride}),
+                std::array<DimSize_t, DIM>({dilation}),
+                padding,
+                input,
+                weights,
+                biases);
+
+            ////////////////////////////////////
+            // setup gradients for backward
+            auto outputGrad =
+                std::make_shared<Tensor>(op->getOutput(0)->dims());
+            outputGrad->setDataType(DataType::Float32);
+            outputGrad->setBackend("cpu");
+            constantFiller(outputGrad, 1.f);
+            op->getOutput(0)->setGrad(outputGrad);
+
+            ////////////////////////////////////
+            // run the backward pass
+            REQUIRE_NOTHROW(op->backward());
+
+            SECTION("Input Grad") {
+                auto expectedInputGrad = std::make_shared<Tensor>(
+                    Array3D<float, batchSize, inChannels, inDataSize>(
+                        {{{{0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.0000,
+                            0.0000},
+                           {0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.0000,
+                            0.0000},
+                           {0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.0000,
+                            0.0000},
+                           {0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.4000,
+                            0.4000,
+                            0.0000,
+                            0.0000,
+                            0.4000,
+                            0.0000,
+                            0.0000}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(0)->grad(),
+                                             *expectedInputGrad));
+            }
+            SECTION("Weight grad") {
+                auto expectedWeightsGrad = std::make_shared<Tensor>(
+                    Array3D<float, outChannels, inChannels, kernelSize>(
+                        {{{{2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.}},
+
+                          {{2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.}},
+
+                          {{2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.}},
+
+                          {{2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.},
+                           {2., 2., 2.}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(1)->grad(),
+                                             *expectedWeightsGrad));
+            }
+            SECTION("Bias Grad") {
+                auto expectedBiasesGrad = std::make_shared<Tensor>(
+                    Array1D<float, outChannels>({{2., 2., 2., 2.}}));
+                CHECK(approxEq<float, float>(*op->getInput(2)->grad(),
+                                             *expectedBiasesGrad));
+            }
+        }
+
+        // Denser values, harder to read; refer to the previous tests if an
+        // issue arises.
+        SECTION("Sequential values") {
+            const DimSize_t batchSize = 1;
+            const DimSize_t inChannels = 2;
+            const DimSize_t outChannels = 2;
+            const DimSize_t kernelSize = 3;
+            const DimSize_t inDataSize = 8;
+
+            const DimSize_t stride = 2;
+            const DimSize_t dilation = 2;
+            const std::array<DimSize_t, 2 * DIM> padding({0, 0});
+
+            const DimSize_t outDataSize = 2;
+
+            auto inputSize =
+                std::vector<DimSize_t>({batchSize, inChannels, inDataSize});
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize>(
+                    {{{{1., 2., 3., 4., 5., 6., 7., 8.},
+                       {9., 10., 11., 12., 13., 14., 15., 16.}}}}));
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, outChannels, inChannels, kernelSize>(
+                    {{{{0.1000, 0.2000, 0.3000}, {0.4000, 0.5000, 0.6000}},
+
+                      {{0.7000, 0.8000, 0.9000}, {1.0000, 1.1000, 1.2000}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.0100, 0.0200}}));
+
+            auto outputGrad = std::make_shared<Tensor>(
+                Array3D<float, batchSize, outChannels, outDataSize>(
+                    {{{{1., 2.}, {3., 4.}}}}));
+
+            auto op = setupTestConv<DIM>(
+                batchSize,
+                inChannels,
+                outChannels,
+                std::array<DimSize_t, DIM>({kernelSize}),
+                std::array<DimSize_t, DIM>({inDataSize}),
+                std::array<DimSize_t, DIM>({stride}),
+                std::array<DimSize_t, DIM>({dilation}),
+                padding,
+                input,
+                weights,
+                biases);
+
+            ////////////////////////////////////
+            // setup gradients for backward
+            op->getOutput(0)->setGrad(outputGrad);
+
+            REQUIRE_NOTHROW(op->backward());
+
+            SECTION("Input Grad") {
+                auto expectedInputGrad = std::make_shared<Tensor>(
+                    Array3D<float, batchSize, inChannels, inDataSize>(
+                        {{{{2.2000,
+                            0.0000,
+                            5.6000,
+                            0.0000,
+                            6.6000,
+                            0.0000,
+                            4.2000,
+                            0.0000},
+                           {3.4000,
+                            0.0000,
+                            8.6000,
+                            0.0000,
+                            9.6000,
+                            0.0000,
+                            6.0000,
+                            0.0000}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(0)->grad(),
+                                             *expectedInputGrad));
+            }
+            SECTION("Weight grad") {
+                auto expectedWeightsGrad = std::make_shared<Tensor>(
+                    Array3D<float, outChannels, inChannels, kernelSize>(
+                        {{{{7., 13., 19.}, {31., 37., 43.}},
+
+                          {{15., 29., 43.}, {71., 85., 99.}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(1)->grad(),
+                                             *expectedWeightsGrad));
+            }
+            SECTION("Bias Grad") {
+                auto expectedBiasesGrad = std::make_shared<Tensor>(
+                    Array1D<float, outChannels>({{3., 7.}}));
+                CHECK(approxEq<float, float>(*op->getInput(2)->grad(),
+                                             *expectedBiasesGrad));
+            }
+        }
+        SECTION("random values testing") {
+            const DimSize_t batchSize = 1;
+            const DimSize_t inChannels = 4;
+            const DimSize_t outChannels = 4;
+            const DimSize_t kernelSize = 3;
+            const DimSize_t inDataSize = 13;
+            const DimSize_t outDataSize = 2;
+
+            const DimSize_t stride = 4;
+            const DimSize_t dilation = 3;
+            const std::array<DimSize_t, 2 * DIM> padding({0, 0});
+
+            auto inputSize =
+                std::vector<DimSize_t>({batchSize, inChannels, inDataSize});
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize>(
+                    {{{{0.180772,
+                        -0.069988,
+                        -0.359623,
+                        -0.915204,
+                        0.625765,
+                        0.025510,
+                        0.954514,
+                        0.064349,
+                        0.361151,
+                        1.167878,
+                        -1.349893,
+                        -0.510177,
+                        0.235958},
+                       {-0.239778,
+                        -0.921115,
+                        1.543297,
+                        1.348826,
+                        -0.139642,
+                        0.285797,
+                        0.965120,
+                        -2.037150,
+                        0.493136,
+                        1.486999,
+                        0.591033,
+                        0.126030,
+                        -1.562687},
+                       {-1.160103,
+                        -0.334841,
+                        0.447772,
+                        -0.801645,
+                        1.523611,
+                        2.508587,
+                        -0.663096,
+                        -0.251275,
+                        1.010145,
+                        0.121547,
+                        -1.510835,
+                        2.104773,
+                        2.762959},
+                       {-1.746529,
+                        0.410919,
+                        -0.242185,
+                        0.420812,
+                        0.277596,
+                        0.778898,
+                        1.533269,
+                        1.609736,
+                        -0.403228,
+                        -0.274928,
+                        1.473840,
+                        0.068826,
+                        1.332708}}}}));
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, outChannels, inChannels, kernelSize>(
+                    {{{{0.587285, 0.286069, 0.008287},
+                       {-0.252325, -1.324722, 0.189178},
+                       {0.021100, 0.940420, -0.557690},
+                       {-0.693927, -0.325247, 1.243933}},
+
+                      {{-1.167186, -0.409124, 1.260062},
+                       {-1.563006, 1.134614, -0.082384},
+                       {0.289316, 0.835773, -0.244991},
+                       {0.271223, 0.093636, -0.883432}},
+
+                      {{-0.327417, 0.078394, -0.380766},
+                       {0.377508, 0.111912, 2.314279},
+                       {-0.798906, -0.564303, -1.134660},
+                       {0.170527, 0.994665, 1.262572}},
+
+                      {{1.621816, 1.077471, 0.594781},
+                       {-1.529087, 2.043707, -0.165627},
+                       {0.087070, -0.527656, -0.100288},
+                       {1.053922, -0.623074, -1.590572}}}}));
+
+            auto biases = std::make_shared<Tensor>(Array1D<float, outChannels>(
+                {{1.285940, -0.051787, -0.968103, -0.586324}}));
+
+            auto op = setupTestConv<DIM>(
+                batchSize,
+                inChannels,
+                outChannels,
+                std::array<DimSize_t, DIM>({kernelSize}),
+                std::array<DimSize_t, DIM>({inDataSize}),
+                std::array<DimSize_t, DIM>({stride}),
+                std::array<DimSize_t, DIM>({dilation}),
+                padding,
+                input,
+                weights,
+                biases);
+
+            ////////////////////////////////////
+            // setup gradients for backward
+            auto outputGrad = std::make_shared<Tensor>(
+                Array3D<float, batchSize, outChannels, outDataSize>(
+                    {{{{0.053156, 1.189073},
+                       {0.100228, 1.042344},
+                       {-1.468991, 0.581337},
+                       {1.330418, 0.487802}}}}));
+            op->getOutput(0)->setGrad(outputGrad);
+
+            REQUIRE_NOTHROW(op->backward());
+
+            SECTION("Input Grad") {
+                auto expectedInputGrad = std::make_shared<Tensor>(
+                    Array3D<float, batchSize, inChannels, inDataSize>(
+                        {{{{2.552898,
+                            0.000000,
+                            0.000000,
+                            1.292528,
+                            0.082501,
+                            0.000000,
+                            1.477383,
+                            0.484875,
+                            0.000000,
+                            0.000000,
+                            1.392054,
+                            0.000000,
+                            0.000000},
+                           {-2.758950,
+                            0.000000,
+                            0.000000,
+                            2.597889,
+                            -2.455656,
+                            0.000000,
+                            -3.618210,
+                            0.669449,
+                            0.000000,
+                            0.000000,
+                            1.403657,
+                            0.000000,
+                            0.000000},
+                           {1.319545,
+                            0.000000,
+                            0.000000,
+                            0.260710,
+                            -0.095303,
+                            0.000000,
+                            1.479181,
+                            1.403949,
+                            0.000000,
+                            0.000000,
+                            -1.627040,
+                            0.000000,
+                            0.000000},
+                           {1.141951,
+                            0.000000,
+                            0.000000,
+                            -2.298007,
+                            0.070817,
+                            0.000000,
+                            -3.993255,
+                            -0.014843,
+                            0.000000,
+                            0.000000,
+                            0.516383,
+                            0.000000,
+                            0.000000}}}}));
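+                // Explicit tolerances: the reference values above are only
+                // given to six decimal places.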
+                CHECK(approxEq<float, float>(*op->getInput(0)->grad(),
+                                             *expectedInputGrad,
+                                             1e-5,
+                                             1e-6));
+            }
+            SECTION("Weight grad") {
+                auto expectedWeightsGrad = std::make_shared<Tensor>(
+                    Array3D<float, outChannels, inChannels, kernelSize>(
+                        {{{{0.753690, 0.027866, -1.554383},
+                           {-0.178790, -2.350622, 0.754084},
+                           {1.750019, -0.341397, -1.831741},
+                           {0.237243, 1.936463, 1.834007}},
+
+                          {{0.670381, -0.024656, -1.311384},
+                           {-0.169587, -1.988220, 0.712792},
+                           {1.471852, -0.342263, -1.641270},
+                           {0.114300, 1.720076, 1.689925}},
+
+                          {{0.098228, 1.381835, -2.186914},
+                           {0.271054, -3.165683, -1.074165},
+                           {2.589912, 1.031534, 0.095779},
+                           {2.727013, 0.317630, -1.395561}},
+
+                          {{0.545751, -1.186215, 0.611421},
+                           {-0.387123, 0.800776, 1.572321},
+                           {-0.800201, -1.189095, -1.619183},
+                           {-2.188202, 1.345088, 2.758830}}}
+
+                        }));
+                CHECK(approxEq<float, float>(*op->getInput(1)->grad(),
+                                             *expectedWeightsGrad,
+                                             1e-5,
+                                             1e-6));
+            }
+            SECTION("Bias Grad") {
+                auto expectedBiasesGrad =
+                    std::make_shared<Tensor>(Array1D<float, outChannels>(
+                        {{1.242230, 1.142572, -0.887655, 1.818220}}));
+                CHECK(approxEq<float, float>(*op->getInput(2)->grad(),
+                                             *expectedBiasesGrad));
+            }
+        }
+    }
+    SECTION("2D") {
+        constexpr DimSize_t DIM = 2;
+        SECTION("Sequential values") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 1;
+            constexpr DimSize_t outChannels = 2;
+            constexpr std::array<DimSize_t, DIM> kernelSize = {1, 2};
+            constexpr std::array<DimSize_t, DIM> inDataSize = {3, 4};
+
+            constexpr std::array<DimSize_t, DIM> stride = {1, 2};
+            constexpr std::array<DimSize_t, DIM> dilation = {1, 2};
+            constexpr std::array<DimSize_t, 2 * DIM> padding({0, 0});
+
+            constexpr std::array<DimSize_t, DIM> outDataSize = {3, 1};
+
+            auto input = std::make_shared<Tensor>(
+                Array4D<float,
+                        batchSize,
+                        inChannels,
+                        inDataSize[0],
+                        inDataSize[1]>({{{{{1., 2., 3., 4.},
+                                           {5., 6., 7., 8.},
+                                           {9., 10., 11., 12.}}}}}));
+            auto weights = std::make_shared<Tensor>(
+                Array4D<float,
+                        outChannels,
+                        inChannels,
+                        kernelSize[0],
+                        kernelSize[1]>({{{{{1., 2.}}}, {{{3., 4.}}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{1., 2.}}));
+
+            auto outputGrad = std::make_shared<Tensor>(Array4D<float,
+                                                               batchSize,
+                                                               outChannels,
+                                                               outDataSize[0],
+                                                               outDataSize[1]>(
+                {{{{{1.}, {2.}, {3.}}, {{4.}, {5.}, {6.}}}}}));
+
+            auto op = setupTestConv<DIM>(batchSize,
+                                         inChannels,
+                                         outChannels,
+                                         kernelSize,
+                                         inDataSize,
+                                         stride,
+                                         dilation,
+                                         padding,
+                                         input,
+                                         weights,
+                                         biases);
+
+            ////////////////////////////////////
+            // setup gradients for backward
+            op->getOutput(0)->setGrad(outputGrad);
+
+            REQUIRE_NOTHROW(op->backward());
+
+            SECTION("Input Grad") {
+                auto expectedInputGrad = std::make_shared<Tensor>(
+                    Array4D<float,
+                            batchSize,
+                            inChannels,
+                            inDataSize[0],
+                            inDataSize[1]>({{{{{13., 0., 18., 0.},
+                                               {17., 0., 24., 0.},
+                                               {21., 0., 30., 0.}}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(0)->grad(),
+                                             *expectedInputGrad));
+            }
+            SECTION("Weight grad") {
+                auto expectedWeightsGrad =
+                    std::make_shared<Tensor>(Array4D<float,
+                                                     outChannels,
+                                                     inChannels,
+                                                     kernelSize[0],
+                                                     kernelSize[1]>(
+                        {{{{{38., 50.}}}, {{{83., 113.}}}}}));
+                CHECK(approxEq<float, float>(*op->getInput(1)->grad(),
+                                             *expectedWeightsGrad));
+            }
+            SECTION("Bias Grad") {
+                auto expectedBiasesGrad = std::make_shared<Tensor>(
+                    Array1D<float, outChannels>({{6., 15.}}));
+                CHECK(approxEq<float, float>(*op->getInput(2)->grad(),
+                                             *expectedBiasesGrad));
+            }
+        }
+    }
+}
diff --git a/unit_tests/operator/Test_ConvTranspose.cpp b/unit_tests/operator/Test_ConvTranspose.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6e889e809e0a05d551829bd15fda9cc651068465
--- /dev/null
+++ b/unit_tests/operator/Test_ConvTranspose.cpp
@@ -0,0 +1,2298 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+
+#include <catch2/catch_test_macros.hpp>
+#include <fmt/core.h>
+
+#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/ConvTranspose.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
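+// Builds a ConvTranspose node, wires input/weights/biases on the CPU backend
+// and runs forwardDims() so that the output tensor is allocated for the test.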
+template <DimSize_t DIM>
+static std::shared_ptr<OperatorTensor>
+setupTestConvTranspose(const DimSize_t batchSize,
+                       const DimSize_t inChannels,
+                       const DimSize_t outChannels,
+                       const std::array<DimSize_t, DIM> kernelSize,
+                       const std::array<DimSize_t, DIM> dataSize,
+                       const std::array<DimSize_t, DIM> stride,
+                       const std::array<DimSize_t, DIM> dilation,
+                       const std::shared_ptr<Tensor> input,
+                       const std::shared_ptr<Tensor> weights,
+                       const std::shared_ptr<Tensor> biases) {
+    std::shared_ptr<Node> convTransposeNode = ConvTranspose(inChannels,
+                                                            outChannels,
+                                                            kernelSize,
+                                                            stride,
+                                                            dilation,
+                                                            false,
+                                                            "myconv");
+    auto op = std::static_pointer_cast<OperatorTensor>(
+        convTransposeNode->getOperator());
+
+    op->associateInput(0, input);
+    op->setDataType(DataType::Float32);
+
+    input->setBackend("cpu");
+    op->setBackend("cpu");
+
+    weights->setBackend("cpu");
+    op->associateInput(1, weights);
+
+    biases->setBackend("cpu");
+    op->associateInput(2, biases);
+
+    REQUIRE_NOTHROW(op->forwardDims(true));
+
+    return op;
+}
+
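+// The expected output sizes in the tests below follow the standard
+// transposed-convolution formula (no padding):
+//   outDataSize = (inDataSize - 1) * stride + dilation * (kernelSize - 1) + 1
+// e.g. inDataSize = 4, stride = 1, dilation = 1, kernelSize = 2 gives
+// outDataSize = 3 * 1 + 1 * 1 + 1 = 5.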
+TEST_CASE("[cpu/operator] ConvTranspose(forward)", "[ConvTranspose][CPU]") {
+    constexpr DimSize_t DIM = 1;
+    SECTION("1D") {
+        SECTION("kernel = 2 , in/outChannels = 1") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 1;
+            constexpr DimSize_t outChannels = 1;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{4};
+            constexpr std::array<DimSize_t, DIM> outDataSize{5};
+
+            constexpr std::array<DimSize_t, DIM> stride{1};
+            constexpr std::array<DimSize_t, DIM> dilation{1};
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize[0]>(
+                    {{{{1.000000, 2.000000, 3.000000, 4.000000}}}}));
+
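+            // Note: ConvTranspose weights are laid out as
+            // (inChannels, outChannels, kernelSize), i.e. with the channel
+            // axes swapped compared to a regular Conv weight tensor.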
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, inChannels, outChannels, kernelSize[0]>(
+                    {{{{0.100000, 0.200000}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.010000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
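+            // Sanity check of the first values: out[0] = 1 * 0.1 + 0.01 = 0.11
+            // and out[1] = 1 * 0.2 + 2 * 0.1 + 0.01 = 0.41.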
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array3D<float, batchSize, outChannels, outDataSize[0]>(
+                    {{{{0.110000, 0.410000, 0.710000, 1.010000, 0.810000}}}}));
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("kernel = 2, inChannel = 2, outChannels = 1") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 2;
+            constexpr DimSize_t outChannels = 1;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{4};
+            constexpr std::array<DimSize_t, DIM> outDataSize{5};
+
+            constexpr std::array<DimSize_t, DIM> stride{1};
+            constexpr std::array<DimSize_t, DIM> dilation{1};
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize[0]>(
+                    {{{{1.000000, 2.000000, 3.000000, 4.000000},
+                       {5.000000, 6.000000, 7.000000, 8.000000}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, inChannels, outChannels, kernelSize[0]>(
+                    {{{{0.100000, 0.200000}}, {{0.300000, 0.400000}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.010000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array3D<float, batchSize, outChannels, outDataSize[0]>(
+                    {{{{1.610000, 4.210000, 5.210000, 6.210001, 4.010000}}}}));
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("kernel = 2, inChannel = 1, outChannels = 2") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 1;
+            constexpr DimSize_t outChannels = 2;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{4};
+            constexpr std::array<DimSize_t, DIM> outDataSize{5};
+
+            constexpr std::array<DimSize_t, DIM> stride{1};
+            constexpr std::array<DimSize_t, DIM> dilation{1};
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize[0]>(
+                    {{{{1., 2., 3., 4.}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, inChannels, outChannels, kernelSize[0]>(
+                    {{{{0.1, 0.2}, {0.3, 0.4}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.01, 0.02}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array3D<float, batchSize, outChannels, outDataSize[0]>(
+                    {{{{0.11, 0.41, 0.71, 1.01, 0.81},
+                       {0.32, 1.02, 1.72, 2.42, 1.62}}}}));
+
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("kernel = 1, inChannel = 2, outChannels = 2") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 2;
+            constexpr DimSize_t outChannels = 2;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{1};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{4};
+            constexpr std::array<DimSize_t, DIM> outDataSize{4};
+
+            constexpr std::array<DimSize_t, DIM> stride{1};
+            constexpr std::array<DimSize_t, DIM> dilation{1};
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize[0]>(
+                    {{{{1.000000, 2.000000, 3.000000, 4.000000},
+                       {5.000000, 6.000000, 7.000000, 8.000000}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, inChannels, outChannels, kernelSize[0]>(
+                    {{{{0.100000}, {0.200000}},
+
+                      {{0.300000}, {0.400000}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.010000, 0.020000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array3D<float, batchSize, outChannels, outDataSize[0]>(
+                    {{{{1.610000, 2.010000, 2.410000, 2.810000},
+                       {2.220000, 2.820000, 3.420000, 4.020000}}}}));
+
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("kernel = 2, inChannels = 2, outChannels = 3") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 2;
+            constexpr DimSize_t outChannels = 3;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{4};
+            constexpr std::array<DimSize_t, DIM> outDataSize{5};
+
+            constexpr std::array<DimSize_t, DIM> stride{1};
+            constexpr std::array<DimSize_t, DIM> dilation{1};
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize[0]>(
+                    {{{{1., 2., 3., 4.}, {5., 6., 7., 8.}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, inChannels, outChannels, kernelSize[0]>(
+                    {{{{0.10, 0.20}, {0.30, 0.40}, {0.50, 0.60}},
+
+                      {{0.70, 0.80}, {0.90, 1.}, {1.10, 1.20}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.010000, 0.020000, 0.030000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput = std::make_shared<
+                Tensor>(Array3D<float, batchSize, outChannels, outDataSize[0]>(
+                {{{{3.610000, 8.610001, 10.410000, 12.210001, 7.210001},
+                   {4.820000, 11.420000, 14.020000, 16.620001, 9.620001},
+                   {6.030000, 14.230000, 17.630001, 21.030001, 12.030000}}}}));
+
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+
+        SECTION("Big test to ensure kernel capabilities") {
+            constexpr DimSize_t batchSize = 2;
+            constexpr DimSize_t inChannels = 3;
+            constexpr DimSize_t outChannels = 4;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{6};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{6};
+            constexpr std::array<DimSize_t, DIM> outDataSize{11};
+
+            constexpr std::array<DimSize_t, DIM> stride{1};
+            constexpr std::array<DimSize_t, DIM> dilation{1};
+
+            auto input = std::make_shared<Tensor>(
+                Array3D<float, batchSize, inChannels, inDataSize[0]>(
+                    {{{{1., 2., 3., 4., 5., 6.},
+                       {7., 8., 9., 10., 11., 12.},
+                       {13., 14., 15., 16., 17., 18.}},
+
+                      {{19., 20., 21., 22., 23., 24.},
+                       {25., 26., 27., 28., 29., 30.},
+                       {31., 32., 33., 34., 35., 36.}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array3D<float, inChannels, outChannels, kernelSize[0]>(
+                    {{{{0.1, 0.2, 0.3, 0.4, 0.5, 0.6},
+                       {0.7, 0.8, 0.9, 1., 1.1, 1.2},
+                       {1.3, 1.4, 1.5, 1.6, 1.7, 1.8},
+                       {1.9, 2., 2.1, 2.2, 2.3, 2.4}},
+
+                      {{2.5, 2.6, 2.7, 2.8, 2.9, 3.},
+                       {3.1, 3.2, 3.3, 3.4, 3.5, 3.6},
+                       {3.7, 3.8, 3.9, 4., 4.1, 4.2},
+                       {4.3, 4.4, 4.5, 4.6, 4.7, 4.8}},
+
+                      {{4.9, 5., 5.1, 5.2, 5.3, 5.4},
+                       {5.5, 5.6, 5.7, 5.8, 5.9, 6.},
+                       {6.1, 6.2, 6.3, 6.4, 6.5, 6.6},
+                       {6.7, 6.8, 6.9, 7., 7.1, 7.2}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.01, 0.02, 0.03, 0.04}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array3D<float, batchSize, outChannels, outDataSize[0]>(
+                    {{{{81.310005,
+                        172.210007,
+                        273.010010,
+                        384.010040,
+                        505.509979,
+                        637.810059,
+                        561.010010,
+                        472.809998,
+                        372.910004,
+                        261.010010,
+                        136.809998},
+                       {93.919998,
+                        199.220001,
+                        316.219971,
+                        445.220001,
+                        586.520081,
+                        740.420044,
+                        651.020020,
+                        548.420044,
+                        432.319977,
+                        302.420013,
+                        158.419998},
+                       {106.529999,
+                        226.230011,
+                        359.429993,
+                        506.430054,
+                        667.530090,
+                        843.030029,
+                        741.030029,
+                        624.030029,
+                        491.730042,
+                        343.829987,
+                        180.029999},
+                       {119.140007,
+                        253.240005,
+                        402.640045,
+                        567.640076,
+                        748.539978,
+                        945.639954,
+                        831.039978,
+                        699.640015,
+                        551.140015,
+                        385.239990,
+                        201.639999}},
+
+                      {{216.309998,
+                        447.610016,
+                        694.210022,
+                        956.410034,
+                        1234.510132,
+                        1528.810059,
+                        1317.010010,
+                        1088.410034,
+                        842.710022,
+                        579.610046,
+                        298.810028},
+                       {261.319977,
+                        539.420044,
+                        834.619995,
+                        1147.220093,
+                        1477.520142,
+                        1825.820068,
+                        1569.019897,
+                        1293.619995,
+                        999.320068,
+                        685.820007,
+                        352.819977},
+                       {306.329987,
+                        631.230042,
+                        975.030029,
+                        1338.030151,
+                        1720.530029,
+                        2122.829834,
+                        1821.029785,
+                        1498.830200,
+                        1155.930054,
+                        792.030029,
+                        406.830017},
+                       {351.340027,
+                        723.039978,
+                        1115.440063,
+                        1528.840210,
+                        1963.539917,
+                        2419.839844,
+                        2073.040283,
+                        1704.040039,
+                        1312.540039,
+                        898.239990,
+                        460.840027}}}}));
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+    }
+
+    SECTION("2D") {
+        constexpr DimSize_t DIM = 2;
+        SECTION("inChannels = 1, outChannels = 2, kernelSize = {1,2}, "
+                "inDataSize = {2,3}") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 1;
+            constexpr DimSize_t outChannels = 2;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{1, 2};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{2, 3};
+            constexpr std::array<DimSize_t, DIM> outDataSize{2, 4};
+
+            constexpr std::array<DimSize_t, DIM> stride{1, 1};
+            constexpr std::array<DimSize_t, DIM> dilation{1, 1};
+
+            auto input = std::make_shared<Tensor>(Array4D<float,
+                                                          batchSize,
+                                                          inChannels,
+                                                          inDataSize[0],
+                                                          inDataSize[1]>(
+                {{{{{1.000000, 2.000000, 3.000000},
+                    {4.000000, 5.000000, 6.000000}}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array4D<float,
+                        inChannels,
+                        outChannels,
+                        kernelSize[0],
+                        kernelSize[1]>({{{{{0.100000, 0.200000}},
+
+                                          {{0.300000, 0.400000}}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.010000, 0.020000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput =
+                std::make_shared<Tensor>(Array4D<float,
+                                                 batchSize,
+                                                 outChannels,
+                                                 outDataSize[0],
+                                                 outDataSize[1]>(
+                    {{{{{0.110000, 0.410000, 0.710000, 0.610000},
+                        {0.410000, 1.310000, 1.610000, 1.210000}},
+
+                       {{0.320000, 1.020000, 1.720000, 1.220000},
+                        {1.220000, 3.120000, 3.820000, 2.420000}}}}}));
+
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("inChannels = 1, outChannels = 2, kernelSize = {2,3}, "
+                "inDataSize = {2,3}") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 1;
+            constexpr DimSize_t outChannels = 2;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2, 3};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{2, 3};
+            constexpr std::array<DimSize_t, DIM> outDataSize{3, 5};
+
+            constexpr std::array<DimSize_t, DIM> stride{1, 1};
+            constexpr std::array<DimSize_t, DIM> dilation{1, 1};
+
+            auto input = std::make_shared<Tensor>(Array4D<float,
+                                                          batchSize,
+                                                          inChannels,
+                                                          inDataSize[0],
+                                                          inDataSize[1]>(
+                {{{{{1.000000, 2.000000, 3.000000},
+                    {4.000000, 5.000000, 6.000000}}}}}));
+
+            auto weights = std::make_shared<Tensor>(Array4D<float,
+                                                            inChannels,
+                                                            outChannels,
+                                                            kernelSize[0],
+                                                            kernelSize[1]>(
+                {{{{{0.100000, 0.200000, 0.300000},
+                    {0.400000, 0.500000, 0.600000}},
+
+                   {{0.700000, 0.800000, 0.900000},
+                    {1.000000, 1.100000, 1.200000}}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.010000, 0.020000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput = std::make_shared<
+                Tensor>(Array4D<float,
+                                batchSize,
+                                outChannels,
+                                outDataSize[0],
+                                outDataSize[1]>(
+                {{{{{0.110000, 0.410000, 1.010000, 1.210000, 0.910000},
+                    {0.810000, 2.610000, 5.610000, 5.410000, 3.610000},
+                    {1.610000, 4.010000, 7.310000, 6.010000, 3.610000}},
+
+                   {{0.720000, 2.220000, 4.620000, 4.220000, 2.720000},
+                    {3.820000, 9.820001, 18.220001, 15.020000, 9.020000},
+                    {4.020000, 9.420000, 16.320000, 12.620001, 7.220000}}}}}));
+
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("inChannels = 1, outChannels = 2, kernelSize = {2,3}, "
+                "inDataSize = {6,6}, stride = {2,  2}, dilation = {2,  2}") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 1;
+            constexpr DimSize_t outChannels = 2;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2, 3};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{4, 4};
+            constexpr std::array<DimSize_t, DIM> outDataSize{9, 11};
+
+            constexpr std::array<DimSize_t, DIM> stride{2, 2};
+            constexpr std::array<DimSize_t, DIM> dilation{2, 2};
+
+            auto input = std::make_shared<Tensor>(Array4D<float,
+                                                          batchSize,
+                                                          inChannels,
+                                                          inDataSize[0],
+                                                          inDataSize[1]>(
+                {{{{{1.00, 2.00, 3.00, 4.000000},
+                    {5.00, 6.00, 7.00, 8.000000},
+                    {9.00, 10.00, 11.00, 12.000000},
+                    {13.00, 14.00, 15.00, 16.000000}}}}}));
+
+            auto weights = std::make_shared<Tensor>(Array4D<float,
+                                                            inChannels,
+                                                            outChannels,
+                                                            kernelSize[0],
+                                                            kernelSize[1]>(
+                {{{{{0.10, 0.20, 0.300000}, {0.40, 0.50, 0.600000}},
+
+                   {{0.70, 0.80, 0.900000}, {1.00, 1.10, 1.200000}}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.01, 0.020000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
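+            // With stride 2, every second output row/column receives no
+            // contribution from the input and only contains the bias value.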
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array4D<float,
+                        batchSize,
+                        outChannels,
+                        outDataSize[0],
+                        outDataSize[1]>({{{{{0.11,
+                                             0.01,
+                                             0.41,
+                                             0.01,
+                                             1.01,
+                                             0.01,
+                                             1.61,
+                                             0.01,
+                                             1.71,
+                                             0.01,
+                                             1.210000},
+                                            {0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.010000},
+                                            {0.91,
+                                             0.01,
+                                             2.91,
+                                             0.01,
+                                             6.210001,
+                                             0.01,
+                                             8.31,
+                                             0.01,
+                                             7.510001,
+                                             0.01,
+                                             4.810000},
+                                            {0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.010000},
+                                            {2.91,
+                                             0.01,
+                                             7.710001,
+                                             0.01,
+                                             14.610001,
+                                             0.01,
+                                             16.710001,
+                                             0.01,
+                                             13.910002,
+                                             0.01,
+                                             8.410001},
+                                            {0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.010000},
+                                            {4.91,
+                                             0.01,
+                                             12.51,
+                                             0.01,
+                                             23.01,
+                                             0.01,
+                                             25.110001,
+                                             0.01,
+                                             20.309999,
+                                             0.01,
+                                             12.010000},
+                                            {0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.01,
+                                             0.010000},
+                                            {5.210001,
+                                             0.01,
+                                             12.110001,
+                                             0.01,
+                                             20.809999,
+                                             0.01,
+                                             22.309999,
+                                             0.01,
+                                             17.01,
+                                             0.01,
+                                             9.610001}},
+
+                                           {{0.72,
+                                             0.02,
+                                             2.22,
+                                             0.02,
+                                             4.62,
+                                             0.02,
+                                             7.02,
+                                             0.02,
+                                             5.92,
+                                             0.02,
+                                             3.620000},
+                                            {0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.020000},
+                                            {4.52,
+                                             0.02,
+                                             11.320001,
+                                             0.02,
+                                             20.620003,
+                                             0.02,
+                                             26.320002,
+                                             0.02,
+                                             20.720001,
+                                             0.02,
+                                             12.020000},
+                                            {0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.020000},
+                                            {11.32,
+                                             0.02,
+                                             25.720001,
+                                             0.02,
+                                             43.420002,
+                                             0.02,
+                                             49.120003,
+                                             0.02,
+                                             36.720001,
+                                             0.02,
+                                             20.420002},
+                                            {0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.020000},
+                                            {18.119999,
+                                             0.02,
+                                             40.120003,
+                                             0.02,
+                                             66.220001,
+                                             0.02,
+                                             71.919998,
+                                             0.02,
+                                             52.720001,
+                                             0.02,
+                                             28.820002},
+                                            {0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.02,
+                                             0.020000},
+                                            {13.02,
+                                             0.02,
+                                             28.32,
+                                             0.02,
+                                             46.02,
+                                             0.02,
+                                             49.320004,
+                                             0.02,
+                                             35.619999,
+                                             0.02,
+                                             19.220001}}}}}));
+
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("inChannels = 4, outChannels = 3, kernelSize = {2,2}, "
+                "inDataSize = {3,3}, stride = {2,  2}, dilation = {2,  2}") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 4;
+            constexpr DimSize_t outChannels = 3;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2, 2};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{3, 3};
+            constexpr std::array<DimSize_t, DIM> outDataSize{7, 7};
+
+            constexpr std::array<DimSize_t, DIM> stride{2, 2};
+            constexpr std::array<DimSize_t, DIM> dilation{2, 2};
+
+            auto input = std::make_shared<Tensor>(Array4D<float,
+                                                          batchSize,
+                                                          inChannels,
+                                                          inDataSize[0],
+                                                          inDataSize[1]>(
+                {{{{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}},
+
+                   {{10.0, 11.0, 12.0},
+                    {13.0, 14.0, 15.0},
+                    {16.0, 17.0, 18.0}},
+
+                   {{19.0, 20.0, 21.0},
+                    {22.0, 23.0, 24.0},
+                    {25.0, 26.0, 27.0}},
+
+                   {{28.0, 29.0, 30.0},
+                    {31.0, 32.0, 33.0},
+                    {34.0, 35.0, 36.0}}}}}));
+
+            auto weights = std::make_shared<Tensor>(
+                Array4D<float,
+                        inChannels,
+                        outChannels,
+                        kernelSize[0],
+                        kernelSize[1]>({{{{{0.1, 0.2}, {0.3, 0.4}},
+
+                                          {{0.5, 0.6}, {0.7, 0.8}},
+
+                                          {{0.9, 1.0}, {1.1, 1.2}}},
+
+                                         {{{1.3, 1.4}, {1.5, 1.6}},
+
+                                          {{1.7, 1.8}, {1.9, 2.0}},
+
+                                          {{2.1, 2.2}, {2.3, 2.4}}},
+
+                                         {{{2.5, 2.6}, {2.7, 2.8}},
+
+                                          {{2.9, 3.0}, {3.1, 3.2}},
+
+                                          {{3.3, 3.4}, {3.5, 3.6}}},
+
+                                         {{{3.7, 3.8}, {3.9, 4.0}},
+
+                                          {{4.1, 4.2}, {4.3, 4.4}},
+
+                                          {{4.5, 4.6}, {4.7, 4.8}}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.010000, 0.020000, 0.030000}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array4D<float,
+                        batchSize,
+                        outChannels,
+                        outDataSize[0],
+                        outDataSize[1]>({{{{{164.209991,
+                                             0.010000,
+                                             341.809998,
+                                             0.010000,
+                                             357.410034,
+                                             0.010000,
+                                             186.009995},
+                                            {0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000},
+                                            {362.809998,
+                                             0.010000,
+                                             754.410034,
+                                             0.010000,
+                                             787.210083,
+                                             0.010000,
+                                             409.210022},
+                                            {0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000},
+                                            {410.809998,
+                                             0.010000,
+                                             852.810059,
+                                             0.010000,
+                                             885.609985,
+                                             0.010000,
+                                             459.610016},
+                                            {0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000,
+                                             0.010000},
+                                            {226.209991,
+                                             0.010000,
+                                             469.010010,
+                                             0.010000,
+                                             486.210022,
+                                             0.010000,
+                                             252.009995}},
+
+                                           {{187.419998,
+                                             0.020000,
+                                             389.820007,
+                                             0.020000,
+                                             408.619995,
+                                             0.020000,
+                                             212.420013},
+                                            {0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000},
+                                            {414.019989,
+                                             0.020000,
+                                             860.020020,
+                                             0.020000,
+                                             899.220032,
+                                             0.020000,
+                                             466.820007},
+                                            {0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000},
+                                            {471.620026,
+                                             0.020000,
+                                             977.619995,
+                                             0.020000,
+                                             1016.820068,
+                                             0.020000,
+                                             526.820007},
+                                            {0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000,
+                                             0.020000},
+                                            {259.019989,
+                                             0.020000,
+                                             536.220032,
+                                             0.020000,
+                                             556.619995,
+                                             0.020000,
+                                             288.019989}},
+
+                                           {{210.630005,
+                                             0.030000,
+                                             437.829987,
+                                             0.030000,
+                                             459.829987,
+                                             0.030000,
+                                             238.830002},
+                                            {0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000},
+                                            {465.230011,
+                                             0.030000,
+                                             965.630005,
+                                             0.030000,
+                                             1011.230103,
+                                             0.030000,
+                                             524.430054},
+                                            {0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000},
+                                            {532.430054,
+                                             0.030000,
+                                             1102.430054,
+                                             0.030000,
+                                             1148.030029,
+                                             0.030000,
+                                             594.030029},
+                                            {0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000,
+                                             0.030000},
+                                            {291.830017,
+                                             0.030000,
+                                             603.430054,
+                                             0.030000,
+                                             627.030029,
+                                             0.030000,
+                                             324.029999}}}}}));
+
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("Big test to ensure kernel capabilities 1") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 3;
+            constexpr DimSize_t outChannels = 4;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{2, 2};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{6, 5};
+            constexpr std::array<DimSize_t, DIM> outDataSize{8, 17};
+
+            constexpr std::array<DimSize_t, DIM> stride{1, 3};
+            constexpr std::array<DimSize_t, DIM> dilation{2, 4};
+
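+            // Sanity check on outDataSize (assuming no padding or output
+            // padding): out = (in - 1) * stride + dilation * (kernel - 1) + 1
+            //   height: (6 - 1) * 1 + 2 * (2 - 1) + 1 = 8
+            //   width:  (5 - 1) * 3 + 4 * (2 - 1) + 1 = 17
+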
+            auto input = std::make_shared<Tensor>(
+                Array4D<float,
+                        batchSize,
+                        inChannels,
+                        inDataSize[0],
+                        inDataSize[1]>({{{{{1., 2., 3., 4., 5.},
+                                           {6., 7., 8., 9., 10.},
+                                           {11., 12., 13., 14., 15.},
+                                           {16., 17., 18., 19., 20.},
+                                           {21., 22., 23., 24., 25.},
+                                           {26., 27., 28., 29., 30.}},
+
+                                          {{31., 32., 33., 34., 35.},
+                                           {36., 37., 38., 39., 40.},
+                                           {41., 42., 43., 44., 45.},
+                                           {46., 47., 48., 49., 50.},
+                                           {51., 52., 53., 54., 55.},
+                                           {56., 57., 58., 59., 60.}},
+
+                                          {{61., 62., 63., 64., 65.},
+                                           {66., 67., 68., 69., 70.},
+                                           {71., 72., 73., 74., 75.},
+                                           {76., 77., 78., 79., 80.},
+                                           {81., 82., 83., 84., 85.},
+                                           {86., 87., 88., 89., 90.}}}}}));
+
+            auto weights = std::make_shared<Tensor>(Array4D<float,
+                                                            inChannels,
+                                                            outChannels,
+                                                            kernelSize[0],
+                                                            kernelSize[1]>(
+                {{{{{0.100000, 0.200000}, {0.300000, 0.400000}},
+
+                   {{0.500000, 0.600000}, {0.700000, 0.800000}},
+
+                   {{0.900000, 1.000000}, {1.100000, 1.200000}},
+
+                   {{1.300000, 1.400000}, {1.500000, 1.600000}}},
+
+                  {{{1.700000, 1.800000}, {1.900000, 2.000000}},
+
+                   {{2.100000, 2.200000}, {2.300000, 2.400000}},
+
+                   {{2.500000, 2.600000}, {2.700000, 2.800000}},
+
+                   {{2.900000, 3.000000}, {3.100000, 3.200000}}},
+
+                  {{{3.300000, 3.400000}, {3.500000, 3.600000}},
+
+                   {{3.700000, 3.800000}, {3.900000, 4.000000}},
+
+                   {{4.100000, 4.200000}, {4.300000, 4.400000}},
+
+                   {{4.500000, 4.600000}, {4.700000, 4.800000}}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.01, 0.02, 0.03, 0.04}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
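+            // Reference computation sketch for these expected values:
+            // initialise each output channel with its bias, then for every
+            // input element x[n][ic][ih][iw] and kernel tap w[ic][oc][kh][kw]
+            // (weights are laid out [inChannels][outChannels][kH][kW] above),
+            // accumulate x * w into
+            //   y[n][oc][ih * stride[0] + kh * dilation[0]]
+            //           [iw * stride[1] + kw * dilation[1]].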
+            auto expectedOutput = std::make_shared<Tensor>(
+                Array4D<float,
+                        batchSize,
+                        outChannels,
+                        outDataSize[0],
+                        outDataSize[1]>({{{{{254.110001,
+                                             0.010000,
+                                             0.010000,
+                                             259.210022,
+                                             263.410034,
+                                             0.010000,
+                                             264.309998,
+                                             268.810028,
+                                             0.010000,
+                                             269.410004,
+                                             274.210022,
+                                             0.010000,
+                                             274.510010,
+                                             279.610016,
+                                             0.010000,
+                                             0.010000,
+                                             285.010010},
+                                            {279.610016,
+                                             0.010000,
+                                             0.010000,
+                                             284.710022,
+                                             290.410004,
+                                             0.010000,
+                                             289.809998,
+                                             295.810028,
+                                             0.010000,
+                                             294.910004,
+                                             301.210022,
+                                             0.010000,
+                                             300.010010,
+                                             306.610016,
+                                             0.010000,
+                                             0.010000,
+                                             312.010010},
+                                            {577.810059,
+                                             0.010000,
+                                             0.010000,
+                                             588.609985,
+                                             599.410034,
+                                             0.010000,
+                                             599.410034,
+                                             610.810059,
+                                             0.010000,
+                                             610.209961,
+                                             622.210022,
+                                             0.010000,
+                                             621.010010,
+                                             633.609985,
+                                             0.010000,
+                                             0.010000,
+                                             645.010010},
+                                            {631.810059,
+                                             0.010000,
+                                             0.010000,
+                                             642.609985,
+                                             656.410034,
+                                             0.010000,
+                                             653.410034,
+                                             667.810059,
+                                             0.010000,
+                                             664.209961,
+                                             679.210022,
+                                             0.010000,
+                                             675.010010,
+                                             690.609985,
+                                             0.010000,
+                                             0.010000,
+                                             702.010010},
+                                            {685.810059,
+                                             0.010000,
+                                             0.010000,
+                                             696.609985,
+                                             713.410034,
+                                             0.010000,
+                                             707.410034,
+                                             724.810059,
+                                             0.010000,
+                                             718.209961,
+                                             736.210022,
+                                             0.010000,
+                                             729.010010,
+                                             747.609985,
+                                             0.010000,
+                                             0.010000,
+                                             759.010010},
+                                            {739.810059,
+                                             0.010000,
+                                             0.010000,
+                                             750.609985,
+                                             770.410034,
+                                             0.010000,
+                                             761.410034,
+                                             781.810059,
+                                             0.010000,
+                                             772.209961,
+                                             793.210022,
+                                             0.010000,
+                                             783.010010,
+                                             804.609985,
+                                             0.010000,
+                                             0.010000,
+                                             816.010010},
+                                            {386.710022,
+                                             0.010000,
+                                             0.010000,
+                                             392.410004,
+                                             402.010010,
+                                             0.010000,
+                                             398.110016,
+                                             408.010010,
+                                             0.010000,
+                                             403.809998,
+                                             414.010010,
+                                             0.010000,
+                                             409.510010,
+                                             420.010010,
+                                             0.010000,
+                                             0.010000,
+                                             426.010010},
+                                            {415.210022,
+                                             0.010000,
+                                             0.010000,
+                                             420.910004,
+                                             432.010010,
+                                             0.010000,
+                                             426.610016,
+                                             438.010040,
+                                             0.010000,
+                                             432.309998,
+                                             444.010010,
+                                             0.010000,
+                                             438.010010,
+                                             450.010040,
+                                             0.010000,
+                                             0.010000,
+                                             456.010010}},
+
+                                           {{291.320007,
+                                             0.020000,
+                                             0.020000,
+                                             297.619995,
+                                             300.619995,
+                                             0.020000,
+                                             303.919983,
+                                             307.219971,
+                                             0.020000,
+                                             310.220001,
+                                             313.819977,
+                                             0.020000,
+                                             316.519989,
+                                             320.419983,
+                                             0.020000,
+                                             0.020000,
+                                             327.019989},
+                                            {322.820007,
+                                             0.020000,
+                                             0.020000,
+                                             329.119995,
+                                             333.619995,
+                                             0.020000,
+                                             335.419983,
+                                             340.219971,
+                                             0.020000,
+                                             341.720001,
+                                             346.819977,
+                                             0.020000,
+                                             348.019989,
+                                             353.419983,
+                                             0.020000,
+                                             0.020000,
+                                             360.019989},
+                                            {664.220032,
+                                             0.020000,
+                                             0.020000,
+                                             677.420044,
+                                             685.820068,
+                                             0.020000,
+                                             690.619995,
+                                             699.619995,
+                                             0.020000,
+                                             703.820068,
+                                             713.420044,
+                                             0.020000,
+                                             717.020020,
+                                             727.219971,
+                                             0.020000,
+                                             0.020000,
+                                             741.020020},
+                                            {730.220032,
+                                             0.020000,
+                                             0.020000,
+                                             743.420044,
+                                             754.820068,
+                                             0.020000,
+                                             756.619995,
+                                             768.619995,
+                                             0.020000,
+                                             769.820068,
+                                             782.420044,
+                                             0.020000,
+                                             783.020020,
+                                             796.219971,
+                                             0.020000,
+                                             0.020000,
+                                             810.020020},
+                                            {796.220032,
+                                             0.020000,
+                                             0.020000,
+                                             809.420044,
+                                             823.820068,
+                                             0.020000,
+                                             822.620056,
+                                             837.619995,
+                                             0.020000,
+                                             835.820068,
+                                             851.420044,
+                                             0.020000,
+                                             849.020020,
+                                             865.219971,
+                                             0.020000,
+                                             0.020000,
+                                             879.020020},
+                                            {862.220032,
+                                             0.020000,
+                                             0.020000,
+                                             875.420044,
+                                             892.820068,
+                                             0.020000,
+                                             888.619995,
+                                             906.619995,
+                                             0.020000,
+                                             901.820068,
+                                             920.420044,
+                                             0.020000,
+                                             915.020020,
+                                             934.219971,
+                                             0.020000,
+                                             0.020000,
+                                             948.020020},
+                                            {447.919983,
+                                             0.020000,
+                                             0.020000,
+                                             454.820007,
+                                             463.220001,
+                                             0.020000,
+                                             461.720001,
+                                             470.420013,
+                                             0.020000,
+                                             468.619995,
+                                             477.619995,
+                                             0.020000,
+                                             475.519989,
+                                             484.819977,
+                                             0.020000,
+                                             0.020000,
+                                             492.019989},
+                                            {482.419983,
+                                             0.020000,
+                                             0.020000,
+                                             489.320007,
+                                             499.220001,
+                                             0.020000,
+                                             496.220001,
+                                             506.420013,
+                                             0.020000,
+                                             503.119995,
+                                             513.619995,
+                                             0.020000,
+                                             510.019989,
+                                             520.820007,
+                                             0.020000,
+                                             0.020000,
+                                             528.020020}},
+
+                                           {{328.529999,
+                                             0.030000,
+                                             0.030000,
+                                             336.029999,
+                                             337.830017,
+                                             0.030000,
+                                             343.529999,
+                                             345.630035,
+                                             0.030000,
+                                             351.029999,
+                                             353.430023,
+                                             0.030000,
+                                             358.529999,
+                                             361.230011,
+                                             0.030000,
+                                             0.030000,
+                                             369.030029},
+                                            {366.029999,
+                                             0.030000,
+                                             0.030000,
+                                             373.529999,
+                                             376.830017,
+                                             0.030000,
+                                             381.029999,
+                                             384.630035,
+                                             0.030000,
+                                             388.529999,
+                                             392.430023,
+                                             0.030000,
+                                             396.029999,
+                                             400.230042,
+                                             0.030000,
+                                             0.030000,
+                                             408.030029},
+                                            {750.630005,
+                                             0.030000,
+                                             0.030000,
+                                             766.230042,
+                                             772.230042,
+                                             0.030000,
+                                             781.830078,
+                                             788.430054,
+                                             0.030000,
+                                             797.430054,
+                                             804.630066,
+                                             0.030000,
+                                             813.030029,
+                                             820.830078,
+                                             0.030000,
+                                             0.030000,
+                                             837.030029},
+                                            {828.630005,
+                                             0.030000,
+                                             0.030000,
+                                             844.230042,
+                                             853.230042,
+                                             0.030000,
+                                             859.830078,
+                                             869.430054,
+                                             0.030000,
+                                             875.430054,
+                                             885.630066,
+                                             0.030000,
+                                             891.030029,
+                                             901.830078,
+                                             0.030000,
+                                             0.030000,
+                                             918.030029},
+                                            {906.630005,
+                                             0.030000,
+                                             0.030000,
+                                             922.230042,
+                                             934.230042,
+                                             0.030000,
+                                             937.830078,
+                                             950.430054,
+                                             0.030000,
+                                             953.430054,
+                                             966.630066,
+                                             0.030000,
+                                             969.030029,
+                                             982.830078,
+                                             0.030000,
+                                             0.030000,
+                                             999.030090},
+                                            {984.630005,
+                                             0.030000,
+                                             0.030000,
+                                             1000.230042,
+                                             1015.230103,
+                                             0.030000,
+                                             1015.830078,
+                                             1031.430054,
+                                             0.030000,
+                                             1031.430054,
+                                             1047.630127,
+                                             0.030000,
+                                             1047.030029,
+                                             1063.830078,
+                                             0.030000,
+                                             0.030000,
+                                             1080.030029},
+                                            {509.130005,
+                                             0.030000,
+                                             0.030000,
+                                             517.230042,
+                                             524.430054,
+                                             0.030000,
+                                             525.330078,
+                                             532.830017,
+                                             0.030000,
+                                             533.430054,
+                                             541.230042,
+                                             0.030000,
+                                             541.530029,
+                                             549.630066,
+                                             0.030000,
+                                             0.030000,
+                                             558.030029},
+                                            {549.630066,
+                                             0.030000,
+                                             0.030000,
+                                             557.730042,
+                                             566.430054,
+                                             0.030000,
+                                             565.830078,
+                                             574.830017,
+                                             0.030000,
+                                             573.930054,
+                                             583.230042,
+                                             0.030000,
+                                             582.030029,
+                                             591.630066,
+                                             0.030000,
+                                             0.030000,
+                                             600.030029}},
+
+                                           {{365.740021,
+                                             0.040000,
+                                             0.040000,
+                                             374.440002,
+                                             375.040009,
+                                             0.040000,
+                                             383.140015,
+                                             384.040009,
+                                             0.040000,
+                                             391.839996,
+                                             393.040009,
+                                             0.040000,
+                                             400.540009,
+                                             402.040009,
+                                             0.040000,
+                                             0.040000,
+                                             411.040009},
+                                            {409.240021,
+                                             0.040000,
+                                             0.040000,
+                                             417.940002,
+                                             420.040009,
+                                             0.040000,
+                                             426.640015,
+                                             429.040009,
+                                             0.040000,
+                                             435.339996,
+                                             438.040009,
+                                             0.040000,
+                                             444.040009,
+                                             447.040009,
+                                             0.040000,
+                                             0.040000,
+                                             456.040009},
+                                            {837.039978,
+                                             0.040000,
+                                             0.040000,
+                                             855.040039,
+                                             858.639954,
+                                             0.040000,
+                                             873.039978,
+                                             877.239990,
+                                             0.040000,
+                                             891.039978,
+                                             895.840027,
+                                             0.040000,
+                                             909.039978,
+                                             914.440002,
+                                             0.040000,
+                                             0.040000,
+                                             933.039978},
+                                            {927.039978,
+                                             0.040000,
+                                             0.040000,
+                                             945.040039,
+                                             951.639954,
+                                             0.040000,
+                                             963.039978,
+                                             970.239990,
+                                             0.040000,
+                                             981.039978,
+                                             988.840027,
+                                             0.040000,
+                                             999.039978,
+                                             1007.440002,
+                                             0.040000,
+                                             0.040000,
+                                             1026.040039},
+                                            {1017.039978,
+                                             0.040000,
+                                             0.040000,
+                                             1035.040039,
+                                             1044.640015,
+                                             0.040000,
+                                             1053.040039,
+                                             1063.239990,
+                                             0.040000,
+                                             1071.040039,
+                                             1081.840088,
+                                             0.040000,
+                                             1089.040039,
+                                             1100.440063,
+                                             0.040000,
+                                             0.040000,
+                                             1119.040039},
+                                            {1107.040039,
+                                             0.040000,
+                                             0.040000,
+                                             1125.040039,
+                                             1137.640137,
+                                             0.040000,
+                                             1143.040039,
+                                             1156.239990,
+                                             0.040000,
+                                             1161.040039,
+                                             1174.840088,
+                                             0.040000,
+                                             1179.040039,
+                                             1193.440063,
+                                             0.040000,
+                                             0.040000,
+                                             1212.040039},
+                                            {570.340027,
+                                             0.040000,
+                                             0.040000,
+                                             579.640015,
+                                             585.640015,
+                                             0.040000,
+                                             588.940002,
+                                             595.239990,
+                                             0.040000,
+                                             598.239990,
+                                             604.840027,
+                                             0.040000,
+                                             607.540039,
+                                             614.440002,
+                                             0.040000,
+                                             0.040000,
+                                             624.039978},
+                                            {616.840027,
+                                             0.040000,
+                                             0.040000,
+                                             626.140015,
+                                             633.640015,
+                                             0.040000,
+                                             635.440002,
+                                             643.239990,
+                                             0.040000,
+                                             644.739990,
+                                             652.840027,
+                                             0.040000,
+                                             654.040039,
+                                             662.440002,
+                                             0.040000,
+                                             0.040000,
+                                             672.039978}}}}}));
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+        SECTION("Big test to ensure kernel capabilities 2") {
+            constexpr DimSize_t batchSize = 1;
+            constexpr DimSize_t inChannels = 3;
+            constexpr DimSize_t outChannels = 4;
+
+            constexpr std::array<DimSize_t, DIM> kernelSize{6, 4};
+
+            constexpr std::array<DimSize_t, DIM> inDataSize{6, 5};
+            constexpr std::array<DimSize_t, DIM> outDataSize{16, 25};
+
+            constexpr std::array<DimSize_t, DIM> stride{1, 3};
+            constexpr std::array<DimSize_t, DIM> dilation{2, 4};
+
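+            // Sanity check on outDataSize (assuming no padding or output
+            // padding): out = (in - 1) * stride + dilation * (kernel - 1) + 1
+            //   height: (6 - 1) * 1 + 2 * (6 - 1) + 1 = 16
+            //   width:  (5 - 1) * 3 + 4 * (4 - 1) + 1 = 25
+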
+            auto input = std::make_shared<Tensor>(
+                Array4D<float,
+                        batchSize,
+                        inChannels,
+                        inDataSize[0],
+                        inDataSize[1]>({{{{{1., 2., 3., 4., 5.},
+                                           {6., 7., 8., 9., 10.},
+                                           {11., 12., 13., 14., 15.},
+                                           {16., 17., 18., 19., 20.},
+                                           {21., 22., 23., 24., 25.},
+                                           {26., 27., 28., 29., 30.}},
+
+                                          {{31., 32., 33., 34., 35.},
+                                           {36., 37., 38., 39., 40.},
+                                           {41., 42., 43., 44., 45.},
+                                           {46., 47., 48., 49., 50.},
+                                           {51., 52., 53., 54., 55.},
+                                           {56., 57., 58., 59., 60.}},
+
+                                          {{61., 62., 63., 64., 65.},
+                                           {66., 67., 68., 69., 70.},
+                                           {71., 72., 73., 74., 75.},
+                                           {76., 77., 78., 79., 80.},
+                                           {81., 82., 83., 84., 85.},
+                                           {86., 87., 88., 89., 90.}}}}}));
+
+            auto weights = std::make_shared<Tensor>(Array4D<float,
+                                                            inChannels,
+                                                            outChannels,
+                                                            kernelSize[0],
+                                                            kernelSize[1]>(
+                {{{{{0.100000, 0.200000, 0.300000, 0.400000},
+                    {0.500000, 0.600000, 0.700000, 0.800000},
+                    {0.900000, 1.000000, 1.100000, 1.200000},
+                    {1.300000, 1.400000, 1.500000, 1.600000},
+                    {1.700000, 1.800000, 1.900000, 2.000000},
+                    {2.100000, 2.200000, 2.300000, 2.400000}},
+
+                   {{2.500000, 2.600000, 2.700000, 2.800000},
+                    {2.900000, 3.000000, 3.100000, 3.200000},
+                    {3.300000, 3.400000, 3.500000, 3.600000},
+                    {3.700000, 3.800000, 3.900000, 4.000000},
+                    {4.100000, 4.200000, 4.300000, 4.400000},
+                    {4.500000, 4.600000, 4.700000, 4.800000}},
+
+                   {{4.900000, 5.000000, 5.100000, 5.200000},
+                    {5.300000, 5.400000, 5.500000, 5.600000},
+                    {5.700000, 5.800000, 5.900000, 6.000000},
+                    {6.100000, 6.200000, 6.300000, 6.400000},
+                    {6.500000, 6.600000, 6.700000, 6.800000},
+                    {6.900000, 7.000000, 7.100000, 7.200000}},
+
+                   {{7.300000, 7.400000, 7.500000, 7.600000},
+                    {7.700000, 7.800000, 7.900000, 8.000000},
+                    {8.100000, 8.200000, 8.300000, 8.400001},
+                    {8.500000, 8.600000, 8.700000, 8.800000},
+                    {8.900001, 9.000000, 9.100000, 9.200000},
+                    {9.300000, 9.400001, 9.500000, 9.600000}}},
+
+                  {{{9.700000, 9.800000, 9.900001, 10.000000},
+                    {10.100000, 10.200000, 10.300000, 10.400001},
+                    {10.500000, 10.600000, 10.700000, 10.800000},
+                    {10.900001, 11.000000, 11.100000, 11.200000},
+                    {11.300000, 11.400001, 11.500000, 11.600000},
+                    {11.700000, 11.800000, 11.900001, 12.000000}},
+
+                   {{12.100000, 12.200000, 12.300000, 12.400001},
+                    {12.500000, 12.600000, 12.700000, 12.800000},
+                    {12.900001, 13.000000, 13.100000, 13.200000},
+                    {13.300000, 13.400001, 13.500000, 13.600000},
+                    {13.700000, 13.800000, 13.900001, 14.000000},
+                    {14.100000, 14.200000, 14.300000, 14.400001}},
+
+                   {{14.500000, 14.600000, 14.700000, 14.800000},
+                    {14.900001, 15.000000, 15.100000, 15.200000},
+                    {15.300000, 15.400001, 15.500000, 15.600000},
+                    {15.700000, 15.800000, 15.900001, 16.000000},
+                    {16.100000, 16.200001, 16.300001, 16.400000},
+                    {16.500000, 16.600000, 16.700001, 16.800001}},
+
+                   {{16.900000, 17.000000, 17.100000, 17.200001},
+                    {17.300001, 17.400000, 17.500000, 17.600000},
+                    {17.700001, 17.800001, 17.900000, 18.000000},
+                    {18.100000, 18.200001, 18.300001, 18.400000},
+                    {18.500000, 18.600000, 18.700001, 18.800001},
+                    {18.900000, 19.000000, 19.100000, 19.200001}}},
+
+                  {{{19.300001, 19.400000, 19.500000, 19.600000},
+                    {19.700001, 19.800001, 19.900000, 20.000000},
+                    {20.100000, 20.200001, 20.300001, 20.400000},
+                    {20.500000, 20.600000, 20.700001, 20.800001},
+                    {20.900000, 21.000000, 21.100000, 21.200001},
+                    {21.300001, 21.400000, 21.500000, 21.600000}},
+
+                   {{21.700001, 21.800001, 21.900000, 22.000000},
+                    {22.100000, 22.200001, 22.300001, 22.400000},
+                    {22.500000, 22.600000, 22.700001, 22.800001},
+                    {22.900000, 23.000000, 23.100000, 23.200001},
+                    {23.300001, 23.400000, 23.500000, 23.600000},
+                    {23.700001, 23.800001, 23.900000, 24.000000}},
+
+                   {{24.100000, 24.200001, 24.300001, 24.400000},
+                    {24.500000, 24.600000, 24.700001, 24.800001},
+                    {24.900000, 25.000000, 25.100000, 25.200001},
+                    {25.300001, 25.400000, 25.500000, 25.600000},
+                    {25.700001, 25.800001, 25.900000, 26.000000},
+                    {26.100000, 26.200001, 26.300001, 26.400000}},
+
+                   {{26.500000, 26.600000, 26.700001, 26.800001},
+                    {26.900000, 27.000000, 27.100000, 27.200001},
+                    {27.300001, 27.400000, 27.500000, 27.600000},
+                    {27.700001, 27.800001, 27.900000, 28.000000},
+                    {28.100000, 28.200001, 28.300001, 28.400000},
+                    {28.500000, 28.600000, 28.700001, 28.800001}}}}}));
+
+            auto biases = std::make_shared<Tensor>(
+                Array1D<float, outChannels>({{0.01, 0.02, 0.03, 0.04}}));
+
+            auto op = setupTestConvTranspose<DIM>(batchSize,
+                                                  inChannels,
+                                                  outChannels,
+                                                  kernelSize,
+                                                  inDataSize,
+                                                  stride,
+                                                  dilation,
+                                                  input,
+                                                  weights,
+                                                  biases);
+
+            REQUIRE_NOTHROW(op->forward());
+
+            auto expectedOutput =
+                std::make_shared<Tensor>(Array4D<float,
+                                                 batchSize,
+                                                 outChannels,
+                                                 outDataSize[0],
+                                                 outDataSize[1]>(
+                    {{{{{1478.110107, 0.010000,    0.010000,    1507.210083,
+                         1487.410034, 0.010000,    1536.310059, 1516.809937,
+                         1496.709961, 1565.410034, 1546.209961, 1526.410034,
+                         3100.510010, 1575.609985, 1556.109985, 1536.010010,
+                         1605.010010, 1585.810059, 1566.010010, 0.010000,
+                         1615.510010, 1596.010010, 0.010000,    0.010000,
+                         1626.010010},
+                        {1623.610107, 0.010000,    0.010000,    1652.710083,
+                         1634.410034, 0.010000,    1681.810059, 1663.809937,
+                         1645.209961, 1710.910034, 1693.209961, 1674.910034,
+                         3396.010010, 1722.609985, 1704.610107, 1686.010010,
+                         1752.010010, 1734.310059, 1716.010010, 0.010000,
+                         1764.010010, 1746.010010, 0.010000,    0.010000,
+                         1776.010010},
+                        {3284.410156, 0.010000,    0.010000,    3343.810303,
+                         3306.010010, 0.010000,    3403.210205, 3366.010010,
+                         3327.610107, 3462.610107, 3426.010010, 3388.209961,
+                         6871.209961, 3486.010010, 3448.810059, 3410.409912,
+                         3546.010010, 3509.409912, 3471.610107, 0.010000,
+                         3570.010010, 3532.810059, 0.010000,    0.010000,
+                         3594.010010},
+                        {3581.410156, 0.010000,    0.010000,    3640.810303,
+                         3606.010010, 0.010000,    3700.210205, 3666.010010,
+                         3630.610107, 3759.610107, 3726.010010, 3691.209961,
+                         7474.209961, 3786.010010, 3751.810059, 3716.409912,
+                         3846.010010, 3812.409912, 3777.610107, 0.010000,
+                         3873.010010, 3838.810059, 0.010000,    0.010000,
+                         3900.010010},
+                        {5430.910156,  0.010000,    0.010000,    5521.809570,
+                         5467.809570,  0.010000,    5612.709961, 5559.609863,
+                         5504.709961,  5703.609863, 5651.409668, 5597.409668,
+                         11336.110352, 5743.209961, 5690.109863, 5635.209473,
+                         5835.009766,  5782.809570, 5728.809570, 0.010000,
+                         5875.509766,  5822.409668, 0.010000,    0.010000,
+                         5916.009766},
+                        {5885.410156,  0.010000,    0.010000,    5976.310059,
+                         5926.809570,  0.010000,    6067.209961, 6018.609863,
+                         5968.209961,  6158.110352, 6110.409668, 6060.909668,
+                         12258.610352, 6202.209961, 6153.609375, 6103.209473,
+                         6294.009766,  6246.309570, 6196.809570, 0.010000,
+                         6339.009766,  6290.409668, 0.010000,    0.010000,
+                         6384.009766},
+                        {5578.509766,  0.010000,    0.010000,    5673.009766,
+                         5615.410156,  0.010000,    5767.510254, 5710.809570,
+                         5652.309570,  5862.009766, 5806.209961, 5748.609863,
+                         11645.710938, 5901.609863, 5844.909668, 5786.409668,
+                         5997.009766,  5941.209961, 5883.609375, 0.010000,
+                         6037.509766,  5980.809570, 0.010000,    0.010000,
+                         6078.009766},
+                        {6051.009766,  0.010000,    0.010000,    6145.509766,
+                         6092.410156,  0.010000,    6240.010254, 6187.810059,
+                         6133.809570,  6334.509766, 6283.209961, 6230.109863,
+                         12604.208984, 6378.610352, 6326.409668, 6272.410156,
+                         6474.009766,  6422.709961, 6369.609375, 0.010000,
+                         6519.009766,  6466.809570, 0.010000,    0.010000,
+                         6564.009766},
+                        {5726.109863,  0.010000,    0.010000,    5824.209473,
+                         5763.009766,  0.010000,    5922.309570, 5862.009766,
+                         5799.910156,  6020.409668, 5961.010254, 5899.809570,
+                         11955.309570, 6060.009766, 5999.709961, 5937.609863,
+                         6159.009766,  6099.609863, 6038.409668, 0.010000,
+                         6199.509766,  6139.209961, 0.010000,    0.010000,
+                         6240.009766},
+                        {6216.609863,  0.010000,    0.010000,    6314.709473,
+                         6258.009766,  0.010000,    6412.809570, 6357.009766,
+                         6299.410156,  6510.909668, 6456.010254, 6399.310059,
+                         12949.809570, 6555.009766, 6499.209961, 6441.609863,
+                         6654.009766,  6599.110352, 6542.409668, 0.010000,
+                         6699.009766,  6643.209961, 0.010000,    0.010000,
+                         6744.009766},
+                        {5873.709961,  0.010000,    0.010000,    5975.409668,
+                         5910.609863,  0.010000,    6077.109375, 6013.209473,
+                         5947.509766,  6178.809570, 6115.809570, 6051.009766,
+                         12264.910156, 6218.409668, 6154.510254, 6088.809570,
+                         6321.009766,  6258.009766, 6193.209961, 0.010000,
+                         6361.509766,  6297.610352, 0.010000,    0.010000,
+                         6402.009766},
+                        {6382.209473,  0.010000,    0.010000,    6483.910156,
+                         6423.609863,  0.010000,    6585.609375, 6526.209473,
+                         6465.009766,  6687.309570, 6628.809570, 6568.509766,
+                         13295.410156, 6731.409668, 6672.010254, 6610.810059,
+                         6834.009766,  6775.509766, 6715.209961, 0.010000,
+                         6879.009766,  6819.610352, 0.010000,    0.010000,
+                         6924.009766},
+                        {4320.009766, 0.010000,    0.010000,    4389.009766,
+                         4347.609863, 0.010000,    4458.009766, 4417.209961,
+                         4375.209961, 4527.009766, 4486.809570, 4445.409668,
+                         8998.809570, 4556.409668, 4515.609863, 4473.609863,
+                         4626.009766, 4585.809570, 4544.410156, 0.010000,
+                         4656.009766, 4615.209961, 0.010000,    0.010000,
+                         4686.009766},
+                        {4665.009766, 0.010000,    0.010000,    4734.009766,
+                         4695.609375, 0.010000,    4803.009766, 4765.209961,
+                         4726.209961, 4872.009766, 4834.809570, 4796.409668,
+                         9697.809570, 4904.409668, 4866.609863, 4827.609863,
+                         4974.009766, 4936.809570, 4898.410156, 0.010000,
+                         5007.009766, 4969.209961, 0.010000,    0.010000,
+                         5040.009766},
+                        {2366.110107, 0.010000,    0.010000,    2401.209961,
+                         2381.409912, 0.010000,    2436.310059, 2416.810059,
+                         2396.709961, 2471.410156, 2452.209961, 2432.409912,
+                         4918.509766, 2487.609863, 2468.110107, 2448.010010,
+                         2523.010010, 2503.810059, 2484.010010, 0.010000,
+                         2539.510010, 2520.010010, 0.010000,    0.010000,
+                         2556.010010},
+                        {2541.610107, 0.010000,    0.010000,    2576.710205,
+                         2558.409912, 0.010000,    2611.810059, 2593.810059,
+                         2575.209961, 2646.910156, 2629.209961, 2610.909912,
+                         5274.009766, 2664.609863, 2646.610107, 2628.010010,
+                         2700.010010, 2682.310059, 2664.010010, 0.010000,
+                         2718.010010, 2700.010010, 0.010000,    0.010000,
+                         2736.010010}},
+
+                       {{1701.320068, 0.020000,    0.020000,    1737.620117,
+                         1710.620117, 0.020000,    1773.920044, 1747.220093,
+                         1719.920044, 1810.220093, 1783.820068, 1756.819946,
+                         3575.719971, 1820.420044, 1793.719971, 1766.420044,
+                         1857.020142, 1830.619995, 1803.619995, 0.020000,
+                         1867.520020, 1840.820068, 0.020000,    0.020000,
+                         1878.020020},
+                        {1882.820068, 0.020000,    0.020000,    1919.120117,
+                         1893.620117, 0.020000,    1955.420044, 1930.220093,
+                         1904.420044, 1991.720093, 1966.820068, 1941.319946,
+                         3943.219971, 2003.420044, 1978.219971, 1952.420044,
+                         2040.020142, 2015.119995, 1989.620117, 0.020000,
+                         2052.020020, 2026.820068, 0.020000,    0.020000,
+                         2064.020020},
+                        {3802.820068, 0.020000,    0.020000,    3876.620117,
+                         3824.420166, 0.020000,    3950.420166, 3898.820068,
+                         3846.020020, 4024.220215, 3973.220215, 3921.020020,
+                         7965.620117, 4047.620117, 3996.020020, 3943.219727,
+                         4122.020020, 4071.020020, 4018.820068, 0.020000,
+                         4146.020020, 4094.419922, 0.020000,    0.020000,
+                         4170.020020},
+                        {4171.819824, 0.020000,    0.020000,    4245.620117,
+                         4196.420410, 0.020000,    4319.420410, 4270.819824,
+                         4221.020020, 4393.220215, 4345.220215, 4296.020020,
+                         8712.620117, 4419.620605, 4371.020020, 4321.219727,
+                         4494.020020, 4446.020020, 4396.819824, 0.020000,
+                         4521.020020, 4472.419922, 0.020000,    0.020000,
+                         4548.020020},
+                        {6316.520020,  0.020000,    0.020000,    6429.020020,
+                         6353.420410,  0.020000,    6541.520508, 6466.819824,
+                         6390.319824,  6654.020020, 6580.220215, 6504.620117,
+                         13193.718750, 6693.620605, 6618.919922, 6542.420410,
+                         6807.020020,  6733.220215, 6657.619629, 0.020000,
+                         6847.520020,  6772.819824, 0.020000,    0.020000,
+                         6888.020020},
+                        {6879.020020,  0.020000,    0.020000,    6991.520020,
+                         6920.420410,  0.020000,    7104.020508, 7033.820312,
+                         6961.819824,  7216.520020, 7147.220215, 7076.120117,
+                         14332.218750, 7260.620605, 7190.420410, 7118.420410,
+                         7374.020020,  7304.720215, 7233.619629, 0.020000,
+                         7419.020020,  7348.819824, 0.020000,    0.020000,
+                         7464.020020},
+                        {6464.120117,  0.020000,    0.020000,    6580.219727,
+                         6501.020020,  0.020000,    6696.319824, 6618.020020,
+                         6537.920410,  6812.419922, 6735.020508, 6655.819824,
+                         13503.319336, 6852.020020, 6773.720215, 6693.620117,
+                         6969.020020,  6891.620605, 6812.419922, 0.020000,
+                         7009.520020,  6931.220215, 0.020000,    0.020000,
+                         7050.020020},
+                        {7044.620117,  0.020000,    0.020000,    7160.720215,
+                         7086.020020,  0.020000,    7276.819824, 7203.020020,
+                         7127.420410,  7392.919434, 7320.020508, 7245.320312,
+                         14677.819336, 7437.020020, 7363.220215, 7287.620117,
+                         7554.020020,  7481.120605, 7406.420410, 0.020000,
+                         7599.020020,  7525.220215, 0.020000,    0.020000,
+                         7644.020020},
+                        {6611.719727,  0.020000,    0.020000,    6731.420410,
+                         6648.620117,  0.020000,    6851.119629, 6769.219727,
+                         6685.520020,  6970.819824, 6889.819824, 6807.020020,
+                         13812.919922, 7010.419922, 6928.520508, 6844.819824,
+                         7131.020020,  7050.020020, 6967.220215, 0.020000,
+                         7171.520020,  7089.620605, 0.020000,    0.020000,
+                         7212.020020},
+                        {7210.219727,  0.020000,    0.020000,    7329.920410,
+                         7251.620117,  0.020000,    7449.619629, 7372.220215,
+                         7293.020020,  7569.319824, 7492.819824, 7414.520020,
+                         15023.418945, 7613.419434, 7536.020508, 7456.820312,
+                         7734.020020,  7657.520020, 7579.220215, 0.020000,
+                         7779.020020,  7701.620605, 0.020000,    0.020000,
+                         7824.020020},
+                        {6759.319824,  0.020000,    0.020000,    6882.620117,
+                         6796.219727,  0.020000,    7005.919922, 6920.420410,
+                         6833.120117,  7129.220215, 7044.619629, 6958.219727,
+                         14122.519531, 7168.819824, 7083.319824, 6996.020020,
+                         7293.020020,  7208.419922, 7122.020508, 0.020000,
+                         7333.520020,  7248.020020, 0.020000,    0.020000,
+                         7374.020020},
+                        {7375.819824,  0.020000,    0.020000,    7499.120117,
+                         7417.219727,  0.020000,    7622.420410, 7541.420410,
+                         7458.620117,  7745.720215, 7665.619629, 7583.720215,
+                         15369.019531, 7789.819824, 7708.819824, 7626.020020,
+                         7914.020020,  7833.919434, 7752.020508, 0.020000,
+                         7959.020020,  7878.020020, 0.020000,    0.020000,
+                         8004.020020},
+                        {4982.420410,  0.020000,    0.020000,    5065.819824,
+                         5010.020020,  0.020000,    5149.220215, 5094.020020,
+                         5037.619629,  5232.620605, 5178.020020, 5122.220215,
+                         10381.219727, 5262.020020, 5206.819824, 5150.419922,
+                         5346.020020,  5291.419922, 5235.620117, 0.020000,
+                         5376.020020,  5320.819824, 0.020000,    0.020000,
+                         5406.020020},
+                        {5399.420410,  0.020000,    0.020000,    5482.820312,
+                         5430.020020,  0.020000,    5566.220215, 5514.020020,
+                         5460.619629,  5649.620605, 5598.020020, 5545.220215,
+                         11224.219727, 5682.020020, 5629.819824, 5576.419922,
+                         5766.020020,  5714.419922, 5661.620117, 0.020000,
+                         5799.020020,  5746.819824, 0.020000,    0.020000,
+                         5832.020020},
+                        {2733.320068, 0.020000,    0.020000,    2775.620117,
+                         2748.620117, 0.020000,    2817.920166, 2791.219971,
+                         2763.919922, 2860.220215, 2833.820068, 2806.820068,
+                         5681.720215, 2876.420166, 2849.719971, 2822.419922,
+                         2919.020020, 2892.619873, 2865.620117, 0.020000,
+                         2935.520020, 2908.820068, 0.020000,    0.020000,
+                         2952.020020},
+                        {2944.820068, 0.020000,    0.020000,    2987.120117,
+                         2961.620117, 0.020000,    3029.420166, 3004.220215,
+                         2978.419922, 3071.720215, 3046.820068, 3021.320068,
+                         6109.220215, 3089.420166, 3064.219971, 3038.419922,
+                         3132.020020, 3107.119873, 3081.620117, 0.020000,
+                         3150.020020, 3124.820068, 0.020000,    0.020000,
+                         3168.020020}},
+
+                       {{1924.530029, 0.030000,    0.030000,    1968.030029,
+                         1933.830078, 0.030000,    2011.530029, 1977.630127,
+                         1943.130127, 2055.030029, 2021.430054, 1987.230103,
+                         4050.929932, 2065.230225, 2031.330078, 1996.829956,
+                         2109.030029, 2075.430176, 2041.229980, 0.030000,
+                         2119.530029, 2085.630127, 0.030000,    0.030000,
+                         2130.030029},
+                        {2142.030029, 0.030000,    0.030000,    2185.530029,
+                         2152.830078, 0.030000,    2229.030029, 2196.630127,
+                         2163.630127, 2272.530029, 2240.430176, 2207.729980,
+                         4490.429688, 2284.230225, 2251.830078, 2218.830078,
+                         2328.030029, 2295.930176, 2263.229980, 0.030000,
+                         2340.030029, 2307.629883, 0.030000,    0.030000,
+                         2352.030029},
+                        {4321.229980, 0.030000,    0.030000,    4409.429688,
+                         4342.829590, 0.030000,    4497.629883, 4431.629883,
+                         4364.430176, 4585.829590, 4520.430176, 4453.829590,
+                         9060.030273, 4609.229980, 4543.229980, 4476.029785,
+                         4698.029785, 4632.630371, 4566.029785, 0.030000,
+                         4722.029785, 4656.029785, 0.030000,    0.030000,
+                         4746.029785},
+                        {4762.229980, 0.030000,    0.030000,    4850.429688,
+                         4786.829590, 0.030000,    4938.629883, 4875.629883,
+                         4811.430176, 5026.829590, 4964.430176, 4900.829590,
+                         9951.030273, 5053.229980, 4990.229980, 4926.029785,
+                         5142.029785, 5079.630371, 5016.029785, 0.030000,
+                         5169.029785, 5106.029785, 0.030000,    0.030000,
+                         5196.029785},
+                        {7202.129883,  0.030000,    0.030000,    7336.229492,
+                         7239.029785,  0.030000,    7470.329590, 7374.029785,
+                         7275.930176,  7604.429688, 7509.030273, 7411.829590,
+                         15051.330078, 7644.029785, 7547.729980, 7449.629883,
+                         7779.029785,  7683.630371, 7586.430176, 0.030000,
+                         7819.529785,  7723.229980, 0.030000,    0.030000,
+                         7860.029785},
+                        {7872.629883,  0.030000,    0.030000,    8006.729980,
+                         7914.029785,  0.030000,    8140.829590, 8049.029785,
+                         7955.430176,  8274.929688, 8184.030273, 8091.330078,
+                         16405.830078, 8319.030273, 8227.230469, 8133.629883,
+                         8454.030273,  8363.130859, 8270.430664, 0.030000,
+                         8499.030273,  8407.230469, 0.030000,    0.030000,
+                         8544.030273},
+                        {7349.729492,  0.030000,    0.030000,    7487.430176,
+                         7386.629883,  0.030000,    7625.129395, 7525.229980,
+                         7423.529785,  7762.829590, 7663.829590, 7563.029785,
+                         15360.929688, 7802.429688, 7702.530273, 7600.829590,
+                         7941.029785,  7842.029785, 7741.229980, 0.030000,
+                         7981.529785,  7881.630371, 0.030000,    0.030000,
+                         8022.029785},
+                        {8038.229492,  0.030000,    0.030000,    8175.930176,
+                         8079.629883,  0.030000,    8313.629883, 8218.230469,
+                         8121.029785,  8451.330078, 8356.830078, 8260.530273,
+                         16751.427734, 8495.429688, 8400.030273, 8302.831055,
+                         8634.030273,  8539.530273, 8443.230469, 0.030000,
+                         8679.030273,  8583.630859, 0.030000,    0.030000,
+                         8724.030273},
+                        {7497.329590,  0.030000,    0.030000,    7638.629883,
+                         7534.229492,  0.030000,    7779.930176, 7676.430176,
+                         7571.130371,  7921.229980, 7818.629395, 7714.229980,
+                         15670.530273, 7960.829590, 7857.329590, 7752.029785,
+                         8103.029785,  8000.429688, 7896.030273, 0.030000,
+                         8143.529785,  8040.029785, 0.030000,    0.030000,
+                         8184.029785},
+                        {8203.830078,  0.030000,    0.030000,    8345.129883,
+                         8245.229492,  0.030000,    8486.430664, 8387.430664,
+                         8286.630859,  8627.730469, 8529.629883, 8429.730469,
+                         17097.029297, 8671.830078, 8572.830078, 8472.030273,
+                         8814.030273,  8715.930664, 8616.030273, 0.030000,
+                         8859.030273,  8760.030273, 0.030000,    0.030000,
+                         8904.030273},
+                        {7644.930176,  0.030000,    0.030000,    7789.829590,
+                         7681.829590,  0.030000,    7934.729980, 7827.629883,
+                         7718.729980,  8079.630371, 7973.430176, 7865.430176,
+                         15980.130859, 8119.229980, 8012.129395, 7903.229980,
+                         8265.030273,  8158.830566, 8050.829590, 0.030000,
+                         8305.530273,  8198.430664, 0.030000,    0.030000,
+                         8346.030273},
+                        {8369.430664,  0.030000,    0.030000,    8514.331055,
+                         8410.830078,  0.030000,    8659.230469, 8556.629883,
+                         8452.231445,  8804.130859, 8702.430664, 8598.930664,
+                         17442.628906, 8848.230469, 8745.629883, 8641.230469,
+                         8994.030273,  8892.331055, 8788.830078, 0.030000,
+                         9039.030273,  8936.430664, 0.030000,    0.030000,
+                         9084.030273},
+                        {5644.829590,  0.030000,    0.030000,    5742.629883,
+                         5672.430176,  0.030000,    5840.430176, 5770.830078,
+                         5700.029785,  5938.229980, 5869.229980, 5799.029785,
+                         11763.630859, 5967.630371, 5898.029785, 5827.229980,
+                         6066.029785,  5997.029785, 5926.829590, 0.030000,
+                         6096.029785,  6026.430176, 0.030000,    0.030000,
+                         6126.029785},
+                        {6133.829590,  0.030000,    0.030000,    6231.629883,
+                         6164.430176,  0.030000,    6329.430176, 6262.830078,
+                         6195.029785,  6427.229980, 6361.229980, 6294.029785,
+                         12750.630859, 6459.630371, 6393.029785, 6325.229980,
+                         6558.029785,  6492.029785, 6424.829590, 0.030000,
+                         6591.029785,  6524.430176, 0.030000,    0.030000,
+                         6624.029785},
+                        {3100.530029, 0.030000,    0.030000,    3150.030029,
+                         3115.830078, 0.030000,    3199.530029, 3165.630127,
+                         3131.130127, 3249.030029, 3215.430176, 3181.230225,
+                         6444.930176, 3265.230225, 3231.330078, 3196.830078,
+                         3315.030029, 3281.430176, 3247.230225, 0.030000,
+                         3331.530029, 3297.630127, 0.030000,    0.030000,
+                         3348.030029},
+                        {3348.030029, 0.030000,    0.030000,    3397.530029,
+                         3364.830078, 0.030000,    3447.030029, 3414.630127,
+                         3381.630127, 3496.530029, 3464.430176, 3431.730225,
+                         6944.430176, 3514.230225, 3481.830078, 3448.830078,
+                         3564.030029, 3531.930176, 3499.230225, 0.030000,
+                         3582.030029, 3549.630127, 0.030000,    0.030000,
+                         3600.030029}},
+
+                       {{2147.739990, 0.040000,    0.040000,    2198.439941,
+                         2157.040039, 0.040000,    2249.140137, 2208.040039,
+                         2166.340088, 2299.840088, 2259.040039, 2217.640137,
+                         4526.140137, 2310.040039, 2268.940186, 2227.240234,
+                         2361.040039, 2320.240234, 2278.840088, 0.040000,
+                         2371.540039, 2330.440186, 0.040000,    0.040000,
+                         2382.040039},
+                        {2401.239990, 0.040000,    0.040000,    2451.939941,
+                         2412.040039, 0.040000,    2502.640137, 2463.040039,
+                         2422.840088, 2553.340088, 2514.040039, 2474.140137,
+                         5037.640137, 2565.040039, 2525.440186, 2485.240234,
+                         2616.040039, 2576.740234, 2536.840088, 0.040000,
+                         2628.040039, 2588.440186, 0.040000,    0.040000,
+                         2640.040039},
+                        {4839.640137,  0.040000,    0.040000,    4942.240234,
+                         4861.240234,  0.040000,    5044.839844, 4964.439941,
+                         4882.839844,  5147.440430, 5067.640137, 4986.640137,
+                         10154.440430, 5170.839844, 5090.440430, 5008.840332,
+                         5274.040039,  5194.240234, 5113.240234, 0.040000,
+                         5298.040039,  5217.640625, 0.040000,    0.040000,
+                         5322.040039},
+                        {5352.640137,  0.040000,    0.040000,    5455.240234,
+                         5377.240234,  0.040000,    5557.839844, 5480.439941,
+                         5401.839844,  5660.440430, 5583.640137, 5505.640137,
+                         11189.439453, 5686.839844, 5609.440430, 5530.840332,
+                         5790.040039,  5713.240234, 5635.240234, 0.040000,
+                         5817.040039,  5739.640625, 0.040000,    0.040000,
+                         5844.040039},
+                        {8087.740234,  0.040000,    0.040000,    8243.440430,
+                         8124.640625,  0.040000,    8399.139648, 8281.240234,
+                         8161.540039,  8554.840820, 8437.839844, 8319.040039,
+                         16908.937500, 8594.440430, 8476.540039, 8356.840820,
+                         8751.040039,  8634.040039, 8515.240234, 0.040000,
+                         8791.540039,  8673.640625, 0.040000,    0.040000,
+                         8832.040039},
+                        {8866.240234,  0.040000,    0.040000,    9021.940430,
+                         8907.640625,  0.040000,    9177.639648, 9064.240234,
+                         8949.040039,  9333.340820, 9220.839844, 9106.540039,
+                         18479.437500, 9377.440430, 9264.040039, 9148.840820,
+                         9534.040039,  9421.540039, 9307.240234, 0.040000,
+                         9579.040039,  9465.640625, 0.040000,    0.040000,
+                         9624.040039},
+                        {8235.339844,  0.040000,    0.040000,    8394.639648,
+                         8272.240234,  0.040000,    8553.940430, 8432.440430,
+                         8309.140625,  8713.240234, 8592.639648, 8470.240234,
+                         17218.539062, 8752.840820, 8631.339844, 8508.040039,
+                         8913.040039,  8792.440430, 8670.040039, 0.040000,
+                         8953.540039,  8832.040039, 0.040000,    0.040000,
+                         8994.040039},
+                        {9031.839844,  0.040000,    0.040000,    9191.139648,
+                         9073.240234,  0.040000,    9350.440430, 9233.440430,
+                         9114.640625,  9509.740234, 9393.639648, 9275.740234,
+                         18825.039062, 9553.839844, 9436.839844, 9318.040039,
+                         9714.040039,  9597.940430, 9480.040039, 0.040000,
+                         9759.040039,  9642.040039, 0.040000,    0.040000,
+                         9804.040039},
+                        {8382.940430,  0.040000,    0.040000,    8545.840820,
+                         8419.839844,  0.040000,    8708.740234, 8583.639648,
+                         8456.740234,  8871.640625, 8747.440430, 8621.440430,
+                         17528.138672, 8911.240234, 8786.139648, 8659.240234,
+                         9075.040039,  8950.840820, 8824.839844, 0.040000,
+                         9115.540039,  8990.440430, 0.040000,    0.040000,
+                         9156.040039},
+                        {9197.440430,  0.040000,    0.040000,    9360.340820,
+                         9238.839844,  0.040000,    9523.240234, 9402.639648,
+                         9280.240234,  9686.140625, 9566.440430, 9444.940430,
+                         19170.638672, 9730.240234, 9609.639648, 9487.240234,
+                         9894.040039,  9774.339844, 9652.839844, 0.040000,
+                         9939.040039,  9818.440430, 0.040000,    0.040000,
+                         9984.040039},
+                        {8530.540039,  0.040000,    0.040000,    8697.040039,
+                         8567.440430,  0.040000,    8863.540039, 8734.840820,
+                         8604.339844,  9030.040039, 8902.240234, 8772.639648,
+                         17837.740234, 9069.640625, 8940.940430, 8810.440430,
+                         9237.040039,  9109.240234, 8979.639648, 0.040000,
+                         9277.540039,  9148.840820, 0.040000,    0.040000,
+                         9318.040039},
+                        {9363.040039,  0.040000,    0.040000,    9529.540039,
+                         9404.440430,  0.040000,    9696.040039, 9571.840820,
+                         9445.839844,  9862.540039, 9739.240234, 9614.139648,
+                         19516.240234, 9906.640625, 9782.440430, 9656.440430,
+                         10074.040039, 9950.740234, 9825.639648, 0.040000,
+                         10119.040039, 9994.839844, 0.040000,    0.040000,
+                         10164.040039},
+                        {6307.240234,  0.040000,    0.040000,    6419.439941,
+                         6334.839844,  0.040000,    6531.640137, 6447.640137,
+                         6362.440430,  6643.839844, 6560.440430, 6475.840332,
+                         13146.040039, 6673.240234, 6589.240234, 6504.040039,
+                         6786.040039,  6702.640625, 6618.040039, 0.040000,
+                         6816.040039,  6732.040039, 0.040000,    0.040000,
+                         6846.040039},
+                        {6868.240234,  0.040000,    0.040000,    6980.439941,
+                         6898.839844,  0.040000,    7092.640137, 7011.640137,
+                         6929.440430,  7204.839844, 7124.440430, 7042.840332,
+                         14277.040039, 7237.240234, 7156.240234, 7074.040039,
+                         7350.040039,  7269.640625, 7188.040039, 0.040000,
+                         7383.040039,  7302.040039, 0.040000,    0.040000,
+                         7416.040039},
+                        {3467.739990, 0.040000,    0.040000,    3524.439941,
+                         3483.040039, 0.040000,    3581.140137, 3540.040039,
+                         3498.340088, 3637.840088, 3597.040039, 3555.640137,
+                         7208.140137, 3654.040039, 3612.940186, 3571.240234,
+                         3711.040039, 3670.240234, 3628.840088, 0.040000,
+                         3727.540039, 3686.440186, 0.040000,    0.040000,
+                         3744.040039},
+                        {3751.239990, 0.040000,    0.040000,    3807.939941,
+                         3768.040039, 0.040000,    3864.640137, 3825.040039,
+                         3784.840088, 3921.340088, 3882.040039, 3842.140137,
+                         7779.640137, 3939.040039, 3899.440186, 3859.240234,
+                         3996.040039, 3956.740234, 3916.840088, 0.040000,
+                         4014.040039, 3974.440186, 0.040000,    0.040000,
+                         4032.040039}}}}}));
+            CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput));
+        }
+    }
+}
+
+} // namespace Aidge
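
The expected tensors above show a visible pattern: within each output channel, many positions hold exactly the bias (0.01 for the first channel, 0.02 for the second, and so on). If, as the rest of this patch suggests, these are forward results of a strided transposed convolution, that is exactly what one would expect: an output pixel that no input pixel maps onto accumulates nothing beyond its bias. A minimal per-element sketch of that accumulation rule, assuming NCHW input layout, C x OC x kH x kW weights (the ONNX ConvTranspose convention), and no padding; all names here are hypothetical and this is not the kernel under test:

#include <cstddef>
#include <vector>

// Naive transposed convolution for a single output element.
float convTransposeElem(const std::vector<float> &input,   // N,C,H,W flattened
                        const std::vector<float> &weights, // C,OC,kH,kW flattened
                        float bias,                        // bias of channel oc
                        std::size_t C, std::size_t H, std::size_t W,
                        std::size_t OC, std::size_t kH, std::size_t kW,
                        std::size_t stride, std::size_t dilation,
                        std::size_t n, std::size_t oc,
                        std::size_t oy, std::size_t ox) {
    float acc = bias; // positions reached by no input pixel stay at the bias
    for (std::size_t ic = 0; ic < C; ++ic)
        for (std::size_t ky = 0; ky < kH; ++ky)
            for (std::size_t kx = 0; kx < kW; ++kx) {
                // Input pixel (iy, ix) contributes to output position
                // (iy * stride + ky * dilation, ix * stride + kx * dilation).
                const std::size_t ty = ky * dilation;
                const std::size_t tx = kx * dilation;
                if (oy < ty || ox < tx)
                    continue;
                if ((oy - ty) % stride != 0 || (ox - tx) % stride != 0)
                    continue;
                const std::size_t iy = (oy - ty) / stride;
                const std::size_t ix = (ox - tx) / stride;
                if (iy >= H || ix >= W)
                    continue;
                acc += input[((n * C + ic) * H + iy) * W + ix] *
                       weights[((ic * OC + oc) * kH + ky) * kW + kx];
            }
    return acc;
}

With stride greater than 1, the two modulus checks reject most (oy, ox) positions for every kernel tap, leaving acc at its initial bias value; that accounts for the repeated 0.010000/0.020000/... entries in the tables above.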
diff --git a/unit_tests/operator/Test_ExpandImpl.cpp b/unit_tests/operator/Test_ExpandImpl.cpp
index 878c608110eabb824d8a6c0d1ceb0853b3c1449d..ad30457d33307ca595ecddfd3b06d58e118a02d0 100644
--- a/unit_tests/operator/Test_ExpandImpl.cpp
+++ b/unit_tests/operator/Test_ExpandImpl.cpp
@@ -13,20 +13,20 @@
 
 #include <catch2/catch_test_macros.hpp>
 
-#include "aidge/backend/cpu/data/TensorImpl.hpp"
-#include "aidge/backend/cpu/operator/ExpandImpl.hpp"
 #include "aidge/data/DataType.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/operator/Expand.hpp"
 #include "aidge/utils/ArrayHelpers.hpp"
 
-using std::shared_ptr;
 
-using namespace Aidge;
+namespace Aidge {
+
+using std::shared_ptr;
 
-void setupTestExpand(shared_ptr<Tensor> inputData,
-                     shared_ptr<Tensor> inputShape,
-                     shared_ptr<Expand_Op> &op) {
+static void setupTestExpand(shared_ptr<Tensor> inputData,
+                            shared_ptr<Tensor> inputShape,
+                            shared_ptr<Expand_Op> &op,
+                            Tensor &expectedOutput) {
 
     op->getOutput(0)->setDataType(inputData->dataType());
 
@@ -35,6 +35,9 @@ void setupTestExpand(shared_ptr<Tensor> inputData,
 
     inputShape->setBackend("cpu");
     op->associateInput(1, inputShape);
+
+    expectedOutput.setBackend("cpu");
+    expectedOutput.setDataType(DataType::Int32);
 }
 
 TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") {
@@ -49,7 +52,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") {
             Array4D<cpptype_t<DataType::Int32>, 1, 3, 4, 2>({{{{{1, 3}, {1, 3}, {1, 3}, {1, 3}},
                                         {{1, 3}, {1, 3}, {1, 3}, {1, 3}},
                                         {{1, 3}, {1, 3}, {1, 3}, {1, 3}}}}});
-        setupTestExpand(inputData, inputShape, op);
+        setupTestExpand(inputData, inputShape, op, expectedOutput);
 
         // forwardDims has already been tested in core
         CHECK(op->forwardDims(true));
@@ -63,7 +66,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") {
             std::make_shared<Tensor>(Array1D<std::int64_t, 2>({2, 3}));
         Tensor expectedOutput = Array3D<cpptype_t<DataType::Int32>, 2, 2, 3>(
             {{{{2, 1, 3}, {2, 1, 3}}, {{2, 1, 3}, {2, 1, 3}}}});
-        setupTestExpand(inputData, inputShape, op);
+        setupTestExpand(inputData, inputShape, op, expectedOutput);
 
         // forwardDims has already been tested in core
         CHECK(op->forwardDims(true));
@@ -77,7 +80,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") {
             std::make_shared<Tensor>(Array1D<std::int64_t, 1>({1}));
         Tensor expectedOutput =
             Array4D<cpptype_t<DataType::Int32>, 2, 1, 3, 1>({{{2, 1, 3}, {2, 1, 3}}});
-        setupTestExpand(inputData, inputShape, op);
+        setupTestExpand(inputData, inputShape, op, expectedOutput);
 
         // forwardDims has already been tested in core
         CHECK(op->forwardDims(true));
@@ -91,7 +94,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") {
             std::make_shared<Tensor>(Array1D<std::int64_t, 3>({2, 1, 1}));
         Tensor expectedOutput =
             Array4D<cpptype_t<DataType::Int32>, 1, 2, 3, 1>({{{{2, 1, 3}, {2, 1, 3}}}});
-        setupTestExpand(inputData, inputShape, op);
+        setupTestExpand(inputData, inputShape, op, expectedOutput);
 
         // forwardDims has already been tested in core
         CHECK(op->forwardDims(true));
@@ -101,3 +104,4 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") {
     SECTION("N-Dim to N-Dim") {}
     auto inputData = std::shared_ptr<Tensor>();
 }
+} // namespace Aidge
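
For reference, the broadcast rule these Expand cases exercise follows ONNX semantics: the data shape and the requested shape are right-aligned, a dimension of 1 stretches to match the other operand, and any other pair of dimensions must agree. A minimal sketch of the resulting shape computation, as a hypothetical standalone helper rather than anything from the test suite:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<std::int64_t>
expandedShape(std::vector<std::int64_t> data, std::vector<std::int64_t> shape) {
    const std::size_t rank = std::max(data.size(), shape.size());
    // Left-pad the shorter shape with 1s so both are right-aligned.
    data.insert(data.begin(), rank - data.size(), 1);
    shape.insert(shape.begin(), rank - shape.size(), 1);
    std::vector<std::int64_t> out(rank);
    for (std::size_t i = 0; i < rank; ++i) {
        if (data[i] != shape[i] && data[i] != 1 && shape[i] != 1)
            throw std::invalid_argument("Expand: incompatible dimensions");
        out[i] = std::max(data[i], shape[i]);
    }
    return out;
}

// e.g. expandedShape({2, 1, 3, 1}, {1}) == {2, 1, 3, 1}: a lower-rank shape
// leaves the data dimensions intact, consistent with the later cases above.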