Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/euler/acoustic_pulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import grudge.op as op
from grudge.array_context import PyOpenCLArrayContext, PytatoPyOpenCLArrayContext
from grudge.models.euler import ConservedEulerField, EulerOperator, InviscidWallBC
from grudge.shortcuts import rk4_step
from grudge.shortcuts import compiled_lsrk45_step


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -200,7 +200,7 @@ def rhs(t, q):
assert norm_q < 5

fields = actx.thaw(actx.freeze(fields))
fields = rk4_step(fields, t, dt, compiled_rhs)
fields = compiled_lsrk45_step(actx, fields, t, dt, compiled_rhs)
t += dt
step += 1

Expand Down
7 changes: 4 additions & 3 deletions examples/euler/vortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import grudge.op as op
from grudge.array_context import PyOpenCLArrayContext, PytatoPyOpenCLArrayContext
from grudge.models.euler import EulerOperator, vortex_initial_condition
from grudge.shortcuts import rk4_step
from grudge.shortcuts import compiled_lsrk45_step


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,6 +126,8 @@ def rhs(t, q):

vis = make_visualizer(dcoll)

fields = actx.freeze_thaw(fields)

# {{{ time stepping

step = 0
Expand All @@ -146,8 +148,7 @@ def rhs(t, q):
)
assert norm_q < 200

fields = actx.thaw(actx.freeze(fields))
fields = rk4_step(fields, t, dt, compiled_rhs)
fields = compiled_lsrk45_step(actx, fields, t, dt, compiled_rhs)
t += dt
step += 1

Expand Down
2 changes: 2 additions & 0 deletions examples/wave/wave-op-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ def rhs(t, w):
import time
start = time.time()

fields = actx.freeze_thaw(fields)

t = 0
t_final = 3
istep = 0
Expand Down
19 changes: 14 additions & 5 deletions grudge/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
THE SOFTWARE.
"""

from functools import partial

from arraycontext import BcastUntilActxArray
from arraycontext.context import ArrayContext
from pytools import memoize_in

from grudge.dof_desc import DD_VOLUME_ALL
Expand All @@ -33,19 +37,24 @@ def rk4_step(y, t, h, f):
return y + h/6*(k1 + 2*k2 + 2*k3 + k4)


def _lsrk45_update(y, a, b, h, rhs_val, residual=0):
residual = a*residual + h*rhs_val
y = y + b * residual
def _lsrk45_update(actx: ArrayContext, y, a, b, h, rhs_val, residual=None):
bcast = partial(BcastUntilActxArray, actx)
if residual is None:
residual = bcast(h) * rhs_val
else:
residual = bcast(a) * residual + bcast(h) * rhs_val

y = y + bcast(b) * residual
from pytools.obj_array import make_obj_array
return make_obj_array([y, residual])


def compiled_lsrk45_step(actx, y, t, h, f):
def compiled_lsrk45_step(actx: ArrayContext, y, t, h, f):
from leap.rk import LSRK4MethodBuilder

@memoize_in(actx, (compiled_lsrk45_step, "update"))
def get_state_updater():
return actx.compile(_lsrk45_update)
return actx.compile(partial(_lsrk45_update, actx))

update = get_state_updater()

Expand Down
10 changes: 6 additions & 4 deletions test/test_mpi_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from grudge import dof_desc, op
from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext
from grudge.discretization import make_discretization_collection
from grudge.shortcuts import rk4_step
from grudge.shortcuts import compiled_lsrk45_step


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -246,8 +246,8 @@ def source_f(actx, dcoll, t=0):
[dcoll.zeros(actx) for i in range(dcoll.dim)]
)

dt = actx.to_numpy(
wave_op.estimate_rk4_timestep(actx, dcoll, fields=fields))
dt = float(actx.to_numpy(
wave_op.estimate_rk4_timestep(actx, dcoll, fields=fields)))

wave_op.check_bc_coverage(local_mesh)

Expand Down Expand Up @@ -277,10 +277,12 @@ def rhs(t, w):
from grudge.shortcuts import make_visualizer
vis = make_visualizer(dcoll)

fields = actx.freeze_thaw(fields)

logmgr.tick_before()
for step in range(nsteps):
t = step*dt
fields = rk4_step(fields, t=t, h=dt, f=compiled_rhs)
fields = compiled_lsrk45_step(actx, fields, t=t, h=dt, f=compiled_rhs)
fields = actx.thaw(actx.freeze(fields))

norm = actx.to_numpy(op.norm(dcoll, fields, 2))
Expand Down
Loading