Skip to content

Commit 563fc7c

Browse files
committed
build for sm120a
1 parent 1239842 commit 563fc7c

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def get_extensions():
515515
"-DCUTE_USE_PACKED_TUPLE=1",
516516
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
517517
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
518-
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
518+
"-DCUTLASS_DEBUG_TRACE_LEVEL=1",
519519
"--ftemplate-backtrace-limit=0",
520520
# "--keep",
521521
# "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
@@ -526,6 +526,7 @@ def get_extensions():
526526
)
527527

528528
build_for_sm90a, build_for_sm100a = get_cutlass_build_flags()
529+
build_for_sm100a = True
529530
# Define sm90a sources
530531
cutlass_90a_sources = [
531532
os.path.join(
@@ -623,7 +624,8 @@ def get_extensions():
623624
cutlass_100a_extra_compile_args = copy.deepcopy(extra_compile_args)
624625
# Only use sm100a architecture for these sources, ignoring cuda_arch_flags
625626
cutlass_100a_extra_compile_args["nvcc"].append(
626-
"-gencode=arch=compute_100a,code=sm_100a"
627+
# "-gencode=arch=compute_100a,code=sm_100a"
628+
"-gencode=arch=compute_120a,code=sm_120a",
627629
)
628630
ext_modules.append(
629631
extension(

torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
6969
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
7070
// Kernel functional config
7171
using ElementAccumulator = float; // Element type for internal accumulation
72-
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
72+
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
7373
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
7474

7575

@@ -241,7 +241,8 @@ at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
241241
using ElementD = cutlass::bfloat16_t;
242242

243243
using MmaTileShape = Shape<_128,_128,_128>;
244-
using ClusterShape = Shape<_2,_1,_1>;
244+
// using ClusterShape = Shape<_2,_1,_1>;
245+
using ClusterShape = Shape<_1,_1,_1>;
245246
using PerSmTileShape_MNK = Shape<_128,_128,_128>;
246247

247248
run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);

0 commit comments

Comments (0)