@@ -272,18 +272,15 @@ def get_cutlass_build_flags():
272272 raise ValueError ("No CUDA version found" )
273273
274274 major , minor = map (int , cuda_version .split ("." )[:2 ])
275- build_sm90a = (major , minor ) >= (12 , 6 )
276- build_sm100a = (major , minor ) >= (12 , 8 )
277- build_sm120a = (major , minor ) >= (12 , 8 )
275+ build_sm90a = major > 12 or (major == 12 and minor >= 6 )
276+ build_sm100a = major > 12 or (major == 12 and minor >= 8 )
278277
279278 if build_sm90a :
280279 print (f"CUDA { cuda_version } : Enabling SM90a CUTLASS kernels" )
281280 if build_sm100a :
282281 print (f"CUDA { cuda_version } : Enabling SM100a CUTLASS kernels" )
283- if build_sm120a :
284- print (f"CUDA { cuda_version } : Enabling SM120a CUTLASS kernels" )
285282
286- return build_sm90a , build_sm100a , build_sm120a
283+ return build_sm90a , build_sm100a
287284 except :
288285 # Fallback to architecture flags
289286 cuda_arch_flags = _get_cuda_arch_flags ()
@@ -343,11 +340,6 @@ def __init__(
343340 self .cmake_args = cmake_args
344341
345342
346- def remove_items (a : list , b : list ) -> list :
347- """Remove items in list b from list a"""
348- return [x for x in a if x not in b ]
349-
350-
351343def get_extensions ():
352344 # Skip building C++ extensions if USE_CPP is set to "0"
353345 if use_cpp == "0" :
@@ -462,7 +454,7 @@ def get_extensions():
462454 excluded_sources = list (
463455 glob .glob (os .path .join (extensions_dir , "cpu/*.cpp" ), recursive = True )
464456 )
465- sources = remove_items ( sources , excluded_sources )
457+ sources = [ s for s in sources if s not in excluded_sources ]
466458
467459 # Collect CUDA source files
468460 extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
@@ -506,24 +498,22 @@ def get_extensions():
506498 rocm_sources = list (
507499 glob .glob (os .path .join (extensions_rocm_dir , "**/*.cpp" ), recursive = True )
508500 )
509- sources = remove_items ( sources , rocm_sources )
501+ sources = [ s for s in sources if s not in rocm_sources ]
510502
511- use_cutlass = use_cuda and not IS_WINDOWS
503+ use_cutlass = False
512504 cutlass_90a_sources = None
513505 cutlass_100a_sources = None
514- cutlass_120a_sources = None
515506 build_for_sm90a = False
516507 build_for_sm100a = False
517- build_for_sm120a = False
518-
519- if use_cutlass :
508+ if use_cuda and not IS_WINDOWS :
509+ use_cutlass = True
520510 cutlass_dir = os .path .join (third_party_path , "cutlass" )
521511 cutlass_include_dir = os .path .join (cutlass_dir , "include" )
522512 cutlass_tools_include_dir = os .path .join (
523513 cutlass_dir , "tools" , "util" , "include"
524514 )
525515 cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
526-
516+ if use_cutlass :
527517 extra_compile_args ["nvcc" ].extend (
528518 [
529519 "-DTORCHAO_USE_CUTLASS" ,
@@ -543,7 +533,7 @@ def get_extensions():
543533 ]
544534 )
545535
546- build_for_sm90a , build_for_sm100a , build_for_sm120a = get_cutlass_build_flags ()
536+ build_for_sm90a , build_for_sm100a = get_cutlass_build_flags ()
547537 # Define sm90a sources
548538 cutlass_90a_sources = [
549539 os .path .join (
@@ -567,40 +557,40 @@ def get_extensions():
567557 "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu" ,
568558 )
569559 )
570- sources = remove_items (sources , cutlass_90a_sources )
560+ # Always remove sm90a sources from main sources
561+ sources = [s for s in sources if s not in cutlass_90a_sources ]
571562
572563 # Always compile mx_fp_cutlass_kernels.cu ONLY with sm100a architecture
573564 cutlass_100a_sources = [
574565 os .path .join (
575566 extensions_cuda_dir ,
576567 "mx_kernels" ,
577- "mx_fp_cutlass_kernels_sm100a .cu" ,
568+ "mx_fp_cutlass_kernels .cu" ,
578569 ),
579570 ]
580- sources = remove_items (sources , cutlass_100a_sources )
581-
582- # Always compile mx_fp_cutlass_kernels.cu ONLY with sm120a architecture
583- cutlass_120a_sources = [
584- os .path .join (
585- extensions_cuda_dir ,
586- "mx_kernels" ,
587- "mx_fp_cutlass_kernels_sm120a.cu" ,
588- ),
571+ # Remove from main sources to prevent compilation with other architectures
572+ sources = [
573+ s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
589574 ]
590- sources = remove_items (sources , cutlass_120a_sources )
591575
592576 else :
593- # Remove CUTLASS-based kernels from the sources list. An assumption is that
594- # these files will have "cutlass" in its name.
577+ # Remove CUTLASS-based kernels from the sources list. This
578+ # assumes that every such file has "cutlass" in its
579+ # name.
595580 cutlass_sources = list (
596581 glob .glob (
597582 os .path .join (extensions_cuda_dir , "**/*cutlass*.cu" ), recursive = True
598583 )
599584 )
600- sources = remove_items ( sources , cutlass_sources )
585+ sources = [ s for s in sources if s not in cutlass_sources ]
601586
602587 ext_modules = []
603588 if len (sources ) > 0 :
589+ # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
590+ sources = [
591+ s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
592+ ]
593+
604594 ext_modules .append (
605595 extension (
606596 "torchao._C" ,
@@ -653,27 +643,6 @@ def get_extensions():
653643 )
654644 )
655645
656- # Only build the cutlass_120a extension if sm120a is in the architecture flags
657- if (
658- cutlass_120a_sources is not None
659- and len (cutlass_120a_sources ) > 0
660- and build_for_sm120a
661- ):
662- cutlass_120a_extra_compile_args = copy .deepcopy (extra_compile_args )
663- # Only use sm120a architecture for these sources, ignoring cuda_arch_flags
664- cutlass_120a_extra_compile_args ["nvcc" ].append (
665- "-gencode=arch=compute_120a,code=sm_120a"
666- )
667- ext_modules .append (
668- extension (
669- "torchao._C_cutlass_120a" ,
670- cutlass_120a_sources ,
671- py_limited_api = True ,
672- extra_compile_args = cutlass_120a_extra_compile_args ,
673- extra_link_args = extra_link_args ,
674- )
675- )
676-
677646 # Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
678647 if build_macos_arm_auto or os .getenv ("BUILD_TORCHAO_EXPERIMENTAL" ) == "1" :
679648 build_options = BuildOptions ()
0 commit comments