Skip to content

Commit 64d0e4a

Browse files
committed
update tests
Signed-off-by: Tim Gymnich <[email protected]>
1 parent d1fe3bd commit 64d0e4a

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

lit_tests/kernel/wave/mma.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,8 @@ def mma(
644644
# CHECK: func.func @mma
645645

646646
### global buffer is bound to %0, %1 and %2 : MK, NK, MN
647+
# CHECK: %[[C0:.*]] = arith.constant 0 : index
648+
647649
# CHECK: %[[SUBSPAN0:.*]] = stream.binding.subspan
648650
# CHECK: %[[SUBSPAN1:.*]] = stream.binding.subspan
649651
# CHECK: %[[SUBSPAN2:.*]] = stream.binding.subspan
@@ -671,30 +673,17 @@ def mma(
671673
# CHECK: %[[MASK3:.*]] = arith.select %[[COND1]], %{{.*}}, %[[MASK2]] : index
672674
# CHECK: %[[MASK4:.*]] = arith.select %[[COND0]], %{{.*}}, %[[MASK3]] : index
673675

674-
### pack descriptors and invoke tensor load
675-
676-
# CHECK: %[[TENSOR_DESC_0:.*]] = arith.select
677-
# CHECK-NOT: amdgpu.tensor_load_to_lds
678-
# CHECK-NOT: rocdl.s.wait.tensorcnt
679-
# CHECK-NOT: amdgpu.lds_barrier
680-
681-
### get shared buffer pointer
682-
# CHECK: %[[CAST_4:.*]] = memref.reinterpret_cast %[[VIEW0]]
683-
# CHECK: %[[INT_PTR_2:.+]] = memref.extract_aligned_pointer_as_index %[[CAST_4]]
684-
# CHECK: %[[INT_PTR_2_CAST:.+]] = arith.index_cast %[[INT_PTR_2]] : index to i32
685-
# CHECK: %[[INT_PTR_2_CAST_ADDED:.+]] = arith.addi %[[INT_PTR_2_CAST]], %{{.*}} : i32
676+
# CHECK: %[[TENSOR_DESC_0:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE0:.+]]
686677

687-
### pack descriptors and invoke tensor load
688-
# CHECK: %[[D1:.*]] = vector.from_elements %{{.*}}, %[[INT_PTR_2_CAST_ADDED]], %{{.*}}, %{{.*}} : vector<4xi32>
689-
# CHECK: %[[TENSOR_DESC_1:.*]] = vector.from_elements
678+
# CHECK: %[[DMA_BASE1:.+]] = amdgpu.make_dma_base {{.*}} : {{.*}} -> !amdgpu.tdm_base<f16>
679+
# CHECK: %[[TENSOR_DESC_1:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE1:.+]]
690680

691681
# Fused descriptors
692682
# CHECK: %[[SELECTED:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index
693-
# CHECK: %[[D_FUSED:.*]] = arith.select %[[SELECTED]], %[[D0]], %[[D1]] : vector<4xi32>
694-
# CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : vector<8xi32>
683+
# CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : !amdgpu.tdm_descriptor
695684

696685
### resource provider
697-
# CHECK: amdgpu.tensor_load_to_lds %[[D_FUSED]], %[[DESC_FUSED]], {{.*}}, {{.*}} cachepolicy {{.*}} : vector<4xi32>, vector<8xi32>
686+
# CHECK: amdgpu.tensor_load_to_lds %[[DESC_FUSED:.*]]
698687
# CHECK: rocdl.s.wait.tensorcnt 0
699688
# CHECK: rocdl.s.wait.dscnt 0
700689
# CHECK: rocdl.s.barrier.signal id = -1

0 commit comments

Comments
 (0)