Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 78 additions & 30 deletions circulax/solvers/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,71 @@
algebra kernels.
"""

import functools

import jax
import jax.numpy as jnp
from jax import Array


def _real_physics(v: Array, p: Array, group, t1: float) -> tuple[Array, Array]:
return group.physics_func(y=v, args=p, t=t1)


def _complex_physics(
vr: Array, vi: Array, p: Array, group, t1: float
) -> tuple[Array, Array, Array, Array]:
v = vr + 1j * vi
f, q = group.physics_func(y=v, args=p, t=t1)
return f.real, f.imag, q.real, q.imag


def _primal_and_jac_real(
f, v: Array, p: Array
) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
"""Compute f(v,p) and its Jacobian w.r.t. v in a single forward sweep.

Sweeps n unit tangents via ``jax.jvp``, extracting the primal from the
first sweep rather than computing it separately. Jacobian is returned in
``(n_eqs, n_vars)`` shape to match ``jax.jacfwd`` convention.
"""
n = v.shape[0]
g = lambda v_: f(v_, p) # close over p; differentiate w.r.t. v only
(f_vals, q_vals), (dfs, dqs) = jax.vmap(
lambda e: jax.jvp(g, (v,), (e,))
)(jnp.eye(n))
return (f_vals[0], q_vals[0]), (dfs.T, dqs.T)


def _primal_and_jac_complex(
f, vr: Array, vi: Array, p: Array
) -> tuple[
tuple[Array, Array, Array, Array],
tuple[Array, Array, Array, Array],
tuple[Array, Array, Array, Array],
]:
"""Compute f(vr,vi,p) and its Jacobian w.r.t. (vr, vi) in two forward sweeps.

Mirrors ``jax.jacfwd(f, argnums=(0, 1))``: sweeps unit tangents for vr
then vi, extracting the primal from the first sweep. Each Jacobian block
is returned in ``(n_eqs, n_vars)`` shape.
"""
n = vr.shape[0]
zeros_vr = jnp.zeros_like(vr)
zeros_vi = jnp.zeros_like(vi)
g = lambda vr_, vi_: f(vr_, vi_, p) # close over p; differentiate w.r.t. vr, vi only
(fr_s, fi_s, qr_s, qi_s), (dfr_r, dfi_r, dqr_r, dqi_r) = jax.vmap(
lambda e: jax.jvp(g, (vr, vi), (e, zeros_vi))
)(jnp.eye(n))
_, (dfr_i, dfi_i, dqr_i, dqi_i) = jax.vmap(
lambda e: jax.jvp(g, (vr, vi), (zeros_vr, e))
)(jnp.eye(n))
primal = (fr_s[0], fi_s[0], qr_s[0], qi_s[0])
jac_r = (dfr_r.T, dfi_r.T, dqr_r.T, dqi_r.T)
jac_i = (dfr_i.T, dfi_i.T, dqr_i.T, dqi_i.T)
return primal, jac_r, jac_i


def assemble_system_real(
y_guess: Array,
component_groups: dict,
Expand Down Expand Up @@ -67,11 +127,11 @@ def assemble_system_real(
group = component_groups[k]
v_locs = y_guess[group.var_indices]

def physics_at_t1(v: Array, p: Array) -> tuple[Array, Array]:
return group.physics_func(y=v, args=p, t=t1)
physics_at_t1 = functools.partial(_real_physics, group=group, t1=t1)

(f_l, q_l) = jax.vmap(physics_at_t1)(v_locs, group.params)
(df_l, dq_l) = jax.vmap(jax.jacfwd(physics_at_t1))(v_locs, group.params)
(f_l, q_l), (df_l, dq_l) = jax.vmap(
functools.partial(_primal_and_jac_real, physics_at_t1)
)(v_locs, group.params)

total_f = total_f.at[group.eq_indices].add(f_l)
total_q = total_q.at[group.eq_indices].add(q_l)
Expand Down Expand Up @@ -105,19 +165,18 @@ def assemble_residual_only_real(

Returns:
A two-tuple ``(total_f, total_q)`` where both arrays have shape
``(sys_size,)`` and ``dtype=float64``.
``(sys_size,)`` and ``dtype`` matching ``y_guess.dtype``.

"""
sys_size = y_guess.shape[0]
total_f = jnp.zeros(sys_size, dtype=jnp.float64)
total_q = jnp.zeros(sys_size, dtype=jnp.float64)
total_f = jnp.zeros(sys_size, dtype=y_guess.dtype)
total_q = jnp.zeros(sys_size, dtype=y_guess.dtype)

