diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp index f21886e502b9017aa55e250e7257d16bc5d04501..35f02d6905eb4a1a669344848a70c397d2572d79 100644 --- a/include/aidge/backend/cuda/utils/CudaContext.hpp +++ b/include/aidge/backend/cuda/utils/CudaContext.hpp @@ -2,7 +2,9 @@ #define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H #include <vector> +#include <mutex> +#include "aidge/data/DataType.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/backend/cuda/utils/CudaUtils.hpp" @@ -46,14 +48,12 @@ public: static std::vector<bool> init; if (deviceProp.empty()) { -//#pragma omp critical(CudaContext__getDeviceProp) - if (deviceProp.empty()) { - int count = 1; - CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); - - deviceProp.resize(count); - init.resize(count, false); - } + std::lock_guard<std::mutex> guard(initMutex); + int count = 1; + CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); + + deviceProp.resize(count); + init.resize(count, false); } int dev; @@ -73,13 +73,11 @@ public: static std::vector<cublasHandle_t> cublas_h; if (cublas_h.empty()) { -//#pragma omp critical(CudaContext__cublasHandle) - if (cublas_h.empty()) { - int count = 1; - CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); + std::lock_guard<std::mutex> guard(initMutex); + int count = 1; + CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); - cublas_h.resize(count, NULL); - } + cublas_h.resize(count, NULL); } int dev; @@ -87,7 +85,7 @@ public: if (cublas_h[dev] == NULL) { CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev])); - fmt::print("CUBLAS initialized on device #{}\n", dev); + Log::debug("CUBLAS initialized on device #{}\n", dev); } return cublas_h[dev]; @@ -99,13 +97,11 @@ public: static std::vector<cudnnHandle_t> cudnn_h; if (cudnn_h.empty()) { -//#pragma omp critical(CudaContext__cudnnHandle) - if (cudnn_h.empty()) { - int count = 1; - CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); + std::lock_guard<std::mutex> guard(initMutex); + int count = 1; + CHECK_CUDA_STATUS(cudaGetDeviceCount(&count)); - cudnn_h.resize(count, NULL); - } + cudnn_h.resize(count, NULL); } int dev; @@ -113,7 +109,7 @@ public: if (cudnn_h[dev] == NULL) { CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev])); - fmt::print("CUDNN initialized on device #{}\n", dev); + Log::debug("CUDNN initialized on device #{}\n", dev); } return cudnn_h[dev]; @@ -124,6 +120,9 @@ public: static const cudnnDataType_t value = CUDNN_DATA_FLOAT; // Dummy value by default }; + +private: + static std::mutex initMutex; }; } diff --git a/src/utils/CudaContext.cpp b/src/utils/CudaContext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22fc346167c07286565b5c199cc22d20f64104ff --- /dev/null +++ b/src/utils/CudaContext.cpp @@ -0,0 +1,3 @@ +#include "aidge/backend/cuda/utils/CudaContext.hpp" + +std::mutex Aidge::CudaContext::initMutex;