60 changes: 43 additions & 17 deletions benchmarks/float8/bench_matmul.py
@@ -16,6 +16,8 @@
get_name_to_shapes_iter,
)

from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.testing.float8.roofline_utils import get_specs


@@ -62,13 +64,19 @@ def run(
):
device = "cuda"
# TODO(future PR): this is ugly
assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas", "mxfp4_cutlass"), (
"unsupported"
)
use_fp4 = recipe == "mxfp4_cutlass"

specs = get_specs()
bf16_peak_tops = specs["bf16_peak_tops"]
fp8_peak_tops = specs["fp8_peak_tops"]
fp4_peak_tops = specs["fp4_peak_tops"]
print(f"gpu_name: {torch.cuda.get_device_name(0)}")
print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
print(
f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
)

headers = (
"fast_accum",
@@ -77,14 +85,14 @@
"K",
"N",
"ref_time_s",
"fp8_time_s",
"fp8_speedup",
"time_s",
"speedup",
)
results = []

dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]
fast_accum_vals = [False] if use_fp4 else [True, False]

for idx, (fast_accum, (name, (M, K, N))) in enumerate(
itertools.product(fast_accum_vals, name_to_shapes)
@@ -107,35 +115,53 @@

del A

# raw float8 matmul (upper bound for what we can achive in eager mode)
# TODO(future): add e5m2
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
A_hp = torch.randn(M, K, device=device)
B_hp_t = torch.randn(N, K, device=device)

if use_fp4:
_, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
_, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
B = Bt.contiguous().T
peak_tops = fp4_peak_tops
else:
# raw float8 matmul (upper bound for what we can achieve in eager mode)
# TODO(future): add e5m2
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
A = A_hp.to(d1)
B = B_hp_t.to(d2).contiguous().T
peak_tops = fp8_peak_tops

if recipe == "tensorwise":
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
elif recipe == "rowwise":
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
elif recipe == "mxfp8_cublas":
elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
else:
assert False, f"unknown recipe {recipe}"

def do_matmul(A, B):
def do_matmul_fp8(A, B):
nonlocal scale_a
nonlocal scale_b
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
def do_matmul_mxfp4(A, B):
nonlocal scale_a
nonlocal scale_b
return mx_fp4_bf16(A, B, scale_a, scale_b)

do_matmul = do_matmul_mxfp4 if use_fp4 else do_matmul_fp8

time_sec, tops_sec, pct_top_peak = do_benchmarks(
tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
)
print(
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
)

del A, B, scale_a, scale_b
@@ -148,8 +174,8 @@ def do_matmul(A, B):
K,
N,
ref_time_sec,
fp8_time_sec,
ref_time_sec / fp8_time_sec,
time_sec,
ref_time_sec / time_sec,
]
)

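For context, a minimal sketch of the mxfp4 path this benchmark now exercises, assuming a torchao build with the CUTLASS fp4 kernels and a supported GPU; the shapes are hypothetical, and the unit e8m0 scales mirror the benchmark's own setup (the real scales from `to_mx` are discarded above as well):

```python
# Minimal sketch of the mxfp4 matmul path, mirroring the benchmark above.
# Assumes a torchao build with the CUTLASS fp4 kernels and a supported GPU.
import torch

from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import to_mx

M, K, N = 1024, 2048, 4096  # hypothetical shapes; K divisible by the block size 32
A_hp = torch.randn(M, K, device="cuda")
B_hp_t = torch.randn(N, K, device="cuda")

# to_mx returns (scales, packed fp4 data); like the benchmark, we drop the
# real scales and use unit e8m0 scales for a pure-throughput measurement
_, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
_, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
B = Bt.contiguous().T

scale_a = torch.ones(M, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)
out = mx_fp4_bf16(A, B, scale_a, scale_b)  # bf16 output of shape (M, N)
```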
9 changes: 3 additions & 6 deletions benchmarks/float8/utils.py
@@ -352,9 +352,6 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs):
)
# there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds
assert len(data) == 1
if "aten::mm" in data:
return data["aten::mm"] / 1e6 / n_iter
elif "aten::_scaled_mm" in data:
return data["aten::_scaled_mm"] / 1e6 / n_iter
else:
raise AssertionError("unexpected format of data")
key, value = next(iter(data.items()))
assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
return value / 1e6 / n_iter
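The refactor collapses the per-op branches into a single lookup; a small sketch of the resulting behavior, using a hypothetical one-entry profiler result:

```python
# Hypothetical one-entry profiler result: op name -> total device time
data = {"torchao::mx_fp4_bf16": 1.25e6}
n_iter = 10

assert len(data) == 1
key, value = next(iter(data.items()))
assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
time_s = value / 1e6 / n_iter  # same normalization as the helper above
```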
81 changes: 56 additions & 25 deletions setup.py
@@ -272,15 +272,18 @@ def get_cutlass_build_flags():
raise ValueError("No CUDA version found")

major, minor = map(int, cuda_version.split(".")[:2])
build_sm90a = major > 12 or (major == 12 and minor >= 6)
build_sm100a = major > 12 or (major == 12 and minor >= 8)
build_sm90a = (major, minor) >= (12, 6)
build_sm100a = (major, minor) >= (12, 8)
build_sm120a = (major, minor) >= (12, 8)

if build_sm90a:
print(f"CUDA {cuda_version}: Enabling SM90a CUTLASS kernels")
if build_sm100a:
print(f"CUDA {cuda_version}: Enabling SM100a CUTLASS kernels")
if build_sm120a:
print(f"CUDA {cuda_version}: Enabling SM120a CUTLASS kernels")

return build_sm90a, build_sm100a
return build_sm90a, build_sm100a, build_sm120a
except:
# Fallback to architecture flags
cuda_arch_flags = _get_cuda_arch_flags()
@@ -340,6 +343,11 @@ def __init__(
self.cmake_args = cmake_args


def remove_items(a: list, b: list) -> list:
"""Remove items in list b from list a"""
return [x for x in a if x not in b]


def get_extensions():
# Skip building C++ extensions if USE_CPP is set to "0"
if use_cpp == "0":
Expand Down Expand Up @@ -454,7 +462,7 @@ def get_extensions():
excluded_sources = list(
glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True)
)
sources = [s for s in sources if s not in excluded_sources]
sources = remove_items(sources, excluded_sources)

# Collect CUDA source files
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
@@ -498,22 +506,24 @@ def get_extensions():
rocm_sources = list(
glob.glob(os.path.join(extensions_rocm_dir, "**/*.cpp"), recursive=True)
)
sources = [s for s in sources if s not in rocm_sources]
sources = remove_items(sources, rocm_sources)

use_cutlass = False
use_cutlass = use_cuda and not IS_WINDOWS
cutlass_90a_sources = None
cutlass_100a_sources = None
cutlass_120a_sources = None
build_for_sm90a = False
build_for_sm100a = False
if use_cuda and not IS_WINDOWS:
use_cutlass = True
build_for_sm120a = False

if use_cutlass:
cutlass_dir = os.path.join(third_party_path, "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
cutlass_tools_include_dir = os.path.join(
cutlass_dir, "tools", "util", "include"
)
cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
if use_cutlass:

extra_compile_args["nvcc"].extend(
[
"-DTORCHAO_USE_CUTLASS",
@@ -533,7 +543,7 @@ def get_extensions():
]
)

build_for_sm90a, build_for_sm100a = get_cutlass_build_flags()
build_for_sm90a, build_for_sm100a, build_for_sm120a = get_cutlass_build_flags()
# Define sm90a sources
cutlass_90a_sources = [
os.path.join(
@@ -557,40 +567,40 @@ def get_extensions():
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
)
)
# Always remove sm90a sources from main sources
sources = [s for s in sources if s not in cutlass_90a_sources]
sources = remove_items(sources, cutlass_90a_sources)

# Always compile mx_fp_cutlass_kernels_sm100a.cu ONLY with sm100a architecture
cutlass_100a_sources = [
os.path.join(
extensions_cuda_dir,
"mx_kernels",
"mx_fp_cutlass_kernels.cu",
"mx_fp_cutlass_kernels_sm100a.cu",
),
]
# Remove from main sources to prevent compilation with other architectures
sources = [
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
sources = remove_items(sources, cutlass_100a_sources)

# Always compile mx_fp_cutlass_kernels_sm120a.cu ONLY with sm120a architecture
cutlass_120a_sources = [
os.path.join(
extensions_cuda_dir,
"mx_kernels",
"mx_fp_cutlass_kernels_sm120a.cu",
),
]
sources = remove_items(sources, cutlass_120a_sources)

else:
# Remove CUTLASS-based kernels from the sources list. An
# assumption is that these files will have "cutlass" in its
# name.
# Remove CUTLASS-based kernels from the sources list. An assumption is that
# these files will have "cutlass" in their names.
cutlass_sources = list(
glob.glob(
os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
)
)
sources = [s for s in sources if s not in cutlass_sources]
sources = remove_items(sources, cutlass_sources)

ext_modules = []
if len(sources) > 0:
# Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
sources = [
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
]

ext_modules.append(
extension(
"torchao._C",
@@ -643,6 +653,27 @@ def get_extensions():
)
)

# Only build the cutlass_120a extension if sm120a is in the architecture flags
if (
cutlass_120a_sources is not None
and len(cutlass_120a_sources) > 0
and build_for_sm120a
):
cutlass_120a_extra_compile_args = copy.deepcopy(extra_compile_args)
# Only use sm120a architecture for these sources, ignoring cuda_arch_flags
cutlass_120a_extra_compile_args["nvcc"].append(
"-gencode=arch=compute_120a,code=sm_120a"
)
ext_modules.append(
extension(
"torchao._C_cutlass_120a",
cutlass_120a_sources,
py_limited_api=True,
extra_compile_args=cutlass_120a_extra_compile_args,
extra_link_args=extra_link_args,
)
)

# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
build_options = BuildOptions()
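Two small patterns in this file are worth noting: version gating now compares `(major, minor)` tuples, and each arch-specific extension is compiled with its own `-gencode` flag. A hedged sketch of the gating logic, assuming a parsed CUDA version string (the version value here is hypothetical):

```python
# Tuple comparison is lexicographic, so it matches the old compound checks
cuda_version = "12.8"  # hypothetical value
major, minor = map(int, cuda_version.split(".")[:2])

build_sm90a = (major, minor) >= (12, 6)
build_sm100a = (major, minor) >= (12, 8)
build_sm120a = (major, minor) >= (12, 8)

# e.g. the old sm90a check, equivalent by lexicographic ordering:
assert build_sm90a == (major > 12 or (major == 12 and minor >= 6))
```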
5 changes: 3 additions & 2 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -14,7 +14,7 @@
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_100,
is_sm_version,
)

if not TORCH_VERSION_AT_LEAST_2_8:
@@ -59,7 +59,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
not (is_sm_version(10, 0) or is_sm_version(12, 0)),
reason="CUDA capability 10.0 or 12.0 is required for mxfloat8",
)
@pytest.mark.parametrize(
"size",
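The test now requires an exact SM version rather than a minimum, since the fp4 kernels are compiled only for sm100a and sm120a targets. A sketch of an exact-match check, under the assumption that `is_sm_version` in `torchao.utils` compares against `torch.cuda.get_device_capability()`:

```python
import torch

# Assumed behavior of torchao.utils.is_sm_version: exact capability match,
# unlike is_sm_at_least_100, which accepts anything newer
def is_sm_version(major: int, minor: int) -> bool:
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() == (major, minor)

# mirrors the new skip condition in the test above
should_run = is_sm_version(10, 0) or is_sm_version(12, 0)
```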
13 changes: 13 additions & 0 deletions torchao/__init__.py
@@ -25,8 +25,21 @@

so_files = list(Path(__file__).parent.glob("_C*.so"))
if len(so_files) > 0:
compute_capability = (
torch.cuda.get_device_capability() if torch.cuda.is_available() else None
)

for file in so_files:
# only load architecture-specific target if the current GPU matches that target
if (
("cutlass_90a" in file.name and compute_capability != (9, 0))
or ("cutlass_100a" in file.name and compute_capability != (10, 0))
or ("cutlass_120a" in file.name and compute_capability != (12, 0))
):
continue

torch.ops.load_library(str(file))

from . import ops

# The following library contains CPU kernels from torchao/experimental
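The loading rule above skips any arch-suffixed shared library whose target does not match the current GPU. Expressed as a standalone predicate (a sketch; the suffix-to-capability mapping is taken from the code above):

```python
# Sketch of the arch-gating rule: load an arch-specific .so only when the
# current GPU's compute capability matches the suffix baked into its name
_SUFFIX_TO_CC = {
    "cutlass_90a": (9, 0),
    "cutlass_100a": (10, 0),
    "cutlass_120a": (12, 0),
}

def should_load(so_name: str, compute_capability) -> bool:
    for suffix, required_cc in _SUFFIX_TO_CC.items():
        if suffix in so_name and compute_capability != required_cc:
            return False
    return True  # non-arch-specific libraries always load
```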