diff --git a/include/aidge/operator/ConvTranspose.hpp b/include/aidge/operator/ConvTranspose.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..e573a1a0285625081c2f9bb89760bf751794e3d4
--- /dev/null
+++ b/include/aidge/operator/ConvTranspose.hpp
@@ -0,0 +1,208 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_OPERATOR_CONVTRANSPOSE_H_
+#define AIDGE_CORE_OPERATOR_CONVTRANSPOSE_H_
+
+#include <array>
+#include <cmath>   // std::floor
+#include <string>
+#include <utility> // std::pair
+#include <vector>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/ArrayHelpers.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Registrar.hpp" // SET_IMPL_MACRO
+#include "aidge/utils/StaticAttributes.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+enum class ConvTransposeAttr { StrideDims, DilationDims, KernelDims };
+
+template <DimIdx_t DIM>
+class ConvTranspose_Op
+    : public OperatorTensor,
+      public Registrable<ConvTranspose_Op<DIM>,
+                         std::string,
+                         std::function<std::shared_ptr<OperatorImpl>(
+                             const ConvTranspose_Op<DIM> &)>> {
+
+  public:
+    static const std::string Type;
+
+  private:
+    using Attributes_ = StaticAttributes<ConvTransposeAttr,
+                                         std::array<DimSize_t, DIM>,
+                                         std::array<DimSize_t, DIM>,
+                                         std::array<DimSize_t, DIM>>;
+    template <ConvTransposeAttr e>
+    using attr = typename Attributes_::template attr<e>;
+    const std::shared_ptr<Attributes_> mAttributes;
+
+  public:
+    ConvTranspose_Op() = delete;
+
+    constexpr explicit ConvTranspose_Op(
+        const std::array<DimSize_t, DIM> &kernelDims,
+        const std::array<DimSize_t, DIM> &strideDims =
+            create_array<DimSize_t, DIM>(1),
+        const std::array<DimSize_t, DIM> &dilationDims =
+            create_array<DimSize_t, DIM>(1))
+    : OperatorTensor(Type,
+                     {InputCategory::Data,
+                      InputCategory::Param,
+                      InputCategory::OptionalParam},
+                     1),
+      mAttributes(std::make_shared<Attributes_>(
+          attr<ConvTransposeAttr::StrideDims>(strideDims),
+          attr<ConvTransposeAttr::DilationDims>(dilationDims),
+          attr<ConvTransposeAttr::KernelDims>(kernelDims))) {}
+
+    /**
+     * @brief Copy-constructor. Copy the operator attributes and its output
+     * tensor(s), but not its input tensors (the new operator has no input
+     * associated).
+     * @param op Operator to copy.
+     */
+    ConvTranspose_Op(const ConvTranspose_Op<DIM> &op);
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+     * @see Operator::Conv_Op
+     */
+    std::shared_ptr<Operator> clone() const override {
+        return std::make_shared<ConvTranspose_Op<DIM>>(*this);
+    }
+
+    bool forwardDims(bool /*allowDataDependency*/ = false) override final;
+
+    std::vector<std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>>
+    computeReceptiveField(const std::vector<DimSize_t> &firstEltDims,
+                          const std::vector<DimSize_t> &outputDims,
+                          const IOIndex_t outputIdx = 0) const override;
+
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
+    std::set<std::string> getAvailableBackends() const override;
+
+    DimSize_t inChannels() const {
+        if (!getInput(1)) {
+            AIDGE_THROW_OR_ABORT(
+                std::runtime_error,
+                "{}: operator has no weight Tensor associated so no "
+                "specific number of input channel imposed.",
+                Type);
+        }
+        return getInput(1)->template dims<DIM + 2>()[0];
+    }
+
+    DimSize_t outChannels() const {
+        if (!getInput(1)) {
+            AIDGE_THROW_OR_ABORT(
+                std::runtime_error,
+                "{}: operator has no weight Tensor associated so no "
+                "specific number of output channel imposed.",
+                Type);
+        }
+        return getInput(1)->template dims<DIM + 2>()[1];
+    }
+
+    inline std::shared_ptr<Attributes> attributes() const override {
+        return mAttributes;
+    }
+    inline std::array<DimSize_t, DIM> &strideDims() const {
+        return mAttributes->template getAttr<ConvTransposeAttr::StrideDims>();
+    }
+    inline std::array<DimSize_t, DIM> &dilationDims() const {
+        return mAttributes
+            ->template getAttr<ConvTransposeAttr::DilationDims>();
+    }
+    inline std::array<DimSize_t, DIM> &kernelDims() const {
+        return mAttributes->template getAttr<ConvTransposeAttr::KernelDims>();
+    }
+
+    static const std::vector<std::string> getInputsName() {
+        return {"data_input", "weight", "bias"};
+    }
+    static const std::vector<std::string> getOutputsName() {
+        return {"data_output"};
+    }
+};
+
+/**
+ * @brief Perform a convTranspose(/deconvolution) on the input Tensor.
+ *
+ * @tparam DIM Number of dimensions for the feature map.
+ * @param inChannels Number of input channels.
+ * @param outChannels Number of output channels.
+ * @param kernelDims Dimensions of the kernel. Must be the same number of
+ * dimensions as the feature map.
+ * @param name Name of the operator.
+ * @param strideDims Dimensions of the stride attribute. Must be the same
+ * number of dimensions as the feature map.
+ * @param dilationDims Dimensions of the dilation attribute. Must be the same
+ * number of dimensions as the feature map.
+ * @return std::shared_ptr<Node> A Node containing the operator.
+ */
+template <std::array<DimIdx_t, 1>::size_type DIM>
+std::shared_ptr<Node>
+ConvTranspose(const DimSize_t &inChannels,
+              const DimSize_t &outChannels,
+              const std::array<DimSize_t, DIM> &kernelDims,
+              const std::array<DimSize_t, DIM> &strideDims =
+                  create_array<DimSize_t, DIM>(1),
+              const std::array<DimSize_t, DIM> &dilationDims =
+                  create_array<DimSize_t, DIM>(1),
+              const bool noBias = false,
+              const std::string &name = "");
+
+// helper with C-style array instead of std::array for kernel_dims to allow
+// automatic template DIM deduction
+/**
+ * @brief Conv Transpose node constructor
+ * @param[in] inChannels number of input channels of the conv transpose
+ * operator
+ * @param[in] outChannels number of ouptut channels of the convTranspose
+ * operator
+ * @param[in] kernelDims array of size DIM describing the dimensions of the
+ * kernel
+ * @param[in] name name of the node
+ * @param[in] strideDims stride along each dimension of the operator
+ * @param[in] dilationDims dilation along each dimension of the operator
+ * @param[in] noBias describes if the operator has biases or just weights
+ */
+template <DimIdx_t DIM>
+inline std::shared_ptr<Node>
+ConvTranspose(const DimSize_t &inChannels,
+              const DimSize_t &outChannels,
+              DimSize_t const (&kernelDims)[DIM],
+              const std::array<DimSize_t, DIM> &strideDims =
+                  create_array<DimSize_t, DIM>(1),
+              const std::array<DimSize_t, DIM> &dilationDims =
+                  create_array<DimSize_t, DIM>(1),
+              const bool noBias = false,
+              const std::string &name = "");
+} // namespace Aidge
+
+extern template class Aidge::ConvTranspose_Op<1>;
+extern template class Aidge::ConvTranspose_Op<2>;
+
+namespace {
+template <>
+const char *const EnumStrings<Aidge::ConvTransposeAttr>::data[] = {
+    "stride_dims",
+    "dilation_dims",
+    "kernel_dims"};
+}
+
+#endif /* AIDGE_CORE_OPERATOR_CONVTRANSPOSE_H_ */
diff --git a/python_binding/operator/pybind_ConvTranspose.cpp b/python_binding/operator/pybind_ConvTranspose.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0f759e3dbc59183e26514631f0198c8204c4ac1a
--- /dev/null
+++ b/python_binding/operator/pybind_ConvTranspose.cpp
@@ -0,0 +1,111 @@
+
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <array>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <string>
+#include <vector>
+
+#include "aidge/operator/ConvTranspose.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/Registrar.hpp" // declare_registrable
+#include "aidge/utils/Types.h"
+
+namespace py = pybind11;
+namespace Aidge {
+
+template <DimIdx_t DIM> void declare_ConvTransposeOp(py::module &m) {
+    const std::string pyClassName("ConvTranspose" + std::to_string(DIM) +
+                                  "DOp");
+    py::class_<ConvTranspose_Op<DIM>,
+               std::shared_ptr<ConvTranspose_Op<DIM>>,
+               OperatorTensor>(m,
+                               pyClassName.c_str(),
+                               py::multiple_inheritance())
+        .def(py::init([](const std::vector<DimSize_t> &kernel_dims,
+                         const std::vector<DimSize_t> &stride_dims,
+                         const std::vector<DimSize_t> &dilation_dims) {
+                 AIDGE_ASSERT(kernel_dims.size() == DIM,
+                              "kernel_dims size [{}] does not match DIM [{}]",
+                              kernel_dims.size(),
+                              DIM);
+                 AIDGE_ASSERT(stride_dims.size() == DIM,
+                              "stride_dims size [{}] does not match DIM [{}]",
+                              stride_dims.size(),
+                              DIM);
+                 AIDGE_ASSERT(
+                     dilation_dims.size() == DIM,
+                     "dilation_dims size [{}] does not match DIM [{}]",
+                     dilation_dims.size(),
+                     DIM);
+
+                 return new ConvTranspose_Op<DIM>(
+                     to_array<DIM>(kernel_dims.begin()),
+                     to_array<DIM>(stride_dims.begin()),
+                     to_array<DIM>(dilation_dims.begin()));
+             }),
+             py::arg("kernel_dims"),
+             py::arg("stride_dims") = std::vector<DimSize_t>(DIM, 1),
+             py::arg("dilation_dims") = std::vector<DimSize_t>(DIM, 1))
+        .def_static("get_inputs_name", &ConvTranspose_Op<DIM>::getInputsName)
+        .def_static("get_outputs_name", &ConvTranspose_Op<DIM>::getOutputsName)
+        .def("in_channels", &ConvTranspose_Op<DIM>::inChannels)
+        .def("out_channels", &ConvTranspose_Op<DIM>::outChannels)
+        .def_readonly_static("Type", &ConvTranspose_Op<DIM>::Type);
+
+    declare_registrable<ConvTranspose_Op<DIM>>(m, pyClassName);
+
+    m.def(("ConvTranspose" + std::to_string(DIM) + "D").c_str(),
+          [](const DimSize_t &in_channels,
+             const DimSize_t &out_channels,
+             const std::vector<DimSize_t> &kernel_dims,
+             const std::vector<DimSize_t> &stride_dims,
+             const std::vector<DimSize_t> &dilation_dims,
+             bool noBias,
+             const std::string &name){
+              AIDGE_ASSERT(kernel_dims.size() == DIM,
+                           "kernel_dims size [{}] does not match DIM [{}]",
+                           kernel_dims.size(),
+                           DIM);
+              AIDGE_ASSERT(stride_dims.size() == DIM,
+                           "stride_dims size [{}] does not match DIM [{}]",
+                           stride_dims.size(),
+                           DIM);
+              AIDGE_ASSERT(dilation_dims.size() == DIM,
+                           "dilation_dims size [{}] does not match DIM [{}]",
+                           dilation_dims.size(),
+                           DIM);
+
+              return ConvTranspose<DIM>(in_channels,
+                               out_channels,
+                               to_array<DIM>(kernel_dims.begin()),
+                               to_array<DIM>(stride_dims.begin()),
+                               to_array<DIM>(dilation_dims.begin()),
+                               noBias,
+                               name);
+          },
+          py::arg("in_channels"),
+          py::arg("out_channels"),
+          py::arg("kernel_dims"),
+          py::arg("name") = "",
+          py::arg("stride_dims") = std::vector<DimSize_t>(DIM, 1),
+          py::arg("dilation_dims") = std::vector<DimSize_t>(DIM, 1),
+          py::arg("no_bias") = false);
+}
+
+void init_ConvTranspose(py::module &m) {
+    declare_ConvTransposeOp<1>(m);
+    declare_ConvTransposeOp<2>(m);
+}
+
+} // namespace Aidge
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index cc3f3abef2dee3888b5f542a3efddd0a5b422ec9..7fef82847f3a9e5252e14c1aff584b21f182e36c 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -51,6 +51,7 @@ void init_Clip(py::module&);
 void init_Concat(py::module&);
 void init_ConstantOfShape(py::module&);
 void init_Conv(py::module&);
