Skip to content

Build mxfp4 kernel for sm120a #2285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions benchmarks/mx_formats/mm_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import itertools
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch
from jsonargparse import CLI
from tabulate import tabulate
from torch._inductor.utils import do_bench_using_profiling
from tqdm import tqdm

from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.utils import to_blocked


class Format(Enum):
    """MX element formats this benchmark can exercise.

    The string values are what ``print_results`` writes into the "Format" column.
    """

    # get_mx_matmul pairs this with torch.float8_e4m3fn and torch._scaled_mm.
    MX_FP8 = "MX-FP8"
    # get_mx_matmul pairs this with torch.float4_e2m1fn_x2 and mx_fp4_bf16.
    MX_FP4 = "MX-FP4"


def get_mx_matmul(A: torch.Tensor, B: torch.Tensor, format: Format):
    """Quantize ``A`` and ``B`` once and return a callable that runs the MX matmul.

    All quantization and scale re-layout happens here, outside the returned
    callable, so invoking the callable only launches the matmul kernel.

    Args:
        A: Left-hand operand; its scale is viewed as
            ``(A.shape[0], A.shape[1] // 32)``.
        B: Right-hand operand; it is quantized through ``B.T`` and the
            quantized data is transposed back before use.
        format: Which MX format (and therefore which kernel) to use.

    Returns:
        A zero-argument callable that performs the low-precision matmul.

    Raises:
        ValueError: If ``format`` is not a supported ``Format`` member.
    """
    # Block size handed to to_mx and used to derive the scale shapes.
    block_size = 32

    # Resolve the element dtype lazily so that the FP4 dtype attribute is only
    # looked up when FP4 is actually requested.
    if format == Format.MX_FP4:
        elem_dtype = torch.float4_e2m1fn_x2
        kernel = mx_fp4_bf16
    elif format == Format.MX_FP8:
        elem_dtype = torch.float8_e4m3fn
        kernel = partial(torch._scaled_mm, out_dtype=torch.bfloat16)
    else:
        raise ValueError(f"Invalid format: {format}")

    scale_a, A_lp = to_mx(A, elem_dtype, block_size)
    scale_b, B_lp_t = to_mx(B.T, elem_dtype, block_size)
    assert B_lp_t.is_contiguous()

    # Reshape the flat scales to one entry per (row, block) before re-layout.
    scale_a = to_blocked(scale_a.view(A.shape[0], A.shape[1] // block_size))
    scale_b = to_blocked(scale_b.view(B.shape[1], B.shape[0] // block_size))

    # Undo the transpose applied for quantization.
    B_lp = B_lp_t.T

    def run_matmul():
        return kernel(A_lp, B_lp, scale_a, scale_b)

    return run_matmul


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one benchmark case: problem size plus MX format."""

    # run_experiment allocates A as (M, K) and B as the transpose of an (N, K) tensor.
    M: int
    K: int
    N: int
    # Selects the element dtype and kernel inside get_mx_matmul.
    format: Format


@dataclass(frozen=True)
class ExperimentResult:
    """Immutable measurement produced by ``run_experiment`` for one config."""

    # run_experiment fills this from its ``time_us`` local (the benchmark result scaled by 1e3).
    time: float
    # Throughput derived from ``time`` by calculate_tflops.
    tflops: float


@dataclass(frozen=True)
class Experiment:
    """Immutable pairing of a benchmark configuration with its measured result."""

    # The case that was run.
    config: ExperimentConfig
    # What run_experiment returned for that case.
    result: ExperimentResult


def calculate_tflops(M: int, N: int, K: int, time_us: float) -> float:
    """Return the TFLOPS achieved by a matmul that took ``time_us`` microseconds.

    The operation count is taken as ``2 * M * N * K``.

    Args:
        M: First matrix dimension.
        N: Second matrix dimension.
        K: Third (reduction) matrix dimension.
        time_us: Elapsed time in microseconds; must be non-zero.

    Returns:
        Throughput in tera floating-point operations per second.
    """
    total_flops = 2 * M * N * K
    # FLOPs per microsecond equals MFLOP/s; dividing by 1e6 rescales to TFLOP/s.
    flops_per_us = total_flops / time_us
    return flops_per_us / 1e6


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Benchmark one MX matmul configuration on the CUDA device.

    Args:
        config: Problem size and MX format to benchmark.

    Returns:
        The measured time (stored from the ``* 1e3``-scaled benchmark value,
        treated as microseconds) and the throughput derived from it.
    """
    # Zero-filled bfloat16 operands; the right-hand side is the transpose of an (N, K) tensor.
    lhs = torch.zeros(config.M, config.K, device="cuda", dtype=torch.bfloat16)
    rhs = torch.zeros(config.N, config.K, device="cuda", dtype=torch.bfloat16).T

    bench_fn = get_mx_matmul(lhs, rhs, config.format)

    # Run a few untimed iterations, then drain the GPU queue before timing.
    for _ in range(5):
        bench_fn()
    torch.cuda.synchronize()

    # Scale the benchmark result by 1e3 to get the value used as microseconds.
    elapsed_us = do_bench_using_profiling(bench_fn) * 1e3

    return ExperimentResult(
        time=elapsed_us,
        tflops=calculate_tflops(config.M, config.N, config.K, elapsed_us),
    )


def print_results(experiments: list[Experiment], save_path: Path | None = None):
    """Print the benchmark results as a table and optionally save them as CSV.

    Args:
        experiments: Completed experiments, one table row each.
        save_path: If given, the same rows are also written to this CSV file.
    """
    # run_experiment stores microseconds in ExperimentResult.time, so the
    # column is labelled in us (it was previously mislabelled as ms).
    headers = ["M", "K", "N", "Format", "Time (us)", "TFLOPS"]
    rows = [
        [
            experiment.config.M,
            experiment.config.K,
            experiment.config.N,
            experiment.config.format.value,
            f"{experiment.result.time:.4f}",
            f"{experiment.result.tflops:.2f}",
        ]
        for experiment in experiments
    ]

    print(tabulate(rows, headers=headers))

    if save_path is not None:
        pd.DataFrame(rows, columns=headers).to_csv(save_path, index=False)
        print(f"💾 Results saved to: {save_path}")


def plot_tflops_comparison(df, save_path: Path):
    """Plot TFLOPS against K, one line per format, and save the figure as a PNG.

    Args:
        df: Results table with "M", "K", "N", "Format" and "TFLOPS" columns.
            M and N for the title and file name are read from the first row.
        save_path: Path of the results CSV; the PNG is written next to it
            with a timestamped name.
    """
    plt.figure(figsize=(12, 6))

    by_k_and_format = df.groupby(["K", "Format"])
    k_values = sorted(df["K"].unique())
    format_names = df["Format"].unique()
    m_value = df["M"].iloc[0]
    n_value = df["N"].iloc[0]

    # One line per format; a format missing any K value is left out entirely.
    for fmt in format_names:
        try:
            series = [
                by_k_and_format.get_group((k, fmt))["TFLOPS"].values[0]
                for k in k_values
            ]
            plt.plot(k_values, series, label=fmt)
        except KeyError:
            # This (K, format) combination is absent from the data.
            pass

    plt.xlabel("K (Matrix Dimension)")
    plt.ylabel("TFLOPS")

    # Anchor the y-axis at zero.
    plt.ylim(bottom=0)

    plt.title(f"MX Matrix Multiplication Performance \nM={m_value}, N={n_value}")

    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.xticks(k_values, rotation=45, ha="right")
    plt.tight_layout()

    # The PNG goes in the CSV's directory with a timestamped name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    graph_path = save_path.parent / f"mx_{m_value}_{n_value}_{timestamp}.png"
    plt.savefig(graph_path, dpi=300)
    print(f"TFLOPS comparison plot saved as {graph_path}")


def get_configs_varying_k(M: int = 8192, N: int = 8192) -> list[ExperimentConfig]:
    """Build the benchmark sweep: fixed M and N, K from 1024 to 16384 in steps of 1024.

    Args:
        M: Value used for every config's ``M``.
        N: Value used for every config's ``N``.

    Returns:
        One config per (K, format) pair, ordered by K and, within each K,
        MX-FP8 before MX-FP4.
    """
    return [
        ExperimentConfig(M=M, K=k, N=N, format=fmt)
        for k in range(1024, 16385, 1024)
        for fmt in (Format.MX_FP8, Format.MX_FP4)
    ]


def main(
    save_path: str | None = None, M: int = 8192, N: int = 8192, graph: bool = False
):
    """Benchmark MX MatMul with different configurations and optionally graph results.

    Args:
        save_path (Optional[str], optional): Path to save the results. The suffix is replaced with ".csv". Defaults to None.
        M (int, optional): Number of rows in the first matrix. Defaults to 8192.
        N (int, optional): Number of columns in the second matrix. Defaults to 8192.
        graph (bool, optional): Whether to create a graph of the results. Requires save_path. Defaults to False.
    """
    torch.random.manual_seed(123)
    configs = get_configs_varying_k(M, N)

    # Keep the normalized path in its own local instead of rebinding the
    # str-typed parameter to a Path.
    csv_path = None
    if save_path is not None:
        csv_path = Path(save_path).with_suffix(".csv")
        csv_path.parent.mkdir(parents=True, exist_ok=True)
    elif graph:
        # The plot is built from the saved CSV, so say so before the long
        # benchmark run rather than silently dropping the request at the end.
        print("Warning: graph=True requires save_path; no plot will be generated.")

    results = [
        Experiment(config=config, result=run_experiment(config))
        for config in tqdm(configs)
    ]
    print_results(results, csv_path)

    if graph and csv_path is not None:
        plot_tflops_comparison(pd.read_csv(csv_path), csv_path)


# Script entry point: hand main to jsonargparse's CLI helper.
if __name__ == "__main__":
    CLI(main)
70 changes: 45 additions & 25 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,18 @@ def get_cutlass_build_flags():
raise ValueError("No CUDA version found")

major, minor = map(int, cuda_version.split(".")[:2])
build_sm90a = major > 12 or (major == 12 and minor >= 6)
build_sm100a = major > 12 or (major == 12 and minor >= 8)
build_sm90a = (major, minor) >= (12, 6)
build_sm100a = (major, minor) >= (12, 8)
build_sm120a = (major, minor) >= (12, 8)

if build_sm90a:
print(f"CUDA {cuda_version}: Enabling SM90a CUTLASS kernels")
if build_sm100a:
print(f"CUDA {cuda_version}: Enabling SM100a CUTLASS kernels")
if build_sm120a:
print(f"CUDA {cuda_version}: Enabling SM120a CUTLASS kernels")

return build_sm90a, build_sm100a
return build_sm90a, build_sm100a, build_sm120a
except:
# Fallback to architecture flags
cuda_arch_flags = _get_cuda_arch_flags()
Expand Down Expand Up @@ -495,8 +498,10 @@ def get_extensions():
use_cutlass = False
cutlass_90a_sources = None
cutlass_100a_sources = None
cutlass_120a_sources = None
build_for_sm90a = False
build_for_sm100a = False
build_for_sm120a = False
if use_cuda and not IS_WINDOWS:
use_cutlass = True
cutlass_dir = os.path.join(third_party_path, "cutlass")
Expand Down Expand Up @@ -525,7 +530,7 @@ def get_extensions():
]
)

build_for_sm90a, build_for_sm100a = get_cutlass_build_flags()
build_for_sm90a, build_for_sm100a, build_for_sm120a = get_cutlass_build_flags()
# Define sm90a sources
cutlass_90a_sources = [
os.path.join(
Expand All @@ -549,40 +554,34 @@ def get_extensions():
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
)
)
# Always remove sm90a sources from main sources
sources = [s for s in sources if s not in cutlass_90a_sources]

# Always compile mx_fp_cutlass_kernels_sm100a.cu ONLY with sm100a architecture
cutlass_100a_sources = [
os.path.join(
extensions_cuda_dir,
"mx_kernels",
"mx_fp_cutlass_kernels.cu",
"mx_fp_cutlass_kernels_sm100a.cu",
),
]
# Remove from main sources to prevent compilation with other architectures
sources = [
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"

# Always compile mx_fp_cutlass_kernels_sm120a.cu ONLY with sm120a architecture
cutlass_120a_sources = [
os.path.join(
extensions_cuda_dir,
"mx_kernels",
"mx_fp_cutlass_kernels_sm120a.cu",
),
]

else:
# Remove CUTLASS-based kernels from the sources list. An
# assumption is that these files will have "cutlass" in its
# name.
cutlass_sources = list(
glob.glob(
os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
)
)
sources = [s for s in sources if s not in cutlass_sources]
# Remove CUTLASS-based kernels from the sources list. An assumption is that
# these files will have "cutlass" in its name.
cutlass_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True)
)
sources = [s for s in sources if s not in cutlass_sources]

ext_modules = []
if len(sources) > 0:
# Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
sources = [
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
]

ext_modules.append(
extension(
"torchao._C",
Expand Down Expand Up @@ -635,6 +634,27 @@ def get_extensions():
)
)

# Only build the cutlass_120a extension if sm120a is in the architecture flags
if (
cutlass_120a_sources is not None
and len(cutlass_120a_sources) > 0
and build_for_sm120a
):
cutlass_120a_extra_compile_args = copy.deepcopy(extra_compile_args)
# Only use sm120a architecture for these sources, ignoring cuda_arch_flags
cutlass_120a_extra_compile_args["nvcc"].append(
"-gencode=arch=compute_120a,code=sm_120a"
)
ext_modules.append(
extension(
"torchao._C_cutlass_120a",
cutlass_120a_sources,
py_limited_api=True,
extra_compile_args=cutlass_120a_extra_compile_args,
extra_link_args=extra_link_args,
)
)

# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
build_options = BuildOptions()
Expand Down
5 changes: 3 additions & 2 deletions test/prototype/mx_formats/test_mx_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_100,
is_sm_version,
)

if not TORCH_VERSION_AT_LEAST_2_8:
Expand Down Expand Up @@ -59,7 +59,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
not (is_sm_version(10, 0) or is_sm_version(12, 0)),
reason="CUDA capability 10.0 or 12.0 is required for mxfloat8",
)
@pytest.mark.parametrize(
"size",
Expand Down
13 changes: 13 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,21 @@

so_files = list(Path(__file__).parent.glob("_C*.so"))
if len(so_files) > 0:
compute_capability = (
torch.cuda.get_device_capability() if torch.cuda.is_available() else None
)

for file in so_files:
# only load architecture-specific target if the current GPU matches that target
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

if (
("cutlass_90a" in file.name and compute_capability != (9, 0))
or ("cutlass_100a" in file.name and compute_capability != (10, 0))
or ("cutlass_120a" in file.name and compute_capability != (12, 0))
):
continue

torch.ops.load_library(str(file))

from . import ops

# The following library contains CPU kernels from torchao/experimental
Expand Down
Loading
Loading