Skip to content

Commit 8552fbc

Browse files
authored
Merge pull request #104 from coreweave/es/vllm-0.10.0
feat(vllm-tensorizer): Update to vLLM v0.10.0 & `flashinfer` v0.2.8
2 parents 11a75a4 + dc236fb commit 8552fbc

File tree

3 files changed

+193
-8
lines changed

3 files changed

+193
-8
lines changed

.github/configurations/vllm-tensorizer.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
vllm-commit:
2-
- 'v0.9.2'
2+
- 'v0.10.0'
33
flashinfer-commit:
4-
- 'v0.2.6.post1'
4+
- 'v0.2.8'
55
builder-base-image:
66
- 'ghcr.io/coreweave/ml-containers/torch-extras:es-fa3-te-update-7a94157-nccl-cuda12.9.1-ubuntu22.04-nccl2.27.6-1-torch2.7.1-vision0.22.1-audio2.7.1-abi1'
77
final-base-image:

vllm-tensorizer/Dockerfile

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,28 @@ WORKDIR /workspace
2525
RUN --mount=type=bind,from=freezer,target=/tmp/frozen \
2626
/tmp/frozen/freeze.sh torch torchaudio torchvision xformers > /opt/constraints.txt
2727

28+
COPY --link --chmod=755 nvcc-wrapper.py /opt/nvcc-wrapper.py
29+
ENV PYTORCH_NVCC='/opt/nvcc-wrapper.py' \
30+
CMAKE_CUDA_COMPILER='/opt/nvcc-wrapper.py'
31+
32+
ARG TARGETPLATFORM
33+
# Switch 9.0, 10.0, and 12.0 to -a variants; preserve originals for PTX
34+
# Flashinfer v0.2.8 in particular can only build for 12.0a, not for 12.0
35+
RUN printf 'TORCH_CUDA_ARCH_LIST=' && \
36+
echo "${TORCH_CUDA_ARCH_LIST}" \
37+
| sed -E 's@\b(9|10|12)\.0\b@\1\.0a@g; s@\+PTX\b@@g' \
38+
| tee /opt/arch_list.txt && \
39+
printf 'NVCC_WRAPPER_FILTER_CODES=' && \
40+
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
41+
echo 'sm_80;sm_89;sm_90;sm_100;sm_120;compute_80;compute_89;compute_90;compute_100;compute_120'; \
42+
else \
43+
echo 'sm_90;sm_100;sm_120;compute_90;compute_100;compute_120'; \
44+
fi \
45+
| tee /opt/filter_codes.txt && \
46+
printf '#!/bin/sh\nexport %s %s;\n' \
47+
'TORCH_CUDA_ARCH_LIST="$(cat /opt/arch_list.txt)"' \
48+
'NVCC_WRAPPER_FILTER_CODES="$(cat /opt/filter_codes.txt)"' \
49+
| install -m 500 /dev/stdin /opt/arch_flags.sh
2850

2951
FROM alpine/git:2.36.3 AS vllm-downloader
3052
WORKDIR /git
@@ -66,6 +88,7 @@ RUN git clone --filter=tree:0 --no-single-branch --no-checkout \
6688

6789
FROM builder-base AS vllm-builder
6890
RUN --mount=type=bind,from=vllm-downloader,source=/git/vllm,target=/workspace,rw \
91+
. /opt/arch_flags.sh && \
6992
if [ -z "$MAX_JOBS" ]; then unset MAX_JOBS; fi && \
7093
python3 -m pip install --no-cache-dir py-cpuinfo && \
7194
if [ -f 'use_existing_torch.py' ]; then \
@@ -88,7 +111,9 @@ WORKDIR /wheels
88111

89112
FROM builder-base AS flashinfer-builder
90113
RUN --mount=type=bind,from=flashinfer-downloader,source=/git/flashinfer,target=/workspace,rw \
114+
. /opt/arch_flags.sh && \
91115
export TORCH_CUDA_ARCH_LIST="$(echo "${TORCH_CUDA_ARCH_LIST}" | sed 's@[67]\.0 \+@@g')" && \
116+
sed -i 's@torch\.cuda\.get_device_capability()@(12, 0)@' flashinfer/comm/trtllm_ar.py && \
92117
python3 -m flashinfer.aot && \
93118
python3 -m pip wheel -w /wheels \
94119
-v --no-cache-dir --no-build-isolation --no-deps \
@@ -101,6 +126,7 @@ WORKDIR /wheels
101126
FROM builder-base AS lmcache-builder
102127
# LMCache must be built from source as it doesn't have pre-built ARM binaries
103128
RUN --mount=type=bind,from=lmcache-downloader,source=/git/LMCache,target=/workspace,rw \
129+
. /opt/arch_flags.sh && \
104130
python3 -m pip install --no-cache-dir 'xxhash==3.5.0' 'setuptools_scm>=8' && \
105131
sed -Ei \
106132
'/[ "]*(torch(vision|audio)?|xformers) *[<>=~]+/d' \
@@ -155,14 +181,13 @@ RUN --mount=type=bind,from=lmcache-builder,source=/wheels,target=/tmp/wheels \
155181
ARG TARGETPLATFORM
156182

