Commit ecb83c7

Use amdgpu.tensor_load_to_lds instead of rocdl_d.tensor_load_to_lds
Signed-off-by: Tim Gymnich <[email protected]>
1 parent 7171e98

1 file changed: wave_lang/kernel/wave/codegen/read_write.py (43 additions, 182 deletions)

@@ -799,26 +799,17 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
         destinations
     ), "sources and destinations must have the same number of elements."
 
-    # construct default descriptors
+    i1 = IntegerType.get_signless(1)
+    i16 = IntegerType.get_signless(16)
     i32 = IntegerType.get_signless(32)
-    i48 = IntegerType.get_signless(48)
-    i57 = IntegerType.get_signless(57)
-
-    vec_type_4 = VectorType.get((4,), i32)
-    vec_type_8 = VectorType.get((8,), i32)
-
-    c0 = arith_d.constant(i32, 0)
+    v1i16 = VectorType.get([1], i16)
+    v16i1 = VectorType.get([16], i1)
 
     d0_results = []
-    d1_results = []
-    d2_results = []
-    d3_results = []
 
     subs = add_emitter_subs(emitter)
 
     for i, (src, dst) in enumerate(zip(sources, destinations)):
-        dst_memory = get_custom(dst)
-
         symbolic_shape = _get_symbolic_shape(src)
         global_tile_index_current = {k: global_tile_index[k] for k in symbolic_shape}
        global_tile_index_current = _subs_index_dict(
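
The first hunk swaps the hand-built descriptor scaffolding (vector<4xi32>/vector<8xi32> groups seeded with zeros) for two small types, v1i16 and v16i1, whose only job is to re-encode the multicast mask for the new amdgpu API further down. A minimal pure-Python sketch of the reinterpretation the later vector.bitcast from vector<1xi16> to vector<16xi1> is assumed to perform (mask_to_lanes is an illustrative name, not a wave_lang helper, and lane 0 is taken to be the least significant bit):

    # Sketch only: model an i16 multicast mask viewed as 16 i1 lanes.
    def mask_to_lanes(mask16: int) -> list[bool]:
        assert 0 <= mask16 < (1 << 16), "multicast mask must fit in 16 bits"
        return [bool((mask16 >> lane) & 1) for lane in range(16)]

    # Workgroups 0 and 2 selected:
    assert mask_to_lanes(0b0101) == [True, False, True, False] + [False] * 12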
@@ -833,210 +824,85 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node):
         strides = strides_from_symbolic_shape(
             IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
         )
-        # Descriptor assumes rightmost stride 1 and expect last stride as full data size
-        strides = [strides[0] * symbolic_shape[0]] + strides[:-1]
-        strides = [gen_sympy_index(subs, s) for s in strides]
 
         distributed_shape_vals = [
             gen_sympy_index(subs, distributed_shape[s]) for s in symbolic_shape
         ]
 
-        d0 = vector_d.broadcast(vec_type_4, c0)
-        d1 = vector_d.broadcast(vec_type_8, c0)
-        d2 = vector_d.broadcast(vec_type_4, c0)
-        d3 = vector_d.broadcast(vec_type_4, c0)
-
-        # descriptor properties
-        mode = 2 # vimage
-        valid = 1
-        dim_stride_1 = arith_d.index_cast(i48, strides[0])
-        dim_stride_0 = arith_d.index_cast(i48, strides[1])
-        tile_size_1 = arith_d.index_cast(i32, distributed_shape_vals[0])
-        tile_size_0 = arith_d.index_cast(i32, distributed_shape_vals[1])
-        dim_size_1 = arith_d.index_cast(i32, local_bounds[0])
-        dim_size_0 = arith_d.index_cast(i32, local_bounds[1])
-
-        # 0: 1 byte; 1: 2 byte; 2: 4 byte; 3: 8 byte
-        descriptor_type = lambda x: int(math.log2(x.bitwidth() >> 3))
-        data_size = cast_py_value(emitter, descriptor_type(element_type), i32).ir_value
-
         global_mem = cast_py_value(emitter, src)
         shared_mem = cast_py_value(emitter, dst)
 
         global_value = global_mem.ir_value
         shared_value = shared_mem.ir_value
 
         bytewidth = element_type.bitwidth() // 8
-        element_byte_index = arith_d.constant(IndexType.get(), bytewidth)
-
-        # calculcate global address
-        # 0. breakdown index sequence to WG & TH offsets : ele
-        # 1. uniform per wave access : ele
-        # 2. linearize global memory buffer
-        # 3. offset = X + Y * tensor dim 0 stride : ele
-        # 4. offset_byte = offset * element byte : byte
-        # 5. get global memory pointer
-        # 6. move global memory pointer by offset_byte to get global address of a tile : byte
-        index, _, _ = _build_start_indices(emitter, global_tile_index_current)
-
-        wave_index_x = assume_index_subgroup_uniform(index[1], i32) # k
-        wave_index_y = assume_index_subgroup_uniform(index[0], i32) # m
-
-        stride0 = arith_d.index_cast(IndexType.get(), dim_stride_0)
-        y_offset = arith_d.muli(wave_index_y, stride0)
-        global_base_offset = arith_d.addi(wave_index_x, y_offset)
-        global_index_offset = arith_d.muli(global_base_offset, element_byte_index)
-
-        global_ptr = memref_d.extract_aligned_pointer_as_index(global_value)
-        global_byte_address = arith_d.addi(global_ptr, global_index_offset)
 
-        # calculate shared address
-        # 0. extract shared tile index from IndexSequence structure
-        # 1. calculate byte offset from tile indices and distributed shape
-        # 2. get shared memory pointer
-        # 3. move shared memory pointer by offset_byte to get shared memory address of a tile.
-        shared_buffer = _linearize_shared_mem(shared_value)
-
-        shared_strides = strides_from_symbolic_shape(
-            IndexingContext.current(),
-            dst_memory.distributed_shape,
-            allow_mixed_shapes=True,
-        )
+        index, _, _ = _build_start_indices(emitter, global_tile_index_current)
 
         shared_tile_index_current = {k: shared_tile_index[k] for k in symbolic_shape}
         shared_tile_index_current = _subs_index_dict(
             shared_tile_index_current, {INPUT_SELECTOR: i}
         )
 
         linearized_index = {
-            "linearized_idx": linearize_index(shared_tile_index_current, shared_strides)
+            "linearized_idx": linearize_index(
+                shared_tile_index_current, shared_tile_index_current
+            )
         }
 
         # Calculate shared memory offset from tile indices
         shared_index, _, _ = _build_start_indices(emitter, linearized_index)
 
-        shared_index_offset = arith_d.muli(shared_index[0], element_byte_index)
-        shared_byte_offset = arith_d.index_cast(i32, shared_index_offset)
-
-        shared_ptr = memref_d.extract_aligned_pointer_as_index(shared_buffer)
-        shared_ptr = arith_d.index_cast(i32, shared_ptr)
-
-        shared_ptr_base_offset = memref_d.extract_strided_metadata(shared_buffer)[1]
-        shared_ptr_base_offset = arith_d.index_cast(i32, shared_ptr_base_offset)
-
-        shared_byte_address = arith_d.addi(shared_ptr_base_offset, shared_byte_offset)
-        shared_byte_address = arith_d.addi(shared_ptr, shared_byte_address)
-
-        # assume no mapping
-        def lshift(value, bits):
-            sh = arith_d.constant(value.type, bits)
-            val = arith_d.shli(value, sh)
-            return val
+        ir_type = IrType.parse(element_type.dtype.ir_type_asm())
+        dmaType = amdgpu_d.TDMBaseType.get(ir_type)
 
-        def rshift(value, bits):
-            sh = arith_d.constant(value.type, bits)
-            val = arith_d.shrui(value, sh)
-            return val
-
-        # pack global address of a tile
-        # 1. get lower 32 bit from global value
-        global_val = arith_d.index_cast(i57, global_byte_address) # i57
-        global_val_lower = arith_d.trunci(i32, global_val)
-        d0 = vector_d.insert(
-            global_val_lower, d0, static_position=[2], dynamic_position=[]
-        )
-        # 2. get rest of the upper 25 bit from global value and cast to i32
-        global_val_rest = rshift(global_val, 32)
-        global_val_upper = arith_d.trunci(i32, global_val_rest)
-        # 3. pack with image mode bit
-        mode = arith_d.constant(i32, mode)
-        image_mode = lshift(mode, 30)
-        pack = arith_d.ori(image_mode, global_val_upper)
-        d0 = vector_d.insert(pack, d0, static_position=[3], dynamic_position=[])
-
-        # insert shared addreess to descriptor 0
-        d0 = vector_d.insert(
-            shared_byte_address, d0, static_position=[1], dynamic_position=[]
+        base = amdgpu_d.make_dma_base(
+            dmaType,
+            global_value,
+            index,
+            shared_value,
+            shared_index,
         )
 
-        # valid tensor
-        valid_tensor = arith_d.constant(i32, valid)
-        d0 = vector_d.insert(valid_tensor, d0, static_position=[0], dynamic_position=[])
-
-        # get data size val packed to i32
-        data_size_val = lshift(data_size, 16)
-
+        pad_interval = None
+        pad_amount = None
         original_dst = propagate_loop_carried_vars(dst)
         original_dst = get_custom(original_dst)
         if padding := original_dst.padding:
             unpadded_dim = int(subs_idxc(original_dst.unpadded_shape[-1])) * bytewidth
             assert (
                 unpadded_dim >= 8
             ), f"Invalid unpadded_dim for padding: {unpadded_dim} (must be at least 8 bytes)"
-            pad_enable = 1 << 20
-            pad_interval = int(math.log2((unpadded_dim // 4) - 1)) << 22
-            pad_amount = ((padding * bytewidth) // 4 - 1) << 25
-            pad_packed = pad_enable | pad_interval | pad_amount
-            data_size_val = arith_d.ori(
-                data_size_val, arith_d.constant(i32, pad_packed)
+            pad_interval = arith_d.constant(
+                i32, int(math.log2((unpadded_dim // 4) - 1))
             )
+            pad_amount = arith_d.constant(i32, ((padding * bytewidth) // 4 - 1))
 
-        local_multicast_mask = subs_idxc(safe_subs(multicast_mask, {INPUT_SELECTOR: i}))
-
-        if local_multicast_mask:
+        workgroup_mask = None
+        if local_multicast_mask := subs_idxc(
+            safe_subs(multicast_mask, {INPUT_SELECTOR: i})
+        ):
             local_multicast_mask = sympy.simplify(local_multicast_mask)
             local_multicast_mask_val = gen_sympy_index(subs, local_multicast_mask)
-            local_multicast_mask_val = arith_d.index_cast(i32, local_multicast_mask_val)
-            data_size_val = arith_d.ori(data_size_val, local_multicast_mask_val)
-
-        d1 = vector_d.insert(
-            data_size_val, d1, static_position=[0], dynamic_position=[]
-        )
-
-        # get lower 16 bit from tensor dim 0 and pack to i32
-        tensor_dim_0_lower = lshift(dim_size_0, 16)
-        d1 = vector_d.insert(
-            tensor_dim_0_lower, d1, static_position=[1], dynamic_position=[]
-        )
-
-        # get upper 16 bit from tensor dim 0 and lower 16 bit from tensor dim 1, pack to i32
-        tensor_dim_0_upper = rshift(dim_size_0, 16)
-        tensor_dim_1_lower = lshift(dim_size_1, 16)
-        pack = arith_d.ori(tensor_dim_1_lower, tensor_dim_0_upper)
-        d1 = vector_d.insert(pack, d1, static_position=[2], dynamic_position=[])
-
-        # get upper 16 bit from tensor dim 1, packed with tile size 0
-        tensor_dim_1_upper = rshift(dim_size_1, 16)
-        tile_size_0_shift = lshift(tile_size_0, 16)
-        pack = arith_d.ori(tensor_dim_1_upper, tile_size_0_shift)
-        d1 = vector_d.insert(pack, d1, static_position=[3], dynamic_position=[])
-
-        # tile size 1 is in good form
-        d1 = vector_d.insert(tile_size_1, d1, static_position=[4], dynamic_position=[])
-
-        # truncate upper 16 bit from dim stride 0 -> i48 to i32
-        dim_stride_0_trunc = arith_d.trunci(i32, dim_stride_0)
-        d1 = vector_d.insert(
-            dim_stride_0_trunc, d1, static_position=[5], dynamic_position=[]
+            workgroup_mask = arith_d.index_cast(i16, local_multicast_mask_val)
+            workgroup_mask = vector_d.from_elements(v1i16, [workgroup_mask])
+            workgroup_mask = vector_d.bitcast(v16i1, workgroup_mask)
+
+        desc = amdgpu_d.make_dma_descriptor(
+            base,
+            local_bounds,
+            [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()],
+            None,
+            strides,
+            distributed_shape_vals,
+            [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()],
+            None,
+            workgroup_mask=workgroup_mask,
+            pad_amount=pad_amount,
+            pad_interval=pad_interval,
        )
 
-        # get upper 16 bit from dim stride 0, get lower 16 bit from dim stride 1, packed to i32
-        dim_stride_0_upper = rshift(dim_stride_0, 32)
-        dim_stride_0_trunc = arith_d.trunci(i32, dim_stride_0_upper)
-        dim_stride_1_lower = arith_d.trunci(i32, dim_stride_1)
-        dim_stride_1_trunc = lshift(dim_stride_1_lower, 16)
-        pack = arith_d.ori(dim_stride_0_trunc, dim_stride_1_trunc)
-        d1 = vector_d.insert(pack, d1, static_position=[6], dynamic_position=[])
-
-        # shift dim stride 1 to get upper 32 bit and pack to i32
-        dim_stride_1_sh = rshift(dim_stride_1, 16)
-        pack = arith_d.trunci(i32, dim_stride_1_sh)
-        d1 = vector_d.insert(pack, d1, static_position=[7], dynamic_position=[])
-
-        d0_results.append(d0)
-        d1_results.append(d1)
-        d2_results.append(d2)
-        d3_results.append(d3)
+        d0_results.append(desc)
 
     # Select the appropriate descriptors based on input_selector
    # Build chained select operations for each descriptor
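
Both versions derive the same two padding parameters; what changed is where the bit packing happens. The removed code shifted the pad-enable flag, pad_interval, and pad_amount into bit fields of the descriptor word by hand, while the new code passes plain i32 constants and leaves the encoding to amdgpu_d.make_dma_descriptor. A worked example of the shared arithmetic (the tile numbers below are hypothetical, chosen only to make the values concrete):

    import math

    # Hypothetical f16 tile: rows of 64 elements, 2 bytes each, padded by 8 elements.
    bytewidth = 2
    unpadded_dim = 64 * bytewidth                           # 128 bytes; the assert requires >= 8
    padding = 8
    pad_interval = int(math.log2((unpadded_dim // 4) - 1))  # int(log2(31)) == 4
    pad_amount = (padding * bytewidth) // 4 - 1             # 16 // 4 - 1 == 3

    # The removed code packed these fields into the descriptor word itself:
    pad_packed = (1 << 20) | (pad_interval << 22) | (pad_amount << 25)
    assert pad_packed == 0x7100000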
@@ -1060,13 +926,8 @@ def select_descriptor(results_list, input_selector_val):
 
     input_selector_val = gen_sympy_index(subs, input_selector)
     d0_selected = select_descriptor(d0_results, input_selector_val)
-    d1_selected = select_descriptor(d1_results, input_selector_val)
-    d2_selected = select_descriptor(d2_results, input_selector_val)
-    d3_selected = select_descriptor(d3_results, input_selector_val)
 
-    return rocdl_d.tensor_load_to_lds(
-        d0_selected, d1_selected, d2_selected, d3_selected, 0
-    )
+    return amdgpu_d.tensor_load_to_lds(d0_selected)
 
 
 @handle_op(gather_to_lds)
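
The selection step keeps its shape and simply narrows from four descriptor groups to one. Per the comments retained above, select_descriptor builds chained select operations over the per-input results; a pure-Python model of that behavior (a sketch only, the real code presumably emits compare-and-select ops on SSA values rather than branching in Python):

    def select_descriptor(results_list, input_selector_val):
        # Chained select: start from entry 0; each later entry wins
        # when the selector equals its index.
        selected = results_list[0]
        for i, candidate in enumerate(results_list[1:], start=1):
            selected = candidate if input_selector_val == i else selected
        return selected

    assert select_descriptor(["desc0", "desc1", "desc2"], 1) == "desc1"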
