4646/*
4747 * CUDA ERROR HANDLING
4848 *
49- * which is only defined when CUDA-compiling, since it is invoked only a macro (defined
50- * in gpu_config.hpp) which wraps CUDA API calls
49+ * which are only defined when CUDA-compiling, since only ever invoked
50+ * when encountering issues through use of the CUDA API
5151 */
5252
5353#if COMPILE_CUDA
5454
5555void assertCudaCallSucceeded (int result, const char * call, const char * caller, const char * file, int line) {
5656
57+ // this function is only invoked by the CUDA_CHECK macro defined in gpu_config.hpp header
58+
5759 // result (int) is actually type cudaError_t but we cannot use this CUDA-defined type
5860 // in gpu_config.hpp (since it's included by non-CUDA-compiled files), and we wish to keep
5961 // the signature consistent.
@@ -63,6 +65,32 @@ void assertCudaCallSucceeded(int result, const char* call, const char* caller, c
6365 error_cudaCallFailed (cudaGetErrorString (code), call, caller, file, line);
6466}
6567
68+ void clearPossibleCudaError () {
69+
70+ // beware that in addition to clearing anticipated CUDA errors (like
71+ // cudaMalloc failing), this function will check that the CUDA API is
72+ // generally working (i.e. has not encountered an irrecoverable error),
73+ // including whether e.g. the CUDA drivers match the runtime version. It
74+ // should ergo never be called in settings where GPU is compiled but not
75+ // runtime activated, since such settings see CUDA be in an acceptably
76+ // broken state - calling this function would throw an internal error
77+
78+ // clear "non-sticky" errors so that future CUDA API use is not corrupted
79+ cudaError_t initialCode = cudaGetLastError ();
80+
81+ // nothing to do if no error had occurred
82+ if (initialCode == cudaSuccess)
83+ return ;
84+
85+ // sync and re-check if error code is erroneously unchanged, which
86+ // indicates that CUDA encountered an irrecoverable "sticky" error
87+ CUDA_CHECK ( cudaDeviceSynchronize () );
88+
89+ cudaError_t finalCode = cudaGetLastError ();
90+ if (initialCode == finalCode)
91+ error_cudaEncounteredIrrecoverableError ();
92+ }
93+
6694#endif
6795
6896
@@ -153,6 +181,12 @@ int gpu_getNumberOfLocalGpus() {
153181 // is called but no devices exist, which we handle
154182 int num;
155183 auto status = cudaGetDeviceCount (&num);
184+
185+ // treat query failure as indication of no local GPUs
186+ // so do not call clearPossibleCudaError(). This is
187+ // necessary because cudaGetDeviceCount() can report
188+ // driver version errors when QuEST is GPU-compiled
189+ // on a platform without a GPU, which we tolerate
156190 return (status == cudaSuccess)? num : 0 ;
157191
158192#else
@@ -176,8 +210,12 @@ bool gpu_isGpuAvailable() {
176210 struct cudaDeviceProp props;
177211 auto status = cudaGetDeviceProperties (&props, deviceInd);
178212
179- // if the query failed, device is anyway unusable
180- if (status != cudaSuccess)
213+ // if the query failed, device is anyway unusable; we do not
214+ // clear the error with clearPossibleCudaError() since this
215+ // can trigger an internal error when QuEST is GPU-compiled
216+ // but no valid GPU exists (hence no valid driver), like
217+ // occurs on cluster submission nodes
218+ if (status != cudaSuccess)
181219 continue ;
182220
183221 // if the device is a real GPU, it's 'major' compute capability is != 9999 (meaning emulation)
@@ -405,9 +443,16 @@ qcomp* gpu_allocArray(qindex length) {
405443 qcomp* ptr;
406444 cudaError_t errCode = cudaMalloc (&ptr, numBytes);
407445
408- // intercept memory-alloc error and merely return nullptr pointer (to be handled by validation)
409- if (errCode == cudaErrorMemoryAllocation)
446+ // intercept memory-alloc error (handled by caller's validation)
447+ if (errCode == cudaErrorMemoryAllocation) {
448+
449+ // malloc failure can break CUDA API state, so recover it in
450+ // case execution is continuing (e.g. by unit tests)
451+ clearPossibleCudaError ();
452+
453+ // indicate alloc failure
410454 return nullptr ;
455+ }
411456
412457 // pass all other unexpected errors to internal error handling
413458 CUDA_CHECK (errCode);
0 commit comments