+void init_ConvTranspose(py::module&);
 void init_ConvDepthWise(py::module&);
 void init_CryptoHash(py::module&);
 void init_DepthToSpace(py::module&);
@@ -157,6 +158,7 @@ void init_Aidge(py::module& m) {
     init_Clip(m);
     init_Concat(m);
     init_Conv(m);
+    init_ConvTranspose(m);
     init_ConvDepthWise(m);
     init_ConstantOfShape(m);
     init_CryptoHash(m);
diff --git a/src/operator/ConvTranspose.cpp b/src/operator/ConvTranspose.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8571518d7976a516283a13c651576a095e5e017a
--- /dev/null
+++ b/src/operator/ConvTranspose.cpp
@@ -0,0 +1,322 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/operator/ConvTranspose.hpp"
+
+#include <cmath>     // std::floor
+#include <cstddef>   // std::size_t
+#include <cstdint>
+#include <stdexcept> // std::runtime_error
+#include <string>
+#include <utility>   // std::pair
+#include <vector>
+
+#include "aidge/operator/Producer.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
+template <DimIdx_t DIM>
+const std::string ConvTranspose_Op<DIM>::Type =
+    "ConvTranspose" + std::to_string(DIM) + "D";
+
+template <DimIdx_t DIM>
+ConvTranspose_Op<DIM>::ConvTranspose_Op(const ConvTranspose_Op<DIM> &op)
+    : OperatorTensor(op), mAttributes(op.mAttributes) {
+    if (op.mImpl) {
+        SET_IMPL_MACRO(ConvTranspose_Op<DIM>, *this, op.backend());
+    } else {
+        mImpl = nullptr;
+    }
+}
+
+template <DimIdx_t DIM>
+bool ConvTranspose_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
+    if (!inputsAssociated()) {
+        return false;
+    }
+    constexpr std::int8_t batchIdx = 0;
+    constexpr std::int8_t channelIdx = 1;
+    // DIM only defines the dimensions of the input, this defines channel &
+    // batch idx offset
+    constexpr std::int8_t NCIdx = 2;
+
+    // first check weight since it defines inChannels and outChannels
+    AIDGE_ASSERT((getInput(1)->nbDims() == (DIM + NCIdx)),
+                 "{}: Wrong weight Tensor dimension: {}. "
+                 "Expected number of dimensions is {}.",
+                 type(),
+                 getInput(1)->nbDims(),
+                 DIM + NCIdx);
+    // check data
+    AIDGE_ASSERT(
+        getInput(0)->template dims<DIM + NCIdx>()[channelIdx] == inChannels(),
+        "{}: Wrong input size ({}). Expected dims are [x, {}, {}] as weights "
+        "dim size "
+        "on 1st axis describes number of input channel.",
+        type(),
+        getInput(0)->dims(),
+        inChannels(),
+        fmt::join(std::vector<std::string>(DIM, "x"), ", "));
+    // check optional bias
+    if (getInput(2)) {
+        AIDGE_ASSERT((getInput(2)->nbDims() == 1) &&
+                         (getInput(2)->template dims<1>()[0] == outChannels()),
+                     "{}: Wrong bias size ({}). Expected dims are [{}].",
+                     type(),
+                     getInput(2)->dims(),
+                     outChannels());
+    }
+    std::array<DimSize_t, DIM + NCIdx> outputDims{};
+    const std::array<DimSize_t, DIM + NCIdx> inputDims(
+        getInput(0)->template dims<DIM + NCIdx>());
+
+    outputDims[channelIdx] = outChannels();
+    outputDims[batchIdx] = inputDims[batchIdx];
+
+    for (std::size_t dim = 0; dim < DIM; ++dim) {
+        const DimSize_t kernelExtent =
+            dilationDims()[dim] * (kernelDims()[dim] - 1) + 1;
+        outputDims[dim + NCIdx] =
+            ((inputDims[dim + NCIdx] - 1) * strideDims()[dim]) + kernelExtent;
+    }
+
+    mOutputs[0]->resize(outputDims);
+    return true;
+}
+
+template <DimIdx_t DIM>
+std::vector<std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>>
+ConvTranspose_Op<DIM>::computeReceptiveField(
+    const std::vector<DimSize_t> &firstEltDims,
+    const std::vector<DimSize_t> &outputDims,
+    const IOIndex_t outputIdx) const {
+
+    constexpr std::int8_t inBatchIdx = 0;
+    constexpr std::int8_t inChannelIdx = 1;
+    // DIM only defines the dimensions of the input, this defines channel &
+    // batch idx offset
+    constexpr std::int8_t NCChannels = 2;
+
+    if (outputIdx != 0) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error,
+                             "{}: Operator has got only one output Tensor.",
+                             type());
+    }
+    if (firstEltDims.size() != outputDims.size()) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error,
+                             "{}: outputDims and firstEltDims should have the "
+                             "size of the output Tensor dimensions.",
+                             type());
+    }
+    if ((outputDims.size() != (DIM + NCChannels)) || !dimsForwarded()) {
+        AIDGE_THROW_OR_ABORT(
+            std::runtime_error,
+            "Given outputDim out of range or output dim not forwarded yet.");
+    }
+
+    // Offset
+    auto inputIdxDims = firstEltDims; // batch idx is the same
+    // each channel is used so start with the first one
+    inputIdxDims[inChannelIdx] = 0;
+
+    // Error checking : parameters will not create an out of bound error.
+    for (DimIdx_t i = 0; i < (DIM + NCChannels); ++i) {
+        AIDGE_ASSERT(
+            ((outputDims[i] + firstEltDims[i]) <=
+             mOutputs[outputIdx]->template dims<DIM + NCChannels>()[i]) &&
+                outputDims[i] != 0,
+            "{}: Given outputDim out of range for dimension {} ({} + {})",
+            type(),
+            static_cast<std::size_t>(i),
+            firstEltDims[i],
+            outputDims[i]);
+    }
+
+    ////////////////////////
+    // Input
+    std::vector<DimSize_t> inputDims{outputDims[inBatchIdx],
+                                     getInput(0)->dims()[inChannelIdx]};
+    for (DimIdx_t i = 0; i < DIM; ++i) {
+        const DimSize_t kernelExtent =
+            dilationDims()[i] * (kernelDims()[i] - 1) + 1;
+
+        inputDims.push_back(
+            1 +
+            static_cast<DimSize_t>(floor(
+                static_cast<float>(inputDims[i + NCChannels] - kernelExtent) /
+                static_cast<float>(strideDims()[i]))));
+        inputIdxDims[NCChannels + i] *= strideDims()[i];
+    }
+
+    ////////////////////////
+    // Weight
+    // same output value, every input channel is used
+    std::vector<DimSize_t> weightDims{outputDims[inChannelIdx],
+                                      getInput(0)->dims()[inChannelIdx]};
+    for (std::size_t i = 0; i < DIM; ++i) {
+        weightDims.push_back(kernelDims()[i]);
+    }
+    std::vector<DimSize_t> weightIdxDims =
+        std::vector<DimSize_t>(DIM + NCChannels, 0);
+    weightIdxDims[0] = firstEltDims[inChannelIdx];
+
+    ////////////////////////
+    // Result
+    std::vector<std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>> res;
+    res.push_back(
+        std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>(inputIdxDims,
+                                                                  inputDims));
+    res.push_back(std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>(
+        weightIdxDims,
+        weightDims));
+
+    ////////////////////////
+    // Bias
+    if (getInput(2)) {
+        const std::vector<DimSize_t> biasDims{
+            outputDims[1]}; // the number of output channel
+        const std::vector<DimSize_t> biasIdxDims{firstEltDims[1]};
+        res.push_back(
+            std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>(
+                biasIdxDims,
+                biasDims));
+    }
+    return res;
+}
+
+template <DimIdx_t DIM>
+void ConvTranspose_Op<DIM>::setBackend(const std::string &name,
+                                       DeviceIdx_t device) {
+    SET_IMPL_MACRO(ConvTranspose_Op<DIM>, *this, name);
+    mOutputs[0]->setBackend(name, device);
+
+    // By default, automatically set backend for weight and bias inputs
+    if (getInput(1)) {
+        getInput(1)->setBackend(name, device);
+    } else {
+        Log::notice("ConvTranspose_Op::setBackend(): could not set backend "
+                    "for weight input, because input is not connected");
+    }
+
+    if (getInput(2)) {
+        // Bias is optional
+        getInput(2)->setBackend(name, device);
+    }
+}
+
+template <DimIdx_t DIM>
+std::set<std::string> ConvTranspose_Op<DIM>::getAvailableBackends() const {
+    return Registrar<ConvTranspose_Op<DIM>>::getKeys();
+}
+
+template class ConvTranspose_Op<1>;
+template class ConvTranspose_Op<2>;
+
+/////////////////////////////////////////////////////////////
+
+template <std::array<DimIdx_t, 1>::size_type DIM>
+std::shared_ptr<Node>
+ConvTranspose(const DimSize_t &inChannels,
+              const DimSize_t &outChannels,
+              const std::array<DimSize_t, DIM> &kernelDims,
+              const std::array<DimSize_t, DIM> &strideDims,
+              const std::array<DimSize_t, DIM> &dilationDims,
+              const bool noBias,
+              const std::string &name) {
+    AIDGE_ASSERT(DIM <= MaxDim,
+                 "Too many kernel dimensions required by Conv, not supported");
+    AIDGE_ASSERT(
+        !std::any_of(dilationDims.cbegin(),
+                     dilationDims.cend(),
+                     [](DimSize_t val) { return val <= 0; }),
+        "Conv : at least of of the dilation dimension is <= 0, expecting "
+        "strictly positive values. Got {}",
+        dilationDims);
+    AIDGE_ASSERT(!std::any_of(strideDims.cbegin(),
+                              strideDims.cend(),
+                              [](DimSize_t val) { return val <= 0; }),
+                 "Conv : at least of of the stride dimension is 0<= , expecting "
+                 "strictly positive values. Got {}",
+                 strideDims);
+    auto conv = std::make_shared<Node>(
+        std::make_shared<ConvTranspose_Op<static_cast<DimIdx_t>(DIM)>>(
+            kernelDims,
+            strideDims,
+            dilationDims),
+        name);
+    addProducer(conv,
+                1,
+                append(inChannels, append(outChannels, kernelDims)),
+                "w");
+    if (!noBias) {
+        addProducer(conv, 2, {outChannels}, "b"); // already sets bias dims
+    }
+    return conv;
+}
+
+template std::shared_ptr<Node>
+ConvTranspose<1>(const DimSize_t &,
+                 const DimSize_t &,
+                 const std::array<DimSize_t, 1> &,
+                 const std::array<DimSize_t, 1> &,
+                 const std::array<DimSize_t, 1> &,
+                 const bool,
+                 const std::string &);
+
+template std::shared_ptr<Node>
+ConvTranspose<2>(const DimSize_t &,
+                 const DimSize_t &,
+                 const std::array<DimSize_t, 2> &,
+                 const std::array<DimSize_t, 2> &,
+                 const std::array<DimSize_t, 2> &,
+                 const bool,
+                 const std::string &);
+
+template <DimIdx_t DIM>
+inline std::shared_ptr<Node>
+ConvTranspose(const DimSize_t &inChannels,
+              const DimSize_t &outChannels,
+              DimSize_t const (&kernelDims)[DIM],
+              const std::array<DimSize_t, DIM> &strideDims,
+              const std::array<DimSize_t, DIM> &dilationDims,
+              const bool noBias,
+              const std::string &name) {
+    return ConvTranspose<DIM>(inChannels,
+                              outChannels,
+                              to_array(kernelDims),
+                              strideDims,
+                              dilationDims,
+                              noBias,
+                              name);
+}
+
+template std::shared_ptr<Node>
+ConvTranspose<1>(const DimSize_t &,
+                 const DimSize_t &,
+                 DimSize_t const (&)[1],
+                 const std::array<DimSize_t, 1> &,
+                 const std::array<DimSize_t, 1> &,
+                 const bool,
+                 const std::string &);
+
+template std::shared_ptr<Node>
+ConvTranspose<2>(const DimSize_t &,
+                 const DimSize_t &,
+                 DimSize_t const (&)[2],
+                 const std::array<DimSize_t, 2> &,
+                 const std::array<DimSize_t, 2> &,
+                 const bool,
+                 const std::string &);
+
+} // namespace Aidge
diff --git a/unit_tests/operator/Test_ConvTranspose.cpp b/unit_tests/operator/Test_ConvTranspose.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f284be237034e8d8937d7ae750a6c636a6f6f545
--- /dev/null
+++ b/unit_tests/operator/Test_ConvTranspose.cpp
@@ -0,0 +1,241 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+#include <memory>
+#include <vector>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ConvTranspose.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
+template <DimSize_t DIM>
+static std::shared_ptr<OperatorTensor> setupTestConvTransposeForwardDims(
+    const DimSize_t batchSize,
+    const DimSize_t inChannels,
+    const DimSize_t outChannels,
+    const DimSize_t kernelDims,
+    const DimSize_t inDataSize,
+    const std::array<DimSize_t, DIM> strideDims,
+    const std::array<DimSize_t, DIM> dilationDims) {
+
+    auto convTrans = ConvTranspose(inChannels,
+                                   outChannels,
+                                   {kernelDims},
+                                   strideDims,
+                                   dilationDims,
+                                   false,
+                                   "yeet");
+
+    auto op =
+        std::dynamic_pointer_cast<OperatorTensor>(convTrans->getOperator());
+
+    auto input = std::make_shared<Tensor>(
+        std::vector<DimSize_t>({batchSize, inChannels, inDataSize}));
+
+    op->associateInput(0, input);
+    return op;
+}
+
+/***********************************************************
+ * This test is based on the assumption that conv and
+ * ConvTranspose are the exact oppposite operations
+ * Hence :
+ * Conv::computeReceptiveField() <=> ConvTranspose::forwardDims()
+ * Conv::forwardDims() <=> ConvTranspose::computeReceptiveField()
+ *
+ * This means that this test relies on Conv Operator's tests
+ * properties.
+ ***********************************************************/
+TEST_CASE("[core/operator] ConvTranspose_Op(forwarDims)",
+          "[Operator][forwardDims][ConvTranspose]") {
+
+    SECTION("a (conv()=>convTranspose()) graph has inputs & outputs of "
+            "identical dimensions") {
+        auto prod = Producer({16, 3, 224, 224}, "dataProvider");
+
+        // output dims: {16, 32, 220, 220}
+        auto conv =
+            Conv(3, 32, {5, 5}, "conv"); // output dims: {16, 32, 220, 220}
+
+        auto convTranspose = ConvTranspose(32,
+                                           3,
+                                           {5, 5},
+                                           std::array<DimSize_t, 2>({1, 1}),
+                                           std::array<DimSize_t, 2>({1, 1}),
+                                           false,
+                                           "convtranspose");
+
+        auto g = std::make_shared<GraphView>("TestGraph");
+
+        prod->addChild(conv, 0);
+        g->add(conv);
+        g->addChild(convTranspose, conv, 0);
+        g->forwardDims();
+
+        auto prodOp =
+            std::dynamic_pointer_cast<OperatorTensor>(prod->getOperator());
+        auto op1 =
+            std::dynamic_pointer_cast<OperatorTensor>(conv->getOperator());
+        auto op2 = std::dynamic_pointer_cast<OperatorTensor>(
+            convTranspose->getOperator());
+
+        REQUIRE(g->forwardDims({prodOp->getOutput(0)->dims()}, true));
+        CHECK(prodOp->getOutput(0)->dims() ==
+              std::dynamic_pointer_cast<OperatorTensor>(
+                  g->getOrderedOutputs()[0].first->getOperator())
+                  ->getOutput(0)
+                  ->dims());
+    }
+    SECTION("1D") {
+        constexpr DimSize_t DIM = 1;
+        SECTION("Test with reference output") {
+            SECTION("no stride / no dilation") {
+                constexpr DimSize_t batchSize = 2;
+                constexpr DimSize_t inChannels = 3;
+                constexpr DimSize_t outChannels = 4;
+                constexpr DimSize_t kernelDims = 2;
+
+                constexpr std::array<DimSize_t, DIM> strideDims{1};
+                constexpr std::array<DimSize_t, DIM> dilationDims{1};
+
+                constexpr DimSize_t inDataSize = 6;
+                constexpr DimSize_t outDataSize = 7;
+
+                auto op = setupTestConvTransposeForwardDims(batchSize,
+                                                            inChannels,
+                                                            outChannels,
+                                                            kernelDims,
+                                                            inDataSize,
+                                                            strideDims,
+                                                            dilationDims);
+                REQUIRE(op->forwardDims());
+
+                CHECK(op->getOutput(0)->dims() ==
+                      std::vector<DimSize_t>(
+                          {batchSize, outChannels, outDataSize}));
+            }
+        }
+        SECTION("stride / no dilation") {
+            constexpr DimSize_t batchSize = 2;
+            constexpr DimSize_t inChannels = 3;
+            constexpr DimSize_t outChannels = 4;
+            constexpr DimSize_t kernelDims = 2;
+
+            constexpr std::array<DimSize_t, DIM> strideDims{3};
+            constexpr std::array<DimSize_t, DIM> dilationDims{1};
+
+            constexpr DimSize_t inDataSize = 6;
+            constexpr DimSize_t outDataSize = 17;
+
+            auto op = setupTestConvTransposeForwardDims(batchSize,
+                                                        inChannels,
+                                                        outChannels,
+                                                        kernelDims,
+                                                        inDataSize,
+                                                        strideDims,
+                                                        dilationDims);
+
+            REQUIRE(op->forwardDims());
+
+            CHECK(
+                op->getOutput(0)->dims() ==
+                std::vector<DimSize_t>({batchSize, outChannels, outDataSize}));
+        }
+        SECTION("no stride / dilation") {
+            constexpr DimSize_t batchSize = 2;
+            constexpr DimSize_t inChannels = 3;
+            constexpr DimSize_t outChannels = 4;
+            constexpr DimSize_t kernelDims = 2;
+
+            constexpr std::array<DimSize_t, DIM> strideDims{1};
+            constexpr std::array<DimSize_t, DIM> dilationDims{3};
+
+            constexpr DimSize_t inDataSize = 6;
+            constexpr DimSize_t outDataSize = 9;
+
+            auto op = setupTestConvTransposeForwardDims(batchSize,
+                                                        inChannels,
+                                                        outChannels,
+                                                        kernelDims,
+                                                        inDataSize,
+                                                        strideDims,
+                                                        dilationDims);
+
+            REQUIRE(op->forwardDims());
+
+            CHECK(
+                op->getOutput(0)->dims() ==
+                std::vector<DimSize_t>({batchSize, outChannels, outDataSize}));
+        }
+        SECTION("stride / dilation") {
+            constexpr DimSize_t batchSize = 2;
+            constexpr DimSize_t inChannels = 3;
+            constexpr DimSize_t outChannels = 4;
+            constexpr DimSize_t kernelDims = 4;
+
+            constexpr std::array<DimSize_t, DIM> strideDims{3};
+            constexpr std::array<DimSize_t, DIM> dilationDims{3};
+
+            constexpr DimSize_t inDataSize = 15;
+            constexpr DimSize_t outDataSize = 52;
+
+            auto op = setupTestConvTransposeForwardDims(batchSize,
+                                                        inChannels,
+                                                        outChannels,
+                                                        kernelDims,
+                                                        inDataSize,
+                                                        strideDims,
+                                                        dilationDims);
+
+            REQUIRE(op->forwardDims());
+
+            CHECK(
+                op->getOutput(0)->dims() ==
+                std::vector<DimSize_t>({batchSize, outChannels, outDataSize}));
+        }
+        SECTION("stride / dilation") {
+            constexpr DimSize_t batchSize = 2;
+            constexpr DimSize_t inChannels = 3;
+            constexpr DimSize_t outChannels = 4;
+
+            constexpr DimSize_t kernelDims = 2;
+
+            constexpr DimSize_t inDataSize = 6;
+            constexpr DimSize_t outDataSize = 7;
+
+            constexpr std::array<DimSize_t, DIM> strideDims{1};
+            constexpr std::array<DimSize_t, DIM> dilationDims{1};
+
+            auto op = setupTestConvTransposeForwardDims(batchSize,
+                                                        inChannels,
+                                                        outChannels,
+                                                        kernelDims,
+                                                        inDataSize,
+                                                        strideDims,
+                                                        dilationDims);
+
+            REQUIRE(op->forwardDims());
+
+            CHECK(
+                op->getOutput(0)->dims() ==
+                std::vector<DimSize_t>({batchSize, outChannels, outDataSize}));
+        }
+    }
+}
+
+} // namespace Aidge