157183
RUN if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
158-
python3 -m pip install --no-cache-dir \
159-
accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' \
160-
boto3 runai-model-streamer runai-model-streamer[s3] -c /tmp/constraints.txt; \
184+
BITSANDBYTES_VER='0.42.0'; \
161185
else \
162-
python3 -m pip install --no-cache-dir \
163-
accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.46.1' 'timm==0.9.10' \
164-
boto3 runai-model-streamer runai-model-streamer[s3] -c /tmp/constraints.txt; \
186+
BITSANDBYTES_VER='0.46.1'; \
165187
fi && \
188+
python3 -m pip install --no-cache-dir \
189+
accelerate hf_transfer 'modelscope!=1.15.0' "bitsandbytes>=${BITSANDBYTES_VER:?}" 'timm==0.9.10' \
190+
boto3 runai-model-streamer runai-model-streamer[s3] -c /tmp/constraints.txt && \
166191
rm /tmp/constraints.txt
167192

168193
EXPOSE 8080

vllm-tensorizer/nvcc-wrapper.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#!/bin/env python3
2+
3+
"""
4+
Wraps invocations of ``nvcc``, watching for evidence of SIGKILL or SIGSEGV,
5+
and then re-running the ``nvcc`` command a configurable number of times.
6+
7+
Checking for SIGKILL or SIGSEGV is implemented by checking for either:
8+
9+
- A subprocess return code indicating either of these signals, or
10+
- The standard ``dash`` error messages for either signal.
11+
12+
``dash`` status messages are checked as NVCC utilizes ``sh``
13+
subprocesses internally, and ``sh`` usually resolves to
14+
the ``dash`` shell within Ubuntu container images.
15+
16+
This wrapper also has the ability to filter out some -gencode flags.
17+
Gencode flags to filter out should be identified by their code parameter
18+
in a semicolon-delimited list stored in the NVCC_WRAPPER_FILTER_CODES
19+
environment variable.
20+
"""
21+
22+
import asyncio
23+
import os
24+
import re
25+
import shutil
26+
import signal
27+
import subprocess
28+
import sys
29+
from typing import BinaryIO, Final, FrozenSet, Iterable, List, Sequence, Set
30+
31+
# Location of the real compiler, looked up on PATH once at import time.
NVCC_PATH: Final[str] = shutil.which("nvcc")
if NVCC_PATH is None:
    raise SystemExit("NVCC wrapper: fatal: nvcc binary not found")

# Total number of times nvcc may be run for one invocation of the wrapper.
# An unset or empty NVCC_WRAPPER_ATTEMPTS falls back to 10. A value that is
# not an integer is reported with the same fatal message as an out-of-range
# one, rather than surfacing as a bare ValueError traceback.
try:
    WRAPPER_ATTEMPTS: Final[int] = int(os.getenv("NVCC_WRAPPER_ATTEMPTS") or 10)
except ValueError:
    raise SystemExit(
        "NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_ATTEMPTS"
    ) from None
if WRAPPER_ATTEMPTS < 1:
    raise SystemExit("NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_ATTEMPTS")

# Codes whose -gencode flags are stripped from the nvcc command line, read
# from a semicolon-delimited list; empty entries are ignored.
FILTER_CODES: Final[FrozenSet[str]] = frozenset(
    filter(None, os.getenv("NVCC_WRAPPER_FILTER_CODES", "").split(";"))
)
# Every entry must look like sm_90, compute_100a, lto_120f, ...
if FILTER_CODES and not all(
    re.fullmatch(r"(?:sm|compute|lto)_\d+[af]?", a) for a in FILTER_CODES
):
    raise SystemExit("NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_FILTER_CODES")

# Exit statuses that trigger a retry: the negative signal numbers that the
# subprocess API reports for a child killed by SIGSEGV/SIGKILL, the
# 128 + signal form, and 255.
# NOTE(review): 255 is presumably what nvcc returns when one of its internal
# tools crashes -- confirm.
RETRY_RET_CODES: Final[FrozenSet[int]] = frozenset({
    -signal.SIGSEGV,
    -signal.SIGKILL,
    128 + signal.SIGSEGV,
    128 + signal.SIGKILL,
    255,
})
54+
55+
56+
async def main(args) -> int:
    """Run nvcc with the filtered arguments, retrying after crash-like exits.

    The command is run up to ``WRAPPER_ATTEMPTS`` times. Its stdout and
    stderr are relayed to this process's own streams while being scanned
    by ``monitor_stream``. A run is final when it exits with status 0, or
    when neither stream showed a crash message and the exit status is not
    in ``RETRY_RET_CODES``.

    :param args: nvcc command-line arguments, without the program name.
    :return: the exit status of the last nvcc run.
    """
    nvcc_args = transform_args(args)
    status: int = 0
    for attempt in range(1, WRAPPER_ATTEMPTS + 1):
        if attempt != 1:
            notice = (
                "NVCC wrapper: info:"
                f" Retrying [{attempt:d}/{WRAPPER_ATTEMPTS:d}]"
                f" after exit code {status:d}"
            )
            print(notice, file=sys.stderr, flush=True)
            # Back off exponentially between attempts, capped at one minute
            backoff = min(60, int(1.5**attempt))
            await asyncio.sleep(backoff)

        child = await asyncio.create_subprocess_exec(
            NVCC_PATH, *nvcc_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        # Drain both pipes concurrently; each result says whether a crash
        # message was seen on that stream.
        crash_on_stdout, crash_on_stderr = await asyncio.gather(
            monitor_stream(child.stdout, sys.stdout.buffer),
            monitor_stream(child.stderr, sys.stderr.buffer),
        )
        status = await child.wait()
        del child

        if status == 0:
            return status
        crash_like = crash_on_stdout or crash_on_stderr or status in RETRY_RET_CODES
        if not crash_like:
            # An ordinary compile failure: retrying will not help
            return status

    print(
        "NVCC wrapper: info:"
        f" Maximum attempts reached, exiting with status {status:d}",
        file=sys.stderr,
        flush=True,
    )
    return status
90+
91+
92+
async def monitor_stream(
    stream: asyncio.StreamReader,
    output: BinaryIO,
    watch_for: Iterable[bytes] = (
        b"Segmentation fault",
        b"Segmentation fault (core dumped)",
        b"Killed",
    ),
) -> bool:
    """Relay ``stream`` to ``output`` and report whether a crash line was seen.

    Data is read in chunks rather than with ``StreamReader.readline()``,
    which raises ``ValueError`` and discards buffered data when a single
    line exceeds the reader's buffer limit. Every byte is forwarded to
    ``output`` (and flushed) as soon as it arrives.

    :param stream: the subprocess pipe to drain until EOF.
    :param output: binary sink that receives everything read.
    :param watch_for: messages to look for; a line matches when, after
        stripping surrounding whitespace, it equals one of them exactly.
    :return: ``True`` if any line of the stream matched ``watch_for``.
    """
    # Materialise once so one-shot iterables work and lookups are O(1)
    markers = frozenset(watch_for)
    found: bool = False
    # Bytes of the current, not yet newline-terminated line
    tail: bytes = b""
    while chunk := await stream.read(1 << 16):
        output.write(chunk)
        output.flush()
        if found:
            # Nothing left to detect; keep relaying the remaining output
            continue
        *complete_lines, tail = (tail + chunk).split(b"\n")
        found = any(line.strip() in markers for line in complete_lines)
    # The last line may end at EOF without a trailing newline
    return found or tail.strip() in markers
107+
108+
109+
def transform_args(
    args: Sequence[str],
    filter_codes: "Iterable[str] | None" = None,
) -> Sequence[str]:
    """Remove unwanted ``-gencode`` targets from an nvcc argument list.

    Handles both spellings of the flag, ``-gencode=arch=X,code=Y`` and
    ``-gencode arch=X,code=Y`` (likewise ``--generate-code``), including
    bracketed groups such as ``code=[Y,Z]``. Codes in the filter set are
    removed from a group; if a single code is left the brackets are
    dropped, and if none are left the whole flag is dropped. Arguments
    given via ``--gpu-architecture`` / ``--gpu-code`` are not touched.

    :param args: nvcc command-line arguments, without the program name.
    :param filter_codes: codes to strip (e.g. ``sm_90``, ``compute_90``);
        defaults to the module-level ``FILTER_CODES``.
    :return: ``args`` itself when there is nothing to filter, otherwise a
        new list with the filtered arguments.
    """
    unwanted = frozenset(FILTER_CODES if filter_codes is None else filter_codes)
    if not unwanted:
        return args

    gencode_flags = {"-gencode", "--generate-code"}
    # Value that follows a bare `-gencode` argument
    split_form = re.compile(r"(arch=[^,]+,code=)(\S+)")
    # Flag and value joined with `=` in a single argument
    joined_form = re.compile(r"((?:-gencode|--generate-code)=arch=\S+,code=)(\S+)")

    kept: List[str] = []
    # True while the previous argument was a bare gencode flag awaiting its value
    partial: bool = False
    for arg in args:
        if not partial and arg in gencode_flags:
            partial = True
            kept.append(arg)
            continue

        m = (split_form if partial else joined_form).fullmatch(arg)
        if m is None:
            kept.append(arg)
            partial = False
            continue

        prefix, code = m.groups()
        if re.fullmatch(r"\[\S+]", code):
            remaining = [c for c in code[1:-1].split(",") if c not in unwanted]
        elif code in unwanted:
            remaining = []
        else:
            remaining = [code]

        if remaining:
            joined = ",".join(remaining)
            if len(remaining) > 1:
                joined = f"[{joined}]"
            kept.append(prefix + joined)
        elif partial:
            # Every code was filtered out, so the hanging `-gencode` that
            # was appended on the previous iteration has to go as well
            kept.pop()
        partial = False
    return kept
154+
155+
156+
if __name__ == "__main__":
    # Exit with nvcc's final status; Ctrl-C maps to 130
    try:
        exit_status = asyncio.run(main(sys.argv[1:]))
    except KeyboardInterrupt:
        exit_status = 130
    sys.exit(exit_status)

0 commit comments

Comments
 (0)