diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 73b211113b3b21b9c8294e51e16cc001afad25e1..7f2ee4067ca65e93c1dad13d85d2a655f61d7770 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -5,11 +5,16 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
 
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
 
 namespace Aidge {
+
+template <typename SRC_T, typename DST_T>
+void thrust_copy(SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
+
 /**
  * @brief Abstract class for the TensorImpl_cuda class template.
  * @details Its purpose is to provide access to base methods that are specific 
@@ -51,14 +56,90 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_  {
 
     std::size_t scalarSize() const override { return sizeof(T); }
 
+    void setDevice(int device) override {
+        mDevice = device;
+    }
+
     void copy(const void *src, NbElts_t length) override {
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+    }
+
+    void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
+        if (srcDt == DataType::Float64) {
+            thrust_copy(static_cast<const double*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Float32) {
+            thrust_copy(static_cast<const float*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int64) {
+            thrust_copy(static_cast<const int64_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt64) {
+            thrust_copy(static_cast<const uint64_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int32) {
+            thrust_copy(static_cast<const int32_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt32) {
+            thrust_copy(static_cast<const uint32_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int16) {
+            thrust_copy(static_cast<const int16_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt16) {
+            thrust_copy(static_cast<const uint16_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int8) {
+            thrust_copy(static_cast<const int8_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt8) {
+            thrust_copy(static_cast<const uint8_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
+        }
+    }
+
+    void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override {
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+    }
+
+    void copyFromHost(const void *src, NbElts_t length) override {
         CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyHostToDevice));
     }
 
+    void copyToHost(void *dst, NbElts_t length) override {
+        CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost));
+    }
+
     void *rawPtr() override {
         lazyInit(reinterpret_cast<void**>(&mData));
         return mData;
-    }
+    };
+
+    void *hostPtr() override {
+        return nullptr;
+    };
 
     void* getRaw(std::size_t idx) {
         return static_cast<void*>(static_cast<T*>(rawPtr()) + idx);
diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu
index ba2a1348b4c5eae8499884bfe2488d67f016a060..beb76f627c8a63410393d60e09cefaaaeaf6220d 100644
--- a/src/data/TensorImpl.cu
+++ b/src/data/TensorImpl.cu
@@ -14,6 +14,14 @@
 #include <thrust/equal.h>
 #include <thrust/device_ptr.h>
 
+template <typename SRC_T, typename DST_T>
+void Aidge::thrust_copy(SRC_T* srcData, DST_T* dstData, size_t size)
+{
+    thrust::device_ptr<SRC_T> thrustSrcPtr(srcData);
+    thrust::device_ptr<DST_T> thrustDstPtr(dstData);
+    thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr);
+}
+
 template <class T>
 bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl);
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 515f5f19d7702ea5bc037b672e182e97800a703b..88b5e7deb26c6b5cff7ea7a0466102f7c3c6a7f3 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -25,8 +25,8 @@
 template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
     // FIXME: uncomment the following code once memory handling will work
-    assert(mOp.getInput(0) && "missing input #0");
-    assert(mOp.getInput(1) && "missing input #1");
+    assert(mOp.getRawInput(0) && "missing input #0");
+    assert(mOp.getRawInput(1) && "missing input #1");
 
     // Lazy-initialize CuDNN convolution descriptor
     if (mConvDesc == nullptr) {
@@ -43,16 +43,16 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
                                             &strides[0],
                                             &upscales[0],
                                             CUDNN_CROSS_CORRELATION,
-                                            DataTypeToCudnn(mOp.getOutput(0)->dataType())));
+                                            DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType())));
     }
 
     // Lazy-initialize CuDNN filter descriptor
     if (mFilterDesc == nullptr) {
-        const std::vector<int> kernels(mOp.getInput(1)->dims().begin(), mOp.getInput(1)->dims().end());
+        const std::vector<int> kernels(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims().begin(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims().end());
 
         CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
         CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
-                                                    DataTypeToCudnn(mOp.getInput(1)->dataType()),
+                                                    DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType()),
                                                     CUDNN_TENSOR_NCHW,
                                                     kernels.size(),
                                                     &kernels[0]));
@@ -72,10 +72,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
         CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
                             CudaContext::cudnnHandle(),
-                            dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
+                            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(),
                             mFilterDesc,
                             mConvDesc,
-                            dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
+                            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
                             maxAlgoIterations,
                             &returnAlgoCounts,
                             &returnFwdAlgo[0]));
@@ -86,10 +86,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
         CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
             CudaContext::cudnnHandle(),
-            dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
+            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(),
             mFilterDesc,
             mConvDesc,
-            dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
+            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
             mFwdAlgo,
             &workspaceSize));
 
@@ -100,7 +100,7 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     // Do the actual forward computation
     // Template is only for scaling parameters, which are always in float
     // excepted when the convolution is performed in double precision.
-    if (mOp.getOutput(0)->dataType() == DataType::Float64) {
+    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
         forward_<double>();
     }
     else {
@@ -117,26 +117,26 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
     CHECK_CUDNN_STATUS(
         cudnnConvolutionForward(CudaContext::cudnnHandle(),
                                 &alpha,
-                                dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
-                                mOp.getInput(0)->getImpl()->rawPtr(),
+                                dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl().get())->getCudnnTensorDesc(),
+                                std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
                                 mFilterDesc,
-                                mOp.getInput(1)->getImpl()->rawPtr(),
+                                std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
                                 mConvDesc,
                                 mFwdAlgo,
                                 mWorkspace,
                                 mWorkspaceSize,
                                 &beta,
-                                dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
-                                mOp.getOutput(0)->getImpl()->rawPtr()));
+                                dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+                                std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
 
     // Add bias (if there is any)
-    if (mOp.getInput(2) && mOp.getInput(2)->size() > 0) {
+    if (mOp.getRawInput(2) && std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->size() > 0) {
         // Bias tensor needs to have the same number of dims than output tensor for cudnnAddTensor()
         std::vector<DimSize_t> biasDims(DIM+2, 1);
-        biasDims[1] = mOp.getInput(2)->size();
+        biasDims[1] = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->size();
 
         // Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
-        Tensor bias(mOp.getInput(2)->dataType());
+        Tensor bias(std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType());
         bias.setBackend("cuda");
         bias.resize(biasDims);
         // TODO: find a more elegant solution(?)
@@ -144,10 +144,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
         CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
                                             &alpha,
                                             dynamic_cast<TensorImpl_cuda_*>(bias.getImpl().get())->getCudnnTensorDesc(),
-                                            mOp.getInput(2)->getImpl()->rawPtr(),
+                                            std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(),
                                             &alpha,
-                                            dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
-                                            mOp.getOutput(0)->getImpl()->rawPtr()));
+                                            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+                                            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
     }
 }
 
diff --git a/unit_tests/Test_ConvImpl.cpp b/unit_tests/Test_ConvImpl.cpp
index 659528dd1b2a45fcdd67ca0bd3440391a0e79654..b7faadd677336b9ff72274ea250251f95785b24f 100644
--- a/unit_tests/Test_ConvImpl.cpp
+++ b/unit_tests/Test_ConvImpl.cpp
@@ -25,8 +25,9 @@ using namespace Aidge;
 TEST_CASE("[gpu/operator] Conv(forward)") {
     SECTION("Simple Conv no bias") {
         std::shared_ptr<Node> myConv = Conv(1,1,{3,3}, "myconv");
-        myConv->getOperator()->setDatatype(DataType::Float32);
-        myConv->getOperator()->setBackend("cuda");
+        auto op = std::static_pointer_cast<OperatorTensor>(myConv->getOperator());
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
         std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,1,1,3,3> {
             {
                 {
@@ -50,12 +51,12 @@ TEST_CASE("[gpu/operator] Conv(forward)") {
         myInput->setBackend("cuda");
         myWeights->setBackend("cuda");
 
-        myConv->getOperator()->associateInput(0,myInput);
-        myConv->getOperator()->associateInput(1,myWeights);
-        myConv->getOperator()->computeOutputDims();
+        op->associateInput(0,myInput);
+        op->associateInput(1,myWeights);
+        op->computeOutputDims();
         myConv->forward();
 
-        REQUIRE(myConv->getOperator()->getOutput(0)->size() == 1);
+        REQUIRE(op->getOutput(0)->size() == 1);
 
         std::array<float, 9> kernel;
         cudaMemcpy(&kernel[0], myWeights->getImpl()->rawPtr(), 9 * sizeof(float), cudaMemcpyDeviceToHost);
@@ -68,15 +69,16 @@ TEST_CASE("[gpu/operator] Conv(forward)") {
         }
 
         float computedOutput;
-        cudaMemcpy(&computedOutput, myConv->getOperator()->getOutput(0)->getImpl()->rawPtr(), sizeof(float), cudaMemcpyDeviceToHost);
+        cudaMemcpy(&computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float), cudaMemcpyDeviceToHost);
 
         REQUIRE(fabs(computedOutput - myOutput) < 1e-6);
     }
 
     SECTION("Classic Conv") {
         std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv");
-        myConv->getOperator()->setDatatype(DataType::Float32);
-        myConv->getOperator()->setBackend("cuda");
+        auto op = std::static_pointer_cast<OperatorTensor>(myConv->getOperator());
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
         std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,3,3> {
             {
                 {
@@ -205,15 +207,15 @@ TEST_CASE("[gpu/operator] Conv(forward)") {
         myWeights->setBackend("cuda");
         myBias->setBackend("cuda");
 
-        myConv->getOperator()->associateInput(0,myInput);
-        myConv->getOperator()->associateInput(1,myWeights);
-        myConv->getOperator()->associateInput(2,myBias);
-        myConv->getOperator()->computeOutputDims();
+        op->associateInput(0,myInput);
+        op->associateInput(1,myWeights);
+        op->associateInput(2,myBias);
+        op->computeOutputDims();
         myConv->forward();
-        // myConv->getOperator()->getOutput(0)->print();
+        // op->getOutput(0)->print();
 
         float* computedOutput   = new float[myOutput->size()]();
-        cudaMemcpy(computedOutput, myConv->getOperator()->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
 
         for(int i = 0; i < myOutput->size(); i++){
             const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);