[mlir] Initialize CUDA context lazily.

So we can remove the ignore-warning pragma again.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D97864
This commit is contained in:
Christian Sigg 2021-03-03 17:35:02 +01:00
parent b7aeece47c
commit f69d5a7fc7
1 changed files with 15 additions and 21 deletions

View File

@ -37,32 +37,26 @@
fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \
}(expr) }(expr)
#pragma clang diagnostic push // Make the primary context of device 0 current for the duration of the instance
#pragma clang diagnostic ignored "-Wglobal-constructors" // and restore the previous context on destruction.
// Static reference to CUDA primary context for device ordinal 0.
static CUcontext Context = [] {
CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
CUdevice device;
CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
CUcontext context;
CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
return context;
}();
#pragma clang diagnostic pop
// Sets the `Context` for the duration of the instance and restores the previous
// context on destruction.
class ScopedContext { class ScopedContext {
public: public:
ScopedContext() { ScopedContext() {
CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous)); // Static reference to CUDA primary context for device ordinal 0.
CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context)); static CUcontext context = [] {
CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
CUdevice device;
CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
CUcontext ctx;
// Note: this does not affect the current context.
CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
return ctx;
}();
CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
} }
~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); } ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
private:
CUcontext previous;
}; };
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) { extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {