 # pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
 """Dispatch sort and scan operators to platform dependent implementation."""

-from tvm import topi
+from tvm import topi, dlight, relax
 from tvm.ir import Op
 from tvm.ir.module import IRModule
 from tvm.ir.transform import PassContext, module_pass
 from tvm.target import Target
 from tvm.contrib.thrust import can_use_thrust
-from tvm.relax import Expr, Function, Call, PyExprMutator, expr_functor, TensorStructInfo
+from tvm.relax import PyExprMutator, expr_functor


 @expr_functor.mutator
@@ -36,13 +36,17 @@ class SortScanDispatcher(PyExprMutator):
     def __init__(self, mod):
         super().__init__(mod)

-    def _get_target(self, expr: Expr) -> Target:
-        sinfo = expr.struct_info
+    def _get_target(self, sinfo: relax.StructInfo) -> Target:
         # Get target information from TensorStructInfo
-        if isinstance(sinfo, TensorStructInfo):
+        if isinstance(sinfo, relax.TensorStructInfo):
             vdevice = sinfo.vdevice
             if vdevice is not None:
                 return vdevice.target
+        elif isinstance(sinfo, relax.TupleStructInfo):
+            for f in sinfo.fields:
+                tgt = self._get_target(f)
+                if tgt != Target.current():
+                    return tgt
         # Return the target in current context
         target = Target.current()
         if target is None:
@@ -52,38 +56,94 @@ def _get_target(self, expr: Expr) -> Target:
             )
         return target

-    def visit_call_(self, call: Call) -> Expr:
+    def _apply_dlight_gpu_fallback(self, target: Target, tir_call: relax.Call) -> None:
+        # Apply dlight.gpu.Fallback() on GPU
+        gvar = tir_call.args[0]
+        assert isinstance(gvar, relax.GlobalVar)
+        scan_prim_func = self.builder_.get()[gvar]
+        sch = dlight.base.transform._apply_rules(
+            scan_prim_func,
+            target,
+            [
+                dlight.gpu.Fallback(),
+            ],
+            False,
+        )
+        if sch is not None:
+            assert len(sch) == 1
+            self.builder_.update_func(gvar, sch[0].mod["main"].with_attr("tir.is_scheduled", 1))
+
+    def visit_call_(self, call: relax.Call) -> relax.Expr:
         if not isinstance(call.op, Op):
             return super().visit_call_(call)

         if call.op.name == "relax.sort":
-            tgt = self._get_target(call)
+            tgt = self._get_target(call.struct_info)
+            te_func = topi.sort
             with tgt:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
-                    return self.builder_.call_te(
-                        topi.cuda.sort_thrust,
-                        call.args[0],
-                        call.attrs.axis,
-                        not call.attrs.descending,
-                    )
-                return self.builder_.call_te(
-                    topi.cuda.sort if tgt.kind.name == "cuda" else topi.sort,
-                    call.args[0],
-                    call.attrs.axis,
-                    not call.attrs.descending,
-                )
-
-        if call.op.name == "relax.cumsum":
-            tgt = self._get_target(call)
-            axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
+                    te_func = topi.cuda.sort_thrust
+                elif tgt.kind.name == "cuda":
+                    te_func = topi.cuda.sort
+            return self.builder_.call_te(
+                te_func,
+                call.args[0],
+                call.attrs.axis,
+                not call.attrs.descending,
+            )
+        if call.op.name == "relax.argsort":
+            tgt = self._get_target(call.struct_info)
+            te_func = topi.argsort
             with tgt:
-                return self.builder_.call_te(
-                    topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum,
-                    call.args[0],
-                    axis,
-                    call.attrs.dtype,
-                )
-
+                if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
+                    te_func = topi.cuda.argsort_thrust
+                elif tgt.kind.name == "cuda":
+                    te_func = topi.cuda.argsort
+            return self.builder_.call_te(
+                te_func,
+                call.args[0],
+                axis=call.attrs.axis,
+                is_ascend=not call.attrs.descending,
+                dtype=call.attrs.dtype,
+            )
+        if call.op.name == "relax.topk":
+            tgt = self._get_target(call.struct_info)
+            te_func = topi.topk
+            if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
+                te_func = topi.cuda.topk_thrust
+            elif tgt.kind.name == "cuda":
+                te_func = topi.cuda.topk
+            tir_call = self.builder_.call_te(
+                te_func,
+                call.args[0],
+                axis=call.attrs.axis,
+                ret_type=call.attrs.ret_type,
+                is_ascend=not call.attrs.largest,
+                dtype=call.attrs.dtype,
+            )
+            if tgt.kind.name != "cuda":
+                return tir_call
+            # apply dlight gpu fallback
+            self._apply_dlight_gpu_fallback(tgt, tir_call)
+            return tir_call
+        if call.op.name in ("relax.cumprod", "relax.cumsum"):
+            tgt = self._get_target(call.struct_info)
+            axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
+            te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
+            if call.op.name == "relax.cumprod":
+                te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
+            tir_call = self.builder_.call_te(
+                te_func,
+                call.args[0],
+                axis,
+                call.attrs.dtype,
+                call.attrs.exclusive,
+            )
+            if tgt.kind.name != "cuda":
+                return tir_call
+            # apply dlight gpu fallback
+            self._apply_dlight_gpu_fallback(tgt, tir_call)
+            return tir_call
         return super().visit_call_(call)

@@ -96,7 +156,7 @@ class DispatchSortScan:
     def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
         sort_scan_dispater = SortScanDispatcher(mod)
         for gv, func in mod.functions_items():
-            if isinstance(func, Function):
+            if isinstance(func, relax.Function):
                 func = sort_scan_dispater.visit_expr(func)
                 sort_scan_dispater.builder_.update_func(gv, func)
-        return sort_scan_dispater.builder_.get()
+        return sort_scan_dispater.builder_.finalize()
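For context, a minimal usage sketch of the pass this diff touches: it builds a small Relax function around relax.op.sort and runs the dispatcher under a CUDA target. This assumes the pass is exported as tvm.relax.backend.DispatchSortScan and uses the Relax BlockBuilder API; exact import paths may vary across TVM versions.

# Usage sketch (not part of the diff): applying DispatchSortScan to a small module.
# Assumption: the pass is importable from tvm.relax.backend.
import tvm
from tvm import relax
from tvm.relax.backend import DispatchSortScan

# Build a tiny Relax function containing a relax.sort call.
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((16, 32), "float32"))
with bb.function("main", [x]):
    gv = bb.emit(relax.op.sort(x, axis=-1, descending=False))
    bb.emit_func_output(gv)
mod = bb.get()

# Under a CUDA target the sort is dispatched to topi.cuda.sort (or the thrust
# kernel when thrust is available); otherwise it falls back to topi.sort.
with tvm.target.Target("cuda"):
    mod = DispatchSortScan()(mod)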