Skip to content

Commit e995e6a

Browse files
authored
Merge pull request #13 from McLavish/jiahao/npbenchs
Jiahao/npbenchs
2 parents e06985c + fad77da commit e995e6a

File tree

12 files changed

+536
-0
lines changed

12 files changed

+536
-0
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"timeout": 60,
3+
"memory": 2048,
4+
"languages": ["python"],
5+
"modules": ["storage"]
6+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
size_generators = {
2+
"test": {"ny": 61, "nx": 61, "nit": 5, "rho": 1.0, "nu": 0.1, "F": 1.0},
3+
"small": {"ny": 121, "nx": 121, "nit": 10, "rho": 1.0, "nu": 0.1, "F": 1.0},
4+
"large": {"ny": 201, "nx": 201, "nit": 20, "rho": 1.0, "nu": 0.1, "F": 1.0},
5+
}
6+
7+
8+
def generate_input(
9+
data_dir,
10+
size,
11+
benchmarks_bucket,
12+
input_paths,
13+
output_paths,
14+
upload_func,
15+
nosql_func,
16+
):
17+
return {"size": size_generators[size]}
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# Barba, Lorena A., and Forsyth, Gilbert F. (2018).
2+
# CFD Python: the 12 steps to Navier-Stokes equations.
3+
# Journal of Open Source Education, 1(9), 21,
4+
# https://doi.org/10.21105/jose.00021
5+
# TODO: License
6+
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth.
7+
# All content is under Creative Commons Attribution CC-BY 4.0,
8+
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018).
9+
10+
import datetime
11+
12+
import jax.numpy as jnp
13+
import jax
14+
from jax import lax
15+
from functools import partial
16+
17+
18+
@partial(jax.jit, static_argnums=(0,))
19+
def build_up_b(rho, dt, dx, dy, u, v):
20+
b = jnp.zeros_like(u)
21+
b = b.at[1:-1, 1:-1].set(
22+
(
23+
rho
24+
* (
25+
1
26+
/ dt
27+
* (
28+
(u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx)
29+
+ (v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)
30+
)
31+
- ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx)) ** 2
32+
- 2
33+
* (
34+
(u[2:, 1:-1] - u[0:-2, 1:-1])
35+
/ (2 * dy)
36+
* (v[1:-1, 2:] - v[1:-1, 0:-2])
37+
/ (2 * dx)
38+
)
39+
- ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) ** 2
40+
)
41+
)
42+
)
43+
44+
# Periodic BC Pressure @ x = 2
45+
b = b.at[1:-1, -1].set(
46+
(
47+
rho
48+
* (
49+
1
50+
/ dt
51+
* ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) + (v[2:, -1] - v[0:-2, -1]) / (2 * dy))
52+
- ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx)) ** 2
53+
- 2 * ((u[2:, -1] - u[0:-2, -1]) / (2 * dy) * (v[1:-1, 0] - v[1:-1, -2]) / (2 * dx))
54+
- ((v[2:, -1] - v[0:-2, -1]) / (2 * dy)) ** 2
55+
)
56+
)
57+
)
58+
59+
# Periodic BC Pressure @ x = 0
60+
b = b.at[1:-1, 0].set(
61+
(
62+
rho
63+
* (
64+
1
65+
/ dt
66+
* ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx) + (v[2:, 0] - v[0:-2, 0]) / (2 * dy))
67+
- ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx)) ** 2
68+
- 2 * ((u[2:, 0] - u[0:-2, 0]) / (2 * dy) * (v[1:-1, 1] - v[1:-1, -1]) / (2 * dx))
69+
- ((v[2:, 0] - v[0:-2, 0]) / (2 * dy)) ** 2
70+
)
71+
)
72+
)
73+
74+
return b
75+
76+
77+
@partial(jax.jit, static_argnums=(0,))
78+
def pressure_poisson_periodic(nit, p, dx, dy, b):
79+
def body_func(p, q):
80+
pn = p.copy()
81+
p = p.at[1:-1, 1:-1].set(
82+
((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2)
83+
/ (2 * (dx**2 + dy**2))
84+
- dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]
85+
)
86+
87+
# Periodic BC Pressure @ x = 2
88+
p = p.at[1:-1, -1].set(
89+
((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 + (pn[2:, -1] + pn[0:-2, -1]) * dx**2)
90+
/ (2 * (dx**2 + dy**2))
91+
- dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, -1]
92+
)
93+
94+
# Periodic BC Pressure @ x = 0
95+
p = p.at[1:-1, 0].set(
96+
(
97+
((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 + (pn[2:, 0] + pn[0:-2, 0]) * dx**2)
98+
/ (2 * (dx**2 + dy**2))
99+
- dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0]
100+
)
101+
)
102+
103+
# Wall boundary conditions, pressure
104+
p = p.at[-1, :].set(p[-2, :]) # dp/dy = 0 at y = 2
105+
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0
106+
107+
return p, None
108+
109+
p, _ = lax.scan(body_func, p, jnp.arange(nit))
110+
111+
112+
@partial(jax.jit, static_argnums=(0, 7, 8, 9))
113+
def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F):
114+
udiff = 1
115+
stepcount = 0
116+
117+
array_vals = (udiff, stepcount, u, v, p)
118+
119+
def conf_func(array_vals):
120+
udiff, _, _, _, _ = array_vals
121+
return udiff > 0.001
122+
123+
def body_func(array_vals):
124+
_, stepcount, u, v, p = array_vals
125+
126+
un = u.copy()
127+
vn = v.copy()
128+
129+
b = build_up_b(rho, dt, dx, dy, u, v)
130+
pressure_poisson_periodic(nit, p, dx, dy, b)
131+
132+
u = u.at[1:-1, 1:-1].set(
133+
un[1:-1, 1:-1]
134+
- un[1:-1, 1:-1] * dt / dx * (un[1:-1, 1:-1] - un[1:-1, 0:-2])
135+
- vn[1:-1, 1:-1] * dt / dy * (un[1:-1, 1:-1] - un[0:-2, 1:-1])
136+
- dt / (2 * rho * dx) * (p[1:-1, 2:] - p[1:-1, 0:-2])
137+
+ nu
138+
* (
139+
dt / dx**2 * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2])
140+
+ dt / dy**2 * (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])
141+
)
142+
+ F * dt
143+
)
144+
145+
v = v.at[1:-1, 1:-1].set(
146+
vn[1:-1, 1:-1]
147+
- un[1:-1, 1:-1] * dt / dx * (vn[1:-1, 1:-1] - vn[1:-1, 0:-2])
148+
- vn[1:-1, 1:-1] * dt / dy * (vn[1:-1, 1:-1] - vn[0:-2, 1:-1])
149+
- dt / (2 * rho * dy) * (p[2:, 1:-1] - p[0:-2, 1:-1])
150+
+ nu
151+
* (
152+
dt / dx**2 * (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2])
153+
+ dt / dy**2 * (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])
154+
)
155+
)
156+
157+
# Periodic BC u @ x = 2
158+
u = u.at[1:-1, -1].set(
159+
un[1:-1, -1]
160+
- un[1:-1, -1] * dt / dx * (un[1:-1, -1] - un[1:-1, -2])
161+
- vn[1:-1, -1] * dt / dy * (un[1:-1, -1] - un[0:-2, -1])
162+
- dt / (2 * rho * dx) * (p[1:-1, 0] - p[1:-1, -2])
163+
+ nu
164+
* (
165+
dt / dx**2 * (un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2])
166+
+ dt / dy**2 * (un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])
167+
)
168+
+ F * dt
169+
)
170+
171+
# Periodic BC u @ x = 0
172+
u = u.at[1:-1, 0].set(
173+
un[1:-1, 0]
174+
- un[1:-1, 0] * dt / dx * (un[1:-1, 0] - un[1:-1, -1])
175+
- vn[1:-1, 0] * dt / dy * (un[1:-1, 0] - un[0:-2, 0])
176+
- dt / (2 * rho * dx) * (p[1:-1, 1] - p[1:-1, -1])
177+
+ nu
178+
* (
179+
dt / dx**2 * (un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1])
180+
+ dt / dy**2 * (un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])
181+
)
182+
+ F * dt
183+
)
184+
185+
# Periodic BC v @ x = 2
186+
v = v.at[1:-1, -1].set(
187+
vn[1:-1, -1]
188+
- un[1:-1, -1] * dt / dx * (vn[1:-1, -1] - vn[1:-1, -2])
189+
- vn[1:-1, -1] * dt / dy * (vn[1:-1, -1] - vn[0:-2, -1])
190+
- dt / (2 * rho * dy) * (p[2:, -1] - p[0:-2, -1])
191+
+ nu
192+
* (
193+
dt / dx**2 * (vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2])
194+
+ dt / dy**2 * (vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1])
195+
)
196+
)
197+
198+
# Periodic BC v @ x = 0
199+
v = v.at[1:-1, 0].set(
200+
vn[1:-1, 0]
201+
- un[1:-1, 0] * dt / dx * (vn[1:-1, 0] - vn[1:-1, -1])
202+
- vn[1:-1, 0] * dt / dy * (vn[1:-1, 0] - vn[0:-2, 0])
203+
- dt / (2 * rho * dy) * (p[2:, 0] - p[0:-2, 0])
204+
+ nu
205+
* (
206+
dt / dx**2 * (vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1])
207+
+ dt / dy**2 * (vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0])
208+
)
209+
)
210+
211+
# Wall BC: u,v = 0 @ y = 0,2
212+
u = u.at[0, :].set(0)
213+
u = u.at[-1, :].set(0)
214+
v = v.at[0, :].set(0)
215+
v = v.at[-1, :].set(0)
216+
217+
udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u)
218+
stepcount += 1
219+
220+
return (udiff, stepcount, u, v, p)
221+
222+
_, stepcount, _, _, _ = lax.while_loop(conf_func, body_func, array_vals)
223+
224+
return stepcount
225+
226+
227+
def initialize(ny, nx):
228+
u = jnp.zeros((ny, nx), dtype=jnp.float64)
229+
v = jnp.zeros((ny, nx), dtype=jnp.float64)
230+
p = jnp.ones((ny, nx), dtype=jnp.float64)
231+
dx = 2 / (nx - 1)
232+
dy = 2 / (ny - 1)
233+
dt = 0.1 / ((nx - 1) * (ny - 1))
234+
return u, v, p, dx, dy, dt
235+
236+
237+
def handler(event):
238+
239+
if "size" in event:
240+
size = event["size"]
241+
ny = size["ny"]
242+
nx = size["nx"]
243+
nit = size["nit"]
244+
rho = size["rho"]
245+
nu = size["nu"]
246+
F = size["F"]
247+
248+
generate_begin = datetime.datetime.now()
249+
250+
u, v, p, dx, dy, dt = initialize(ny, nx)
251+
252+
generate_end = datetime.datetime.now()
253+
254+
process_begin = datetime.datetime.now()
255+
256+
results = channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F)
257+
258+
process_end = datetime.datetime.now()
259+
260+
# y_re_im = jnp.stack([jnp.real(result), jnp.imag(result)], axis=-1).tolist()
261+
262+
process_time = (process_end - process_begin) / datetime.timedelta(milliseconds=1)
263+
generate_time = (generate_end - generate_begin) / datetime.timedelta(milliseconds=1)
264+
265+
try:
266+
results = jax.device_get(results)
267+
except Exception:
268+
pass
269+
270+
if hasattr(results, "item"):
271+
results = results.item()
272+
elif hasattr(results, "tolist"):
273+
results = results.tolist()
274+
275+
return {
276+
"size": size,
277+
"result": results,
278+
"measurement": {"compute_time": process_time, "generate_time": generate_time},
279+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
jax[cuda12]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"timeout": 60,
3+
"memory": 2048,
4+
"languages": ["python"],
5+
"modules": ["storage"]
6+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
size_generators = {
2+
"test": {"M": 2000, "N": 2000},
3+
"small": {"M": 5000, "N": 5000},
4+
"large": {"M": 16000, "N": 16000},
5+
}
6+
7+
8+
def generate_input(
9+
data_dir,
10+
size,
11+
benchmarks_bucket,
12+
input_paths,
13+
output_paths,
14+
upload_func,
15+
nosql_func,
16+
):
17+
return {"size": size_generators[size]}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import datetime
2+
3+
import jax.numpy as jnp
4+
import jax
5+
6+
7+
@jax.jit
8+
def compute(array_1, array_2, a, b, c):
9+
return jnp.clip(array_1, 2, 10) * a + array_2 * b + c
10+
11+
12+
def initialize(M, N):
13+
from numpy.random import default_rng
14+
15+
rng = default_rng(42)
16+
array_1 = rng.uniform(0, 1000, size=(M, N)).astype(jnp.int64)
17+
array_2 = rng.uniform(0, 1000, size=(M, N)).astype(jnp.int64)
18+
a = jnp.int64(4)
19+
b = jnp.int64(3)
20+
c = jnp.int64(9)
21+
return array_1, array_2, a, b, c
22+
23+
24+
def handler(event):
25+
26+
if "size" in event:
27+
size = event["size"]
28+
M = size["M"]
29+
N = size["N"]
30+
31+
generate_begin = datetime.datetime.now()
32+
33+
array_1, array_2, a, b, c = initialize(M, N)
34+
35+
generate_end = datetime.datetime.now()
36+
37+
process_begin = datetime.datetime.now()
38+
39+
results = compute(array_1, array_2, a, b, c)
40+
41+
process_end = datetime.datetime.now()
42+
43+
# y_re_im = jnp.stack([jnp.real(result), jnp.imag(result)], axis=-1).tolist()
44+
45+
process_time = (process_end - process_begin) / datetime.timedelta(milliseconds=1)
46+
generate_time = (generate_end - generate_begin) / datetime.timedelta(milliseconds=1)
47+
48+
try:
49+
results = jax.device_get(results)
50+
except Exception:
51+
pass
52+
53+
if getattr(results, "ndim", 0) == 0 or getattr(results, "size", 0) == 1:
54+
results = results.item()
55+
else:
56+
results = results.tolist()
57+
58+
return {
59+
"size": size,
60+
"result": results,
61+
"measurement": {"compute_time": process_time, "generate_time": generate_time},
62+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
jax[cuda12]

0 commit comments

Comments
 (0)