@@ -799,26 +799,17 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
         destinations
     ), "sources and destinations must have the same number of elements."
 
-    # construct default descriptors
+    i1 = IntegerType.get_signless(1)
+    i16 = IntegerType.get_signless(16)
     i32 = IntegerType.get_signless(32)
-    i48 = IntegerType.get_signless(48)
-    i57 = IntegerType.get_signless(57)
-
-    vec_type_4 = VectorType.get((4,), i32)
-    vec_type_8 = VectorType.get((8,), i32)
-
-    c0 = arith_d.constant(i32, 0)
+    v1i16 = VectorType.get([1], i16)
+    v16i1 = VectorType.get([16], i1)
 
     d0_results = []
-    d1_results = []
-    d2_results = []
-    d3_results = []
 
     subs = add_emitter_subs(emitter)
 
     for i, (src, dst) in enumerate(zip(sources, destinations)):
-        dst_memory = get_custom(dst)
-
         symbolic_shape = _get_symbolic_shape(src)
         global_tile_index_current = {k: global_tile_index[k] for k in symbolic_shape}
         global_tile_index_current = _subs_index_dict(
@@ -833,210 +824,85 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
         strides = strides_from_symbolic_shape(
             IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
         )
-        # Descriptor assumes rightmost stride 1 and expect last stride as full data size
-        strides = [strides[0] * symbolic_shape[0]] + strides[:-1]
-        strides = [gen_sympy_index(subs, s) for s in strides]
 
         distributed_shape_vals = [
             gen_sympy_index(subs, distributed_shape[s]) for s in symbolic_shape
         ]
 
-        d0 = vector_d.broadcast(vec_type_4, c0)
-        d1 = vector_d.broadcast(vec_type_8, c0)
-        d2 = vector_d.broadcast(vec_type_4, c0)
-        d3 = vector_d.broadcast(vec_type_4, c0)
-
-        # descriptor properties
-        mode = 2  # vimage
-        valid = 1
-        dim_stride_1 = arith_d.index_cast(i48, strides[0])
-        dim_stride_0 = arith_d.index_cast(i48, strides[1])
-        tile_size_1 = arith_d.index_cast(i32, distributed_shape_vals[0])
-        tile_size_0 = arith_d.index_cast(i32, distributed_shape_vals[1])
-        dim_size_1 = arith_d.index_cast(i32, local_bounds[0])
-        dim_size_0 = arith_d.index_cast(i32, local_bounds[1])
-
-        # 0: 1 byte; 1: 2 byte; 2: 4 byte; 3: 8 byte
-        descriptor_type = lambda x: int(math.log2(x.bitwidth() >> 3))
-        data_size = cast_py_value(emitter, descriptor_type(element_type), i32).ir_value
-
         global_mem = cast_py_value(emitter, src)
         shared_mem = cast_py_value(emitter, dst)
 
         global_value = global_mem.ir_value
         shared_value = shared_mem.ir_value
 
         bytewidth = element_type.bitwidth() // 8
-        element_byte_index = arith_d.constant(IndexType.get(), bytewidth)
-
-        # calculcate global address
-        # 0. breakdown index sequence to WG & TH offsets : ele
-        # 1. uniform per wave access : ele
-        # 2. linearize global memory buffer
-        # 3. offset = X + Y * tensor dim 0 stride : ele
-        # 4. offset_byte = offset * element byte : byte
-        # 5. get global memory pointer
-        # 6. move global memory pointer by offset_byte to get global address of a tile : byte
-        index, _, _ = _build_start_indices(emitter, global_tile_index_current)
-
-        wave_index_x = assume_index_subgroup_uniform(index[1], i32)  # k
-        wave_index_y = assume_index_subgroup_uniform(index[0], i32)  # m
-
-        stride0 = arith_d.index_cast(IndexType.get(), dim_stride_0)
-        y_offset = arith_d.muli(wave_index_y, stride0)
-        global_base_offset = arith_d.addi(wave_index_x, y_offset)
-        global_index_offset = arith_d.muli(global_base_offset, element_byte_index)
-
-        global_ptr = memref_d.extract_aligned_pointer_as_index(global_value)
-        global_byte_address = arith_d.addi(global_ptr, global_index_offset)
 
-        # calculate shared address
-        # 0. extract shared tile index from IndexSequence structure
-        # 1. calculate byte offset from tile indices and distributed shape
-        # 2. get shared memory pointer
-        # 3. move shared memory pointer by offset_byte to get shared memory address of a tile.
-        shared_buffer = _linearize_shared_mem(shared_value)
-
-        shared_strides = strides_from_symbolic_shape(
-            IndexingContext.current(),
-            dst_memory.distributed_shape,
-            allow_mixed_shapes=True,
-        )
+        index, _, _ = _build_start_indices(emitter, global_tile_index_current)
 
         shared_tile_index_current = {k: shared_tile_index[k] for k in symbolic_shape}
         shared_tile_index_current = _subs_index_dict(
             shared_tile_index_current, {INPUT_SELECTOR: i}
         )
 
         linearized_index = {
-            "linearized_idx": linearize_index(shared_tile_index_current, shared_strides)
+            "linearized_idx": linearize_index(
+                shared_tile_index_current, shared_tile_index_current
+            )
         }
 
         # Calculate shared memory offset from tile indices
         shared_index, _, _ = _build_start_indices(emitter, linearized_index)
 
-        shared_index_offset = arith_d.muli(shared_index[0], element_byte_index)
-        shared_byte_offset = arith_d.index_cast(i32, shared_index_offset)
-
-        shared_ptr = memref_d.extract_aligned_pointer_as_index(shared_buffer)
-        shared_ptr = arith_d.index_cast(i32, shared_ptr)
-
-        shared_ptr_base_offset = memref_d.extract_strided_metadata(shared_buffer)[1]
-        shared_ptr_base_offset = arith_d.index_cast(i32, shared_ptr_base_offset)
-
-        shared_byte_address = arith_d.addi(shared_ptr_base_offset, shared_byte_offset)
-        shared_byte_address = arith_d.addi(shared_ptr, shared_byte_address)
-
-        # assume no mapping
-        def lshift(value, bits):
-            sh = arith_d.constant(value.type, bits)
-            val = arith_d.shli(value, sh)
-            return val
+        ir_type = IrType.parse(element_type.dtype.ir_type_asm())
+        dmaType = amdgpu_d.TDMBaseType.get(ir_type)
 
-        def rshift(value, bits):
-            sh = arith_d.constant(value.type, bits)
-            val = arith_d.shrui(value, sh)
-            return val
-
-        # pack global address of a tile
-        # 1. get lower 32 bit from global value
-        global_val = arith_d.index_cast(i57, global_byte_address)  # i57
-        global_val_lower = arith_d.trunci(i32, global_val)
-        d0 = vector_d.insert(
-            global_val_lower, d0, static_position=[2], dynamic_position=[]
-        )
-        # 2. get rest of the upper 25 bit from global value and cast to i32
-        global_val_rest = rshift(global_val, 32)
-        global_val_upper = arith_d.trunci(i32, global_val_rest)
-        # 3. pack with image mode bit
-        mode = arith_d.constant(i32, mode)
-        image_mode = lshift(mode, 30)
-        pack = arith_d.ori(image_mode, global_val_upper)
-        d0 = vector_d.insert(pack, d0, static_position=[3], dynamic_position=[])
-
-        # insert shared addreess to descriptor 0
-        d0 = vector_d.insert(
-            shared_byte_address, d0, static_position=[1], dynamic_position=[]
+        base = amdgpu_d.make_dma_base(
+            dmaType,
+            global_value,
+            index,
+            shared_value,
+            shared_index,
         )
 
-        # valid tensor
-        valid_tensor = arith_d.constant(i32, valid)
-        d0 = vector_d.insert(valid_tensor, d0, static_position=[0], dynamic_position=[])
-
-        # get data size val packed to i32
-        data_size_val = lshift(data_size, 16)
-
+        pad_interval = None
+        pad_amount = None
         original_dst = propagate_loop_carried_vars(dst)
         original_dst = get_custom(original_dst)
         if padding := original_dst.padding:
             unpadded_dim = int(subs_idxc(original_dst.unpadded_shape[-1])) * bytewidth
             assert (
                 unpadded_dim >= 8
             ), f"Invalid unpadded_dim for padding: {unpadded_dim} (must be at least 8 bytes)"
-            pad_enable = 1 << 20
-            pad_interval = int(math.log2((unpadded_dim // 4) - 1)) << 22
-            pad_amount = ((padding * bytewidth) // 4 - 1) << 25
-            pad_packed = pad_enable | pad_interval | pad_amount
-            data_size_val = arith_d.ori(
-                data_size_val, arith_d.constant(i32, pad_packed)
+            pad_interval = arith_d.constant(
+                i32, int(math.log2((unpadded_dim // 4) - 1))
             )
+            pad_amount = arith_d.constant(i32, ((padding * bytewidth) // 4 - 1))
 
-        local_multicast_mask = subs_idxc(safe_subs(multicast_mask, {INPUT_SELECTOR: i}))
-
-        if local_multicast_mask:
+        workgroup_mask = None
+        if local_multicast_mask := subs_idxc(
+            safe_subs(multicast_mask, {INPUT_SELECTOR: i})
+        ):
             local_multicast_mask = sympy.simplify(local_multicast_mask)
             local_multicast_mask_val = gen_sympy_index(subs, local_multicast_mask)
-            local_multicast_mask_val = arith_d.index_cast(i32, local_multicast_mask_val)
-            data_size_val = arith_d.ori(data_size_val, local_multicast_mask_val)
-
-        d1 = vector_d.insert(
-            data_size_val, d1, static_position=[0], dynamic_position=[]
-        )
-
-        # get lower 16 bit from tensor dim 0 and pack to i32
-        tensor_dim_0_lower = lshift(dim_size_0, 16)
-        d1 = vector_d.insert(
-            tensor_dim_0_lower, d1, static_position=[1], dynamic_position=[]
-        )
-
-        # get upper 16 bit from tensor dim 0 and lower 16 bit from tensor dim 1, pack to i32
-        tensor_dim_0_upper = rshift(dim_size_0, 16)
-        tensor_dim_1_lower = lshift(dim_size_1, 16)
-        pack = arith_d.ori(tensor_dim_1_lower, tensor_dim_0_upper)
-        d1 = vector_d.insert(pack, d1, static_position=[2], dynamic_position=[])
-
-        # get upper 16 bit from tensor dim 1, packed with tile size 0
-        tensor_dim_1_upper = rshift(dim_size_1, 16)
-        tile_size_0_shift = lshift(tile_size_0, 16)
-        pack = arith_d.ori(tensor_dim_1_upper, tile_size_0_shift)
-        d1 = vector_d.insert(pack, d1, static_position=[3], dynamic_position=[])
-
-        # tile size 1 is in good form
-        d1 = vector_d.insert(tile_size_1, d1, static_position=[4], dynamic_position=[])
-
-        # truncate upper 16 bit from dim stride 0 -> i48 to i32
-        dim_stride_0_trunc = arith_d.trunci(i32, dim_stride_0)
-        d1 = vector_d.insert(
-            dim_stride_0_trunc, d1, static_position=[5], dynamic_position=[]
+            workgroup_mask = arith_d.index_cast(i16, local_multicast_mask_val)
+            workgroup_mask = vector_d.from_elements(v1i16, [workgroup_mask])
+            workgroup_mask = vector_d.bitcast(v16i1, workgroup_mask)
+
+        desc = amdgpu_d.make_dma_descriptor(
+            base,
+            local_bounds,
+            [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()],
+            None,
+            strides,
+            distributed_shape_vals,
+            [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()],
+            None,
+            workgroup_mask=workgroup_mask,
+            pad_amount=pad_amount,
+            pad_interval=pad_interval,
         )
 
-        # get upper 16 bit from dim stride 0, get lower 16 bit from dim stride 1, packed to i32
-        dim_stride_0_upper = rshift(dim_stride_0, 32)
-        dim_stride_0_trunc = arith_d.trunci(i32, dim_stride_0_upper)
-        dim_stride_1_lower = arith_d.trunci(i32, dim_stride_1)
-        dim_stride_1_trunc = lshift(dim_stride_1_lower, 16)
-        pack = arith_d.ori(dim_stride_0_trunc, dim_stride_1_trunc)
-        d1 = vector_d.insert(pack, d1, static_position=[6], dynamic_position=[])
-
-        # shift dim stride 1 to get upper 32 bit and pack to i32
-        dim_stride_1_sh = rshift(dim_stride_1, 16)
-        pack = arith_d.trunci(i32, dim_stride_1_sh)
-        d1 = vector_d.insert(pack, d1, static_position=[7], dynamic_position=[])
-
-        d0_results.append(d0)
-        d1_results.append(d1)
-        d2_results.append(d2)
-        d3_results.append(d3)
+        d0_results.append(desc)
 
         # Select the appropriate descriptors based on input_selector
         # Build chained select operations for each descriptor
@@ -1060,13 +926,8 @@ def select_descriptor(results_list, input_selector_val):
 
     input_selector_val = gen_sympy_index(subs, input_selector)
     d0_selected = select_descriptor(d0_results, input_selector_val)
-    d1_selected = select_descriptor(d1_results, input_selector_val)
-    d2_selected = select_descriptor(d2_results, input_selector_val)
-    d3_selected = select_descriptor(d3_results, input_selector_val)
 
-    return rocdl_d.tensor_load_to_lds(
-        d0_selected, d1_selected, d2_selected, d3_selected, 0
-    )
+    return amdgpu_d.tensor_load_to_lds(d0_selected)
 
 
 @handle_op(gather_to_lds)
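
For reference, the pad_interval and pad_amount values handed to make_dma_descriptor come from the same arithmetic the old bit-packing path used, just without the shifts into the packed descriptor word. A minimal sketch with illustrative values (an assumed f16 tile whose unpadded innermost row is 64 elements, padded by 4 elements; these numbers are not taken from the patch):

import math

# Illustrative inputs (assumed, not from the patch): f16 elements,
# a 64-element unpadded row, and 4 padding elements appended per row.
bytewidth = 2                     # f16 -> 2 bytes per element
unpadded_dim = 64 * bytewidth     # 128 bytes; the assert above requires >= 8
padding = 4                       # padding elements per row

# Same formulas as the patch, now passed as plain i32 constants.
pad_interval = int(math.log2((unpadded_dim // 4) - 1))  # int(log2(31)) == 4
pad_amount = (padding * bytewidth) // 4 - 1             # (8 // 4) - 1 == 1

print(pad_interval, pad_amount)   # -> 4 1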