@@ -272,15 +272,18 @@ def get_cutlass_build_flags():
272272 raise ValueError ("No CUDA version found" )
273273
274274 major , minor = map (int , cuda_version .split ("." )[:2 ])
275- build_sm90a = major > 12 or (major == 12 and minor >= 6 )
276- build_sm100a = major > 12 or (major == 12 and minor >= 8 )
275+ build_sm90a = (major , minor ) >= (12 , 6 )
276+ build_sm100a = (major , minor ) >= (12 , 8 )
277+ build_sm120a = (major , minor ) >= (12 , 8 )
277278
278279 if build_sm90a :
279280 print (f"CUDA { cuda_version } : Enabling SM90a CUTLASS kernels" )
280281 if build_sm100a :
281282 print (f"CUDA { cuda_version } : Enabling SM100a CUTLASS kernels" )
283+ if build_sm120a :
284+ print (f"CUDA { cuda_version } : Enabling SM120a CUTLASS kernels" )
282285
283- return build_sm90a , build_sm100a
286+ return build_sm90a , build_sm100a , build_sm120a
284287 except :
285288 # Fallback to architecture flags
286289 cuda_arch_flags = _get_cuda_arch_flags ()
@@ -340,6 +343,11 @@ def __init__(
340343 self .cmake_args = cmake_args
341344
342345
346+ def remove_items (a : list , b : list ) -> list :
347+ """Remove items in list b from list a"""
348+ return [x for x in a if x not in b ]
349+
350+
343351def get_extensions ():
344352 # Skip building C++ extensions if USE_CPP is set to "0"
345353 if use_cpp == "0" :
@@ -454,7 +462,7 @@ def get_extensions():
454462 excluded_sources = list (
455463 glob .glob (os .path .join (extensions_dir , "cpu/*.cpp" ), recursive = True )
456464 )
457- sources = [ s for s in sources if s not in excluded_sources ]
465+ sources = remove_items ( sources , excluded_sources )
458466
459467 # Collect CUDA source files
460468 extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
@@ -498,22 +506,24 @@ def get_extensions():
498506 rocm_sources = list (
499507 glob .glob (os .path .join (extensions_rocm_dir , "**/*.cpp" ), recursive = True )
500508 )
501- sources = [ s for s in sources if s not in rocm_sources ]
509+ sources = remove_items ( sources , rocm_sources )
502510
503- use_cutlass = False
511+ use_cutlass = use_cuda and not IS_WINDOWS
504512 cutlass_90a_sources = None
505513 cutlass_100a_sources = None
514+ cutlass_120a_sources = None
506515 build_for_sm90a = False
507516 build_for_sm100a = False
508- if use_cuda and not IS_WINDOWS :
509- use_cutlass = True
517+ build_for_sm120a = False
518+
519+ if use_cutlass :
510520 cutlass_dir = os .path .join (third_party_path , "cutlass" )
511521 cutlass_include_dir = os .path .join (cutlass_dir , "include" )
512522 cutlass_tools_include_dir = os .path .join (
513523 cutlass_dir , "tools" , "util" , "include"
514524 )
515525 cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
516- if use_cutlass :
526+
517527 extra_compile_args ["nvcc" ].extend (
518528 [
519529 "-DTORCHAO_USE_CUTLASS" ,
@@ -533,7 +543,7 @@ def get_extensions():
533543 ]
534544 )
535545
536- build_for_sm90a , build_for_sm100a = get_cutlass_build_flags ()
546+ build_for_sm90a , build_for_sm100a , build_for_sm120a = get_cutlass_build_flags ()
537547 # Define sm90a sources
538548 cutlass_90a_sources = [
539549 os .path .join (
@@ -557,40 +567,40 @@ def get_extensions():
557567 "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu" ,
558568 )
559569 )
560- # Always remove sm90a sources from main sources
561- sources = [s for s in sources if s not in cutlass_90a_sources ]
570+ sources = remove_items (sources , cutlass_90a_sources )
562571
563572 # Always compile mx_fp_cutlass_kernels.cu ONLY with sm100a architecture
564573 cutlass_100a_sources = [
565574 os .path .join (
566575 extensions_cuda_dir ,
567576 "mx_kernels" ,
568- "mx_fp_cutlass_kernels .cu" ,
577+ "mx_fp_cutlass_kernels_sm100a .cu" ,
569578 ),
570579 ]
571- # Remove from main sources to prevent compilation with other architectures
572- sources = [
573- s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
580+ sources = remove_items (sources , cutlass_100a_sources )
581+
582+ # Always compile mx_fp_cutlass_kernels.cu ONLY with sm120a architecture
583+ cutlass_120a_sources = [
584+ os .path .join (
585+ extensions_cuda_dir ,
586+ "mx_kernels" ,
587+ "mx_fp_cutlass_kernels_sm120a.cu" ,
588+ ),
574589 ]
590+ sources = remove_items (sources , cutlass_120a_sources )
575591
576592 else :
577- # Remove CUTLASS-based kernels from the sources list. An
578- # assumption is that these files will have "cutlass" in its
579- # name.
593+ # Remove CUTLASS-based kernels from the sources list. An assumption is that
594+ # these files will have "cutlass" in its name.
580595 cutlass_sources = list (
581596 glob .glob (
582597 os .path .join (extensions_cuda_dir , "**/*cutlass*.cu" ), recursive = True
583598 )
584599 )
585- sources = [ s for s in sources if s not in cutlass_sources ]
600+ sources = remove_items ( sources , cutlass_sources )
586601
587602 ext_modules = []
588603 if len (sources ) > 0 :
589- # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
590- sources = [
591- s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
592- ]
593-
594604 ext_modules .append (
595605 extension (
596606 "torchao._C" ,
@@ -643,6 +653,27 @@ def get_extensions():
643653 )
644654 )
645655
656+ # Only build the cutlass_120a extension if sm120a is in the architecture flags
657+ if (
658+ cutlass_120a_sources is not None
659+ and len (cutlass_120a_sources ) > 0
660+ and build_for_sm120a
661+ ):
662+ cutlass_120a_extra_compile_args = copy .deepcopy (extra_compile_args )
663+ # Only use sm120a architecture for these sources, ignoring cuda_arch_flags
664+ cutlass_120a_extra_compile_args ["nvcc" ].append (
665+ "-gencode=arch=compute_120a,code=sm_120a"
666+ )
667+ ext_modules .append (
668+ extension (
669+ "torchao._C_cutlass_120a" ,
670+ cutlass_120a_sources ,
671+ py_limited_api = True ,
672+ extra_compile_args = cutlass_120a_extra_compile_args ,
673+ extra_link_args = extra_link_args ,
674+ )
675+ )
676+
646677 # Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
647678 if build_macos_arm_auto or os .getenv ("BUILD_TORCHAO_EXPERIMENTAL" ) == "1" :
648679 build_options = BuildOptions ()
0 commit comments