Skip to content

Commit 50a55a8

Browse files
Implemented insert_noop_call scheduling operation (#661)
implement a new scheduling op to insert no-op function calls anywhere, addressing #565. This scheduling op will be useful for inserting prefetching and potentially synchronization primitives arbitrarily. For now, it only works if the body of the proc is a single `pass` statement, but we could potentially extend this in the future. --------- Co-authored-by: Yuka Ikarashi <[email protected]>
1 parent 0cf7e19 commit 50a55a8

File tree

10 files changed

+218
-58
lines changed

10 files changed

+218
-58
lines changed

src/exo/API.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .API_types import ProcedureBase, ExoType
1212
from . import LoopIR as LoopIR
1313
from .LoopIR_compiler import run_compile, compile_to_strings
14-
from .LoopIR_unification import DoReplace, UnificationError
1514
from .configs import Config
1615
from .boundscheck import CheckBounds
1716
from .memory import Memory

src/exo/API_scheduling.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,14 @@ def parse_arg(a):
765765
return buf_name, args
766766

767767

768+
class NewExprOrCustomWindowExprA(NewExprA):
769+
def __call__(self, expr_str, all_args):
770+
try:
771+
return NewExprA(self.cursor_arg, self.before)(expr_str, all_args)
772+
except:
773+
return CustomWindowExprA(self.cursor_arg, self.before)(expr_str, all_args)
774+
775+
768776
# --------------------------------------------------------------------------- #
769777
# - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * -
770778
# --------------------------------------------------------------------------- #
@@ -848,6 +856,15 @@ def insert_pass(proc, gap_cursor):
848856
return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
849857

850858

859+
@sched_op([GapCursorA, ProcA, ListA(NewExprOrCustomWindowExprA("gap_cursor"))])
860+
def insert_noop_call(proc, gap_cursor, instr, args):
861+
if len(args) != len(instr._loopir_proc.args):
862+
raise TypeError("Function argument count mismatch")
863+
864+
ir, fwd = scheduling.DoInsertNoopCall(gap_cursor._impl, instr._loopir_proc, args)
865+
return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd)
866+
867+
851868
@sched_op([])
852869
def delete_pass(proc):
853870
"""
@@ -1080,7 +1097,7 @@ def replace(proc, block_cursor, subproc, quiet=False):
10801097
except UnificationError:
10811098
if quiet:
10821099
raise
1083-
print(f"Failed to unify the following:\nSubproc:\n{subproc}Statements:\n")
1100+
print(f"Failed to unify the following:\nSubproc:\n{subproc}\nStatements:")
10841101
[print(sc._impl._node) for sc in block_cursor]
10851102
raise
10861103

src/exo/LoopIR_scheduling.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import exo.API as api
4242
from .pattern_match import match_pattern
4343
from .memory import DRAM
44+
from .typecheck import check_call_types
4445

4546
from functools import partial
4647

@@ -2727,6 +2728,54 @@ def DoInsertPass(gap):
27272728
return ir, fwd
27282729

27292730

2731+
def DoInsertNoopCall(gap, proc, args):
2732+
srcinfo = gap.parent()._node.srcinfo
2733+
2734+
body = proc.body
2735+
if not (len(body) == 1 and isinstance(body[0], LoopIR.Pass)):
2736+
# TODO: We should allow for a more general case, e.g. loops of passes
2737+
raise SchedulingError("Cannot insert a proc whose body is not pass")
2738+
2739+
syms_env = extract_env(gap.anchor())
2740+
2741+
def get_typ_mem(buf_name):
2742+
for name, typ, mem in syms_env:
2743+
if str(name) == buf_name:
2744+
return name, typ, mem
2745+
2746+
def process_slice(idx):
2747+
if not isinstance(idx, tuple):
2748+
return idx
2749+
2750+
buf_name, w_exprs = idx
2751+
name, typ, _ = get_typ_mem(buf_name)
2752+
2753+
idxs = []
2754+
win_shape = []
2755+
for w_e in w_exprs:
2756+
if isinstance(w_e, tuple):
2757+
lo, hi = w_e
2758+
win_shape.append(LoopIR.BinOp("-", hi, lo, hi.type, srcinfo))
2759+
idxs.append(LoopIR.Interval(lo, hi, srcinfo))
2760+
else:
2761+
idxs.append(LoopIR.Point(w_e, srcinfo))
2762+
2763+
as_tensor = T.Tensor(win_shape, True, typ.basetype())
2764+
w_typ = T.Window(typ, as_tensor, name, idxs)
2765+
return LoopIR.WindowExpr(name, idxs, w_typ, srcinfo)
2766+
2767+
args = [process_slice(arg) for arg in args]
2768+
call_stmt = LoopIR.Call(proc, args, srcinfo)
2769+
ir, fwd = gap._insert([call_stmt])
2770+
2771+
def err_handler(_, msg):
2772+
raise SchedulingError(f"Function argument type mismatch:" + msg)
2773+
2774+
check_call_types(err_handler, args, proc.args)
2775+
2776+
return ir, fwd
2777+
2778+
27302779
def DoDeleteConfig(proc_cursor, config_cursor):
27312780
eq_mod_config = Check_DeleteConfigWrite(proc_cursor._node, [config_cursor._node])
27322781
p, fwd = config_cursor._delete()

src/exo/platforms/x86.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
from .. import instr, DRAM
44
from ..libs.memories import AVX2, AVX512
55

6+
# --------------------------------------------------------------------------- #
7+
# Prefetching
8+
# --------------------------------------------------------------------------- #
9+
10+
11+
@instr("_mm_prefetch(&{A_data}, {locality_hint});")
12+
def prefetch(A: [R][1] @ DRAM, locality_hint: size):
13+
assert 0 <= locality_hint
14+
assert locality_hint < 8
15+
pass
16+
617

718
# --------------------------------------------------------------------------- #
819
# AVX2 intrinsics
@@ -127,7 +138,7 @@ def mm256_broadcast_sd_scalar(out: [f64][4] @ AVX2, val: f64):
127138
out[i] = val
128139

129140

130-
@instr("{dst_data} = _mm512_fmadd_ps({dst_data}, {lhs_data}, {rhs_data});")
141+
@instr("{dst_data} = _mm256_fmadd_ps({dst_data}, {lhs_data}, {rhs_data});")
131142
def mm256_fmadd_ps_broadcast(
132143
dst: [f32][8] @ AVX2, lhs: [f32][8] @ AVX2, rhs: [f32][1] @ DRAM
133144
):
@@ -251,6 +262,39 @@ def mm256_add_epi16(out: [ui16][16] @ AVX2, x: [ui16][16] @ AVX2, y: [ui16][16]
251262
# --------------------------------------------------------------------------- #
252263

253264

265+
@instr("{dst_data} = _mm512_setzero_ps();")
266+
def mm512_setzero_ps(dst: [f32][16] @ AVX512):
267+
assert stride(dst, 0) == 1
268+
269+
for i in seq(0, 16):
270+
dst[i] = 0.0
271+
272+
273+
@instr("{out_data} = _mm512_add_ps({x_data}, {y_data});")
274+
def mm512_add_ps(out: [f32][16] @ AVX512, x: [f32][16] @ AVX512, y: [f32][16] @ AVX512):
275+
assert stride(out, 0) == 1
276+
assert stride(x, 0) == 1
277+
assert stride(y, 0) == 1
278+
279+
for i in seq(0, 16):
280+
out[i] = x[i] + y[i]
281+
282+
283+
@instr(
284+
"{out_data} = _mm512_mask_add_ps({out_data}, ((1 << {N}) - 1), {x_data}, {y_data});"
285+
)
286+
def mm512_mask_add_ps(
287+
N: size, out: [f32][16] @ AVX512, x: [f32][16] @ AVX512, y: [f32][16] @ AVX512
288+
):
289+
assert stride(out, 0) == 1
290+
assert stride(x, 0) == 1
291+
assert stride(y, 0) == 1
292+
293+
for i in seq(0, 16):
294+
if i < N:
295+
out[i] = x[i] + y[i]
296+
297+
254298
@instr("{dst_data} = _mm512_loadu_ps(&{src_data});")
255299
def mm512_loadu_ps(dst: [f32][16] @ AVX512, src: [f32][16] @ DRAM):
256300
assert stride(src, 0) == 1

src/exo/pyparser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def parse_num_type(self, node, is_arg=False):
484484
elif isinstance(node, pyast.Name) and node.id in Parser._prim_types:
485485
return Parser._prim_types[node.id]
486486
else:
487-
self.err(node, "unrecognized type: " + ast.dump(node))
487+
self.err(node, "unrecognized type: " + pyast.dump(node))
488488

489489
def parse_stmt_block(self, stmts):
490490
assert isinstance(stmts, list)

src/exo/stdlib/scheduling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
inline,
3636
replace,
3737
call_eqv,
38+
insert_noop_call,
3839
#
3940
# precision, memory, and window annotation setting
4041
set_precision,

src/exo/typecheck.py

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,63 @@
3434
# The typechecker
3535

3636

37+
def check_call_types(err_handler, args, call_args):
38+
for call_a, sig_a in zip(args, call_args):
39+
if call_a.type == T.err:
40+
pass
41+
elif sig_a.type is T.size or sig_a.type is T.index:
42+
if not call_a.type.is_indexable():
43+
err_handler(
44+
call_a,
45+
"expected size or index type "
46+
"expression, "
47+
f"but got type {call_a.type}",
48+
)
49+
50+
elif sig_a.type is T.bool:
51+
if not call_a.type is T.bool:
52+
err_handler(
53+
call_a,
54+
"expected bool-type variable, " f"but got type {call_a.type}",
55+
)
56+
57+
elif sig_a.type is T.stride:
58+
if not call_a.type.is_stridable():
59+
err_handler(
60+
call_a,
61+
"expected stride-type variable, " f"but got type {call_a.type}",
62+
)
63+
64+
elif sig_a.type.is_numeric():
65+
if call_a.type.is_numeric():
66+
if len(call_a.type.shape()) != len(sig_a.type.shape()):
67+
err_handler(
68+
call_a,
69+
f"expected argument of type '{sig_a.type}', "
70+
f"but got type '{call_a.type}'",
71+
)
72+
73+
# ensure scalars are simply variable names
74+
elif (
75+
call_a.type.is_real_scalar()
76+
and not isinstance(call_a, LoopIR.ReadConfig)
77+
and not (isinstance(call_a, LoopIR.Read) and len(call_a.idx) == 0)
78+
):
79+
err_handler(
80+
call_a,
81+
"expected scalar arguments "
82+
"to be simply variable names "
83+
"for now",
84+
)
85+
else:
86+
err_handler(
87+
call_a,
88+
"expected numeric type expression, " f"but got type {call_a.type}",
89+
)
90+
else:
91+
assert False, "bad argument type case"
92+
93+
3794
class TypeChecker:
3895
def __init__(self, proc):
3996
self.uast_proc = proc
@@ -235,7 +292,6 @@ def check_single_stmt(self, stmt):
235292
return [LoopIR.WriteConfig(stmt.config, stmt.field, rhs, stmt.srcinfo)]
236293
elif isinstance(stmt, UAST.Pass):
237294
return [LoopIR.Pass(stmt.srcinfo)]
238-
239295
elif isinstance(stmt, UAST.If):
240296
cond = self.check_e(stmt.cond, is_index=True)
241297
if cond.type != T.err and cond.type != T.bool:
@@ -290,60 +346,11 @@ def check_single_stmt(self, stmt):
290346
for call_a, sig_a in zip(stmt.args, stmt.f.args)
291347
]
292348

293-
for call_a, sig_a in zip(args, stmt.f.args):
294-
if call_a.type == T.err:
295-
pass
296-
elif sig_a.type is T.size or sig_a.type is T.index:
297-
if not call_a.type.is_indexable():
298-
self.err(
299-
call_a,
300-
"expected size or index type "
301-
"expression, "
302-
f"but got type {call_a.type}",
303-
)
304-
305-
elif sig_a.type is T.bool:
306-
if not call_a.type is T.bool:
307-
self.err(
308-
call_a,
309-
"expected bool-type variable, "
310-
f"but got type {call_a.type}",
311-
)
312-
313-
elif sig_a.type is T.stride:
314-
if not call_a.type.is_stridable():
315-
self.err(
316-
call_a,
317-
"expected stride-type variable, "
318-
f"but got type {call_a.type}",
319-
)
320-
321-
elif sig_a.type.is_numeric():
322-
if len(call_a.type.shape()) != len(sig_a.type.shape()):
323-
self.err(
324-
call_a,
325-
f"expected argument of type '{sig_a.type}', "
326-
f"but got '{call_a.type}'",
327-
)
328-
329-
# ensure scalars are simply variable names
330-
elif call_a.type.is_real_scalar():
331-
if not isinstance(call_a, LoopIR.ReadConfig) and not (
332-
isinstance(call_a, LoopIR.Read) and len(call_a.idx) == 0
333-
):
334-
self.err(
335-
call_a,
336-
"expected scalar arguments "
337-
"to be simply variable names "
338-
"for now",
339-
)
340-
341-
else:
342-
assert False, "bad argument type case"
349+
check_call_types(self.err, args, stmt.f.args)
343350

344351
return [LoopIR.Call(stmt.f, args, stmt.srcinfo)]
345352
else:
346-
assert False, "not a loopir in check_stmts"
353+
assert False, f"not a loopir in check_stmts {type(stmt)}"
347354

348355
def check_w_access(self, e, orig_hi):
349356
if isinstance(e, UAST.Point):
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def foo(n: size, x: i8[n] @ DRAM, locality_hint: size):
2+
assert locality_hint >= 0
3+
assert locality_hint < 8
4+
prefetch(x[1:2], locality_hint)
5+
pass

tests/test_schedules.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import pytest
44

5-
from exo import ParseFragmentError
6-
from exo import proc, DRAM, Procedure, config
5+
from exo import ParseFragmentError, proc, DRAM, Procedure, config
76
from exo.libs.memories import GEMM_SCRATCH
87
from exo.stdlib.scheduling import *
98
from exo.platforms.x86 import *
@@ -4630,6 +4629,33 @@ def foo(x: f32[30], result: f32):
46304629
assert str(simplify(foo)) == golden
46314630

46324631

4632+
def test_insert_noop_call(golden):
4633+
@proc
4634+
def foo(n: size, x: i8[n], locality_hint: size):
4635+
assert locality_hint >= 0
4636+
assert locality_hint < 8
4637+
pass
4638+
4639+
foo = insert_noop_call(
4640+
foo, foo.find("pass").before(), prefetch, ["x[1:2]", "locality_hint"]
4641+
)
4642+
assert str(foo) == golden
4643+
4644+
4645+
def test_insert_noop_call_bad_args():
4646+
@proc
4647+
def foo(n: size, x: i8[n], locality_hint: size):
4648+
pass
4649+
4650+
with pytest.raises(TypeError, match="Function argument count mismatch"):
4651+
insert_noop_call(foo, foo.find("pass").before(), prefetch, [])
4652+
4653+
with pytest.raises(SchedulingError, match="Function argument type mismatch"):
4654+
insert_noop_call(
4655+
foo, foo.find("pass").before(), prefetch, ["n", "locality_hint"]
4656+
)
4657+
4658+
46334659
def test_old_lift_alloc_config(golden):
46344660
@config
46354661
class CFG:

tests/test_typecheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,3 +497,15 @@ def dot(m: size, x: R[1, 1], y: R[m]):
497497
@proc
498498
def proj(n: size, x: R[100, 10, 1], y: R[10, n]):
499499
dot(n, x[1], y[0])
500+
501+
502+
def test_numeric_type_mismatch():
503+
with pytest.raises(TypeError, match="but got type size"):
504+
505+
@proc
506+
def bar(n: R):
507+
pass
508+
509+
@proc
510+
def foo(n: size):
511+
bar(n)

0 commit comments

Comments
 (0)