diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 67f6175efd7bcc77ec26da2c657982bf0229e038..dfd347dc53ff019048d36b401443335acb2c1f9d 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -282,16 +282,17 @@ private:
 template <typename T>
 const std::string TensorImpl_cuda<T>::Backend = "cuda";
 
-namespace {
-static Registrar<Tensor> registrarTensorImpl_cuda_Float64(
-        {"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create);
-static Registrar<Tensor> registrarTensorImpl_cuda_Float32(
-        {"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create);
-static Registrar<Tensor> registrarTensorImpl_cuda_Float16(
-        {"cuda", DataType::Float16}, Aidge::TensorImpl_cuda<half_float::half>::create);
-static Registrar<Tensor> registrarTensorImpl_cuda_Int32(
-        {"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int32_t>::create);
-}  // namespace
+REGISTRAR(Tensor, {"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create);
+REGISTRAR(Tensor, {"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create);
+REGISTRAR(Tensor, {"cuda", DataType::Float16}, Aidge::TensorImpl_cuda<half_float::half>::create);
+REGISTRAR(Tensor, {"cuda", DataType::Int64}, Aidge::TensorImpl_cuda<int64_t>::create);
+REGISTRAR(Tensor, {"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int32_t>::create);
+REGISTRAR(Tensor, {"cuda", DataType::Int16}, Aidge::TensorImpl_cuda<int16_t>::create);
+REGISTRAR(Tensor, {"cuda", DataType::Int8}, Aidge::TensorImpl_cuda<int8_t>::create);
+REGISTRAR(Tensor, {"cuda", DataType::UInt64}, Aidge::TensorImpl_cuda<uint64_t>::create);
+REGISTRAR(Tensor, {"cuda", DataType::UInt32}, Aidge::TensorImpl_cuda<uint32_t>::create);
+REGISTRAR(Tensor, {"cuda", DataType::UInt16}, Aidge::TensorImpl_cuda<uint16_t>::create);
+REGISTRAR(Tensor, {"cuda", DataType::UInt8}, Aidge::TensorImpl_cuda<uint8_t>::create);
 }  // namespace Aidge
 
 #endif /* AIDGE_BACKEND_CUDA_DATA_TENSORIMPL_H_ */
diff --git a/include/aidge/backend/cuda/operator/AddImpl.hpp b/include/aidge/backend/cuda/operator/AddImpl.hpp
index 42d420f8410f79100fdfdbe3eabb8b43e616a74a..719c447b47b1dd6e48268f2e3d1906ae0d1753d1 100644
--- a/include/aidge/backend/cuda/operator/AddImpl.hpp
+++ b/include/aidge/backend/cuda/operator/AddImpl.hpp
@@ -46,10 +46,18 @@ public:
 
     void forward() override;
     void backward() override;
+    ~AddImpl_cuda();
 
 private:
-    template <class T> void forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
-    template <class T> void backward_(const Tensor& outGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
+    std::vector<cudnnTensorDescriptor_t> mTensorDesc;
+    cudnnReduceTensorDescriptor_t mBwdReduceDesc = nullptr;
+    size_t mBwdWorkspaceSize = 0;
+    void* mBwdWorkspace = nullptr;
+    std::vector<std::shared_ptr<Tensor>> mInputFallbacks;
+    std::shared_ptr<Tensor> mOutputGradFallback;
+
+    template <class T> void forward_(const std::vector<std::reference_wrapper<Tensor>>& inputs);
+    template <class T> void backward_(const Tensor& outGrad);
 };
 
 // Implementation entry point registration to Operator
diff --git a/include/aidge/backend/cuda/operator/MulImpl.hpp b/include/aidge/backend/cuda/operator/MulImpl.hpp
index 9a1a4d79d32c7a962d2086319d948e60a9f51049..4c1e2765eb7670dc16d83ea80d8baa9fded13c1a 100644
--- a/include/aidge/backend/cuda/operator/MulImpl.hpp
+++ b/include/aidge/backend/cuda/operator/MulImpl.hpp
@@ -46,10 +46,16 @@ public:
 
     void forward() override;
     void backward() override;
+    ~MulImpl_cuda();
 
 private:
-    template <class T> void forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
-    template <class T> void backward_(const Tensor& outputGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
+    std::vector<cudnnTensorDescriptor_t> mTensorDesc;
+    cudnnOpTensorDescriptor_t mOpTensorDesc = nullptr;
+    std::vector<std::shared_ptr<Tensor>> mInputFallbacks;
+    std::shared_ptr<Tensor> mOutputGradFallback;
+
+    template <class T> void forward_(const std::vector<std::reference_wrapper<Tensor>>& inputs);
+    template <class T> void backward_(const Tensor& outputGrad);
 };
 
 // Implementation entry point registration to Operator
diff --git a/include/aidge/backend/cuda/operator/ReduceMeanImpl.hpp b/include/aidge/backend/cuda/operator/ReduceMeanImpl.hpp
index 1f6878480d69e19f8c73a12862cc12b2d675440d..d84f39c2cc5b0b89878c0cc9149edc298ab2c326 100644
--- a/include/aidge/backend/cuda/operator/ReduceMeanImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ReduceMeanImpl.hpp
@@ -46,10 +46,16 @@ public:
 
     void forward() override;
     void backward() override;
+    ~ReduceMeanImpl_cuda();
 
 private:
     // CuDNN specific variables
-    std::shared_ptr<Tensor> mInputFallback, mOutputGradFallback;
+    cudnnReduceTensorDescriptor_t mReduceDesc = nullptr;
+    cudnnTensorDescriptor_t mOutputDesc = nullptr;
+    size_t mWorkspaceSize = 0;
+    void* mWorkspace = nullptr;
+    std::shared_ptr<Tensor> mInputFallback;
+    std::shared_ptr<Tensor> mOutputGradFallback;
 
     template <class T> void forward_(const Tensor& input, const std::vector<int>& axes, bool keepDims);
     template <class T> void backward_(const Tensor& output_grad, const std::vector<int>& axes);
diff --git a/include/aidge/backend/cuda/operator/ReduceSumImpl.hpp b/include/aidge/backend/cuda/operator/ReduceSumImpl.hpp
index 10af90ba3a4ffc1d1464dd73f15313315b0c0032..500670d00043361e353e3e689a23b0a5d1c69530 100644
--- a/include/aidge/backend/cuda/operator/ReduceSumImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ReduceSumImpl.hpp
@@ -46,10 +46,16 @@ public:
 
     void forward() override;
     void backward() override;
+    ~ReduceSumImpl_cuda();
 
 private:
     // CuDNN specific variables
-    std::shared_ptr<Tensor> mInputFallback, mOutputGradFallback;
+    cudnnReduceTensorDescriptor_t mReduceDesc = nullptr;
+    cudnnTensorDescriptor_t mOutputDesc = nullptr;
+    size_t mWorkspaceSize = 0;
+    void* mWorkspace = nullptr;
+    std::shared_ptr<Tensor> mInputFallback;
+    std::shared_ptr<Tensor> mOutputGradFallback;
 
     template <class T> void forward_(const Tensor& input, const std::vector<int>& axes, bool keepDims);
     template <class T> void backward_(const Tensor& output_grad, const std::vector<int>& axes);
diff --git a/include/aidge/backend/cuda/operator/SubImpl.hpp b/include/aidge/backend/cuda/operator/SubImpl.hpp
index 529d0b2b2dd4a0ec8a3dae5bf0219f8a4f2968c6..791e231e166caf1565a206de793726ca42983e89 100644
--- a/include/aidge/backend/cuda/operator/SubImpl.hpp
+++ b/include/aidge/backend/cuda/operator/SubImpl.hpp
@@ -46,10 +46,18 @@ public:
 
     void forward() override;
     void backward() override;
+    ~SubImpl_cuda();
 
 private:
-    template <class T> void forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
-    template <class T> void backward_(const Tensor& outGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
+    std::vector<cudnnTensorDescriptor_t> mTensorDesc;
+    cudnnReduceTensorDescriptor_t mBwdReduceDesc = nullptr;
+    size_t mBwdWorkspaceSize = 0;
+    void* mBwdWorkspace = nullptr;
+    std::vector<std::shared_ptr<Tensor>> mInputFallbacks;
+    std::shared_ptr<Tensor> mOutputGradFallback;
+
+    template <class T> void forward_(const std::vector<std::reference_wrapper<Tensor>>& inputs);
+    template <class T> void backward_(const Tensor& outGrad);
 };
 
 // Implementation entry point registration to Operator
diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu
index c70b024fbab1a031ea69d5d9b169dc115b7320db..208414348f1f346c765cf8b97e919e3053513df0 100644
--- a/src/data/TensorImpl.cu
+++ b/src/data/TensorImpl.cu
@@ -36,7 +36,7 @@ cudaCopyToH_kernel(const SRC_T* srcData,
     }
 }
 
-template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
+template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type*>
 void Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size)
 {
     cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>>
@@ -58,7 +58,7 @@ cudaCopyFromH_kernel(const __half* srcData,
     }
 }
 
-template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
+template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type*>
 void Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size)
 {
     cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>>
@@ -99,33 +99,135 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     return thrust::equal(thrustData, thrustData + mNbElts, thrustOtherData);
 }
 
-template void Aidge::thrust_copy<double, double>(double const*, double*, unsigned long);
-template void Aidge::thrust_copy<double, float>(double const*, float*, unsigned long);
-template void Aidge::thrust_copy<double, int>(double const*, int*, unsigned long);
-template void Aidge::thrust_copy<float, double>(float const*, double*, unsigned long);
-template void Aidge::thrust_copy<float, float>(float const*, float*, unsigned long);
-template void Aidge::thrust_copy<float, int>(float const*, int*, unsigned long);
-template void Aidge::thrust_copy<int, double>(int const*, double*, unsigned long);
-template void Aidge::thrust_copy<int, float>(int const*, float*, unsigned long);
-template void Aidge::thrust_copy<int, int>(int const*, int*, unsigned long);
-template void Aidge::thrust_copy<long, double>(long const*, double*, unsigned long);
-template void Aidge::thrust_copy<long, float>(long const*, float*, unsigned long);
-template void Aidge::thrust_copy<long, int>(long const*, int*, unsigned long);
-template void Aidge::thrust_copy<short, double>(short const*, double*, unsigned long);
-template void Aidge::thrust_copy<short, float>(short const*, float*, unsigned long);
-template void Aidge::thrust_copy<short, int>(short const*, int*, unsigned long);
-template void Aidge::thrust_copy<signed char, double>(signed char const*, double*, unsigned long);
-template void Aidge::thrust_copy<signed char, float>(signed char const*, float*, unsigned long);
-template void Aidge::thrust_copy<signed char, int>(signed char const*, int*, unsigned long);
-template void Aidge::thrust_copy<unsigned char, double>(unsigned char const*, double*, unsigned long);
-template void Aidge::thrust_copy<unsigned char, float>(unsigned char const*, float*, unsigned long);
-template void Aidge::thrust_copy<unsigned char, int>(unsigned char const*, int*, unsigned long);
-template void Aidge::thrust_copy<unsigned int, double>(unsigned int const*, double*, unsigned long);
-template void Aidge::thrust_copy<unsigned int, float>(unsigned int const*, float*, unsigned long);
-template void Aidge::thrust_copy<unsigned int, int>(unsigned int const*, int*, unsigned long);
-template void Aidge::thrust_copy<unsigned long, double>(unsigned long const*, double*, unsigned long);
-template void Aidge::thrust_copy<unsigned long, float>(unsigned long const*, float*, unsigned long);
-template void Aidge::thrust_copy<unsigned long, int>(unsigned long const*, int*, unsigned long);
-template void Aidge::thrust_copy<unsigned short, double>(unsigned short const*, double*, unsigned long);
-template void Aidge::thrust_copy<unsigned short, float>(unsigned short const*, float*, unsigned long);
-template void Aidge::thrust_copy<unsigned short, int>(unsigned short const*, int*, unsigned long);
+// double
+template void Aidge::thrust_copy<>(double const*, double*, size_t);
+template void Aidge::thrust_copy<>(double const*, float*, size_t);
+template void Aidge::thrust_copy<>(double const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(double const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(double const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(double const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(double const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(double const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(double const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(double const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(double const*, uint8_t*, size_t);
+// float
+template void Aidge::thrust_copy<>(float const*, double*, size_t);
+template void Aidge::thrust_copy<>(float const*, float*, size_t);
+template void Aidge::thrust_copy<>(float const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(float const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(float const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(float const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(float const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(float const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(float const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(float const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(float const*, uint8_t*, size_t);
+// half_float::half
+template void Aidge::thrust_copy<>(half_float::half const*, double*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, float*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(half_float::half const*, uint8_t*, size_t);
+// int64_t
+template void Aidge::thrust_copy<>(int64_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(int64_t const*, uint8_t*, size_t);
+// int32_t
+template void Aidge::thrust_copy<>(int32_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(int32_t const*, uint8_t*, size_t);
+// int16_t
+template void Aidge::thrust_copy<>(int16_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(int16_t const*, uint8_t*, size_t);
+// int8_t
+template void Aidge::thrust_copy<>(int8_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(int8_t const*, uint8_t*, size_t);
+// uint64_t
+template void Aidge::thrust_copy<>(uint64_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(uint64_t const*, uint8_t*, size_t);
+// uint32_t
+template void Aidge::thrust_copy<>(uint32_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(uint32_t const*, uint8_t*, size_t);
+// uint16_t
+template void Aidge::thrust_copy<>(uint16_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(uint16_t const*, uint8_t*, size_t);
+// uint8_t
+template void Aidge::thrust_copy<>(uint8_t const*, double*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, float*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, half_float::half*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, int64_t*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, int32_t*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, int16_t*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, int8_t*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, uint64_t*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, uint32_t*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, uint16_t*, size_t);
+template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t);
\ No newline at end of file
diff --git a/src/operator/AddImpl.cpp b/src/operator/AddImpl.cpp
index 8771a79e938dff893d5295bd847567a0dcb18f32..ad1d852f98336accb20504c9a50d2a64bc3a797f 100644
--- a/src/operator/AddImpl.cpp
+++ b/src/operator/AddImpl.cpp
@@ -33,40 +33,46 @@ void Aidge::AddImpl_cuda::forward() {
         AIDGE_ASSERT(op.getInput(i)->dataType() == datatypeFirstInput, "Cannot add inputs with two differents data type.");
     }
 
-    std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
-    std::vector<Tensor> inputs(op.nbInputs());
-    std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
-    std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
-    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
-        inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
-
-        // Get tensor dims and broadcast them
-        std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
-        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
-
-        if (dims[i].size() < 4) {
-            dims[i].resize(4, 1);
-        }
+    if (mInputFallbacks.empty()) {
+        mInputFallbacks.resize(op.nbInputs());
+    }
 
-        // Compute the corresponding strides
-        std::vector<int> tensorStrides(dims[i].size());
-        int product = 1;
-        for (size_t j = dims[i].size(); j > 0; --j) {
-            tensorStrides[j - 1] = product;
-            product *= dims[i][j - 1];
+    std::vector<std::reference_wrapper<Tensor>> inputs;
+    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
+        inputs.push_back(op.getInput(i)->refCastFrom(mInputFallbacks[i], *op.getOutput(0)));
+
+        if (mTensorDesc.size() <= i) {
+            // Get tensor dims and broadcast them
+            std::vector<int> dims(inputs[i].get().dims().begin(), inputs[i].get().dims().end());
+            dims.insert(dims.cbegin(), op.getOutput(0)->nbDims() - dims.size(), int(1));
+
+            if (dims.size() < 4) {
+                dims.resize(4, 1);
+            }
+
+            // Compute the corresponding strides
+            std::vector<int> strides(dims.size());
+            int product = 1;
+            for (size_t j = dims.size(); j > 0; --j) {
+                strides[j - 1] = product;
+                product *= dims[j - 1];
+            }
+
+            mTensorDesc.push_back(nullptr);
+            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mTensorDesc[i]));
+            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mTensorDesc[i], DataTypeToCudnn(op.getOutput(0)->dataType()), dims.size(), dims.data(), strides.data()));
         }
-        strides[i] = tensorStrides;
     }
 
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
-            forward_<double>(inputs, dims, strides);
+            forward_<double>(inputs);
             break;
         case DataType::Float32:
-            forward_<float>(inputs, dims, strides);
+            forward_<float>(inputs);
             break;
         case DataType::Float16:
-            forward_<half>(inputs, dims, strides);
+            forward_<half>(inputs);
             break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
@@ -74,40 +80,23 @@ void Aidge::AddImpl_cuda::forward() {
 }
 
 template <class T>
-void Aidge::AddImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) {
+void Aidge::AddImpl_cuda::forward_(const std::vector<std::reference_wrapper<Tensor>>& inputs) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
     const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
 
-    // Create a Tensor descriptor with the broadcasted dims and strides
-    cudnnTensorDescriptor_t tensorDesc;
-    CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc));
-    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
-    // Add first input
-    CHECK_CUDNN_STATUS(
-        cudnnAddTensor(CudaContext::cudnnHandle(),
-                       &alpha,
-                       tensorDesc,
-                       inputs[0].getImpl()->rawPtr(),
-                       &beta,
-                       std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
-                       std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
-    );
-    // Add other inputs if there are any
-    for (size_t i = 1; i < op.nbInputs(); ++i)
+    for (size_t i = 0; i < op.nbInputs(); ++i)
     {
-        CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[i].size(), inputsDims[i].data(), inputsStrides[i].data()));
         CHECK_CUDNN_STATUS(
             cudnnAddTensor(CudaContext::cudnnHandle(),
                         &alpha,
-                        tensorDesc,
-                        inputs[i].getImpl()->rawPtr(),
-                        &alpha,
+                        mTensorDesc[i],
+                        inputs[i].get().getImpl()->rawPtr(),
+                        (i > 0) ? &alpha : &beta, // first input overwrites the output (beta = 0), the others accumulate onto it
                         std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                         std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
         );
     }
-    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
 }
 
 void Aidge::AddImpl_cuda::backward() {
@@ -116,38 +105,42 @@ void Aidge::AddImpl_cuda::backward() {
     AIDGE_ASSERT(op.getOutput(0)->grad(), "missing output gradient in Add operator");
     AIDGE_ASSERT(op.getOutput(0)->grad()->hasImpl(), "cannot run Add backward because the output gradient has no implementation.");
 
-    std::shared_ptr<Tensor> outputGradFallback;
-    const auto& outputGrad = op.getOutput(0)->grad()->refCastFrom(outputGradFallback, *op.getOutput(0)->grad());
+    const auto& outputGrad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
 
-    std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
-    std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
-    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
-        std::shared_ptr<Tensor> inputFallback;
-        const Tensor input = op.getInput(i)->refCastFrom(inputFallback, *op.getOutput(0));
-
-        // Get tensor dims and broadcast them
-        std::copy(input.dims().begin(), input.dims().end(), std::back_inserter(dims[i]));
-        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
-
-        // Compute the corresponding strides
-        std::vector<int> tensorStrides(dims[i].size());
-        int product = 1;
-        for (size_t j = dims[i].size(); j > 0; --j) {
-            tensorStrides[j - 1] = product;
-            product *= dims[i][j - 1];
+    if (mBwdReduceDesc == nullptr) {
+        CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&mBwdReduceDesc));
+        CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(mBwdReduceDesc,
+                                                            CUDNN_REDUCE_TENSOR_ADD,
+                                                            DataTypeToCudnn(op.getOutput(0)->dataType()),
+                                                            CUDNN_PROPAGATE_NAN,
+                                                            CUDNN_REDUCE_TENSOR_NO_INDICES,
+                                                            CUDNN_32BIT_INDICES));
+    }
+
+    if (mBwdWorkspace == nullptr) {
+        size_t workspaceSize = 0;
+        for (std::size_t i = 0; i < mTensorDesc.size(); i++) {
+            CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
+                                mBwdReduceDesc,
+                                std::dynamic_pointer_cast<TensorImpl_cuda_>(outputGrad.getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                                mTensorDesc[i],
+                                &workspaceSize));
+
+            mBwdWorkspaceSize = std::max(workspaceSize, mBwdWorkspaceSize);
         }
-        strides[i] = tensorStrides;
+
+        CHECK_CUDA_STATUS(cudaMalloc(&mBwdWorkspace, mBwdWorkspaceSize));
     }
 
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
-            backward_<double>(outputGrad, dims, strides);
+            backward_<double>(outputGrad);
             break;
         case DataType::Float32:
-            backward_<float>(outputGrad, dims, strides);
+            backward_<float>(outputGrad);
             break;
         case DataType::Float16:
-            backward_<half>(outputGrad, dims, strides);
+            backward_<half>(outputGrad);
             break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
@@ -155,65 +148,56 @@ void Aidge::AddImpl_cuda::backward() {
 }
 
 template <class T>
-void Aidge::AddImpl_cuda::backward_(const Tensor& outputGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) 
+void Aidge::AddImpl_cuda::backward_(const Tensor& outputGrad)
 {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
 
     const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
     const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate
 
-    for (std::size_t i = 0; i < inputsDims.size(); i++)
+    for (std::size_t i = 0; i < mTensorDesc.size(); i++)
     {
         if (op.getInput(i)->size() == op.getOutput(0)->size())
         {
-            // TODO: Test if we can avoid copy and simply set rawPtr
-            op.getInput(i)->grad()->getImpl()->copy(outputGrad.getImpl()->rawPtr(), op.getInput(i)->grad()->size());
+            CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
+                        &alpha,
+                        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                        outputGrad.getImpl()->rawPtr(),
+                        &beta,
+                        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(i)->getImpl())->getCudnnTensorDesc(*op.getInput(i)),
+                        op.getInput(i)->grad()->getImpl()->rawPtr()));
         }
         else // In case of broadcasting
         {
             // Gradient with respect to input_i: sum outputGrad over the broadcasted dimensions using cudnnReduceTensor
-            cudnnReduceTensorDescriptor_t reduceDesc;
-            CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&reduceDesc));
-            CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(reduceDesc,
-                                                              CUDNN_REDUCE_TENSOR_ADD,
-                                                              CudaContext::data_type<T>::value,
-                                                              CUDNN_PROPAGATE_NAN,
-                                                              CUDNN_REDUCE_TENSOR_NO_INDICES,
-                                                              CUDNN_32BIT_INDICES));
-
-            cudnnTensorDescriptor_t outputDesc = std::dynamic_pointer_cast<TensorImpl_cuda_>(outputGrad.getImpl())->getCudnnTensorDesc(*op.getOutput(0));
-            // Create a Tensor descriptor with the broadcasted dims and strides
-            cudnnTensorDescriptor_t tensorDesc;
-            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc));
-            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc,
-                                                          CudaContext::data_type<T>::value,
-                                                          inputsDims[i].size(),
-                                                          inputsDims[i].data(),
-                                                          inputsStrides[i].data()));
-            size_t workspaceSize;
-            CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
-                               reduceDesc,
-                               outputDesc,
-                               tensorDesc,
-                               &workspaceSize));
-
-            void *d_workspace;
-            CHECK_CUDA_STATUS(cudaMalloc(&d_workspace, workspaceSize));
-
             CHECK_CUDNN_STATUS(cudnnReduceTensor(CudaContext::cudnnHandle(),
-                               reduceDesc,
+                               mBwdReduceDesc,
                                NULL,
                                0,
-                               d_workspace,
-                               workspaceSize,
+                               mBwdWorkspace,
+                               mBwdWorkspaceSize,
                                &alpha,
-                               outputDesc,
+                               std::dynamic_pointer_cast<TensorImpl_cuda_>(outputGrad.getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                                outputGrad.getImpl()->rawPtr(),
                                &beta,
-                               tensorDesc,
+                               mTensorDesc[i],
                                op.getInput(i)->grad()->getImpl()->rawPtr()));
+        }
+    }
+}
 
-            CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
+Aidge::AddImpl_cuda::~AddImpl_cuda() {
+    for (auto tensorDesc : mTensorDesc) {
+        if (tensorDesc != nullptr) {
+            cudnnDestroyTensorDescriptor(tensorDesc);
         }
     }
+
+    if (mBwdReduceDesc != nullptr) {
+        cudnnDestroyReduceTensorDescriptor(mBwdReduceDesc);
+    }
+
+    if (mBwdWorkspace != nullptr) {
+        cudaFree(mBwdWorkspace);
+    }
 }
diff --git a/src/operator/MulImpl.cpp b/src/operator/MulImpl.cpp
index aa9b4c74785d3d5785f9d9d62d1a72503f8be104..e86a6b0e7e936e7ea7bcb293b704979a6e00490c 100644
--- a/src/operator/MulImpl.cpp
+++ b/src/operator/MulImpl.cpp
@@ -34,40 +34,51 @@ void Aidge::MulImpl_cuda::forward() {
         AIDGE_ASSERT(op.getInput(i)->dataType() == datatypeFirstInput, "Cannot Mul inputs with two differents data type.");
     }
 
-    std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
-    std::vector<Tensor> inputs(op.nbInputs());
-    std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
-    std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
-    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
-        inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
-
-        // Get tensor dims and broadcast them
-        std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
-        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
+    if (mInputFallbacks.empty()) {
+        mInputFallbacks.resize(op.nbInputs());
+    }
 
-        if (dims[i].size() < 4) {
-            dims[i].resize(4, 1);
+    std::vector<std::reference_wrapper<Tensor>> inputs;
+    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
+        inputs.push_back(op.getInput(i)->refCastFrom(mInputFallbacks[i], *op.getOutput(0)));
+
+        if (mTensorDesc.size() <= i) {
+            // Get tensor dims and broadcast them
+            std::vector<int> dims(inputs[i].get().dims().begin(), inputs[i].get().dims().end());
+            dims.insert(dims.cbegin(), op.getOutput(0)->nbDims() - dims.size(), int(1));
+
+            if (dims.size() < 4) {
+                dims.resize(4, 1);
+            }
+
+            // Compute the corresponding strides
+            std::vector<int> strides(dims.size());
+            int product = 1;
+            for (size_t j = dims.size(); j > 0; --j) {
+                strides[j - 1] = product;
+                product *= dims[j - 1];
+            }
+
+            mTensorDesc.push_back(nullptr);
+            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mTensorDesc[i]));
+            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mTensorDesc[i], DataTypeToCudnn(op.getOutput(0)->dataType()), dims.size(), dims.data(), strides.data()));
         }
+    }
 
-        // Compute the corresponding strides
-        std::vector<int> tensorStrides(dims[i].size());
-        int product = 1;
-        for (size_t j = dims[i].size(); j > 0; --j) {
-            tensorStrides[j - 1] = product;
-            product *= dims[i][j - 1];
-        }
-        strides[i] = tensorStrides;
+    if (mOpTensorDesc == nullptr) {
+        CHECK_CUDNN_STATUS(cudnnCreateOpTensorDescriptor(&mOpTensorDesc));
+        CHECK_CUDNN_STATUS(cudnnSetOpTensorDescriptor(mOpTensorDesc, CUDNN_OP_TENSOR_MUL, DataTypeToCudnn(op.getOutput(0)->dataType()), CUDNN_PROPAGATE_NAN));
     }
 
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
-            forward_<double>(inputs, dims, strides);
+            forward_<double>(inputs);
             break;
         case DataType::Float32:
-            forward_<float>(inputs, dims, strides);
+            forward_<float>(inputs);
             break;
         case DataType::Float16:
-            forward_<half>(inputs, dims, strides);
+            forward_<half>(inputs);
             break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
@@ -75,51 +86,37 @@ void Aidge::MulImpl_cuda::forward() {
 }
 
 template <class T>
-void Aidge::MulImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) {
+void Aidge::MulImpl_cuda::forward_(const std::vector<std::reference_wrapper<Tensor>>& inputs) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
     const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
 
-    // Create a Tensor descriptor with the broadcasted dims and strides
-    cudnnTensorDescriptor_t tensorDesc0, tensorDesc1;
-    CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc0));
-    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc0, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
-    CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc1));
-    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc1, CudaContext::data_type<T>::value, inputsDims[1].size(), inputsDims[1].data(), inputsStrides[1].data()));
-    // Multiply inputs
-    cudnnOpTensorDescriptor_t opTensorDesc;
-    CHECK_CUDNN_STATUS(cudnnCreateOpTensorDescriptor(&opTensorDesc));
-    CHECK_CUDNN_STATUS(cudnnSetOpTensorDescriptor(opTensorDesc, CUDNN_OP_TENSOR_MUL, CudaContext::data_type<T>::value, CUDNN_PROPAGATE_NAN));
-    if(inputs[0].size()>inputs[1].size()) {
+    if(inputs[0].get().size()>inputs[1].get().size()) {
         CHECK_CUDNN_STATUS(cudnnOpTensor(CudaContext::cudnnHandle(),
-                                        opTensorDesc,
+                                        mOpTensorDesc,
                                         &alpha,
-                                        tensorDesc0,
-                                        inputs[0].getImpl()->rawPtr(),
+                                        mTensorDesc[0],
+                                        inputs[0].get().getImpl()->rawPtr(),
                                         &alpha,
-                                        tensorDesc1,
-                                        inputs[1].getImpl()->rawPtr(),
+                                        mTensorDesc[1],
+                                        inputs[1].get().getImpl()->rawPtr(),
                                         &beta,
                                         std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                                         std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
     }
     else {
         CHECK_CUDNN_STATUS(cudnnOpTensor(CudaContext::cudnnHandle(),
-                                opTensorDesc,
+                                mOpTensorDesc,
                                 &alpha,
-                                tensorDesc1,
-                                inputs[1].getImpl()->rawPtr(),
+                                mTensorDesc[1],
+                                inputs[1].get().getImpl()->rawPtr(),
                                 &alpha,
-                                tensorDesc0,
-                                inputs[0].getImpl()->rawPtr(),
+                                mTensorDesc[0],
+                                inputs[0].get().getImpl()->rawPtr(),
                                 &beta,
                                 std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                                 std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
     }
-
-    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc0));
-    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc1));
-    CHECK_CUDNN_STATUS(cudnnDestroyOpTensorDescriptor(opTensorDesc));
 }
 
 void Aidge::MulImpl_cuda::backward() {
@@ -128,42 +125,17 @@ void Aidge::MulImpl_cuda::backward() {
     AIDGE_ASSERT(op.getOutput(0)->grad(), "missing output gradient in Mul operator");
     AIDGE_ASSERT(op.getOutput(0)->grad()->hasImpl(), "cannot run Mul backward because the output gradient has no implementation.");
 
-    std::shared_ptr<Tensor> outputGradFallback;
-    const auto& outputGrad = op.getOutput(0)->grad()->refCastFrom(outputGradFallback, *op.getOutput(0)->grad());
-
-    std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
-    std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
-    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
-        std::shared_ptr<Tensor> inputFallback;
-        const Tensor input = op.getInput(i)->refCastFrom(inputFallback, *op.getOutput(0));
-
-        // Get tensor dims and broadcast them
-        std::copy(input.dims().begin(), input.dims().end(), std::back_inserter(dims[i]));
-        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
-        
-        if (dims[i].size() < 4) {
-            dims[i].resize(4, 1);
-        }
-
-        // Compute the corresponding strides
-        std::vector<int> tensorStrides(dims[i].size());
-        int product = 1;
-        for (size_t j = dims[i].size(); j > 0; --j) {
-            tensorStrides[j - 1] = product;
-            product *= dims[i][j - 1];
-        }
-        strides[i] = tensorStrides;
-    }
+    const auto& outputGrad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
 
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
-            backward_<double>(outputGrad, dims, strides);
+            backward_<double>(outputGrad);
             break;
         case DataType::Float32:
-            backward_<float>(outputGrad, dims, strides);
+            backward_<float>(outputGrad);
             break;
         case DataType::Float16:
-            backward_<half>(outputGrad, dims, strides);
+            backward_<half>(outputGrad);
             break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
@@ -171,51 +143,47 @@ void Aidge::MulImpl_cuda::backward() {
 }
 
 template <class T>
-void Aidge::MulImpl_cuda::backward_(const Tensor& outputGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) {
+void Aidge::MulImpl_cuda::backward_(const Tensor& outputGrad) {
 
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
     const typename Cuda::cudnn_scaling_type<T>::type beta = 1.0f; // accumulate
 
-    // Create a Tensor descriptor with the broadcasted dims and strides
-    cudnnTensorDescriptor_t tensorDesc0, tensorDesc1;
-    CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc0));
-    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc0, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
-    CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc1));
-    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc1, CudaContext::data_type<T>::value, inputsDims[1].size(), inputsDims[1].data(), inputsStrides[1].data()));
-    
-    // Create the operation descriptor
-    cudnnOpTensorDescriptor_t opTensorDesc;
-    CHECK_CUDNN_STATUS(cudnnCreateOpTensorDescriptor(&opTensorDesc));
-    CHECK_CUDNN_STATUS(cudnnSetOpTensorDescriptor(opTensorDesc, CUDNN_OP_TENSOR_MUL, CudaContext::data_type<T>::value, CUDNN_PROPAGATE_NAN));
-
     // Input0_grad = output_grad * Input1
     CHECK_CUDNN_STATUS(cudnnOpTensor(CudaContext::cudnnHandle(),
-                            opTensorDesc,
+                            mOpTensorDesc,
                             &alpha,
                             std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                             outputGrad.getImpl()->rawPtr(),
                             &alpha,
-                            tensorDesc1,
+                            mTensorDesc[1],
                             op.getInput(1)->getImpl()->rawPtr(),
                             &beta,
-                            tensorDesc0,
+                            mTensorDesc[0],
                             op.getInput(0)->grad()->getImpl()->rawPtr()));
 
     // Input1_grad = output_grad * Input0
     CHECK_CUDNN_STATUS(cudnnOpTensor(CudaContext::cudnnHandle(),
-                            opTensorDesc,
+                            mOpTensorDesc,
                             &alpha,
                             std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                             outputGrad.getImpl()->rawPtr(),
                             &alpha,
-                            tensorDesc0,
+                            mTensorDesc[0],
                             op.getInput(0)->getImpl()->rawPtr(),
                             &beta,
-                            tensorDesc1,
+                            mTensorDesc[1],
                             op.getInput(1)->grad()->getImpl()->rawPtr()));
-    
-    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc0));
-    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc1));
-    CHECK_CUDNN_STATUS(cudnnDestroyOpTensorDescriptor(opTensorDesc));
-}
\ No newline at end of file
+}
+
+Aidge::MulImpl_cuda::~MulImpl_cuda() {
+    for (auto tensorDesc : mTensorDesc) {
+        if (tensorDesc != nullptr) {
+            cudnnDestroyTensorDescriptor(tensorDesc);
+        }
+    }
+
+    if (mOpTensorDesc != nullptr) {
+        cudnnDestroyOpTensorDescriptor(mOpTensorDesc);
+    }
+}
diff --git a/src/operator/ReduceMeanImpl.cpp b/src/operator/ReduceMeanImpl.cpp
index 2746d4c36fd2d2a5cfe8196d4b091c67ce0f2324..3f8e8b92aba68ed50b13578ff0fae65a13ef94ff 100644
--- a/src/operator/ReduceMeanImpl.cpp
+++ b/src/operator/ReduceMeanImpl.cpp
@@ -36,6 +36,51 @@ void Aidge::ReduceMeanImpl_cuda::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->copy(input.getImpl()->rawPtr(), input.size());
     }
     else {
+        if (!keepDims && mOutputDesc == nullptr) {
+            std::vector<int> outputDims;
+            std::copy(input.dims().begin(), input.dims().end(), std::back_inserter(outputDims));
+            for (const auto axis:axes) {
+                outputDims[axis] = 1;
+            }
+            if (outputDims.size() < 4) {
+                outputDims.resize(4, 1);
+            }
+            // Compute the corresponding strides
+            std::vector<int> outputStrides(outputDims.size());
+            int product = 1;
+            for (size_t i = outputDims.size(); i > 0; --i) {
+                outputStrides[i - 1] = product;
+                product *= outputDims[i - 1];
+            }
+
+            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mOutputDesc));
+            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mOutputDesc, DataTypeToCudnn(op.getOutput(0)->dataType()), outputDims.size(), outputDims.data(), outputStrides.data()));
+        }
+
+        if (mReduceDesc == nullptr) {
+            CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&mReduceDesc));
+            CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(mReduceDesc,
+                                                                CUDNN_REDUCE_TENSOR_AVG,
+                                                                DataTypeToCudnn(op.getOutput(0)->dataType()),
+                                                                CUDNN_PROPAGATE_NAN,
+                                                                CUDNN_REDUCE_TENSOR_NO_INDICES,
+                                                                CUDNN_32BIT_INDICES));
+        }
+
+        if (mWorkspace == nullptr) {
+            const auto outputDesc = (keepDims)
+                ? std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0))
+                : mOutputDesc;
+
+            CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
+                                mReduceDesc,
+                                std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
+                                outputDesc,
+                                &mWorkspaceSize));
+
+            CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, mWorkspaceSize));
+        }
+
         switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
             case DataType::Float64:
                 forward_<double>(input, axes, keepDims);
@@ -59,97 +104,33 @@ void Aidge::ReduceMeanImpl_cuda::forward_(const Tensor& input, const std::vector
     const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
     const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
 
-    cudnnReduceTensorDescriptor_t reduceDesc;
-    cudnnTensorDescriptor_t outputDesc;
     if (keepDims) {
-        outputDesc = std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0));
-        CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&reduceDesc));
-        CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(reduceDesc,
-                                                            CUDNN_REDUCE_TENSOR_AVG,
-                                                            CudaContext::data_type<T>::value,
-                                                            CUDNN_PROPAGATE_NAN,
-                                                            CUDNN_REDUCE_TENSOR_NO_INDICES,
-                                                            CUDNN_32BIT_INDICES));
-
-
-        size_t workspaceSize;
-        CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
-                            reduceDesc,
-                            std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
-                            outputDesc,
-                            &workspaceSize));
-
-        void *d_workspace;
-        CHECK_CUDA_STATUS(cudaMalloc(&d_workspace, workspaceSize));
-
         CHECK_CUDNN_STATUS(cudnnReduceTensor(CudaContext::cudnnHandle(),
-                            reduceDesc,
+                            mReduceDesc,
                             NULL,
                             0,
-                            d_workspace,
-                            workspaceSize,
+                            mWorkspace,
+                            mWorkspaceSize,
                             &alpha,
                             std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
                             input.getImpl()->rawPtr(),
                             &beta,
-                            outputDesc,
-                            std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
-
-        CHECK_CUDNN_STATUS(cudnnDestroyReduceTensorDescriptor(reduceDesc));
-    }
+                            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                            std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
+    }
     else {
-        CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&outputDesc));
-        std::vector<int> outputDims;
-        std::copy(input.dims().begin(), input.dims().end(), std::back_inserter(outputDims));
-        for (const auto axis:axes) {
-            outputDims[axis] = 1;
-        }
-        if (outputDims.size() < 4) {
-            outputDims.resize(4, 1);
-        }
-        // Compute the corresponding strides
-        std::vector<int> outputStrides(outputDims.size());
-        int product = 1;
-        for (size_t i = outputDims.size(); i > 0; --i) {
-            outputStrides[i - 1] = product;
-            product *= outputDims[i - 1];
-        }
-        CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(outputDesc, CudaContext::data_type<T>::value, outputDims.size(), outputDims.data(), outputStrides.data()));
-    
-        CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&reduceDesc));
-        CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(reduceDesc,
-                                                            CUDNN_REDUCE_TENSOR_AVG,
-                                                            CudaContext::data_type<T>::value,
-                                                            CUDNN_PROPAGATE_NAN,
-                                                            CUDNN_REDUCE_TENSOR_NO_INDICES,
-                                                            CUDNN_32BIT_INDICES));
-
-
-        size_t workspaceSize;
-        CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
-                            reduceDesc,
-                            std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
-                            outputDesc,
-                            &workspaceSize));
-
-        void *d_workspace;
-        CHECK_CUDA_STATUS(cudaMalloc(&d_workspace, workspaceSize));
-
         CHECK_CUDNN_STATUS(cudnnReduceTensor(CudaContext::cudnnHandle(),
-                            reduceDesc,
+                            mReduceDesc,
                             NULL,
                             0,
-                            d_workspace,
-                            workspaceSize,
+                            mWorkspace,
+                            mWorkspaceSize,
                             &alpha,
                             std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
                             input.getImpl()->rawPtr(),
                             &beta,
-                            outputDesc,
+                            mOutputDesc,
                             std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
-
-        CHECK_CUDNN_STATUS(cudnnDestroyReduceTensorDescriptor(reduceDesc));
-        CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(outputDesc));
     }
 }
 
@@ -206,3 +186,17 @@ void Aidge::ReduceMeanImpl_cuda::backward_(const Tensor& outGrad, const std::vec
                             alpha,
                             beta);
 }
+
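+// Free the cached cuDNN descriptors and the reduction workspace, if they were created.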
+Aidge::ReduceMeanImpl_cuda::~ReduceMeanImpl_cuda() {
+    if (mReduceDesc != nullptr) {
+        cudnnDestroyReduceTensorDescriptor(mReduceDesc);
+    }
+
+    if (mOutputDesc != nullptr) {
+        cudnnDestroyTensorDescriptor(mOutputDesc);
+    }
+
+    if (mWorkspace != nullptr) {
+        cudaFree(mWorkspace);
+    }
+}
diff --git a/src/operator/ReduceSumImpl.cpp b/src/operator/ReduceSumImpl.cpp
index e8c5b1e98d10d40dc01157465ba21f3a5330ced4..46469c5ec57a80b08527fe1b62b2a6821010f988 100644
--- a/src/operator/ReduceSumImpl.cpp
+++ b/src/operator/ReduceSumImpl.cpp
@@ -36,6 +36,51 @@ void Aidge::ReduceSumImpl_cuda::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->copy(input.getImpl()->rawPtr(), input.size());
     }
     else {
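+        // keepDims == false: the actual output drops the reduced axes, so build (once) a
+        // descriptor with those axes collapsed to size 1 and padded to at least 4-D for cuDNN.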
+        if (!keepDims && mOutputDesc == nullptr) {
+            std::vector<int> outputDims;
+            std::copy(input.dims().begin(), input.dims().end(), std::back_inserter(outputDims));
+            for (const auto axis:axes) {
+                outputDims[axis] = 1;
+            }
+            if (outputDims.size() < 4) {
+                outputDims.resize(4, 1);
+            }
+            // Compute the corresponding strides
+            std::vector<int> outputStrides(outputDims.size());
+            int product = 1;
+            for (size_t i = outputDims.size(); i > 0; --i) {
+                outputStrides[i - 1] = product;
+                product *= outputDims[i - 1];
+            }
+
+            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mOutputDesc));
+            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mOutputDesc, DataTypeToCudnn(op.getOutput(0)->dataType()), outputDims.size(), outputDims.data(), outputStrides.data()));
+        }
+
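+        // The ADD reduction descriptor is shape-independent, so create it only on the first call.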
+        if (mReduceDesc == nullptr) {
+            CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&mReduceDesc));
+            CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(mReduceDesc,
+                                                                CUDNN_REDUCE_TENSOR_ADD,
+                                                                DataTypeToCudnn(op.getOutput(0)->dataType()),
+                                                                CUDNN_PROPAGATE_NAN,
+                                                                CUDNN_REDUCE_TENSOR_NO_INDICES,
+                                                                CUDNN_32BIT_INDICES));
+        }
+
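+        // Query and allocate the reduction workspace once; it is reused on every later forward pass.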
+        if (mWorkspace == nullptr) {
+            const auto outputDesc = (keepDims)
+                ? std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0))
+                : mOutputDesc;
+
+            CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
+                                mReduceDesc,
+                                std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
+                                outputDesc,
+                                &mWorkspaceSize));
+
+            CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, mWorkspaceSize));
+        }
+
         switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
             case DataType::Float64:
                 forward_<double>(input, axes, keepDims);
@@ -59,97 +104,33 @@ void Aidge::ReduceSumImpl_cuda::forward_(const Tensor& input, const std::vector<
     const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
     const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
 
-    cudnnReduceTensorDescriptor_t reduceDesc;
-    cudnnTensorDescriptor_t outputDesc;
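+    // Both branches use the reduce descriptor and workspace prepared lazily in forward().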
     if (keepDims) {
-        outputDesc = std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0));
-        CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&reduceDesc));
-        CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(reduceDesc,
-                                                            CUDNN_REDUCE_TENSOR_ADD,
-                                                            CudaContext::data_type<T>::value,
-                                                            CUDNN_PROPAGATE_NAN,
-                                                            CUDNN_REDUCE_TENSOR_NO_INDICES,
-                                                            CUDNN_32BIT_INDICES));
-
-
-        size_t workspaceSize;
-        CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
-                            reduceDesc,
-                            std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
-                            outputDesc,
-                            &workspaceSize));
-
-        void *d_workspace;
-        CHECK_CUDA_STATUS(cudaMalloc(&d_workspace, workspaceSize));
-
         CHECK_CUDNN_STATUS(cudnnReduceTensor(CudaContext::cudnnHandle(),
-                            reduceDesc,
+                            mReduceDesc,
                             NULL,
                             0,
-                            d_workspace,
-                            workspaceSize,
+                            mWorkspace,
+                            mWorkspaceSize,
                             &alpha,
                             std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
                             input.getImpl()->rawPtr(),
                             &beta,
-                            outputDesc,
+                            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                             std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
-
-        CHECK_CUDNN_STATUS(cudnnDestroyReduceTensorDescriptor(reduceDesc));
     }
     else {
-        CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&outputDesc));
-        std::vector<int> outputDims;
-        std::copy(input.dims().begin(), input.dims().end(), std::back_inserter(outputDims));
-        for (const auto axis:axes) {
-            outputDims[axis] = 1;
-        }
-        if (outputDims.size() < 4) {
-            outputDims.resize(4, 1);
-        }
-        // Compute the corresponding strides
-        std::vector<int> outputStrides(outputDims.size());
-        int product = 1;
-        for (size_t i = outputDims.size(); i > 0; --i) {
-            outputStrides[i - 1] = product;
-            product *= outputDims[i - 1];
-        }
-        CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(outputDesc, CudaContext::data_type<T>::value, outputDims.size(), outputDims.data(), outputStrides.data()));
-    
-        CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&reduceDesc));
-        CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(reduceDesc,
-                                                            CUDNN_REDUCE_TENSOR_ADD,
-                                                            CudaContext::data_type<T>::value,
-                                                            CUDNN_PROPAGATE_NAN,
-                                                            CUDNN_REDUCE_TENSOR_NO_INDICES,
-                                                            CUDNN_32BIT_INDICES));
-
-
-        size_t workspaceSize;
-        CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
-                            reduceDesc,
-                            std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
-                            outputDesc,
-                            &workspaceSize));
-
-        void *d_workspace;
-        CHECK_CUDA_STATUS(cudaMalloc(&d_workspace, workspaceSize));
-
         CHECK_CUDNN_STATUS(cudnnReduceTensor(CudaContext::cudnnHandle(),
-                            reduceDesc,
+                            mReduceDesc,
                             NULL,
                             0,
-                            d_workspace,
-                            workspaceSize,
+                            mWorkspace,
+                            mWorkspaceSize,
                             &alpha,
                             std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
                             input.getImpl()->rawPtr(),
                             &beta,
-                            outputDesc,
+                            mOutputDesc,
                             std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
-
-        CHECK_CUDNN_STATUS(cudnnDestroyReduceTensorDescriptor(reduceDesc));
-        CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(outputDesc));
     }
 }
 
@@ -203,3 +184,17 @@ void Aidge::ReduceSumImpl_cuda::backward_(const Tensor& outGrad, const std::vect
                         alpha,
                         beta);
 }
+
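+// Free the cuDNN descriptors and the reduction workspace created lazily in forward().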
+Aidge::ReduceSumImpl_cuda::~ReduceSumImpl_cuda() {
+    if (mReduceDesc != nullptr) {
+        cudnnDestroyReduceTensorDescriptor(mReduceDesc);
+    }
+
+    if (mOutputDesc != nullptr) {
+        cudnnDestroyTensorDescriptor(mOutputDesc);
+    }
+
+    if (mWorkspace != nullptr) {
+        cudaFree(mWorkspace);
+    }
+}
diff --git a/src/operator/SubImpl.cpp b/src/operator/SubImpl.cpp
index 249d95f5a03c17e96db41c924361be3de1cbc6b0..b0927205901fb0b48bda9da942d1e7b54b005c96 100644
--- a/src/operator/SubImpl.cpp
+++ b/src/operator/SubImpl.cpp
@@ -33,40 +33,46 @@ void Aidge::SubImpl_cuda::forward() {
         AIDGE_ASSERT(op.getInput(i)->dataType() == datatypeFirstInput, "Cannot add inputs with two differents data type.");
     }
 
-    std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
-    std::vector<Tensor> inputs(op.nbInputs());
-    std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
-    std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
-    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
-        inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
-
-        // Get tensor dims and broadcast them
-        std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
-        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
-
-        if (dims[i].size() < 4) {
-            dims[i].resize(4, 1);
-        }
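+    // Keep the conversion fallback tensors as members so they persist across calls.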
+    if (mInputFallbacks.empty()) {
+        mInputFallbacks.resize(op.nbInputs());
+    }
 
-        // Compute the corresponding strides
-        std::vector<int> tensorStrides(dims[i].size());
-        int product = 1;
-        for (size_t j = dims[i].size(); j > 0; --j) {
-            tensorStrides[j - 1] = product;
-            product *= dims[i][j - 1];
+    std::vector<std::reference_wrapper<Tensor>> inputs;
+    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
+        inputs.push_back(op.getInput(i)->refCastFrom(mInputFallbacks[i], *op.getOutput(0)));
+
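+        // Build each input's broadcast descriptor only on the first call: dims are left-padded
+        // with 1s up to the output rank, padded to at least 4-D, and given contiguous strides.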
+        if (mTensorDesc.size() <= i) {
+            // Get tensor dims and broadcast them
+            std::vector<int> dims(inputs[i].get().dims().begin(), inputs[i].get().dims().end());
+            dims.insert(dims.cbegin(), op.getOutput(0)->nbDims() - dims.size(), int(1));
+
+            if (dims.size() < 4) {
+                dims.resize(4, 1);
+            }
+
+            // Compute the corresponding strides
+            std::vector<int> strides(dims.size());
+            int product = 1;
+            for (size_t j = dims.size(); j > 0; --j) {
+                strides[j - 1] = product;
+                product *= dims[j - 1];
+            }
+
+            mTensorDesc.push_back(nullptr);
+            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mTensorDesc[i]));
+            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mTensorDesc[i], DataTypeToCudnn(op.getOutput(0)->dataType()), dims.size(), dims.data(), strides.data()));
         }
-        strides[i] = tensorStrides;
     }
 
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
-            forward_<double>(inputs, dims, strides);
+            forward_<double>(inputs);
             break;
         case DataType::Float32:
-            forward_<float>(inputs, dims, strides);
+            forward_<float>(inputs);
             break;
         case DataType::Float16:
-            forward_<half>(inputs, dims, strides);
+            forward_<half>(inputs);
             break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
@@ -74,40 +80,24 @@ void Aidge::SubImpl_cuda::forward() {
 }
 
 template <class T>
-void Aidge::SubImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) {
+void Aidge::SubImpl_cuda::forward_(const std::vector<std::reference_wrapper<Tensor>>& inputs) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
     const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
     const typename Cuda::cudnn_scaling_type<T>::type gamma = -1.0f;
-    // Create a Tensor descriptor with the broadcasted dims and strides
-    cudnnTensorDescriptor_t tensorDesc;
-    CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc));
-    CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
-    // Add first input to the output
-    CHECK_CUDNN_STATUS(
-        cudnnAddTensor(CudaContext::cudnnHandle(),
-                       &alpha,
-                       tensorDesc,
-                       inputs[0].getImpl()->rawPtr(),
-                       &beta,
-                       std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
-                       std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
-    );
-    // Substract other inputs if there are any
-    for (size_t i = 1; i < op.nbInputs(); ++i)
+
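+    // The first iteration overwrites the output with input 0; each subsequent input is then
+    // subtracted from the accumulated output (scaling factor -1, accumulation factor 1).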
+    for (size_t i = 0; i < op.nbInputs(); ++i)
     {
-        CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[i].size(), inputsDims[i].data(), inputsStrides[i].data()));
         CHECK_CUDNN_STATUS(
             cudnnAddTensor(CudaContext::cudnnHandle(),
-                        &gamma,
-                        tensorDesc,
-                        inputs[i].getImpl()->rawPtr(),
-                        &alpha,
+                        (i > 0) ? &gamma : &alpha,
+                        mTensorDesc[i],
+                        inputs[i].get().getImpl()->rawPtr(),
+                        (i > 0) ? &alpha : &beta,
                         std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                         std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
         );
     }
-    CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
 }
 
 void Aidge::SubImpl_cuda::backward() {
@@ -116,38 +106,42 @@ void Aidge::SubImpl_cuda::backward() {
     AIDGE_ASSERT(op.getOutput(0)->grad(), "missing output gradient in Sub operator");
     AIDGE_ASSERT(op.getOutput(0)->grad()->hasImpl(), "cannot run Sub backward because the output gradient has no implementation.");
 
-    std::shared_ptr<Tensor> outputGradFallback;
-    const auto& outputGrad = op.getOutput(0)->grad()->refCastFrom(outputGradFallback, *op.getOutput(0)->grad());
+    const auto& outputGrad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
 
-    std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
-    std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
-    for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
-        std::shared_ptr<Tensor> inputFallback;
-        const Tensor input = op.getInput(i)->refCastFrom(inputFallback, *op.getOutput(0));
-
-        // Get tensor dims and broadcast them
-        std::copy(input.dims().begin(), input.dims().end(), std::back_inserter(dims[i]));
-        dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
-
-        // Compute the corresponding strides
-        std::vector<int> tensorStrides(dims[i].size());
-        int product = 1;
-        for (size_t j = dims[i].size(); j > 0; --j) {
-            tensorStrides[j - 1] = product;
-            product *= dims[i][j - 1];
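+    // Broadcast input gradients sum outputGrad over the broadcast dims; the ADD reduce descriptor is created once.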
+    if (mBwdReduceDesc == nullptr) {
+        CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&mBwdReduceDesc));
+        CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(mBwdReduceDesc,
+                                                            CUDNN_REDUCE_TENSOR_ADD,
+                                                            DataTypeToCudnn(op.getOutput(0)->dataType()),
+                                                            CUDNN_PROPAGATE_NAN,
+                                                            CUDNN_REDUCE_TENSOR_NO_INDICES,
+                                                            CUDNN_32BIT_INDICES));
+    }
+
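+    // Allocate a single workspace sized for the largest per-input reduction.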
+    if (mBwdWorkspace == nullptr) {
+        size_t workspaceSize = 0;
+        for (std::size_t i = 0; i < mTensorDesc.size(); i++) {
+            CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
+                                mBwdReduceDesc,
+                                std::dynamic_pointer_cast<TensorImpl_cuda_>(outputGrad.getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                                mTensorDesc[i],
+                                &workspaceSize));
+
+            mBwdWorkspaceSize = std::max(workspaceSize, mBwdWorkspaceSize);
         }
-        strides[i] = tensorStrides;
+
+        CHECK_CUDA_STATUS(cudaMalloc(&mBwdWorkspace, mBwdWorkspaceSize));
     }
 
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
-            backward_<double>(outputGrad, dims, strides);
+            backward_<double>(outputGrad);
             break;
         case DataType::Float32:
-            backward_<float>(outputGrad, dims, strides);
+            backward_<float>(outputGrad);
             break;
         case DataType::Float16:
-            backward_<half>(outputGrad, dims, strides);
+            backward_<half>(outputGrad);
             break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
@@ -155,10 +149,7 @@ void Aidge::SubImpl_cuda::backward() {
 }
 
 template <class T>
-void Aidge::SubImpl_cuda::backward_(
-    const Tensor& outputGrad, 
-    const std::vector<std::vector<int>>& inputsDims, 
-    const std::vector<std::vector<int>>& inputsStrides) 
+void Aidge::SubImpl_cuda::backward_(const Tensor& outputGrad)
 {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
 
@@ -166,12 +157,11 @@ void Aidge::SubImpl_cuda::backward_(
     const typename Cuda::cudnn_scaling_type<T>::type beta  = 1.0f; // accumulate
     const typename Cuda::cudnn_scaling_type<T>::type gamma = -1.0f;
 
-    for (std::size_t i = 0; i < inputsDims.size(); i++)
+    for (std::size_t i = 0; i < mTensorDesc.size(); i++)
     {
         if (op.getInput(i)->size() == op.getOutput(0)->size())
         {
-            CHECK_CUDNN_STATUS(
-            cudnnAddTensor(CudaContext::cudnnHandle(),
+            CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
                         i==0 ? &alpha: &gamma,
                         std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                         outputGrad.getImpl()->rawPtr(),
@@ -182,48 +172,34 @@ void Aidge::SubImpl_cuda::backward_(
         else // In case of broadcasting
         {
             // Gradient with respect to input_i: sum outputGrad over the broadcasted dimensions using cudnnReduceTensor
-            cudnnReduceTensorDescriptor_t reduceDesc;
-            CHECK_CUDNN_STATUS(cudnnCreateReduceTensorDescriptor(&reduceDesc));
-            CHECK_CUDNN_STATUS(cudnnSetReduceTensorDescriptor(reduceDesc,
-                                                              CUDNN_REDUCE_TENSOR_ADD,
-                                                              CudaContext::data_type<T>::value,
-                                                              CUDNN_PROPAGATE_NAN,
-                                                              CUDNN_REDUCE_TENSOR_NO_INDICES,
-                                                              CUDNN_32BIT_INDICES));
-
-            cudnnTensorDescriptor_t outputDesc = std::dynamic_pointer_cast<TensorImpl_cuda_>(outputGrad.getImpl())->getCudnnTensorDesc(*op.getOutput(0));
-            // Create a Tensor descriptor with the broadcasted dims and strides
-            cudnnTensorDescriptor_t tensorDesc;
-            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc));
-            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc,
-                                                          CudaContext::data_type<T>::value,
-                                                          inputsDims[i].size(),
-                                                          inputsDims[i].data(),
-                                                          inputsStrides[i].data()));
-            size_t workspaceSize;
-            CHECK_CUDNN_STATUS(cudnnGetReductionWorkspaceSize(CudaContext::cudnnHandle(),
-                               reduceDesc,
-                               outputDesc,
-                               tensorDesc,
-                               &workspaceSize));
-
-            void *d_workspace;
-            CHECK_CUDA_STATUS(cudaMalloc(&d_workspace, workspaceSize));
-
             CHECK_CUDNN_STATUS(cudnnReduceTensor(CudaContext::cudnnHandle(),
-                               reduceDesc,
+                               mBwdReduceDesc,
                                NULL,
                                0,
-                               d_workspace,
-                               workspaceSize,
+                               mBwdWorkspace,
+                               mBwdWorkspaceSize,
                                i==0 ? &alpha: &gamma,
-                               outputDesc,
+                               std::dynamic_pointer_cast<TensorImpl_cuda_>(outputGrad.getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                                outputGrad.getImpl()->rawPtr(),
                                &beta,
-                               tensorDesc,
+                               mTensorDesc[i],
                                op.getInput(i)->grad()->getImpl()->rawPtr()));
+        }
+    }
+}
 
-            CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
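+// Destroy the per-input tensor descriptors, the backward reduce descriptor and the backward workspace.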
+Aidge::SubImpl_cuda::~SubImpl_cuda() {
+    for (auto tensorDesc : mTensorDesc) {
+        if (tensorDesc != nullptr) {
+            cudnnDestroyTensorDescriptor(tensorDesc);
         }
     }
+
+    if (mBwdReduceDesc != nullptr) {
+        cudnnDestroyReduceTensorDescriptor(mBwdReduceDesc);
+    }
+
+    if (mBwdWorkspace != nullptr) {
+        cudaFree(mBwdWorkspace);
+    }
 }
diff --git a/version.txt b/version.txt
index 8ea2ddfc77b9f41e373e0591f46fc2fc155eb4ca..4b9fcbec101a6ff8ec68e0f95131ccda4861407f 100644
--- a/version.txt
+++ b/version.txt
@@ -1,2 +1 @@
-0.5.0
-
+0.5.1