From f5dbfb261a6eee532103c1236f7d751d35c14d0d Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 10 Apr 2025 22:29:02 +0200 Subject: [PATCH] Added missing guards for multi-GPU --- .../aidge/backend/cuda/utils/CudaContext.hpp | 43 +++++++++---------- src/utils/CudaContext.cpp | 3 ++ 2 files changed, 24 insertions(+), 22 deletions(-) create mode 100644 src/utils/CudaContext.cpp diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp index f21886e..35f02d6 100644 --- a/include/aidge/backend/cuda/utils/CudaContext.hpp +++ b/include/aidge/backend/cuda/utils/CudaContext.hpp @@ -2,7 +2,9 @@ #define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H #include <vector> +#include <mutex> +#include "aidge/data/DataType.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/backend/cuda/utils/CudaUtils.hpp" @@ -46,14 +48,12 @@ public: static std::vector<bool> init; if (deviceProp.empty()) { -//#pragma omp critical(CudaContext__getDeviceProp) - if (deviceProp.empty()) { - int count = 1; - CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); - - deviceProp.resize(count); - init.resize(count, false); - } + std::lock_guard<std::mutex> guard(initMutex); + int count = 1; + CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); + + deviceProp.resize(count); + init.resize(count, false); } int dev; @@ -73,13 +73,11 @@ public: static std::vector<cublasHandle_t> cublas_h; if (cublas_h.empty()) { -//#pragma omp critical(CudaContext__cublasHandle) - if (cublas_h.empty()) { - int count = 1; - CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); + std::lock_guard<std::mutex> guard(initMutex); + int count = 1; + CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); - cublas_h.resize(count, NULL); - } + cublas_h.resize(count, NULL); } int dev; @@ -87,7 +85,7 @@ public: if (cublas_h[dev] == NULL) { CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev])); - fmt::print("CUBLAS initialized on device #{}\n", dev); + Log::debug("CUBLAS initialized on device #{}\n", dev); } return cublas_h[dev]; @@ -99,13 +97,11 @@ public: static std::vector<cudnnHandle_t> cudnn_h; if (cudnn_h.empty()) { -//#pragma omp critical(CudaContext__cudnnHandle) - if (cudnn_h.empty()) { - int count = 1; - CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); + std::lock_guard<std::mutex> guard(initMutex); + int count = 1; + CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); - cudnn_h.resize(count, NULL); - } + cudnn_h.resize(count, NULL); } int dev; @@ -113,7 +109,7 @@ public: if (cudnn_h[dev] == NULL) { CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev])); - fmt::print("CUDNN initialized on device #{}\n", dev); + Log::debug("CUDNN initialized on device #{}\n", dev); } return cudnn_h[dev]; @@ -124,6 +120,9 @@ public: static const cudnnDataType_t value = CUDNN_DATA_FLOAT; // Dummy value by default }; + +private: + static std::mutex initMutex; }; } diff --git a/src/utils/CudaContext.cpp b/src/utils/CudaContext.cpp new file mode 100644 index 0000000..22fc346 --- /dev/null +++ b/src/utils/CudaContext.cpp @@ -0,0 +1,3 @@ +#include "aidge/backend/cuda/utils/CudaContext.hpp" + +std::mutex Aidge::CudaContext::initMutex; -- GitLab