diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index c61e926c88a9baf1fcdf64794c2a975a1b891356..f7de7aaf434ddf8a19f225f1c5db49780fde88d2 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -29,6 +29,9 @@ void thrust_copy(const half_float::half* srcData, half_float::half* dstData, siz
  * class), but whose data type does not need to be known.
  */
 class TensorImpl_cuda_ {
+protected:
+    mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
+
 public:
     /**
      * @brief Return the CuDNN tensor descriptor of the tensor.
@@ -36,7 +39,13 @@ public:
      * (which is therefore mutable in the derived class).
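+     * @param tensor Tensor whose dimensions and strides are used to build (and cache) the descriptor.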
      * @return cudnnTensorDescriptor_t CuDNN tensor descriptor.
      */
-    virtual const cudnnTensorDescriptor_t& getCudnnTensorDesc() const = 0;
+    virtual const cudnnTensorDescriptor_t& getCudnnTensorDesc(const Tensor& tensor) const = 0;
+
+    virtual ~TensorImpl_cuda_() {
+        if (mCudnnTensor != nullptr)
+            cudnnDestroyTensorDescriptor(mCudnnTensor);
+    }
 };
 
 template <class T>
@@ -54,119 +62,116 @@ private:
     }
 
 private:
-    const Tensor &mTensor;  // Impl needs to access Tensor information, but is not
-                            // supposed to change it!
-    /// Pointer to the data and its capacity
     future_std::span<T> mData;
     /// If this instance own the data, std::unique_ptr manages it
     std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner;
-    mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
 
 public:
     static constexpr const char *Backend = "cuda";
 
-    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor), mDataOwner(nullptr, cudaDelete) {}
+    TensorImpl_cuda(DeviceIdx_t device, NbElts_t length) : TensorImpl(Backend, device, length), mDataOwner(nullptr, cudaDelete) {}
 
     bool operator==(const TensorImpl &otherImpl) const override final;
 
-    static std::unique_ptr<TensorImpl_cuda> create(const Tensor &tensor) {
-        return std::make_unique<TensorImpl_cuda<T>>(tensor);
+    static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, NbElts_t length) {
+        return std::make_shared<TensorImpl_cuda<T>>(device, length);
     }
 
     // native interface
     const future_std::span<T>& data() const { return mData; }
 
-    std::size_t size() const override { return mData.size(); }
     std::size_t scalarSize() const override { return sizeof(T); }
 
-    void setDevice(DeviceIdx_t device) override {
-        mDevice = device;
-    }
-
     void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
-        void* dst = static_cast<void*>(static_cast<T*>(rawPtr()) + offset);
-        CHECK_CUDA_STATUS(cudaMemcpy(dst, src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        const T* srcT = static_cast<const T *>(src);
+        T* dstT = static_cast<T *>(rawPtr(offset));
+
+        AIDGE_ASSERT(dstT < srcT || dstT >= srcT + length, "overlapping copy is not supported");
+        CHECK_CUDA_STATUS(cudaMemcpy(dstT, srcT, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }
 
-    void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
+    void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override {
         if (length == 0) {
             return;
         }
 
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        if (srcDt == DataType::Float64) {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        switch (srcDt) {
+        case DataType::Float64:
             thrust_copy(static_cast<const double*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Float32) {
+            break;
+        case DataType::Float32:
             thrust_copy(static_cast<const float*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Float16) {
+            break;
+        case DataType::Float16:
             thrust_copy(static_cast<const half_float::half*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int64) {
+            break;
+        case DataType::Int64:
             thrust_copy(static_cast<const int64_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt64) {
+            break;
+        case DataType::UInt64:
             thrust_copy(static_cast<const uint64_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int32) {
+            break;
+        case DataType::Int32:
             thrust_copy(static_cast<const int32_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt32) {
+            break;
+        case DataType::UInt32:
             thrust_copy(static_cast<const uint32_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int16) {
+            break;
+        case DataType::Int16:
             thrust_copy(static_cast<const int16_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt16) {
+            break;
+        case DataType::UInt16:
             thrust_copy(static_cast<const uint16_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int8) {
+            break;
+        case DataType::Int8:
             thrust_copy(static_cast<const int8_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt8) {
+            break;
+        case DataType::UInt8:
             thrust_copy(static_cast<const uint8_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else {
+            break;
+        default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
+            break;
         }
     }
 
-    void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override {
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+    void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(offset), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }
 
-    void copyFromHost(const void *src, NbElts_t length) override {
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyHostToDevice));
+    void copyFromHost(const void *src, NbElts_t length, NbElts_t offset = 0) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(offset), src, length * sizeof(T), cudaMemcpyHostToDevice));
     }
 
-    void copyToHost(void *dst, NbElts_t length) const override {
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost));
+    void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(offset), length * sizeof(T), cudaMemcpyDeviceToHost));
     }
 
     void *rawPtr(NbElts_t offset = 0) override {
@@ -175,30 +180,27 @@ public:
     };
 
     const void *rawPtr(NbElts_t offset = 0) const override {
-        AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
+        AIDGE_ASSERT(mData.size() >= mNbElts, "accessing uninitialized const rawPtr");
         return (mData.data() + offset);
     };
 
-    const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override {
+    const cudnnTensorDescriptor_t& getCudnnTensorDesc(const Tensor& tensor) const override {
         if (mCudnnTensor == nullptr) {
             CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
 
-            if (mTensor.size() > 0) {
+            if (tensor.size() > 0) {
                 /**
                 **      cuDNN tensors are restricted to having at least 4 dimensions:
                 **      when working with lower-dimensional data, unused dimensions are set to 1.
                 **      Refer to the cudnnSetTensorNdDescriptor documentation at:
                 **      https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html
                 **/
-                std::vector<int> dims(mTensor.dims().begin(), mTensor.dims().end());
+                std::vector<int> dims(tensor.dims().cbegin(), tensor.dims().cend());
+                std::vector<int> strides(tensor.strides().cbegin(), tensor.strides().cend());
 
-                if (dims.size() < 4)
+                if (dims.size() < 4) {
                     dims.resize(4, 1);
-
-                std::vector<int> strides(dims.size(), 1);
-
-                for (size_t dim = 1; dim < dims.size(); ++dim) {
-                    strides[dims.size() - dim - 1] = strides[dims.size() - dim] * dims[dims.size() - dim];
+                    strides.resize(4, 1);
                 }
 
                 CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor,
@@ -213,23 +215,20 @@ public:
     }
 
     void setRawPtr(void *ptr, NbElts_t length) override final {
-        AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity");
+        AIDGE_ASSERT(length >= mNbElts, "trying to set raw pointer of insufficient capacity");
         mData = future_std::span<T>(static_cast<T *>(ptr), length);
         mDataOwner.reset();
     };
 
-    virtual ~TensorImpl_cuda() {
-        if (mCudnnTensor != nullptr)
-            cudnnDestroyTensorDescriptor(mCudnnTensor);
-    }
+    virtual ~TensorImpl_cuda() = default;
 
 private:
     void lazyInit() {
-        if (mData.size() < mTensor.size()) {
+        if (mData.size() < mNbElts) {
             // Need more data, a re-allocation will occur
             AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "trying to enlarge non-owned data");
-            mDataOwner.reset(cudaAlloc(mTensor.size()));
-            mData = future_std::span<T>(mDataOwner.get(), mTensor.size());
+            mDataOwner.reset(cudaAlloc(mNbElts));
+            mData = future_std::span<T>(mDataOwner.get(), mNbElts);
         }
     }
 };
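The reworked TensorImpl_cuda no longer holds a back-reference to its Tensor: capacity comes from the (device, length) constructor, allocation stays lazy, and every copy entry point gains an element offset. A minimal host round-trip sketch using only the members shown in this header (the surrounding scaffolding is illustrative, not part of the patch):

#include <cassert>
#include <vector>

#include "aidge/backend/cuda/data/TensorImpl.hpp"

void roundTripSketch() {
    constexpr Aidge::NbElts_t N = 16;

    // No device memory is allocated yet; lazyInit() runs on first rawPtr() access.
    auto impl = Aidge::TensorImpl_cuda<float>::create(/*device=*/0, /*length=*/N);

    std::vector<float> host(N, 1.0f);
    impl->copyFromHost(host.data(), N);              // host -> device, offset 0
    impl->copyFromHost(host.data(), N / 2, N / 2);   // overwrite the second half only

    std::vector<float> back(N, 0.0f);
    impl->copyToHost(back.data(), N);                // device -> host
    assert(back.front() == 1.0f && back.back() == 1.0f);
}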
diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index 0ad995b24082782c611b93fbcc040d1319a7362f..65180fd54aeb9ff349af37192061cf66415e0a77 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -35,12 +35,15 @@ private:
     cudnnFilterDescriptor_t mFilterDesc = nullptr;
     cudnnConvolutionFwdAlgo_t mFwdAlgo;
     size_t mWorkspaceSize = 0;
-    void* mWorkspace = nullptr;
+    void* mFwdWorkspace = nullptr;
+    std::shared_ptr<Tensor> mInput0Fallback;
+    std::shared_ptr<Tensor> mInput1Fallback;
+    std::shared_ptr<Tensor> mInput2Fallback;
 
 public:
     ConvImpl_cuda(const Conv_Op<DIM> &op) : OperatorImpl(op) {}
 
-    static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) {
+    static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<DIM> &op) {
         return std::make_unique<ConvImpl_cuda>(op);
     }
 
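The create() fix replaces the hard-coded Conv_Op<2> parameter with Conv_Op<DIM>, so the factory is well-formed for every instantiated dimension rather than only the 2D case. A reduced, standalone illustration of the mismatch (the Op/Impl types below are stand-ins, not the Aidge classes):

#include <memory>

// Stand-ins for Conv_Op<DIM> / ConvImpl_cuda<DIM>; only the template
// relationship matters here.
template <int DIM> struct Op {};

template <int DIM>
struct Impl {
    // Before the patch the parameter was fixed to Op<2>, so Impl<1>::create()
    // or Impl<3>::create() could not accept the operator they implement.
    static std::unique_ptr<Impl> create(const Op<DIM>& /*op*/) {
        return std::make_unique<Impl>();
    }
};

int main() {
    Op<1> conv1d;
    auto impl = Impl<1>::create(conv1d);  // only compiles with the DIM-parameterized signature
    return impl ? 0 : 1;
}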
diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu
index ecacd4d678dd7d79462332fb28e238b063d8bdd1..898475b5db325afcaedff44756cc2157cf9e2eec 100644
--- a/src/data/TensorImpl.cu
+++ b/src/data/TensorImpl.cu
@@ -91,10 +91,10 @@ template <class T>
 bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl);
 
-    if (mTensor.size() != otherImplCuda.mTensor.size())
+    if (mNbElts != otherImplCuda.size())
         return false;
 
     thrust::device_ptr<T> thrustData(mData.data());
     thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData.data());
-    return thrust::equal(thrustData, thrustData + mTensor.size(), thrustOtherData);
+    return thrust::equal(thrustData, thrustData + mNbElts, thrustOtherData);
 }
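operator== now compares against the implementation's own element count (mNbElts) instead of dereferencing a stored Tensor. The comparison itself keeps the thrust idiom: check lengths first, then run an element-wise comparison on the device. A standalone sketch of that idiom, independent of Aidge types:

#include <thrust/device_vector.h>
#include <thrust/equal.h>

// Same pattern as TensorImpl_cuda<T>::operator==: equal element counts,
// then element-wise equality evaluated on the device.
bool deviceRangesEqual(const thrust::device_vector<float>& a,
                       const thrust::device_vector<float>& b) {
    return a.size() == b.size()
        && thrust::equal(a.begin(), a.end(), b.begin());
}

int main() {
    thrust::device_vector<float> x(8, 1.0f);
    thrust::device_vector<float> y(x);
    return deviceRangesEqual(x, y) ? 0 : 1;
}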
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 9c3684e89f6b27133ca99be16b332c4e9f9a27b1..19ce56bcb99f60e08427f8d9b110637c90582adf 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -24,18 +24,16 @@
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
     // FIXME: uncomment the following code once memory handling works
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
 
     // Convert input data (no overhead if not needed!)
-    // TODO: right now, if needed, memory will be allocated/deallocated at each
-    // call to forward(). We might put the following shared_ptr as members of
-    // this class to avoid that.
-    std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = op.getInput(0)->refCastFrom(mInput0Fallback, *op.getOutput(0));
+    const auto& input1 = op.getInput(1)->refCastFrom(mInput1Fallback, *op.getOutput(0));
+    const auto& input2 = op.getInput(2)->refCastFrom(mInput2Fallback, *op.getOutput(0));
 
     // Lazy-initialize CuDNN convolution descriptor
     if (mConvDesc == nullptr) {
@@ -45,14 +43,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         const std::vector<int> upscales(convOp.template getAttr<ConvAttr::DilationDims>().begin(), convOp.template getAttr<ConvAttr::DilationDims>().end());
 
         CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
-        CHECK_CUDNN_STATUS(
-            cudnnSetConvolutionNdDescriptor(mConvDesc,
-                                            DIM,
-                                            &paddings[0],
-                                            &strides[0],
-                                            &upscales[0],
-                                            CUDNN_CROSS_CORRELATION,
-                                            DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType())));
+        CHECK_CUDNN_STATUS(cudnnSetConvolutionNdDescriptor(mConvDesc,
+            DIM,
+            &paddings[0],
+            &strides[0],
+            &upscales[0],
+            CUDNN_CROSS_CORRELATION,
+            DataTypeToCudnn(op.getOutput(0)->dataType())));
     }
 
     // Lazy-initialize CuDNN filter descriptor
@@ -61,14 +58,14 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
         CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
         CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
-                                                    DataTypeToCudnn(input1.dataType()),
-                                                    CUDNN_TENSOR_NCHW,
-                                                    kernels.size(),
-                                                    &kernels[0]));
+            DataTypeToCudnn(input1.dataType()),
+            CUDNN_TENSOR_NCHW,
+            kernels.size(),
+            &kernels[0]));
     }
 
     // Set forward algorithm and allocate the required workspace
-    if (mWorkspace == nullptr) {
+    if (mFwdWorkspace == nullptr) {
         // Find the best CuDNN forward algorithm (the one with the lowest compute time)
         int maxAlgoIterations = 0;
         cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(),
@@ -80,14 +77,14 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
 
         CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
-                            CudaContext::cudnnHandle(),
-                            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
-                            mFilterDesc,
-                            mConvDesc,
-                            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-                            maxAlgoIterations,
-                            &returnAlgoCounts,
-                            &returnFwdAlgo[0]));
+            CudaContext::cudnnHandle(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+            mFilterDesc,
+            mConvDesc,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            maxAlgoIterations,
+            &returnAlgoCounts,
+            &returnFwdAlgo[0]));
         mFwdAlgo = returnFwdAlgo[0].algo;
 
         // Allocate the workspace required by the chosen CuDNN forward algorithm
@@ -95,21 +92,21 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
         CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
             CudaContext::cudnnHandle(),
-            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
             mFilterDesc,
             mConvDesc,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
             mFwdAlgo,
             &workspaceSize));
 
-        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, workspaceSize));
+        CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, workspaceSize));
         mWorkspaceSize = workspaceSize;
     }
 
     // Do the actual forward computation
     // Template is only for scaling parameters, which are always in float
     // excepted when the convolution is performed in double precision.
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
         forward_<double>(input0, input1, input2);
     }
     else {
@@ -120,23 +117,23 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T alpha = 1.0f;
     const T beta = 0.0f;
 
-    CHECK_CUDNN_STATUS(
-        cudnnConvolutionForward(CudaContext::cudnnHandle(),
-                                &alpha,
-                                dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
-                                input0.getImpl()->rawPtr(),
-                                mFilterDesc,
-                                input1.getImpl()->rawPtr(),
-                                mConvDesc,
-                                mFwdAlgo,
-                                mWorkspace,
-                                mWorkspaceSize,
-                                &beta,
-                                dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-                                std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+    CHECK_CUDNN_STATUS(cudnnConvolutionForward(CudaContext::cudnnHandle(),
+        &alpha,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+        input0.getImpl()->rawPtr(),
+        mFilterDesc,
+        input1.getImpl()->rawPtr(),
+        mConvDesc,
+        mFwdAlgo,
+        mFwdWorkspace,
+        mWorkspaceSize,
+        &beta,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+        op.getOutput(0)->getImpl()->rawPtr()));
 
     // Add bias (if there is any)
     if (mOp.getRawInput(2) && input2.size() > 0) {
@@ -151,12 +148,12 @@ void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& inp
         // TODO: find a more elegant solution(?)
 
         CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
-                                            &alpha,
-                                            dynamic_cast<TensorImpl_cuda_*>(bias.getImpl().get())->getCudnnTensorDesc(),
-                                            input2.getImpl()->rawPtr(),
-                                            &alpha,
-                                            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-                                            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+            &alpha,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(bias.getImpl())->getCudnnTensorDesc(bias),
+            input2.getImpl()->rawPtr(),
+            &alpha,
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            op.getOutput(0)->getImpl()->rawPtr()));
     }
 }
 
@@ -170,8 +167,8 @@ Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
         cudnnDestroyFilterDescriptor(mFilterDesc);
     }
 
-    if (mWorkspace != nullptr) {
-        cudaFree(mWorkspace);
+    if (mFwdWorkspace != nullptr) {
+        cudaFree(mFwdWorkspace);
     }
 }
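Promoting the refCastFrom() fallbacks to members (mInput0Fallback et al.) addresses the removed TODO: a converted input allocated on the first forward() now survives to the next call instead of being re-allocated each time. A reduced sketch of that caching pattern with stand-in types (Tensor and refCastFrom() themselves are not reproduced here):

#include <memory>
#include <vector>

// Stand-in for a tensor that may need a converted (cast / backend-moved) copy.
struct Buf { std::vector<float> data; };

// Mirrors how the member fallbacks are used: return the source untouched when
// no conversion is needed, otherwise fill and return a fallback buffer that
// the caller keeps alive across invocations.
const Buf& refOrConvert(const Buf& src, bool needsConversion,
                        std::shared_ptr<Buf>& fallback) {
    if (!needsConversion) {
        return src;
    }
    if (!fallback) {
        fallback = std::make_shared<Buf>();  // allocated once, reused on later calls
    }
    fallback->data = src.data;               // the actual cast/copy would happen here
    return *fallback;
}

int main() {
    Buf input{{1.f, 2.f, 3.f}};
    std::shared_ptr<Buf> fallback;           // member-like lifetime in the real code
    const Buf& a = refOrConvert(input, true, fallback);
    const Buf& b = refOrConvert(input, true, fallback);  // no new allocation
    return (&a == &b) ? 0 : 1;
}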