Skip to content
Snippets Groups Projects
Commit f5dbfb26 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added missing guards for multi-GPU

parent 83486f0b
No related branches found
No related tags found
1 merge request!75Update 0.5.1 -> 0.6.0
Pipeline #70670 passed
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
#define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H #define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H
#include <vector> #include <vector>
#include <mutex>
#include "aidge/data/DataType.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp" #include "aidge/backend/cuda/utils/CudaUtils.hpp"
...@@ -46,14 +48,12 @@ public: ...@@ -46,14 +48,12 @@ public:
static std::vector<bool> init; static std::vector<bool> init;
if (deviceProp.empty()) { if (deviceProp.empty()) {
//#pragma omp critical(CudaContext__getDeviceProp) std::lock_guard<std::mutex> guard(initMutex);
if (deviceProp.empty()) { int count = 1;
int count = 1; CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
deviceProp.resize(count);
deviceProp.resize(count); init.resize(count, false);
init.resize(count, false);
}
} }
int dev; int dev;
...@@ -73,13 +73,11 @@ public: ...@@ -73,13 +73,11 @@ public:
static std::vector<cublasHandle_t> cublas_h; static std::vector<cublasHandle_t> cublas_h;
if (cublas_h.empty()) { if (cublas_h.empty()) {
//#pragma omp critical(CudaContext__cublasHandle) std::lock_guard<std::mutex> guard(initMutex);
if (cublas_h.empty()) { int count = 1;
int count = 1; CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
cublas_h.resize(count, NULL); cublas_h.resize(count, NULL);
}
} }
int dev; int dev;
...@@ -87,7 +85,7 @@ public: ...@@ -87,7 +85,7 @@ public:
if (cublas_h[dev] == NULL) { if (cublas_h[dev] == NULL) {
CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev])); CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev]));
fmt::print("CUBLAS initialized on device #{}\n", dev); Log::debug("CUBLAS initialized on device #{}\n", dev);
} }
return cublas_h[dev]; return cublas_h[dev];
...@@ -99,13 +97,11 @@ public: ...@@ -99,13 +97,11 @@ public:
static std::vector<cudnnHandle_t> cudnn_h; static std::vector<cudnnHandle_t> cudnn_h;
if (cudnn_h.empty()) { if (cudnn_h.empty()) {
//#pragma omp critical(CudaContext__cudnnHandle) std::lock_guard<std::mutex> guard(initMutex);
if (cudnn_h.empty()) { int count = 1;
int count = 1; CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
cudnn_h.resize(count, NULL); cudnn_h.resize(count, NULL);
}
} }
int dev; int dev;
...@@ -113,7 +109,7 @@ public: ...@@ -113,7 +109,7 @@ public:
if (cudnn_h[dev] == NULL) { if (cudnn_h[dev] == NULL) {
CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev])); CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev]));
fmt::print("CUDNN initialized on device #{}\n", dev); Log::debug("CUDNN initialized on device #{}\n", dev);
} }
return cudnn_h[dev]; return cudnn_h[dev];
...@@ -124,6 +120,9 @@ public: ...@@ -124,6 +120,9 @@ public:
static const cudnnDataType_t value = CUDNN_DATA_FLOAT; static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
// Dummy value by default // Dummy value by default
}; };
private:
static std::mutex initMutex;
}; };
} }
......
#include "aidge/backend/cuda/utils/CudaContext.hpp"
std::mutex Aidge::CudaContext::initMutex;
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment