|
| 1 | +#!/bin/env python3 |
| 2 | + |
| 3 | +""" |
| 4 | +Wraps invocations of ``nvcc``, watching for evidence of SIGKILL or SIGSEGV, |
| 5 | +and then re-running the ``nvcc`` command a configurable number of times. |
| 6 | +
|
| 7 | +Checking for SIGKILL or SIGSEGV is implemented by checking for either: |
| 8 | +
|
| 9 | +- A subprocess return code indicating either of these signals, or |
| 10 | +- The standard ``dash`` error messages for either signal. |
| 11 | +
|
| 12 | +``dash`` status messages are checked as NVCC utilizes ``sh`` |
| 13 | +subprocesses internally, and ``sh`` usually resolves to |
| 14 | +the ``dash`` shell within Ubuntu container images. |
| 15 | +
|
| 16 | +This wrapper also has the ability to filter out some -gencode flags. |
| 17 | +Gencode flags to filter out should be identified by their code parameter |
| 18 | +in a semicolon-delimited list stored in the NVCC_WRAPPER_FILTER_CODES |
| 19 | +environment variable. |
| 20 | +""" |
| 21 | + |
| 22 | +import asyncio |
| 23 | +import os |
| 24 | +import re |
| 25 | +import shutil |
| 26 | +import signal |
| 27 | +import subprocess |
| 28 | +import sys |
| 29 | +from typing import BinaryIO, Final, FrozenSet, Iterable, List, Sequence, Set |
| 30 | + |
| 31 | +NVCC_PATH: Final[str] = shutil.which("nvcc") |
| 32 | +if NVCC_PATH is None: |
| 33 | + raise SystemExit("NVCC wrapper: fatal: nvcc binary not found") |
| 34 | + |
| 35 | +WRAPPER_ATTEMPTS: Final[int] = int(os.getenv("NVCC_WRAPPER_ATTEMPTS") or 10) |
| 36 | +if WRAPPER_ATTEMPTS < 1: |
| 37 | + raise SystemExit("NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_ATTEMPTS") |
| 38 | + |
| 39 | +FILTER_CODES: Final[FrozenSet[str]] = frozenset( |
| 40 | + filter(None, os.getenv("NVCC_WRAPPER_FILTER_CODES", "").split(";")) |
| 41 | +) |
| 42 | +if FILTER_CODES and not all( |
| 43 | + re.fullmatch(r"(?:sm|compute|lto)_\d+[af]?", a) for a in FILTER_CODES |
| 44 | +): |
| 45 | + raise SystemExit("NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_FILTER_CODES") |
| 46 | + |
| 47 | +RETRY_RET_CODES: Final[FrozenSet[int]] = frozenset({ |
| 48 | + -signal.SIGSEGV, |
| 49 | + -signal.SIGKILL, |
| 50 | + 128 + signal.SIGSEGV, |
| 51 | + 128 + signal.SIGKILL, |
| 52 | + 255, |
| 53 | +}) |
| 54 | + |
| 55 | + |
| 56 | +async def main(args) -> int: |
| 57 | + args = transform_args(args) |
| 58 | + ret: int = 0 |
| 59 | + for attempt in range(1, WRAPPER_ATTEMPTS + 1): |
| 60 | + if attempt > 1: |
| 61 | + print( |
| 62 | + "NVCC wrapper: info:" |
| 63 | + f" Retrying [{attempt:d}/{WRAPPER_ATTEMPTS:d}]" |
| 64 | + f" after exit code {ret:d}", |
| 65 | + file=sys.stderr, |
| 66 | + flush=True, |
| 67 | + ) |
| 68 | + # Wait an exponentially increasing amount of time |
| 69 | + # before trying again, up to one minute |
| 70 | + await asyncio.sleep(min(60, int(1.5**attempt))) |
| 71 | + proc = await asyncio.create_subprocess_exec( |
| 72 | + NVCC_PATH, *args, stdout=subprocess.PIPE, stderr=subprocess.PIPE |
| 73 | + ) |
| 74 | + restart_signals: tuple = await asyncio.gather( |
| 75 | + monitor_stream(proc.stdout, sys.stdout.buffer), |
| 76 | + monitor_stream(proc.stderr, sys.stderr.buffer), |
| 77 | + ) |
| 78 | + ret = await proc.wait() |
| 79 | + del proc |
| 80 | + if ret == 0 or not any(restart_signals) and ret not in RETRY_RET_CODES: |
| 81 | + break |
| 82 | + else: |
| 83 | + print( |
| 84 | + "NVCC wrapper: info:" |
| 85 | + f" Maximum attempts reached, exiting with status {ret:d}", |
| 86 | + file=sys.stderr, |
| 87 | + flush=True, |
| 88 | + ) |
| 89 | + return ret |
| 90 | + |
| 91 | + |
| 92 | +async def monitor_stream( |
| 93 | + stream: asyncio.StreamReader, |
| 94 | + output: BinaryIO, |
| 95 | + watch_for: Iterable[bytes] = ( |
| 96 | + b"Segmentation fault", |
| 97 | + b"Segmentation fault (core dumped)", |
| 98 | + b"Killed", |
| 99 | + ), |
| 100 | +) -> bool: |
| 101 | + found: bool = False |
| 102 | + while line := await stream.readline(): |
| 103 | + found = found or line.strip() in watch_for |
| 104 | + output.write(line) |
| 105 | + output.flush() |
| 106 | + return found |
| 107 | + |
| 108 | + |
| 109 | +def transform_args(args: Sequence[str]) -> Sequence[str]: |
| 110 | + # This filters out args of the form -gencode=arch=X,code=Y |
| 111 | + # or -gencode arch=X,code=Y for any code in FILTER_CODES. |
| 112 | + # This does not filter arguments specified using the |
| 113 | + # --gpu-architecture and --gpu-code flags, nor codes specified |
| 114 | + # among others in groups, like -gencode=arch=X,code=[Y,Z]. |
| 115 | + if not FILTER_CODES: |
| 116 | + return args |
| 117 | + transformed_args = [] |
| 118 | + partial: bool = False |
| 119 | + gencode: Set[str] = {"-gencode", "--generate-code"} |
| 120 | + for arg in args: |
| 121 | + if not partial and arg in gencode: |
| 122 | + partial = True |
| 123 | + transformed_args.append(arg) |
| 124 | + continue |
| 125 | + if partial: |
| 126 | + pattern: str = r"(arch=[^,]+,code=)(\S+)" |
| 127 | + else: |
| 128 | + pattern: str = r"((?:-gencode|--generate-code)=arch=\S+,code=)(\S+)" |
| 129 | + m: re.Match = re.fullmatch(pattern, arg) |
| 130 | + if m: |
| 131 | + code: str = m.group(2) |
| 132 | + if code in FILTER_CODES: |
| 133 | + if partial: |
| 134 | + # There was a hanging `-gencode` arg before this, so delete it |
| 135 | + assert transformed_args[-1] in gencode |
| 136 | + del transformed_args[-1] |
| 137 | + elif re.fullmatch(r"\[\S+]", code): |
| 138 | + codes: List[str] = code[1:-1].split(",") |
| 139 | + filtered_codes: List[str] = [c for c in codes if c not in FILTER_CODES] |
| 140 | + if filtered_codes: |
| 141 | + filtered_code: str = ",".join(filtered_codes) |
| 142 | + if len(filtered_codes) > 1: |
| 143 | + filtered_code = f"[{filtered_code}]" |
| 144 | + transformed_args.append(m.group(1) + filtered_code) |
| 145 | + elif partial: |
| 146 | + assert transformed_args[-1] in gencode |
| 147 | + del transformed_args[-1] |
| 148 | + else: |
| 149 | + transformed_args.append(arg) |
| 150 | + else: |
| 151 | + transformed_args.append(arg) |
| 152 | + partial = False |
| 153 | + return transformed_args |
| 154 | + |
| 155 | + |
| 156 | +if __name__ == "__main__": |
| 157 | + try: |
| 158 | + sys.exit(asyncio.run(main(sys.argv[1:]))) |
| 159 | + except KeyboardInterrupt: |
| 160 | + sys.exit(130) |
0 commit comments