@@ -644,6 +644,8 @@ def mma(
     # CHECK: func.func @mma

     ### global buffer is bound to %0, %1 and %2 : MK, NK, MN
+    # CHECK: %[[C0:.*]] = arith.constant 0 : index
+
     # CHECK: %[[SUBSPAN0:.*]] = stream.binding.subspan
     # CHECK: %[[SUBSPAN1:.*]] = stream.binding.subspan
     # CHECK: %[[SUBSPAN2:.*]] = stream.binding.subspan
@@ -658,14 +660,8 @@ def mma(
     # CHECK: %[[VIEW0:.*]] = memref.view %[[SMEM]][{{.*}}] : memref<4608xi8, #gpu.address_space<workgroup>> to memref<32x36xf16, #gpu.address_space<workgroup>>
     # CHECK: %[[VIEW1:.*]] = memref.view %[[SMEM]][{{.*}}] : memref<4608xi8, #gpu.address_space<workgroup>> to memref<32x36xf16, #gpu.address_space<workgroup>>

-    ### get global buffer pointer
-    # CHECK: %[[INT_PTR_0:.+]] = memref.extract_aligned_pointer_as_index
-
-    ### get shared buffer pointer
-    # CHECK: %[[CAST_3:.*]] = memref.reinterpret_cast %[[VIEW1]]
-    # CHECK: %[[INT_PTR_1:.+]] = memref.extract_aligned_pointer_as_index %[[CAST_3]]
-
-    # CHECK: %[[D0:.*]] = vector.from_elements
+    ### make DMA base
+    # CHECK: %[[DMA_BASE0:.+]] = amdgpu.make_dma_base {{.*}}, %[[VIEW1]][{{.*}}]

     # Cluster mask generation
     # CHECK: %[[COND0:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index
@@ -677,30 +673,17 @@ def mma(
     # CHECK: %[[MASK3:.*]] = arith.select %[[COND1]], %{{.*}}, %[[MASK2]] : index
     # CHECK: %[[MASK4:.*]] = arith.select %[[COND0]], %{{.*}}, %[[MASK3]] : index

-    ### pack descriptors and invoke tensor load
-
-    # CHECK: %[[TENSOR_DESC_0:.*]] = vector.from_elements
-    # CHECK-NOT: rocdl.tensor.load.to.lds
-    # CHECK-NOT: rocdl.s.wait.tensorcnt
-    # CHECK-NOT: amdgpu.lds_barrier
-
-    ### get shared buffer pointer
-    # CHECK: %[[CAST_4:.*]] = memref.reinterpret_cast %[[VIEW0]]
-    # CHECK: %[[INT_PTR_2:.+]] = memref.extract_aligned_pointer_as_index %[[CAST_4]]
-    # CHECK: %[[INT_PTR_2_CAST:.+]] = arith.index_cast %[[INT_PTR_2]] : index to i32
-    # CHECK: %[[INT_PTR_2_CAST_ADDED:.+]] = arith.addi %[[INT_PTR_2_CAST]], %{{.*}} : i32
+    # CHECK: %[[TENSOR_DESC_0:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE0]] globalSize [%{{.*}}, 32] globalStride [32, 1] sharedSize [16, 32]

-    ### pack descriptors and invoke tensor load
-    # CHECK: %[[D1:.*]] = vector.from_elements %{{.*}}, %[[INT_PTR_2_CAST_ADDED]], %{{.*}}, %{{.*}} : vector<4xi32>
-    # CHECK: %[[TENSOR_DESC_1:.*]] = vector.from_elements
+    # CHECK: %[[DMA_BASE1:.+]] = amdgpu.make_dma_base {{.*}}, %[[VIEW0]][{{.*}}]
+    # CHECK: %[[TENSOR_DESC_1:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE1]] globalSize [%{{.*}}, 32] globalStride [32, 1] sharedSize [16, 32]

     # Fused descriptors
     # CHECK: %[[SELECTED:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index
-    # CHECK: %[[D_FUSED:.*]] = arith.select %[[SELECTED]], %[[D0]], %[[D1]] : vector<4xi32>
-    # CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : vector<8xi32>
+    # CHECK: %[[DESC_FUSED:.*]] = arith.select %[[SELECTED]], %[[TENSOR_DESC_0]], %[[TENSOR_DESC_1]] : !amdgpu.tdm_descriptor

     ### resource provider
-    # CHECK: rocdl.tensor.load.to.lds %[[D_FUSED]], %[[ DESC_FUSED]], {{.*}}, {{.*}} cachepolicy {{.*}} : vector<4xi32>, vector<8xi32>
+    # CHECK: amdgpu.tensor_load_to_lds %[[DESC_FUSED]]
     # CHECK: rocdl.s.wait.tensorcnt 0
     # CHECK: rocdl.s.wait.dscnt 0
     # CHECK: rocdl.s.barrier.signal id = -1
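For orientation, here is a minimal hand-written sketch of the lowered IR shape these directives are intended to match. It is composed only from ops named in the CHECK lines above; the SSA names, operand lists, and the exact global sizes are illustrative assumptions, not verbatim compiler output:

    // Sketch only: names and operands are assumptions drawn from the CHECKs.
    %c0 = arith.constant 0 : index

    // One DMA base + descriptor per operand, built from a global binding
    // and its shared-memory view.
    %base0 = amdgpu.make_dma_base %global0, %view1[%c0, %c0]
    %desc0 = amdgpu.make_dma_descriptor %base0 globalSize [%m, 32] globalStride [32, 1] sharedSize [16, 32]
    %base1 = amdgpu.make_dma_base %global1, %view0[%c0, %c0]
    %desc1 = amdgpu.make_dma_descriptor %base1 globalSize [%n, 32] globalStride [32, 1] sharedSize [16, 32]

    // The two descriptors are fused via a select on the cluster id, so a
    // single tensor load serves either operand.
    %is_first = arith.cmpi eq, %cluster_id, %c0 : index
    %desc = arith.select %is_first, %desc0, %desc1 : !amdgpu.tdm_descriptor
    amdgpu.tensor_load_to_lds %desc

    // Wait for the DMA to complete, then synchronize the workgroup.
    rocdl.s.wait.tensorcnt 0
    rocdl.s.wait.dscnt 0
    rocdl.s.barrier.signal id = -1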