@@ -5,9 +5,9 @@ ARG FINAL_BASE_IMAGE="nvidia/cuda:12.9.1-base-ubuntu22.04"
5
5
ARG BUILD_TORCH_VERSION="2.7.1"
6
6
ARG BUILD_TORCH_VISION_VERSION="0.22.1"
7
7
ARG BUILD_TORCH_AUDIO_VERSION="2.7.1"
8
- ARG BUILD_TRANSFORMERENGINE_VERSION="1.13 "
8
+ ARG BUILD_TRANSFORMERENGINE_VERSION="2.4 "
9
9
ARG BUILD_FLASH_ATTN_VERSION="2.7.4.post1"
10
- ARG BUILD_FLASH_ATTN_3_VERSION="2.7.2.post1 "
10
+ ARG BUILD_FLASH_ATTN_3_VERSION="b36ad4ef767d2d5536ff8af2e3f720ae4eba731c "
11
11
ARG BUILD_TRITON_VERSION=""
12
12
ARG BUILD_TRITON="1"
13
13
ARG BUILD_TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0 10.0 12.0+PTX"
@@ -90,7 +90,8 @@ RUN ./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_VER
90
90
FROM downloader-base AS flash-attn-3-downloader
91
91
ARG BUILD_FLASH_ATTN_3_VERSION
92
92
RUN if [ -n "$BUILD_FLASH_ATTN_3_VERSION" ]; then \
93
- ./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_3_VERSION}" ; \
93
+ ./clone.sh Dao-AILab/flash-attention flash-attention "${BUILD_FLASH_ATTN_3_VERSION}" && \
94
+ git -C flash-attention cherry-pick -n 3edf7e0daa62662cd2dd2ec8fd999dd7f254415c; \
94
95
else \
95
96
mkdir flash-attention; \
96
97
fi
@@ -329,11 +330,16 @@ ARG BUILD_MAX_JOBS=""
329
330
RUN --mount=type=bind,from=triton-downloader,source=/git/triton,target=triton/,rw \
330
331
--mount=type=cache,target=/ccache \
331
332
if [ "$BUILD_TRITON" = '1' ]; then \
332
- pip3 install --no-cache-dir pybind11 && \
333
+ pip3 install --no-cache-dir pybind11 lit && \
333
334
export MAX_JOBS="${BUILD_MAX_JOBS:-$(./scale.sh " $(./effective_cpu_count.sh)" 3 32)}" && \
334
- cd triton/python && \
335
- python3 -m pip wheel -w ../../dist/ --no-build-isolation --no-deps -vv . && \
336
- pip3 install ../../dist/*.whl; \
335
+ DIST_DIR="$(realpath -e ./dist)" && \
336
+ if [ -f 'triton/python/setup.py' ]; then \
337
+ cd triton/python; \
338
+ else \
339
+ cd triton; \
340
+ fi && \
341
+ python3 -m pip wheel -w "${DIST_DIR}/" --no-build-isolation --no-deps -vv . && \
342
+ pip3 install --no-cache-dir "${DIST_DIR}" /*.whl; \
337
343
fi
338
344
339
345
ARG BUILD_TORCH_VERSION
@@ -348,15 +354,22 @@ ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
348
354
RUN printf 'Arch: %s\n TORCH_CUDA_ARCH_LIST=%s\n ' "$(uname -m)" "${TORCH_CUDA_ARCH_LIST}"
349
355
350
356
ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
351
- # Add sm_100a & sm_120 builds if NV_CUDA_LIB_VERSION matches 12.[89].*
357
+ # Add sm_100a & sm_120a builds if NV_CUDA_LIB_VERSION matches 12.[89].*
352
358
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
353
359
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
354
360
FLAGS="${FLAGS}$( \
355
- printf -- ' -gencode=arch=compute_%s,code=sm_%s' 120 120 100 100 100a 100a \
361
+ printf -- ' -gencode=arch=compute_%s,code=sm_%s' 120a 120a 100a 100a \
356
362
)" ;; \
357
363
esac && \
358
364
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf
359
365
366
+ COPY --link --chmod=755 nvcc-wrapper.py /build/nvcc-wrapper.py
367
+ ENV PYTORCH_NVCC='/build/nvcc-wrapper.py' \
368
+ CMAKE_CUDA_COMPILER='/build/nvcc-wrapper.py'
369
+ # Filter these codes because we already build for the architecture-specific
370
+ # versions of them instead.
371
+ ENV NVCC_WRAPPER_FILTER_CODES='sm_90;sm_100;sm_120;compute_90;compute_100'
372
+
360
373
# If the directory /opt/nccl-tests exists,
361
374
# the base image is assumed to be nccl-tests,
362
375
# so it uses the system's special NCCL and UCC installations for the build.
@@ -534,10 +547,6 @@ RUN --mount=type=bind,from=transformerengine-downloader,source=/git/TransformerE
534
547
export NVTE_CUDA_ARCHS="${NVTE_CUDA_ARCHS%;100*}" ;; \
535
548
esac && \
536
549
cd TransformerEngine && \
537
- if python3 -c "import sys; sys.exit(sys.version_info.minor > 8)" ; then \
538
- sed -i "s/from functools import cache/from functools import lru_cache as cache/g" \
539
- build_tools/utils.py; \
540
- fi && \
541
550
python3 setup.py bdist_wheel --dist-dir /build/dist
542
551
543
552
FROM builder-base AS flash-attn-builder-base
@@ -550,8 +559,9 @@ COPY <<-"EOT" /build/fa-build.sh
550
559
# !/bin/bash
551
560
set -eo pipefail;
552
561
if [ -n "$1" ]; then cd "$1" ; fi;
562
+ echo "Flash Attention build: building $(realpath -s .)" ;
553
563
python3 setup.py bdist_wheel --dist-dir /build/dist \
554
- | grep -Ev --line-buffered '^ptxas (/tmp/|(info|warning)\s *:)|bytes spill stores'
564
+ | grep -Ev --line-buffered '^ptxas (/tmp/|(info|warning)\s *:)|bytes spill stores' ;
555
565
EOT
556
566
RUN chmod 755 /build/fa-build.sh
557
567
@@ -581,8 +591,10 @@ FROM flash-attn-builder-base AS flash-attn-3-builder
581
591
# Artifically sequence this build stage after the previous one
582
592
# to prevent parallelism, because these are both very resource-intensive
583
593
RUN --mount=type=bind,from=flash-attn-builder,source=/build,target=/build :
594
+ ARG BUILD_FLASH_ATTN_MAX_JOBS="${BUILD_FLASH_ATTN_MAX_JOBS:-3}"
584
595
585
596
# Build flash-attn v3
597
+ SHELL ["/bin/bash" , "-o" , "pipefail" , "-c" ]
586
598
RUN --mount=type=bind,from=flash-attn-3-downloader,source=/git/flash-attention,target=flash-attention/,rw \
587
599
--mount=type=cache,target=/ccache \
588
600
if [ ! -d flash-attention/hopper ]; then \
@@ -592,8 +604,16 @@ RUN --mount=type=bind,from=flash-attn-3-downloader,source=/git/flash-attention,t
592
604
MAX_JOBS="${BUILD_FLASH_ATTN_MAX_JOBS:-$(./scale.sh " $(./effective_cpu_count.sh)" 10 6)}" && \
593
605
echo "MAX_JOBS: ${MAX_JOBS}" && \
594
606
export NVCC_APPEND_FLAGS="$(cat /build/nvcc.conf)" && \
607
+ if [ "$(uname -m)" = 'aarch64' ]; then \
608
+ export FLASH_ATTENTION_DISABLE_SM80=TRUE; \
609
+ else \
610
+ NVCC_APPEND_FLAGS="${NVCC_APPEND_FLAGS:+$NVCC_APPEND_FLAGS }-Xcompiler -mcmodel=medium" ; \
611
+ fi && \
595
612
echo "NVCC_APPEND_FLAGS: ${NVCC_APPEND_FLAGS}" && \
596
- /build/fa-build.sh flash-attention/hopper
613
+ sed -i \
614
+ 's@if bare_metal_version != Version("12.8"):@if bare_metal_version < Version("12.8"):@' \
615
+ flash-attention/hopper/setup.py && \
616
+ NVCC_THREADS=4 /build/fa-build.sh flash-attention/hopper
597
617
598
618
FROM builder-base AS builder
599
619
COPY --link --from=torchaudio-builder /build/dist/ /build/dist/
@@ -671,27 +691,27 @@ COPY --link --chmod=755 install_cudnn.sh /tmp/install_cudnn.sh
671
691
# - libnvjitlink-X-Y only exists for CUDA versions >= 12-0.
672
692
# - Don't mess with libnccl2 when using nccl-tests as a base,
673
693
# checked via the existence of the directory "/opt/nccl-tests".
674
- RUN export \
675
- CUDA_MAJOR_VERSION=$(echo $CUDA_VERSION | cut -d. -f1) \
676
- CUDA_MINOR_VERSION=$(echo $CUDA_VERSION | cut -d. -f2) && \
677
- export \
678
- CUDA_PACKAGE_VERSION="${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}" && \
694
+ RUN CUDA_MAJOR_VERSION="$(echo " $CUDA_VERSION" | cut -d. -f1)" && \
695
+ CUDA_MINOR_VERSION="$(echo " $CUDA_VERSION" | cut -d. -f2)" && \
696
+ CUDA_PACKAGE_VERSION="${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}" && \
697
+ CUDART_VERSION_SPEC="${NV_CUDA_CUDART_VERSION:+=$NV_CUDA_CUDART_VERSION}" && \
679
698
apt-get -qq update && \
680
699
apt-get -qq install --no-upgrade -y \
681
700
libcurand-${CUDA_PACKAGE_VERSION} \
682
701
libcufft-${CUDA_PACKAGE_VERSION} \
683
702
libcublas-${CUDA_PACKAGE_VERSION} \
684
703
cuda-nvrtc-${CUDA_PACKAGE_VERSION} \
704
+ cuda-cudart-dev-${CUDA_PACKAGE_VERSION}"${CUDART_VERSION_SPEC}" \
685
705
libcusparse-${CUDA_PACKAGE_VERSION} \
686
706
libcusolver-${CUDA_PACKAGE_VERSION} \
687
707
libcufile-${CUDA_PACKAGE_VERSION} \
688
708
cuda-cupti-${CUDA_PACKAGE_VERSION} \
689
709
libnvjpeg-${CUDA_PACKAGE_VERSION} \
690
710
libnvtoolsext1 && \
691
- { if [ $CUDA_MAJOR_VERSION -ge 12 ]; then \
711
+ { if [ " $CUDA_MAJOR_VERSION" -ge 12 ]; then \
692
712
apt-get -qq install --no-upgrade -y libnvjitlink-${CUDA_PACKAGE_VERSION}; fi; } && \
693
713
{ if [ ! -d /opt/nccl-tests ]; then \
694
- export NCCL_PACKAGE_VERSION="2.*+cuda${CUDA_MAJOR_VERSION}.${CUDA_MINOR_VERSION}" && \
714
+ NCCL_PACKAGE_VERSION="2.*+cuda${CUDA_MAJOR_VERSION}.${CUDA_MINOR_VERSION}" && \
695
715
apt-get -qq install --no-upgrade -y "libnccl2=$NCCL_PACKAGE_VERSION" ; fi; } && \
696
716
/tmp/install_cudnn.sh "$CUDA_VERSION" runtime && \
697
717
rm /tmp/install_cudnn.sh && \
@@ -717,7 +737,12 @@ RUN <<-"EOT" python3
717
737
from pathlib import Path
718
738
from py_compile import compile
719
739
720
- dist = metadata.distribution("flashattn-hopper" )
740
+ try:
741
+ dist = metadata.distribution("flash-attn-3" )
742
+ record_pattern = "flash?attn?3-*.dist-info/RECORD"
743
+ except metadata.PackageNotFoundError:
744
+ dist = metadata.distribution("flashattn-hopper" )
745
+ record_pattern = "flashattn?hopper-*.dist-info/RECORD"
721
746
p = dist.locate_file("flash_attn_interface.py" )
722
747
print("flash_attn_interface:" , p)
723
748
root = p.parent
@@ -727,7 +752,7 @@ RUN <<-"EOT" python3
727
752
if not p.is_file():
728
753
raise SystemExit("flash_attn_interface path is not a file" )
729
754
730
- d = root / "flashattn_hopper "
755
+ d = root / "flash_attn_3 "
731
756
if d.exists():
732
757
raise SystemExit(f'"{d}" already exists' )
733
758
@@ -747,7 +772,7 @@ RUN <<-"EOT" python3
747
772
748
773
749
774
for f in dist.files:
750
- if f.match("flashattn?hopper-*.dist-info/RECORD" ):
775
+ if f.match(record_pattern ):
751
776
with f.locate().open("a" , encoding="utf-8" , newline="" ) as record:
752
777
for added in (new, compiled):
753
778
record.write(record_entry(added))
0 commit comments