Skip to content

Commit f278051

Browse files
Print entire proc for cursors (#694)
1 parent 50a55a8 commit f278051

18 files changed

+250
-145
lines changed

src/exo/LoopIR_pprint.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -625,35 +625,31 @@ def _print_cursor_proc(
625625
def _print_cursor_block(
626626
cur: Block, target: Cursor, env: PrintEnv, indent: str
627627
) -> list[str]:
628-
def if_cursor(c, move, k):
629-
try:
630-
return k(move(c))
631-
except InvalidCursorError:
632-
return []
633-
634-
def more_stmts(_):
635-
return [f'{indent}"..."']
628+
def while_cursor(c, move, k):
629+
s = []
630+
while True:
631+
try:
632+
c = move(c)
633+
s.expand(k(c))
634+
except:
635+
return s
636636

637637
def local_stmt(c):
638638
return _print_cursor_stmt(c, target, env, indent)
639639

640640
if isinstance(target, Gap) and target in cur:
641641
if target._type == GapType.Before:
642642
return [
643-
*if_cursor(target, lambda g: g.anchor().prev(2), more_stmts),
644-
*if_cursor(target, lambda g: g.anchor().prev(), local_stmt),
643+
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
645644
f"{indent}[GAP - Before]",
646-
*if_cursor(target, lambda g: g.anchor(), local_stmt),
647-
*if_cursor(target, lambda g: g.anchor().next(), more_stmts),
645+
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
648646
]
649647
else:
650648
assert target._type == GapType.After
651649
return [
652-
*if_cursor(target, lambda g: g.anchor().prev(), more_stmts),
653-
*if_cursor(target, lambda g: g.anchor(), local_stmt),
650+
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
654651
f"{indent}[GAP - After]",
655-
*if_cursor(target, lambda g: g.anchor().next(), local_stmt),
656-
*if_cursor(target, lambda g: g.anchor().next(2), more_stmts),
652+
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
657653
]
658654

659655
elif isinstance(target, Block) and target in cur:
@@ -662,21 +658,16 @@ def local_stmt(c):
662658
block.extend(local_stmt(stmt))
663659
block.append(f"{indent}# BLOCK END")
664660
return [
665-
*if_cursor(target, lambda g: g[0].prev(), more_stmts),
661+
*while_cursor(target[0], lambda g: g.prev(), local_stmt),
666662
*block,
667-
*if_cursor(target, lambda g: g[-1].next(), more_stmts),
663+
*while_cursor(target[-1], lambda g: g.next(), local_stmt),
668664
]
669665

670666
else:
671-
stmt = next(filter(lambda s: s.is_ancestor_of(target), cur), None)
672-
if stmt is None:
673-
return [f'{indent}"..."']
674-
675-
return [
676-
*if_cursor(stmt, lambda g: g.prev().before(), more_stmts),
677-
*local_stmt(stmt),
678-
*if_cursor(stmt, lambda g: g.next().after(), more_stmts),
679-
]
667+
block = []
668+
for stmt in cur:
669+
block.extend(local_stmt(stmt))
670+
return block
680671

681672

682673
def _print_cursor_stmt(

tests/golden/asplos25/test_higher_order/test_lrn.txt

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,41 @@ def bar(n: size, A: i8[n] @ DRAM):
22
for i in seq(0, n):
33
for j in seq(0, n):
44
tmp_a: i8[n] @ DRAM # <-- NODE
5-
...
5+
tmp_b: i8[n] @ DRAM
6+
tmp_a[i] = A[i]
7+
tmp_b[i] = A[i]
68
def bar(n: size, A: i8[n] @ DRAM):
79
for i in seq(0, n):
810
for j in seq(0, n):
9-
...
11+
tmp_a: i8[n] @ DRAM
1012
tmp_b: i8[n] @ DRAM # <-- NODE
11-
...
13+
tmp_a[i] = A[i]
14+
tmp_b[i] = A[i]
1215
def bar(n: size, A: i8[n] @ DRAM):
1316
for i in seq(0, n):
1417
for j in seq(0, n):
15-
...
18+
tmp_a: i8[n] @ DRAM
19+
tmp_b: i8[n] @ DRAM
1620
tmp_a[i] = A[i] # <-- NODE
17-
...
21+
tmp_b[i] = A[i]
1822
def bar(n: size, A: i8[n] @ DRAM):
1923
for i in seq(0, n):
2024
for j in seq(0, n):
21-
...
25+
tmp_a: i8[n] @ DRAM
26+
tmp_b: i8[n] @ DRAM
27+
tmp_a[i] = A[i]
2228
tmp_b[i] = A[i] # <-- NODE
2329
def bar(n: size, A: i8[n] @ DRAM):
2430
for i in seq(0, n):
2531
for j in seq(0, n): # <-- NODE
26-
...
32+
tmp_a: i8[n] @ DRAM
33+
tmp_b: i8[n] @ DRAM
34+
tmp_a[i] = A[i]
35+
tmp_b[i] = A[i]
2736
def bar(n: size, A: i8[n] @ DRAM):
2837
for i in seq(0, n): # <-- NODE
29-
...
38+
for j in seq(0, n):
39+
tmp_a: i8[n] @ DRAM
40+
tmp_b: i8[n] @ DRAM
41+
tmp_a[i] = A[i]
42+
tmp_b[i] = A[i]

tests/golden/test_cursors/test_basic_forwarding2.txt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,15 @@ def filter1D(ow: size, kw: size, x: f32[ow + kw - 1] @ DRAM, y: f32[ow] @ DRAM,
22
w: f32[kw] @ DRAM):
33
for outXo in seq(0, ow / 4):
44
sum: f32[4] @ DRAM # <-- NODE
5-
...
6-
...
5+
for outXi in seq(0, 4):
6+
sum[outXi] = 0.0
7+
for k in seq(0, kw):
8+
sum[outXi] += x[4 * outXo + outXi + k] * w[k]
9+
y[4 * outXo + outXi] = sum[outXi]
10+
if ow % 4 > 0:
11+
for outXi in seq(0, ow % 4):
12+
sum: f32 @ DRAM
13+
sum = 0.0
14+
for k in seq(0, kw):
15+
sum += x[outXi + ow / 4 * 4 + k] * w[k]
16+
y[outXi + ow / 4 * 4] = sum

tests/golden/test_cursors/test_basic_forwarding3.txt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,15 @@ def filter1D(ow: size, kw: size, x: f32[ow + kw - 1] @ DRAM, y: f32[ow] @ DRAM,
22
w: f32[kw] @ DRAM):
33
for outXo in seq(0, ow / 4):
44
sum: f32[4] @ DRAM # <-- NODE
5-
...
6-
...
5+
for outXi in seq(0, 4):
6+
sum[outXi] = 0.0
7+
for k in seq(0, kw):
8+
sum[outXi] += x[4 * outXo + outXi + k] * w[k]
9+
y[4 * outXo + outXi] = sum[outXi]
10+
if ow % 4 > 0:
11+
for outXi in seq(0, ow % 4):
12+
sum: f32 @ DRAM
13+
sum = 0.0
14+
for k in seq(0, kw):
15+
sum += x[outXi + ow / 4 * 4 + k] * w[k]
16+
y[outXi + ow / 4 * 4] = sum
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
def foo(x: i8 @ DRAM):
22
for i in seq(0, 5):
33
for j in seq(0, 5): # <-- NODE
4-
...
4+
if i == 0:
5+
x = 1.0
56

67
def foo(x: i8 @ DRAM):
78
for i in seq(0, 5): # <-- NODE
8-
...
9+
for j in seq(0, 5):
10+
if i == 0:
11+
x = 1.0
Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
def foo(x: i8 @ DRAM):
2-
...
2+
for i in seq(0, 8):
3+
if i + 3 < -1:
4+
x = 0.0
5+
pass
36
for i in seq(0, 2): # <-- NODE
4-
...
7+
x = 1.0
58

69
def foo(x: i8 @ DRAM):
710
for i in seq(0, 8): # <-- NODE
8-
...
9-
...
11+
if i + 3 < -1:
12+
x = 0.0
13+
pass
14+
for i in seq(0, 2):
15+
x = 1.0
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
def foo(x: i8 @ DRAM):
22
for i in seq(0, 8): # <-- NODE
3-
...
4-
...
3+
x = 1.0
4+
for j in seq(0, 2):
5+
x = 2.0
56

67
def foo(x: i8 @ DRAM):
7-
...
8+
for i in seq(0, 8):
9+
x = 1.0
810
for j in seq(0, 2): # <-- NODE
9-
...
11+
x = 2.0
Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
def foo():
2-
...
2+
src_0: i32 @ DRAM
3+
src_1: i32 @ DRAM
34
src_0 = 1.0 # <-- NODE
4-
...
5+
src_1 = 1.0
56
def foo():
6-
...
7+
src_0: i32 @ DRAM
8+
src_1: i32 @ DRAM
9+
src_0 = 1.0
710
src_1 = 1.0 # <-- NODE
811
def foo():
9-
...
12+
src_0: i32 @ DRAM
13+
src_1: i32 @ DRAM
1014
src_0 = 1.0 # <-- NODE
11-
...
15+
src_1 = 1.0
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
def scal(n: size, alpha: R @ DRAM, x: [R][n] @ DRAM):
22
for io in seq(0, n / 8):
3-
...
3+
alphaReg: R[8] @ DRAM
4+
for ii in seq(0, 8):
5+
alphaReg[ii] = alpha
46
for ii in seq(0, 8):
57
x[8 * io + ii] = alphaReg[ii] * x[8 * io + ii] # <-- NODE
6-
...
8+
for ii in seq(0, n % 8):
9+
x[ii + n / 8 * 8] = alpha * x[ii + n / 8 * 8]
Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
def baz(n: size, m: size):
22
for i in seq(0, n):
33
for j in seq(0, m): # <-- NODE
4-
...
4+
x: f32 @ DRAM
5+
pass
6+
pass
7+
for k in seq(0, n):
8+
pass
9+
pass
510

611
def baz(n: size, m: size):
712
for i in seq(0, n):
813
for j in seq(0, m):
9-
...
14+
x: f32 @ DRAM
15+
pass
16+
pass
1017
for k in seq(0, n):
1118
# BLOCK START
1219
pass
@@ -19,15 +26,14 @@ def baz(n: size, m: size):
1926
# BLOCK START
2027
x: f32 @ DRAM
2128
# BLOCK END
22-
...
2329

2430
def baz(n: size, m: size):
2531
for i in seq(0, n):
2632
for j in seq(0, m):
27-
...
2833
# BLOCK START
2934
for k in seq(0, n):
30-
...
35+
pass
36+
pass
3137
# BLOCK END
3238

3339
def baz(n: size, m: size):
@@ -38,15 +44,14 @@ def baz(n: size, m: size):
3844
pass
3945
pass
4046
for k in seq(0, n):
41-
...
47+
pass
48+
pass
4249
# BLOCK END
4350

4451
def baz(n: size, m: size):
4552
for i in seq(0, n):
4653
for j in seq(0, m):
47-
...
4854
# BLOCK START
4955
pass
5056
pass
51-
# BLOCK END
52-
...
57+
# BLOCK END

0 commit comments

Comments
 (0)