diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index cfae53b64115aa7946580d00f45be56f17163d7f..a6bae174471e665f229d08a489d6b9f7911a6e9f 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -14,6 +14,5 @@
 
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
-#include "aidge/backend/cuda/operator/ProducerImpl.hpp"
 
 #endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
\ No newline at end of file
diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index 65180fd54aeb9ff349af37192061cf66415e0a77..1f692b09cf44e5d54d1bc9d5b998b90a800f719f 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -34,8 +34,11 @@ private:
     cudnnConvolutionDescriptor_t mConvDesc = nullptr;
     cudnnFilterDescriptor_t mFilterDesc = nullptr;
     cudnnConvolutionFwdAlgo_t mFwdAlgo;
+    cudnnConvolutionBwdFilterAlgo_t mBwdFilterAlgo;
+    cudnnConvolutionBwdDataAlgo_t mBwdDataAlgo;
     size_t mWorkspaceSize = 0;
     void* mFwdWorkspace = nullptr;
+    void* mBwdWorkspace = nullptr;
     std::shared_ptr<Tensor> mInput0Fallback;
     std::shared_ptr<Tensor> mInput1Fallback;
     std::shared_ptr<Tensor> mInput2Fallback;
@@ -49,10 +52,12 @@ public:
 
 public:
     void forward();
+    void backward();
     ~ConvImpl_cuda();
 
 private:
     template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
+    template <class T> void backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
 };
 
 namespace {
diff --git a/include/aidge/backend/cuda/operator/ProducerImpl.hpp b/include/aidge/backend/cuda/operator/ProducerImpl.hpp
deleted file mode 100644
index 9912133072e23181df8f384841660bf89a829b60..0000000000000000000000000000000000000000
--- a/include/aidge/backend/cuda/operator/ProducerImpl.hpp
+++ /dev/null
@@ -1,40 +0,0 @@
-/********************************************************************************
- * Copyright (c) 2023 CEA-List
- *
- * This program and the accompanying materials are made available under the
- * terms of the Eclipse Public License 2.0 which is available at
- * http://www.eclipse.org/legal/epl-2.0.
- *
- * SPDX-License-Identifier: EPL-2.0
- *
- ********************************************************************************/
-
-#ifndef AIDGE_CUDA_OPERATOR_PRODUCERIMPL_H_
-#define AIDGE_CUDA_OPERATOR_PRODUCERIMPL_H_
-
-#include <memory>
-
-#include "aidge/backend/OperatorImpl.hpp"
-#include "aidge/operator/Producer.hpp"
-#include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/Types.h"
-
-namespace Aidge {
-class ProducerImpl_cuda : public OperatorImpl {
-public:
-    ProducerImpl_cuda(const Producer_Op &op) : OperatorImpl(op) {}
-
-    static std::unique_ptr<ProducerImpl_cuda> create(const Producer_Op &op) {
-        return std::make_unique<ProducerImpl_cuda>(op);
-    }
-
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void forward() override;
-};
-
-namespace {
-static Registrar<Producer_Op> registrarProducerImpl_cuda("cuda", Aidge::ProducerImpl_cuda::create);
-}  // namespace
-}  // namespace Aidge
-
-#endif /* AIDGE_CUDA_OPERATOR_PRODUCERIMPL_H_ */
diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp
index 82dd395e6bbb33bae29c5d881290d6996bfb0332..7218cc24aed718f57a1866be74e7ba9124a5a7f1 100644
--- a/include/aidge/backend/cuda/utils/CudaContext.hpp
+++ b/include/aidge/backend/cuda/utils/CudaContext.hpp
@@ -2,8 +2,8 @@
 #define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H
 
 #include <vector>
-#include <cstdio>
 
+#include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 
 namespace Aidge {
@@ -87,7 +87,7 @@ public:
 
         if (cublas_h[dev] == NULL) {
             CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev]));
-            printf("CUBLAS initialized on device #%d\n", dev);
+            fmt::print("CUBLAS initialized on device #{}\n", dev);
         }
 
         return cublas_h[dev];
@@ -113,7 +113,7 @@ public:
 
         if (cudnn_h[dev] == NULL) {
             CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev]));
-            printf("CUDNN initialized on device #%d\n", dev);
+            fmt::print("CUDNN initialized on device #{}\n", dev);
         }
 
         return cudnn_h[dev];
diff --git a/include/aidge/backend/cuda/utils/CudaUtils.hpp b/include/aidge/backend/cuda/utils/CudaUtils.hpp
index 2f66d0e778778400f0b7def345619d635cc37674..c4929364d2ac455e50d174fbc311930106517d2c 100644
--- a/include/aidge/backend/cuda/utils/CudaUtils.hpp
+++ b/include/aidge/backend/cuda/utils/CudaUtils.hpp
@@ -11,6 +11,8 @@
 #include <cuda.h>
 #include <cudnn.h>
 
