From 8b9a6658e18b83b89700200a7e2295fc48879ef7 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sat, 9 Dec 2023 23:20:50 +0100
Subject: [PATCH] Added safeguards

---
 include/aidge/backend/cuda/data/TensorImpl.hpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 6b309c5..45f8a6c 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -85,6 +85,7 @@ public:
             return;
         }
 
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
         if (srcDt == DataType::Float64) {
             thrust_copy(static_cast<const double*>(src),
                         static_cast<T*>(rawPtr()),
@@ -141,14 +142,17 @@ public:
     }
 
     void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
         CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }
 
     void copyFromHost(const void *src, NbElts_t length) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
         CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyHostToDevice));
     }
 
     void copyToHost(void *dst, NbElts_t length) const override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
         CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost));
     }
 
-- 
GitLab