Skip to content

Commit 9e77f3e

Browse files
committed
tmp: start porting wave eq to gpu
1 parent 9efec19 commit 9e77f3e

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

examples/wave_equation.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,14 @@ def run(n, backend, datatype, benchmark_mode):
111111
t_end = 1.0
112112

113113
# coordinate arrays
114+
sync()
114115
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype
116+
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype, device=""
116117
)
117118
y_t_2d = fromfunction(
118-
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype
119+
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype, device=""
119120
)
121+
sync()
120122

121123
T_shape = (nx, ny)
122124
U_shape = (nx + 1, ny)
@@ -132,7 +134,7 @@ def run(n, backend, datatype, benchmark_mode):
132134
info(f"Total DOFs: {dofs_T + dofs_U + dofs_V}")
133135

134136
# prognostic variables: elevation, (u, v) velocity
135-
e = create_full(T_shape, 0.0, dtype)
137+
# e = create_full(T_shape, 0.0, dtype)
136138
u = create_full(U_shape, 0.0, dtype)
137139
v = create_full(V_shape, 0.0, dtype)
138140

@@ -144,6 +146,8 @@ def run(n, backend, datatype, benchmark_mode):
144146
u2 = create_full(U_shape, 0.0, dtype)
145147
v2 = create_full(V_shape, 0.0, dtype)
146148

149+
sync()
150+
147151
def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
148152
"""
149153
Exact solution for elevation field.
@@ -162,8 +166,11 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
162166
sol_t = numpy.cos(2 * omega * t)
163167
return amp * sol_x * sol_y * sol_t
164168

165-
# inital elevation
166-
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
169+
# initial elevation
170+
# e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
171+
# NOTE assignment fails, do not pre-allocate e
172+
e = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly).to_device(device)
173+
sync()
167174

168175
# compute time step
169176
alpha = 0.5
@@ -215,6 +222,8 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
215222
v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt)
216223
e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt)
217224

225+
sync()
226+
218227
t = 0
219228
i_export = 0
220229
next_t_export = 0
@@ -226,9 +235,9 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
226235
t = i * dt
227236

228237
if t >= next_t_export - 1e-8:
229-
_elev_max = np.max(e, all_axes)
230-
_u_max = np.max(u, all_axes)
231-
_total_v = np.sum(e + h, all_axes)
238+
_elev_max = 0 # np.max(e, all_axes)
239+
_u_max = 0 # np.max(u, all_axes)
240+
_total_v = 0 # np.sum(e + h, all_axes)
232241

233242
elev_max = float(_elev_max)
234243
u_max = float(_u_max)
@@ -263,17 +272,17 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
263272
duration = time_mod.perf_counter() - tic
264273
info(f"Duration: {duration:.2f} s")
265274

266-
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
267-
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
268-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
269-
info(f"L2 error: {err_L2:7.5e}")
270-
271-
if nx == 128 and ny == 128 and not benchmark_mode:
272-
if datatype == "f32":
273-
assert numpy.allclose(err_L2, 7.2235471e-03, rtol=1e-4)
274-
else:
275-
assert numpy.allclose(err_L2, 7.224068445111e-03)
276-
info("SUCCESS")
275+
# e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
276+
# err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
277+
# err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
278+
# info(f"L2 error: {err_L2:7.5e}")
279+
280+
# if nx == 128 and ny == 128 and not benchmark_mode:
281+
# if datatype == "f32":
282+
# assert numpy.allclose(err_L2, 7.2235471e-03, rtol=1e-4)
283+
# else:
284+
# assert numpy.allclose(err_L2, 7.224068445111e-03)
285+
# info("SUCCESS")
277286

278287
fini()
279288

0 commit comments

Comments
 (0)