Skip to content
Snippets Groups Projects
Commit f5dbfb26 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added missing guards for multi-GPU

parent 83486f0b
No related branches found
No related tags found
1 merge request!75Update 0.5.1 -> 0.6.0
Pipeline #70670 passed
......@@ -2,7 +2,9 @@
#define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H
#include <vector>
#include <mutex>
#include "aidge/data/DataType.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
......@@ -46,14 +48,12 @@ public:
static std::vector<bool> init;
if (deviceProp.empty()) {
//#pragma omp critical(CudaContext__getDeviceProp)
if (deviceProp.empty()) {
int count = 1;
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
deviceProp.resize(count);
init.resize(count, false);
}
std::lock_guard<std::mutex> guard(initMutex);
int count = 1;
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
deviceProp.resize(count);
init.resize(count, false);
}
int dev;
......@@ -73,13 +73,11 @@ public:
static std::vector<cublasHandle_t> cublas_h;
if (cublas_h.empty()) {
//#pragma omp critical(CudaContext__cublasHandle)
if (cublas_h.empty()) {
int count = 1;
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
std::lock_guard<std::mutex> guard(initMutex);
int count = 1;
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
cublas_h.resize(count, NULL);
}
cublas_h.resize(count, NULL);
}
int dev;
......@@ -87,7 +85,7 @@ public:
if (cublas_h[dev] == NULL) {
CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev]));
fmt::print("CUBLAS initialized on device #{}\n", dev);
Log::debug("CUBLAS initialized on device #{}\n", dev);
}
return cublas_h[dev];
......@@ -99,13 +97,11 @@ public:
static std::vector<cudnnHandle_t> cudnn_h;
if (cudnn_h.empty()) {
//#pragma omp critical(CudaContext__cudnnHandle)
if (cudnn_h.empty()) {
int count = 1;
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
std::lock_guard<std::mutex> guard(initMutex);
int count = 1;
CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
cudnn_h.resize(count, NULL);
}
cudnn_h.resize(count, NULL);
}
int dev;
......@@ -113,7 +109,7 @@ public:
if (cudnn_h[dev] == NULL) {
CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev]));
fmt::print("CUDNN initialized on device #{}\n", dev);
Log::debug("CUDNN initialized on device #{}\n", dev);
}
return cudnn_h[dev];
......@@ -124,6 +120,9 @@ public:
static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
// Dummy value by default
};
private:
static std::mutex initMutex;
};
}
......
#include "aidge/backend/cuda/utils/CudaContext.hpp"
std::mutex Aidge::CudaContext::initMutex;
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment