Skip to content

Commit 7ef4f05

Browse files
authored
Merge pull request #102 from coreweave/es/fa3-te-update
feat: Update `torch` & `vllm-tensorizer` images, improve build process
2 parents b42c222 + f67f9ec commit 7ef4f05

File tree

8 files changed

+282
-41
lines changed

8 files changed

+282
-41
lines changed

.github/configurations/torch-nccl.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ include:
55
- torch: 2.7.1
66
vision: 0.22.1
77
audio: 2.7.1
8-
nccl: 2.27.5-1
9-
nccl-tests-hash: '0120901'
8+
nccl: 2.27.6-1
9+
nccl-tests-hash: '7c12c62'
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
vllm-commit:
2-
- 'b6553be1bc75f046b00046a4ad7576364d03c835'
2+
- 'v0.9.2'
33
flashinfer-commit:
44
- 'v0.2.6.post1'
55
builder-base-image:
6-
- 'ghcr.io/coreweave/ml-containers/torch-extras:es-cuda-12.9.1-74755e9-nccl-cuda12.9.1-ubuntu22.04-nccl2.27.5-1-torch2.7.1-vision0.22.1-audio2.7.1-abi1'
6+
- 'ghcr.io/coreweave/ml-containers/torch-extras:es-fa3-te-update-7a94157-nccl-cuda12.9.1-ubuntu22.04-nccl2.27.6-1-torch2.7.1-vision0.22.1-audio2.7.1-abi1'
77
final-base-image:
8-
- 'ghcr.io/coreweave/ml-containers/torch-extras:es-cuda-12.9.1-74755e9-base-cuda12.9.1-ubuntu22.04-torch2.7.1-vision0.22.1-audio2.7.1-abi1'
8+
- 'ghcr.io/coreweave/ml-containers/torch-extras:es-fa3-te-update-7a94157-base-cuda12.9.1-ubuntu22.04-torch2.7.1-vision0.22.1-audio2.7.1-abi1'

.github/workflows/vllm-tensorizer.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ on:
22
push:
33
paths:
44
- "vllm-tensorizer/**"
5+
- ".github/configurations/vllm-tensorizer.yml"
56
- ".github/workflows/vllm-tensorizer.yml"
67
- ".github/workflows/build.yml"
78

torch-extras/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ARG DEEPSPEED_VERSION="0.14.4"
55
ARG APEX_COMMIT="a1df80457ba67d60cbdb0d3ddfb08a2702c821a8"
66
ARG DEEPSPEED_KERNELS_COMMIT="e77acc40b104696d4e73229b787d1ef29a9685b1"
77
ARG DEEPSPEED_KERNELS_CUDA_ARCH_LIST="80;86;89;90"
8-
ARG XFORMERS_VERSION="0.0.30"
8+
ARG XFORMERS_VERSION="0.0.31.post1"
99
ARG BUILD_MAX_JOBS=""
1010

1111
FROM alpine/git:2.36.3 as apex-downloader

torch/Dockerfile

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ ARG FINAL_BASE_IMAGE="nvidia/cuda:12.9.1-base-ubuntu22.04"
55
ARG BUILD_TORCH_VERSION="2.7.1"
66
ARG BUILD_TORCH_VISION_VERSION="0.22.1"
77
ARG BUILD_TORCH_AUDIO_VERSION="2.7.1"
8-
ARG BUILD_TRANSFORMERENGINE_VERSION="1.13"
8+
ARG BUILD_TRANSFORMERENGINE_VERSION="2.4"
99
ARG BUILD_FLASH_ATTN_VERSION="2.7.4.post1"
10-
ARG BUILD_FLASH_ATTN_3_VERSION="2.7.2.post1"
10+
ARG BUILD_FLASH_ATTN_3_VERSION="b36ad4ef767d2d5536ff8af2e3f720ae4eba731c"
1111
ARG BUILD_TRITON_VERSION=""
1212
ARG BUILD_TRITON="1"
1313
ARG BUILD_TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0 10.0 12.0+PTX"
@@ -90,7 +90,8 @@ RUN ./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_VER
9090
FROM downloader-base AS flash-attn-3-downloader
9191
ARG BUILD_FLASH_ATTN_3_VERSION
9292
RUN if [ -n "$BUILD_FLASH_ATTN_3_VERSION" ]; then \
93-
./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_3_VERSION}"; \
93+
./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_3_VERSION}" && \
94+
git -C flash-attention cherry-pick -n 3edf7e0daa62662cd2dd2ec8fd999dd7f254415c; \
9495
else \
9596
mkdir flash-attention; \
9697
fi
@@ -329,11 +330,16 @@ ARG BUILD_MAX_JOBS=""
329330
RUN --mount=type=bind,from=triton-downloader,source=/git/triton,target=triton/,rw \
330331
--mount=type=cache,target=/ccache \
331332
if [ "$BUILD_TRITON" = '1' ]; then \
332-
pip3 install --no-cache-dir pybind11 && \
333+
pip3 install --no-cache-dir pybind11 lit && \
333334
export MAX_JOBS="${BUILD_MAX_JOBS:-$(./scale.sh "$(./effective_cpu_count.sh)" 3 32)}" && \
334-
cd triton/python && \
335-
python3 -m pip wheel -w ../../dist/ --no-build-isolation --no-deps -vv . && \
336-
pip3 install ../../dist/*.whl; \
335+
DIST_DIR="$(realpath -e ./dist)" && \
336+
if [ -f 'triton/python/setup.py' ]; then \
337+
cd triton/python; \
338+
else \
339+
cd triton; \
340+
fi && \
341+
python3 -m pip wheel -w "${DIST_DIR}/" --no-build-isolation --no-deps -vv . && \
342+
pip3 install --no-cache-dir "${DIST_DIR}"/*.whl; \
337343
fi
338344

339345
ARG BUILD_TORCH_VERSION
@@ -348,15 +354,22 @@ ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
348354
RUN printf 'Arch: %s\nTORCH_CUDA_ARCH_LIST=%s\n' "$(uname -m)" "${TORCH_CUDA_ARCH_LIST}"
349355

350356
ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
351-
# Add sm_100a & sm_120 builds if NV_CUDA_LIB_VERSION matches 12.[89].*
357+
# Add sm_100a & sm_120a builds if NV_CUDA_LIB_VERSION matches 12.[89].*
352358
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
353359
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
354360
FLAGS="${FLAGS}$( \
355-
printf -- ' -gencode=arch=compute_%s,code=sm_%s' 120 120 100 100 100a 100a \
361+
printf -- ' -gencode=arch=compute_%s,code=sm_%s' 120a 120a 100a 100a \
356362
)" ;; \
357363
esac && \
358364
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf
359365

366+
COPY --link --chmod=755 nvcc-wrapper.py /build/nvcc-wrapper.py
367+
ENV PYTORCH_NVCC='/build/nvcc-wrapper.py' \
368+
CMAKE_CUDA_COMPILER='/build/nvcc-wrapper.py'
369+
# Filter these codes because we already build for the architecture-specific
370+
# versions of them instead.
371+
ENV NVCC_WRAPPER_FILTER_CODES='sm_90;sm_100;sm_120;compute_90;compute_100'
372+
360373
# If the directory /opt/nccl-tests exists,
361374
# the base image is assumed to be nccl-tests,
362375
# so it uses the system's special NCCL and UCC installations for the build.
@@ -534,10 +547,6 @@ RUN --mount=type=bind,from=transformerengine-downloader,source=/git/TransformerE
534547
export NVTE_CUDA_ARCHS="${NVTE_CUDA_ARCHS%;100*}" ;; \
535548
esac && \
536549
cd TransformerEngine && \
537-
if python3 -c "import sys; sys.exit(sys.version_info.minor > 8)"; then \
538-
sed -i "s/from functools import cache/from functools import lru_cache as cache/g" \
539-
build_tools/utils.py; \
540-
fi && \
541550
python3 setup.py bdist_wheel --dist-dir /build/dist
542551

543552
FROM builder-base AS flash-attn-builder-base
@@ -550,8 +559,9 @@ COPY <<-"EOT" /build/fa-build.sh
550559
#!/bin/bash
551560
set -eo pipefail;
552561
if [ -n "$1" ]; then cd "$1"; fi;
562+
echo "Flash Attention build: building $(realpath -s .)";
553563
python3 setup.py bdist_wheel --dist-dir /build/dist \
554-
| grep -Ev --line-buffered '^ptxas (/tmp/|(info|warning)\s*:)|bytes spill stores'
564+
| grep -Ev --line-buffered '^ptxas (/tmp/|(info|warning)\s*:)|bytes spill stores';
555565
EOT
556566
RUN chmod 755 /build/fa-build.sh
557567

@@ -581,8 +591,10 @@ FROM flash-attn-builder-base AS flash-attn-3-builder
581591
# Artificially sequence this build stage after the previous one
582592
# to prevent parallelism, because these are both very resource-intensive
583593
RUN --mount=type=bind,from=flash-attn-builder,source=/build,target=/build :
594+
ARG BUILD_FLASH_ATTN_MAX_JOBS="${BUILD_FLASH_ATTN_MAX_JOBS:-3}"
584595

585596
# Build flash-attn v3
597+
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
586598
RUN --mount=type=bind,from=flash-attn-3-downloader,source=/git/flash-attention,target=flash-attention/,rw \
587599
--mount=type=cache,target=/ccache \
588600
if [ ! -d flash-attention/hopper ]; then \
@@ -592,8 +604,16 @@ RUN --mount=type=bind,from=flash-attn-3-downloader,source=/git/flash-attention,t
592604
MAX_JOBS="${BUILD_FLASH_ATTN_MAX_JOBS:-$(./scale.sh "$(./effective_cpu_count.sh)" 10 6)}" && \
593605
echo "MAX_JOBS: ${MAX_JOBS}" && \
594606
export NVCC_APPEND_FLAGS="$(cat /build/nvcc.conf)" && \
607+
if [ "$(uname -m)" = 'aarch64' ]; then \
608+
export FLASH_ATTENTION_DISABLE_SM80=TRUE; \
609+
else \
610+
NVCC_APPEND_FLAGS="${NVCC_APPEND_FLAGS:+$NVCC_APPEND_FLAGS }-Xcompiler -mcmodel=medium"; \
611+
fi && \
595612
echo "NVCC_APPEND_FLAGS: ${NVCC_APPEND_FLAGS}" && \
596-
/build/fa-build.sh flash-attention/hopper
613+
sed -i \
614+
's@if bare_metal_version != Version("12.8"):@if bare_metal_version < Version("12.8"):@' \
615+
flash-attention/hopper/setup.py && \
616+
NVCC_THREADS=4 /build/fa-build.sh flash-attention/hopper
597617

598618
FROM builder-base AS builder
599619
COPY --link --from=torchaudio-builder /build/dist/ /build/dist/
@@ -671,27 +691,27 @@ COPY --link --chmod=755 install_cudnn.sh /tmp/install_cudnn.sh
671691
# - libnvjitlink-X-Y only exists for CUDA versions >= 12-0.
672692
# - Don't mess with libnccl2 when using nccl-tests as a base,
673693
# checked via the existence of the directory "/opt/nccl-tests".
674-
RUN export \
675-
CUDA_MAJOR_VERSION=$(echo $CUDA_VERSION | cut -d. -f1) \
676-
CUDA_MINOR_VERSION=$(echo $CUDA_VERSION | cut -d. -f2) && \
677-
export \
678-
CUDA_PACKAGE_VERSION="${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}" && \
694+
RUN CUDA_MAJOR_VERSION="$(echo "$CUDA_VERSION" | cut -d. -f1)" && \
695+
CUDA_MINOR_VERSION="$(echo "$CUDA_VERSION" | cut -d. -f2)" && \
696+
CUDA_PACKAGE_VERSION="${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}" && \
697+
CUDART_VERSION_SPEC="${NV_CUDA_CUDART_VERSION:+=$NV_CUDA_CUDART_VERSION}" && \
679698
apt-get -qq update && \
680699
apt-get -qq install --no-upgrade -y \
681700
libcurand-${CUDA_PACKAGE_VERSION} \
682701
libcufft-${CUDA_PACKAGE_VERSION} \
683702
libcublas-${CUDA_PACKAGE_VERSION} \
684703
cuda-nvrtc-${CUDA_PACKAGE_VERSION} \
704+
cuda-cudart-dev-${CUDA_PACKAGE_VERSION}"${CUDART_VERSION_SPEC}" \
685705
libcusparse-${CUDA_PACKAGE_VERSION} \
686706
libcusolver-${CUDA_PACKAGE_VERSION} \
687707
libcufile-${CUDA_PACKAGE_VERSION} \
688708
cuda-cupti-${CUDA_PACKAGE_VERSION} \
689709
libnvjpeg-${CUDA_PACKAGE_VERSION} \
690710
libnvtoolsext1 && \
691-
{ if [ $CUDA_MAJOR_VERSION -ge 12 ]; then \
711+
{ if [ "$CUDA_MAJOR_VERSION" -ge 12 ]; then \
692712
apt-get -qq install --no-upgrade -y libnvjitlink-${CUDA_PACKAGE_VERSION}; fi; } && \
693713
{ if [ ! -d /opt/nccl-tests ]; then \
694-
export NCCL_PACKAGE_VERSION="2.*+cuda${CUDA_MAJOR_VERSION}.${CUDA_MINOR_VERSION}" && \
714+
NCCL_PACKAGE_VERSION="2.*+cuda${CUDA_MAJOR_VERSION}.${CUDA_MINOR_VERSION}" && \
695715
apt-get -qq install --no-upgrade -y "libnccl2=$NCCL_PACKAGE_VERSION"; fi; } && \
696716
/tmp/install_cudnn.sh "$CUDA_VERSION" runtime && \
697717
rm /tmp/install_cudnn.sh && \
@@ -717,7 +737,12 @@ RUN <<-"EOT" python3
717737
from pathlib import Path
718738
from py_compile import compile
719739

720-
dist = metadata.distribution("flashattn-hopper")
740+
try:
741+
dist = metadata.distribution("flash-attn-3")
742+
record_pattern = "flash?attn?3-*.dist-info/RECORD"
743+
except metadata.PackageNotFoundError:
744+
dist = metadata.distribution("flashattn-hopper")
745+
record_pattern = "flashattn?hopper-*.dist-info/RECORD"
721746
p = dist.locate_file("flash_attn_interface.py")
722747
print("flash_attn_interface:", p)
723748
root = p.parent
@@ -727,7 +752,7 @@ RUN <<-"EOT" python3
727752
if not p.is_file():
728753
raise SystemExit("flash_attn_interface path is not a file")
729754

730-
d = root / "flashattn_hopper"
755+
d = root / "flash_attn_3"
731756
if d.exists():
732757
raise SystemExit(f'"{d}" already exists')
733758

@@ -747,7 +772,7 @@ RUN <<-"EOT" python3
747772

748773

749774
for f in dist.files:
750-
if f.match("flashattn?hopper-*.dist-info/RECORD"):
775+
if f.match(record_pattern):
751776
with f.locate().open("a", encoding="utf-8", newline="") as record:
752777
for added in (new, compiled):
753778
record.write(record_entry(added))

torch/install_cudnn.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ LIBCUDNN_VER="$(
3333

3434
if [ -z "$LIBCUDNN_VER" ]; then
3535
apt-get -qq update && \
36-
apt-get -qq install --no-upgrade -y "${DEV_PREFIX}cudnn9-cuda-${CUDA_MAJOR_VERSION}" && \
36+
apt-get -qq install --no-upgrade -y \
37+
"${DEV_PREFIX}cudnn9-cuda-${CUDA_MAJOR_VERSION}" \
38+
"libcudnn9-dev-cuda-${CUDA_MAJOR_VERSION}" && \
3739
apt-get clean && \
3840
ldconfig;
3941
else

torch/nvcc-wrapper.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#!/bin/env python3
2+
3+
"""
4+
Wraps invocations of ``nvcc``, watching for evidence of SIGKILL or SIGSEGV,
5+
and then re-running the ``nvcc`` command a configurable number of times.
6+
7+
Checking for SIGKILL or SIGSEGV is implemented by checking for either:
8+
9+
- A subprocess return code indicating either of these signals, or
10+
- The standard ``dash`` error messages for either signal.
11+
12+
``dash`` status messages are checked as NVCC utilizes ``sh``
13+
subprocesses internally, and ``sh`` usually resolves to
14+
the ``dash`` shell within Ubuntu container images.
15+
16+
This wrapper also has the ability to filter out some -gencode flags.
17+
Gencode flags to filter out should be identified by their code parameter
18+
in a semicolon-delimited list stored in the NVCC_WRAPPER_FILTER_CODES
19+
environment variable.
20+
"""
21+
22+
import asyncio
23+
import os
24+
import re
25+
import shutil
26+
import signal
27+
import subprocess
28+
import sys
29+
from typing import BinaryIO, Final, FrozenSet, Iterable, List, Sequence, Set
30+
31+
# Location of the real ``nvcc`` binary, resolved once from PATH.
# The wrapper cannot do anything useful without it, so abort immediately.
NVCC_PATH: Final[str] = shutil.which("nvcc")
if NVCC_PATH is None:
    raise SystemExit("NVCC wrapper: fatal: nvcc binary not found")

# Total number of times a single nvcc command may be run (first try included).
# Read from NVCC_WRAPPER_ATTEMPTS; unset or empty falls back to 10.
try:
    WRAPPER_ATTEMPTS: Final[int] = int(os.getenv("NVCC_WRAPPER_ATTEMPTS") or 10)
except ValueError:
    # A non-numeric value gets the same clean fatal message as an
    # out-of-range one, rather than an uncaught traceback.
    raise SystemExit(
        "NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_ATTEMPTS"
    ) from None
if WRAPPER_ATTEMPTS < 1:
    raise SystemExit("NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_ATTEMPTS")

# Codes (the ``code=`` half of a -gencode flag) to strip from nvcc command
# lines, given as a semicolon-delimited list in NVCC_WRAPPER_FILTER_CODES.
# Empty entries (e.g. from a trailing ";") are ignored.
FILTER_CODES: Final[FrozenSet[str]] = frozenset(
    filter(None, os.getenv("NVCC_WRAPPER_FILTER_CODES", "").split(";"))
)
# Every entry must look like sm_NN, compute_NN or lto_NN,
# optionally followed by an "a" or "f" suffix.
if FILTER_CODES and not all(
    re.fullmatch(r"(?:sm|compute|lto)_\d+[af]?", a) for a in FILTER_CODES
):
    raise SystemExit("NVCC wrapper: fatal: invalid value for NVCC_WRAPPER_FILTER_CODES")

# Exit statuses that trigger a retry even when no crash message was seen
# in the output: a negative value is how the subprocess API reports death
# by signal, and 128 + N is the shell convention for the same event.
# NOTE(review): 255 is also treated as retryable; presumably nvcc's own
# status when one of its internal tools crashes — confirm.
RETRY_RET_CODES: Final[FrozenSet[int]] = frozenset({
    -signal.SIGSEGV,
    -signal.SIGKILL,
    128 + signal.SIGSEGV,
    128 + signal.SIGKILL,
    255,
})
54+
55+
56+
async def main(args) -> int:
    """
    Run ``nvcc`` with the given arguments, retrying after apparent crashes.

    The arguments are first passed through ``transform_args``. The command
    is then run up to ``WRAPPER_ATTEMPTS`` times. An attempt is final when
    it exits with status 0, or when no crash message was seen on either
    output stream and the exit status is not in ``RETRY_RET_CODES``.

    :param args: Command-line arguments for nvcc (without the program name).
    :return: The exit status of the last nvcc invocation.
    """
    args = transform_args(args)
    ret: int = 0
    for attempt in range(1, WRAPPER_ATTEMPTS + 1):
        if attempt > 1:
            print(
                "NVCC wrapper: info:"
                f" Retrying [{attempt:d}/{WRAPPER_ATTEMPTS:d}]"
                f" after exit code {ret:d}",
                file=sys.stderr,
                flush=True,
            )
            # Wait an exponentially increasing amount of time
            # before trying again, up to one minute
            await asyncio.sleep(min(60, int(1.5**attempt)))
        proc = await asyncio.create_subprocess_exec(
            NVCC_PATH, *args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        # Relay both streams concurrently; each monitor reports whether
        # it saw a crash message pass through.
        restart_signals: tuple = await asyncio.gather(
            monitor_stream(proc.stdout, sys.stdout.buffer),
            monitor_stream(proc.stderr, sys.stderr.buffer),
        )
        ret = await proc.wait()
        del proc
        # Success always stops the loop. A failure stops it only when
        # nothing suggests the process was killed or segfaulted.
        if ret == 0 or (not any(restart_signals) and ret not in RETRY_RET_CODES):
            break
    else:
        # The loop ran out of attempts without hitting ``break``.
        print(
            "NVCC wrapper: info:"
            f" Maximum attempts reached, exiting with status {ret:d}",
            file=sys.stderr,
            flush=True,
        )
    return ret
90+
91+
92+
async def monitor_stream(
    stream: asyncio.StreamReader,
    output: BinaryIO,
    watch_for: Iterable[bytes] = (
        b"Segmentation fault",
        b"Segmentation fault (core dumped)",
        b"Killed",
    ),
) -> bool:
    """
    Copy ``stream`` to ``output`` line by line, watching for crash messages.

    Every line is forwarded unchanged and flushed immediately. A line
    counts as a crash message only when, stripped of surrounding
    whitespace, it is exactly equal to one of the ``watch_for`` entries.

    :param stream: Source to read lines from until it returns an empty read.
    :param output: Binary sink each line is written to.
    :param watch_for: Exact line contents that indicate a crash.
    :return: True if any forwarded line matched an entry of ``watch_for``.
    """
    found: bool = False
    while line := await stream.readline():
        # Once a match is seen, keep the flag set for the rest of the stream.
        found = found or line.strip() in watch_for
        output.write(line)
        output.flush()
    return found
107+
108+
109+
def transform_args(args: Sequence[str]) -> Sequence[str]:
    """
    Remove -gencode flags whose code is listed in ``FILTER_CODES``.

    Both spellings are handled: the single-argument form
    ``-gencode=arch=X,code=Y`` and the two-argument form
    ``-gencode arch=X,code=Y`` (likewise for ``--generate-code``).
    Bracketed groups such as ``code=[Y,Z]`` are filtered element by
    element; the flag is dropped entirely if no code survives, and the
    brackets are dropped if exactly one code survives.

    Architectures given through the --gpu-architecture and --gpu-code
    flags are not touched.

    :param args: The nvcc argument list.
    :return: ``args`` itself when ``FILTER_CODES`` is empty,
        otherwise a new list with the filtered flags removed or rewritten.
    """
    if not FILTER_CODES:
        return args
    gencode: Set[str] = {"-gencode", "--generate-code"}
    transformed_args: List[str] = []
    # True while the previous argument was a bare -gencode/--generate-code
    # whose value is expected in the current argument.
    partial: bool = False
    for arg in args:
        if not partial and arg in gencode:
            partial = True
            transformed_args.append(arg)
            continue
        if partial:
            pattern = r"(arch=[^,]+,code=)(\S+)"
        else:
            pattern = r"((?:-gencode|--generate-code)=arch=\S+,code=)(\S+)"
        m = re.fullmatch(pattern, arg)
        # What to emit in place of ``arg``; None means "drop the flag".
        replacement = arg
        if m:
            code: str = m.group(2)
            if code in FILTER_CODES:
                replacement = None
            elif re.fullmatch(r"\[\S+]", code):
                kept: List[str] = [
                    c for c in code[1:-1].split(",") if c not in FILTER_CODES
                ]
                if kept:
                    kept_code: str = ",".join(kept)
                    if len(kept) > 1:
                        kept_code = f"[{kept_code}]"
                    replacement = m.group(1) + kept_code
                else:
                    replacement = None
        if replacement is None:
            if partial:
                # There was a hanging `-gencode` arg before this, so delete it
                assert transformed_args[-1] in gencode
                del transformed_args[-1]
        else:
            transformed_args.append(replacement)
        partial = False
    return transformed_args
154+
155+
156+
# Script entry point: forward the command line (minus the program name)
# to main() and exit with the status it returns.
if __name__ == "__main__":
    try:
        sys.exit(asyncio.run(main(sys.argv[1:])))
    except KeyboardInterrupt:
        # 130 = 128 + SIGINT, the conventional status after Ctrl-C.
        sys.exit(130)

0 commit comments

Comments
 (0)