From f5dbfb261a6eee532103c1236f7d751d35c14d0d Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 10 Apr 2025 22:29:02 +0200
Subject: [PATCH] Added missing guards for multi-GPU

---
 .../aidge/backend/cuda/utils/CudaContext.hpp  | 43 +++++++++----------
 src/utils/CudaContext.cpp                     |  3 ++
 2 files changed, 24 insertions(+), 22 deletions(-)
 create mode 100644 src/utils/CudaContext.cpp

diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp
index f21886e..35f02d6 100644
--- a/include/aidge/backend/cuda/utils/CudaContext.hpp
+++ b/include/aidge/backend/cuda/utils/CudaContext.hpp
@@ -2,7 +2,9 @@
 #define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H
 
 #include <vector>
+#include <mutex>
 
+#include "aidge/data/DataType.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 
@@ -46,14 +48,16 @@ public:
         static std::vector<bool> init;
 
         if (deviceProp.empty()) {
-//#pragma omp critical(CudaContext__getDeviceProp)
-            if (deviceProp.empty()) {
-                int count = 1;
-                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
-
-                deviceProp.resize(count);
-                init.resize(count, false);
-            }
+            // Serialize first-time sizing across threads. Re-check under the
+            // lock so a thread that lost the race does not resize again.
+            std::lock_guard<std::mutex> guard(initMutex);
+            if (deviceProp.empty()) {
+                int count = 1;
+                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
+
+                deviceProp.resize(count);
+                init.resize(count, false);
+            }
         }
 
         int dev;
@@ -73,13 +73,15 @@ public:
         static std::vector<cublasHandle_t> cublas_h;
 
         if (cublas_h.empty()) {
-//#pragma omp critical(CudaContext__cublasHandle)
-            if (cublas_h.empty()) {
-                int count = 1;
-                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
+            // Serialize first-time sizing across threads. Re-check under the
+            // lock so a thread that lost the race does not resize again.
+            std::lock_guard<std::mutex> guard(initMutex);
+            if (cublas_h.empty()) {
+                int count = 1;
+                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
 
-                cublas_h.resize(count, NULL);
-            }
+                cublas_h.resize(count, NULL);
+            }
         }
 
         int dev;
@@ -87,7 +85,7 @@ public:
 
         if (cublas_h[dev] == NULL) {
             CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev]));
-            fmt::print("CUBLAS initialized on device #{}\n", dev);
+            Log::debug("CUBLAS initialized on device #{}\n", dev);
         }
 
         return cublas_h[dev];
@@ -99,13 +97,15 @@ public:
         static std::vector<cudnnHandle_t> cudnn_h;
 
         if (cudnn_h.empty()) {
-//#pragma omp critical(CudaContext__cudnnHandle)
-            if (cudnn_h.empty()) {
-                int count = 1;
-                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
+            // Serialize first-time sizing across threads. Re-check under the
+            // lock so a thread that lost the race does not resize again.
+            std::lock_guard<std::mutex> guard(initMutex);
+            if (cudnn_h.empty()) {
+                int count = 1;
+                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
 
-                cudnn_h.resize(count, NULL);
-            }
+                cudnn_h.resize(count, NULL);
+            }
         }
 
         int dev;
@@ -113,7 +109,7 @@ public:
 
         if (cudnn_h[dev] == NULL) {
             CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev]));
-            fmt::print("CUDNN initialized on device #{}\n", dev);
+            Log::debug("CUDNN initialized on device #{}\n", dev);
         }
 
         return cudnn_h[dev];
@@ -124,6 +120,9 @@ public:
         static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
                                             // Dummy value by default
     };
+
+private:
+    static std::mutex initMutex;
 };
 }
 
diff --git a/src/utils/CudaContext.cpp b/src/utils/CudaContext.cpp
new file mode 100644
index 0000000..22fc346
--- /dev/null
+++ b/src/utils/CudaContext.cpp
@@ -0,0 +1,3 @@
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+
+std::mutex Aidge::CudaContext::initMutex;
-- 
GitLab