for k in sorted(component_groups.keys()):
group = component_groups[k]
v = y_guess[group.var_indices]

def physics_at_t1(v: Array, p: Array) -> tuple[Array, Array]:
return group.physics_func(y=v, args=p, t=t1)
physics_at_t1 = functools.partial(_real_physics, group=group, t1=t1)

f_l, q_l = jax.vmap(physics_at_t1)(v, group.params)

Expand Down Expand Up @@ -179,24 +238,18 @@ def assemble_system_complex(
group = component_groups[k]
v_r, v_i = y_real[group.var_indices], y_imag[group.var_indices]

def physics_split(
vr: Array, vi: Array, p: Array
) -> tuple[Array, Array, Array, Array]:
v = vr + 1j * vi
f, q = group.physics_func(y=v, args=p, t=t1)
return f.real, f.imag, q.real, q.imag
physics_split = functools.partial(_complex_physics, group=group, t1=t1)

fr, fi, qr, qi = jax.vmap(physics_split)(v_r, v_i, group.params)
(fr, fi, qr, qi), (dfr_r, dfi_r, dqr_r, dqi_r), (dfr_i, dfi_i, dqr_i, dqi_i) = (
jax.vmap(functools.partial(_primal_and_jac_complex, physics_split))(
v_r, v_i, group.params
)
)

idx_r, idx_i = group.eq_indices, group.eq_indices + half_size
total_f = total_f.at[idx_r].add(fr).at[idx_i].add(fi)
total_q = total_q.at[idx_r].add(qr).at[idx_i].add(qi)

jac_res = jax.vmap(jax.jacfwd(physics_split, argnums=(0, 1)))(
v_r, v_i, group.params
)
((dfr_r, dfr_i), (dfi_r, dfi_i), (dqr_r, dqr_i), (dqi_r, dqi_i)) = jac_res

vals_blocks[0].append((dfr_r + dqr_r / dt).reshape(-1)) # RR
vals_blocks[1].append((dfr_i + dqr_i / dt).reshape(-1)) # RI
vals_blocks[2].append((dfi_r + dqi_r / dt).reshape(-1)) # IR
Expand Down Expand Up @@ -229,26 +282,21 @@ def assemble_residual_only_complex(

Returns:
A two-tuple ``(total_f, total_q)`` where both arrays have shape
``(2 * num_vars,)`` and ``dtype=float64``.
``(2 * num_vars,)`` and ``dtype`` matching ``y_guess.dtype``.

"""
sys_size = y_guess.shape[0]
half_size = sys_size // 2
y_real, y_imag = y_guess[:half_size], y_guess[half_size:]

total_f = jnp.zeros(sys_size, dtype=jnp.float64)
total_q = jnp.zeros(sys_size, dtype=jnp.float64)
total_f = jnp.zeros(sys_size, dtype=y_guess.dtype)
total_q = jnp.zeros(sys_size, dtype=y_guess.dtype)

for k in sorted(component_groups.keys()):
group = component_groups[k]
v_r, v_i = y_real[group.var_indices], y_imag[group.var_indices]

def physics_split(
vr: Array, vi: Array, p: Array
) -> tuple[Array, Array, Array, Array]:
v = vr + 1j * vi
f, q = group.physics_func(y=v, args=p, t=t1)
return f.real, f.imag, q.real, q.imag
physics_split = functools.partial(_complex_physics, group=group, t1=t1)

fr, fi, qr, qi = jax.vmap(physics_split)(v_r, v_i, group.params)

Expand Down
Loading