@@ -644,6 +644,8 @@ def mma(
644644 # CHECK: func.func @mma
645645
646646 ### global buffer is bound to %0, %1 and %2 : MK, NK, MN
647+ # CHECK: %[[C0:.*]] = arith.constant 0 : index
648+
647649 # CHECK: %[[SUBSPAN0:.*]] = stream.binding.subspan
648650 # CHECK: %[[SUBSPAN1:.*]] = stream.binding.subspan
649651 # CHECK: %[[SUBSPAN2:.*]] = stream.binding.subspan
@@ -671,30 +673,17 @@ def mma(
671673 # CHECK: %[[MASK3:.*]] = arith.select %[[COND1]], %{{.*}}, %[[MASK2]] : index
672674 # CHECK: %[[MASK4:.*]] = arith.select %[[COND0]], %{{.*}}, %[[MASK3]] : index
673675
674- ### pack descriptors and invoke tensor load
675-
676- # CHECK: %[[TENSOR_DESC_0:.*]] = arith.select
677- # CHECK-NOT: amdgpu.tensor_load_to_lds
678- # CHECK-NOT: rocdl.s.wait.tensorcnt
679- # CHECK-NOT: amdgpu.lds_barrier
680-
681- ### get shared buffer pointer
682- # CHECK: %[[CAST_4:.*]] = memref.reinterpret_cast %[[VIEW0]]
683- # CHECK: %[[INT_PTR_2:.+]] = memref.extract_aligned_pointer_as_index %[[CAST_4]]
684- # CHECK: %[[INT_PTR_2_CAST:.+]] = arith.index_cast %[[INT_PTR_2]] : index to i32
685- # CHECK: %[[INT_PTR_2_CAST_ADDED:.+]] = arith.addi %[[INT_PTR_2_CAST]], %{{.*}} : i32
676+ # CHECK: %[[TENSOR_DESC_0:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE0:.+]]
686677
687- ### pack descriptors and invoke tensor load
688- # CHECK: %[[D1:.*]] = vector.from_elements %{{.*}}, %[[INT_PTR_2_CAST_ADDED]], %{{.*}}, %{{.*}} : vector<4xi32>
689- # CHECK: %[[TENSOR_DESC_1:.*]] = vector.from_elements
678+ # CHECK: %[[DMA_BASE1:.+]] = amdgpu.make_dma_base {{.*}} : {{.*}} -> !amdgpu.tdm_base<f16>
679+ # CHECK: %[[TENSOR_DESC_1:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE1:.+]]
690680
691681 # Fused descriptors
692682 # CHECK: %[[SELECTED:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index
693- # CHECK: %[[D_FUSED:.*]] = arith.select %[[SELECTED]], %[[D0]], %[[D1]] : vector<4xi32>
694- # CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : vector<8xi32>
683+ # CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : !amdgpu.tdm_descriptor
695684
696685 ### resource provider
697- # CHECK: amdgpu.tensor_load_to_lds %[[D_FUSED]], %[[ DESC_FUSED]], {{.*}}, {{.*}} cachepolicy {{.*}} : vector<4xi32>, vector<8xi32>
686+ # CHECK: amdgpu.tensor_load_to_lds %[[DESC_FUSED:.*]]
698687 # CHECK: rocdl.s.wait.tensorcnt 0
699688 # CHECK: rocdl.s.wait.dscnt 0
700689 # CHECK: rocdl.s.barrier.signal id = -1
0 commit comments