+#include "aidge/utils/ErrorHandling.hpp"
+
 #define CHECK_CUDNN_STATUS(status)                                             \
     do {                                                                       \
         const cudnnStatus_t e = (status);                                      \
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 19ce56bcb99f60e08427f8d9b110637c90582adf..e2ee633db5473863edc25e340d9680b7ac613f39 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -157,6 +157,176 @@ void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& inp
     }
 }
 
+template <Aidge::DimIdx_t DIM>
+void Aidge::ConvImpl_cuda<DIM>::backward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
+    // FIXME: uncomment the following code once memory handling works
+    assert(mOp.getRawInput(0) && "missing input #0");
+    assert(mOp.getRawInput(1) && "missing input #1");
+
+    // Convert input data (no overhead if not needed!)
+    const auto& input0 = op.getInput(0)->ref(mInput0Fallback, *op.getOutput(0));
+    const auto& input1 = op.getInput(1)->ref(mInput1Fallback, *op.getOutput(0));
+    const auto& input2 = op.getInput(2)->ref(mInput2Fallback, *op.getOutput(0));
+
+    // Set backward algorithms and allocate the required workspace
+    if (mBwdWorkspace == nullptr) {
+        // Find the best CuDNN backward algorithm (the one with the lowest compute time)
+        int maxAlgoIterations = 0;
+        cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(CudaContext::cudnnHandle(),
+                                                    &maxAlgoIterations);
+        assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardFilterAlgorithm");
+
+        int returnAlgoCounts = 0;
+        std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> returnBwdFilterAlgo(maxAlgoIterations);
+
+        CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardFilterAlgorithm(
+            CudaContext::cudnnHandle(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            mFilterDesc,
+            maxAlgoIterations,
+            &returnAlgoCounts,
+            &returnBwdFilterAlgo[0]));
+
+        mBwdFilterAlgo = returnBwdFilterAlgo[0].algo;
+
+        maxAlgoIterations = 0;
+        cudnnGetConvolutionBackwardDataAlgorithmMaxCount(CudaContext::cudnnHandle(),
+                                                    &maxAlgoIterations);
+        assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionBackwardDataAlgorithm");
+
+        returnAlgoCounts = 0;
+        std::vector<cudnnConvolutionBwdDataAlgoPerf_t> returnBwdDataAlgo(maxAlgoIterations);
+
+        CHECK_CUDNN_STATUS(cudnnFindConvolutionBackwardDataAlgorithm(
+            CudaContext::cudnnHandle(),
+            mFilterDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            maxAlgoIterations,
+            &returnAlgoCounts,
+            &returnBwdDataAlgo[0]));
+
+        mBwdDataAlgo = returnBwdDataAlgo[0].algo;
+
+        // Allocate the workspace required by the chosen CuDNN backward algorithm
+        size_t workspaceSize = 0;
+        CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+            CudaContext::cudnnHandle(),
+            // same arguments as cudnnFindConvolutionBackwardFilterAlgorithm()
+            // -->
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            mFilterDesc,
+            // <--
+            mBwdFilterAlgo,
+            &workspaceSize));
+
+        size_t workspaceSizeData = 0;
+        CHECK_CUDNN_STATUS(cudnnGetConvolutionBackwardDataWorkspaceSize(
+            CudaContext::cudnnHandle(),
+            // same arguments as cudnnFindConvolutionBackwardDataAlgorithm() -->
+            mFilterDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            mConvDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            // <--
+            mBwdDataAlgo,
+            &workspaceSizeData));
+
+        if (workspaceSizeData > workspaceSize)
+            workspaceSize = workspaceSizeData;
+
+        if (workspaceSize > mWorkspaceSize) {
+            if (mFwdWorkspace != nullptr) {
+                cudaFree(mFwdWorkspace);
+            }
+            CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, workspaceSize));
+            mWorkspaceSize = workspaceSize;
+        }
+
+        mBwdWorkspace = mFwdWorkspace;
+    }
+
+    // Do the actual backward computation
+    // Template is only for scaling parameters, which are always in float
+    // except when the convolution is performed in double precision.
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
+        backward_<double>(input0, input1, input2);
+    }
+    else {
+        backward_<float>(input0, input1, input2);
+    }
+}
+
+template <Aidge::DimIdx_t DIM>
+template <class T>
+void Aidge::ConvImpl_cuda<DIM>::backward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
+    std::shared_ptr<Tensor> gradOutputFallback;
+    const auto& gradOutput = op.getOutput(0)->grad()->refCastFrom(gradOutputFallback, *(op.getInput(0)->grad()));
+
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
+
+    CHECK_CUDNN_STATUS(cudnnConvolutionBackwardFilter(
+        CudaContext::cudnnHandle(),
+        &alpha,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+        input0.getImpl()->rawPtr(),
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
+        gradOutput.getImpl()->rawPtr(),
+        mConvDesc,
+        mBwdFilterAlgo,
+        mBwdWorkspace,
+        mWorkspaceSize,
+        &beta,
+        mFilterDesc,
+        op.getInput(1)->grad()->getImpl()->rawPtr()));
+
+    CHECK_CUDNN_STATUS(cudnnConvolutionBackwardData(
+        CudaContext::cudnnHandle(),
+        &alpha,
+        mFilterDesc,
+        input1.getImpl()->rawPtr(),
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
+        gradOutput.getImpl()->rawPtr(),
+        mConvDesc,
+        mBwdDataAlgo,
+        mBwdWorkspace,
+        mWorkspaceSize,
+        &beta,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
+        op.getInput(0)->grad()->getImpl()->rawPtr()));
+
+    // Add bias (if there is any)
+    if (mOp.getRawInput(2) && input2.size() > 0) {
+        // Bias tensor needs to have the same number of dims than output tensor for cudnnAddTensor()
+        std::vector<DimSize_t> gradBiasDims(DIM+2, 1);
+        gradBiasDims[1] = op.getInput(2)->grad()->size();
+
+        // Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
+        Tensor gradBias(op.getInput(2)->grad()->dataType());
+        gradBias.setBackend("cuda");
+        gradBias.resize(gradBiasDims);
+        // TODO: find a more elegant solution(?)
+
+        CHECK_CUDNN_STATUS(cudnnConvolutionBackwardBias(CudaContext::cudnnHandle(),
+            &alpha,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(gradOutput.getImpl())->getCudnnTensorDesc(gradOutput),
+            gradOutput.getImpl()->rawPtr(),
+            &beta,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(gradBias.getImpl())->getCudnnTensorDesc(gradBias),
+            op.getInput(2)->grad()->getImpl()->rawPtr()));
+    }
+}
+
 template <Aidge::DimIdx_t DIM>
 Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
     if (mConvDesc != nullptr) {
diff --git a/src/operator/ProducerImpl.cpp b/src/operator/ProducerImpl.cpp
deleted file mode 100644
index aca3c4945e357be13017e302cb6e7f12ba61237c..0000000000000000000000000000000000000000
--- a/src/operator/ProducerImpl.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-/********************************************************************************
- * Copyright (c) 2023 CEA-List
- *
- * This program and the accompanying materials are made available under the
- * terms of the Eclipse Public License 2.0 which is available at
- * http://www.eclipse.org/legal/epl-2.0.
- *
- * SPDX-License-Identifier: EPL-2.0
- *
- ********************************************************************************/
-
-#include <cassert>
-#include <numeric> // std::accumulate
-#include <vector>
-
-#include "aidge/data/Tensor.hpp"
-#include "aidge/operator/Producer.hpp"
-#include "aidge/utils/Types.h"
-
-#include "aidge/backend/cuda/operator/ProducerImpl.hpp"
-
-Aidge::DimSize_t Aidge::ProducerImpl_cuda::getNbProducedData(
-    Aidge::IOIndex_t outputIdx) const
-{
-    // Requires the whole tensors, regardless of available data on inputs
-    assert(outputIdx == 0 && "operator has only one output");
-    (void) outputIdx;
-
-    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
-}
-
-void Aidge::ProducerImpl_cuda::forward()
-{
-}
diff --git a/src/utils/CudaUtils.cpp b/src/utils/CudaUtils.cpp
index a6e0514f9a949a805561d966e2a712701c18936c..ca3263a282322e70157b7537c502a63a3edb526f 100644
--- a/src/utils/CudaUtils.cpp
+++ b/src/utils/CudaUtils.cpp
@@ -40,7 +40,7 @@ void Aidge::Cuda::setMultiDevicePeerAccess(unsigned int size, unsigned int* devi
                     CHECK_CUDA_STATUS(cudaSetDevice(devices[j]));
                     const cudaError_t status = cudaDeviceEnablePeerAccess(devices[i], 0);
                     if (status == cudaErrorPeerAccessAlreadyEnabled) {
-                        printf("Peer access already enabled between device %d and device %d\n", devices[j], devices[i]);
+                        fmt::print("Peer access already enabled between device {} and device {}\n", devices[j], devices[i]);
                     } else {
                         CHECK_CUDA_STATUS(status);
                     }
diff --git a/unit_tests/Test_CastMove.cpp b/unit_tests/Test_CastMove.cpp
index 0b68a4f9dcb6c72df91506a6f92be8c31e95f068..c96600f79967c69e43b3c334d3624f6514b6f936 100644
--- a/unit_tests/Test_CastMove.cpp
+++ b/unit_tests/Test_CastMove.cpp
@@ -18,8 +18,8 @@
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/OpArgs.hpp"
-#include "aidge/scheduler/Scheduler.hpp"
-#include "aidge/recipies/Recipies.hpp"
+#include "aidge/scheduler/SequentialScheduler.hpp"
+#include "aidge/recipes/Recipes.hpp"
 
 #include "aidge/backend/cuda.hpp"
 
diff --git a/version.txt b/version.txt
index 8a9ecc2ea99d607e92feae1656ddbf6fdd82a2c1..341cf11faf9a29504168de4e54beaad182c5adc5 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.0.1
\ No newline at end of file
+0.2.0
\ No newline at end of file