Skip to content

Commit e1d71b3

Browse files
authored
[Unity] Add dlight.gpu.Fallback in DispatchSortScan, add argsort, topk, and cumprod (#16351)
1 parent 298ad2c commit e1d71b3

File tree

21 files changed

+1178
-274
lines changed

21 files changed

+1178
-274
lines changed

include/tvm/relax/attrs/sort.h

Lines changed: 0 additions & 52 deletions
This file was deleted.

include/tvm/relax/attrs/sorting.h

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/attrs/sorting.h
22+
* \brief Attributes for sorting operators.
23+
*/
24+
#ifndef TVM_RELAX_ATTRS_SORTING_H_
25+
#define TVM_RELAX_ATTRS_SORTING_H_
26+
27+
#include <tvm/relax/expr.h>
28+
#include <tvm/tir/index_map.h>
29+
30+
namespace tvm {
31+
namespace relax {
32+
33+
/*! \brief Attributes used in sort operator */
34+
struct SortAttrs : public tvm::AttrsNode<SortAttrs> {
35+
int axis;
36+
bool descending;
37+
38+
TVM_DECLARE_ATTRS(SortAttrs, "relax.attrs.SortAttrs") {
39+
TVM_ATTR_FIELD(axis).set_default(-1).describe(
40+
"Axis along which the sort is computed."
41+
"The default the last axis is used.");
42+
TVM_ATTR_FIELD(descending)
43+
.set_default(false)
44+
.describe(
45+
"Whether to sort in descending order."
46+
"If it is not specified, it defaults to the ascending order.");
47+
}
48+
}; // struct SortAttrs
49+
50+
/*! \brief Attributes used in argsort operator */
51+
struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
52+
int axis;
53+
bool descending;
54+
DataType dtype;
55+
56+
TVM_DECLARE_ATTRS(ArgsortAttrs, "relax.attrs.ArgsortAttrs") {
57+
TVM_ATTR_FIELD(axis).set_default(-1).describe(
58+
"Axis along which the argsort is computed."
59+
"The default the last axis is used.");
60+
TVM_ATTR_FIELD(descending)
61+
.set_default(false)
62+
.describe(
63+
"Whether to argsort in descending order."
64+
"If it is not specified, it defaults to the ascending order.");
65+
TVM_ATTR_FIELD(dtype)
66+
.set_default(NullValue<DataType>())
67+
.describe("DType of the output indices.");
68+
}
69+
}; // struct ArgsortAttrs
70+
71+
/*! \brief Attributes used in topk operator */
72+
struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
73+
int k;
74+
int axis;
75+
bool largest;
76+
String ret_type;
77+
DataType dtype;
78+
79+
TVM_DECLARE_ATTRS(TopKAttrs, "relax.attrs.TopKAttrs") {
80+
TVM_ATTR_FIELD(k).describe("Number of top elements to select");
81+
TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor.");
82+
TVM_ATTR_FIELD(ret_type).set_default("both").describe(
83+
"The return type [both, values, indices]."
84+
"both - return both top k data and indices."
85+
"values - return top k data only."
86+
"indices - return top k indices only.");
87+
TVM_ATTR_FIELD(largest).set_default(true).describe(
88+
"Whether to return largest or smallest elements."
89+
"By default, return the largest k elements.");
90+
TVM_ATTR_FIELD(dtype)
91+
.set_default(NullValue<DataType>())
92+
.describe("Data type of the output indices.");
93+
}
94+
}; // struct TopKAttrs
95+
96+
} // namespace relax
97+
} // namespace tvm
98+
99+
#endif // TVM_RELAX_ATTRS_SORTING_H_

include/tvm/relax/attrs/statistical.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,24 @@ struct StatisticalAttrs : public tvm::AttrsNode<StatisticalAttrs> {
4242
}
4343
}; // struct StatisticalAttrs
4444

45-
/*! \brief Attributes used in cumsum operators */
46-
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
45+
/*! \brief Attributes used in scan operators like cumsum, cumprod */
46+
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
4747
Optional<Integer> axis;
4848
DataType dtype;
49+
Bool exclusive = Bool(false);
4950

50-
TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") {
51+
TVM_DECLARE_ATTRS(ScanopAttrs, "relax.attrs.ScanopAttrs") {
5152
TVM_ATTR_FIELD(axis).describe(
52-
"Axis along which the cumulative sum is computed."
53-
"The default (None) is to compute the cumsum over the flattened array.");
53+
"The axis along which to perform the scan computation."
54+
"The default (None) is to compute over the flattened array.");
5455
TVM_ATTR_FIELD(dtype).describe(
55-
"Type of the returned array and of the accumulator in which the elements are summed."
56-
"If dtype is not specified, it defaults to the dtype of data.");
56+
"The output data type."
57+
"If dtype is not specified, it defaults to the dtype of input data.");
58+
TVM_ATTR_FIELD(exclusive)
59+
.describe("The first element is not included")
60+
.set_default(Bool(false));
5761
}
58-
}; // struct CumsumAttrs
62+
}; // struct ScanopAttrs
5963

6064
} // namespace relax
6165
} // namespace tvm

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 92 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
1818
"""Dispatch sort and scan operators to platform dependent implementation."""
1919

20-
from tvm import topi
20+
from tvm import topi, dlight, relax
2121
from tvm.ir import Op
2222
from tvm.ir.module import IRModule
2323
from tvm.ir.transform import PassContext, module_pass
2424
from tvm.target import Target
2525
from tvm.contrib.thrust import can_use_thrust
26-
from tvm.relax import Expr, Function, Call, PyExprMutator, expr_functor, TensorStructInfo
26+
from tvm.relax import PyExprMutator, expr_functor
2727

2828

2929
@expr_functor.mutator
@@ -36,13 +36,17 @@ class SortScanDispatcher(PyExprMutator):
3636
def __init__(self, mod):
3737
super().__init__(mod)
3838

39-
def _get_target(self, expr: Expr) -> Target:
40-
sinfo = expr.struct_info
39+
def _get_target(self, sinfo: relax.StructInfo) -> Target:
4140
# Get target information from TensorStructInfo
42-
if isinstance(sinfo, TensorStructInfo):
41+
if isinstance(sinfo, relax.TensorStructInfo):
4342
vdevice = sinfo.vdevice
4443
if vdevice is not None:
4544
return vdevice.target
45+
elif isinstance(sinfo, relax.TupleStructInfo):
46+
for f in sinfo.fields:
47+
tgt = self._get_target(f)
48+
if tgt != Target.current():
49+
return tgt
4650
# Return the target in current context
4751
target = Target.current()
4852
if target is None:
@@ -52,38 +56,94 @@ def _get_target(self, expr: Expr) -> Target:
5256
)
5357
return target
5458

55-
def visit_call_(self, call: Call) -> Expr:
59+
def _apply_dlight_gpu_fallback(self, target: Target, tir_call: relax.Call) -> None:
60+
# Apply dlight.gpu.Fallback() on GPU
61+
gvar = tir_call.args[0]
62+
assert isinstance(gvar, relax.GlobalVar)
63+
scan_prim_func = self.builder_.get()[gvar]
64+
sch = dlight.base.transform._apply_rules(
65+
scan_prim_func,
66+
target,
67+
[
68+
dlight.gpu.Fallback(),
69+
],
70+
False,
71+
)
72+
if sch is not None:
73+
assert len(sch) == 1
74+
self.builder_.update_func(gvar, sch[0].mod["main"].with_attr("tir.is_scheduled", 1))
75+
76+
def visit_call_(self, call: relax.Call) -> relax.Expr:
5677
if not isinstance(call.op, Op):
5778
return super().visit_call_(call)
5879

5980
if call.op.name == "relax.sort":
60-
tgt = self._get_target(call)
81+
tgt = self._get_target(call.struct_info)
82+
te_func = topi.sort
6183
with tgt:
6284
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
63-
return self.builder_.call_te(
64-
topi.cuda.sort_thrust,
65-
call.args[0],
66-
call.attrs.axis,
67-
not call.attrs.descending,
68-
)
69-
return self.builder_.call_te(
70-
topi.cuda.sort if tgt.kind.name == "cuda" else topi.sort,
71-
call.args[0],
72-
call.attrs.axis,
73-
not call.attrs.descending,
74-
)
75-
76-
if call.op.name == "relax.cumsum":
77-
tgt = self._get_target(call)
78-
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
85+
te_func = topi.cuda.sort_thrust
86+
elif tgt.kind.name == "cuda":
87+
te_func = topi.cuda.sort
88+
return self.builder_.call_te(
89+
te_func,
90+
call.args[0],
91+
call.attrs.axis,
92+
not call.attrs.descending,
93+
)
94+
if call.op.name == "relax.argsort":
95+
tgt = self._get_target(call.struct_info)
96+
te_func = topi.argsort
7997
with tgt:
80-
return self.builder_.call_te(
81-
topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum,
82-
call.args[0],
83-
axis,
84-
call.attrs.dtype,
85-
)
86-
98+
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
99+
te_func = topi.cuda.argsort_thrust
100+
elif tgt.kind.name == "cuda":
101+
te_func = topi.cuda.argsort
102+
return self.builder_.call_te(
103+
te_func,
104+
call.args[0],
105+
axis=call.attrs.axis,
106+
is_ascend=not call.attrs.descending,
107+
dtype=call.attrs.dtype,
108+
)
109+
if call.op.name == "relax.topk":
110+
tgt = self._get_target(call.struct_info)
111+
te_func = topi.topk
112+
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
113+
te_func = topi.cuda.topk_thrust
114+
elif tgt.kind.name == "cuda":
115+
te_func = topi.cuda.topk
116+
tir_call = self.builder_.call_te(
117+
te_func,
118+
call.args[0],
119+
axis=call.attrs.axis,
120+
ret_type=call.attrs.ret_type,
121+
is_ascend=not call.attrs.largest,
122+
dtype=call.attrs.dtype,
123+
)
124+
if tgt.kind.name != "cuda":
125+
return tir_call
126+
# apply dlight gpu fallback
127+
self._apply_dlight_gpu_fallback(tgt, tir_call)
128+
return tir_call
129+
if call.op.name in ("relax.cumprod", "relax.cumsum"):
130+
tgt = self._get_target(call.struct_info)
131+
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
132+
te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
133+
if call.op.name == "relax.cumprod":
134+
te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
135+
tir_call = self.builder_.call_te(
136+
te_func,
137+
call.args[0],
138+
axis,
139+
call.attrs.dtype,
140+
call.attrs.exclusive,
141+
)
142+
if tgt.kind.name != "cuda":
143+
return tir_call
144+
# apply dlight gpu fallback
145+
self._apply_dlight_gpu_fallback(tgt, tir_call)
146+
return tir_call
87147
return super().visit_call_(call)
88148

89149

@@ -96,7 +156,7 @@ class DispatchSortScan:
96156
def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
97157
sort_scan_dispater = SortScanDispatcher(mod)
98158
for gv, func in mod.functions_items():
99-
if isinstance(func, Function):
159+
if isinstance(func, relax.Function):
100160
func = sort_scan_dispater.visit_expr(func)
101161
sort_scan_dispater.builder_.update_func(gv, func)
102-
return sort_scan_dispater.builder_.get()
162+
return sort_scan_dispater.builder_.finalize()

python/tvm/relax/op/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@
9999
from .qdq import quantize, dequantize
100100
from .search import argmax, argmin, where
101101
from .set import unique
102-
from .sort import sort
103-
from .statistical import cumsum, max, mean, min, prod, std, sum, variance
102+
from .sorting import sort, argsort, topk
103+
from .statistical import cumsum, cumprod, max, mean, min, prod, std, sum, variance
104104
from .ternary import ewise_fma
105105
from .unary import (
106106
abs,

python/tvm/relax/op/op_attrs.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ class SortAttrs(Attrs):
119119
"""Attributes for sort operator"""
120120

121121

122+
@tvm._ffi.register_object("relax.attrs.ArgsortAttrs")
123+
class ArgsortAttrs(Attrs):
124+
"""Attributes for argsort operator"""
125+
126+
122127
@tvm._ffi.register_object("relax.attrs.SplitAttrs")
123128
class SplitAttrs(Attrs):
124129
"""Attributes used in split operator"""
@@ -154,9 +159,14 @@ class TileAttrs(Attrs):
154159
"""Attributes for tile operator"""
155160

156161

157-
@tvm._ffi.register_object("relax.attrs.CumsumAttrs")
158-
class CumsumAttrs(Attrs):
159-
"""Attributes for cumsum operator"""
162+
@tvm._ffi.register_object("relax.attrs.ScanopAttrs")
163+
class ScanopAttrs(Attrs):
164+
"""Attributes for scan operators"""
165+
166+
167+
@tvm._ffi.register_object("relax.attrs.TopKAttrs")
168+
class TopKAttrs(Attrs):
169+
"""Attributes for topk operators"""
160170

161171

162172
@tvm._ffi.register_object("relax.attrs.EinsumAttrs")

0 commit comments

Comments (0)