Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions jax_spice/analysis/solver_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,38 @@
DEFAULT_ABSTOL = 1e4 # Corresponds to ~10nV voltage accuracy with G=1e12


def _compute_all_diag_indices(
bcsr_indptr: Array,
bcsr_indices: Array,
n_unknowns: int,
) -> Array:
"""Pre-compute CSR indices for ALL diagonal elements.

Used to add Tikhonov regularization to sparse Jacobians to prevent
singular matrix errors in spsolve.

Args:
bcsr_indptr: CSR row pointers
bcsr_indices: CSR column indices
n_unknowns: Number of unknowns (rows/cols)

Returns:
Array of CSR data indices corresponding to diagonal elements
"""
indptr_np = np.asarray(bcsr_indptr)
indices_np = np.asarray(bcsr_indices)

diag_indices = []
for row in range(n_unknowns):
row_start, row_end = int(indptr_np[row]), int(indptr_np[row + 1])
for j in range(row_start, row_end):
if indices_np[j] == row:
diag_indices.append(j)
break

return jnp.array(diag_indices, dtype=jnp.int32)


def _compute_noi_masks(
noi_indices: Optional[Array],
n_nodes: int,
Expand Down Expand Up @@ -226,8 +258,12 @@ def body_fn(state):
J = J.at[noi_res_idx, noi_res_idx].set(1.0)
f = f.at[noi_res_idx].set(0.0)

# Add regularization for numerical stability (prevents singular matrix errors)
reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype)
J_reg = J + reg

# Solve linear system J @ delta = -f
delta = jax.scipy.linalg.solve(J, -f)
delta = jax.scipy.linalg.solve(J_reg, -f)

# Step limiting
max_delta = jnp.max(jnp.abs(delta))
Expand Down Expand Up @@ -328,6 +364,11 @@ def make_sparse_full_mna_solver(
noi_diag_indices = masks['noi_diag_indices']
noi_res_indices_arr = masks['noi_res_indices_arr']

# Pre-compute ALL diagonal indices for Tikhonov regularization
all_diag_indices = None
if use_precomputed:
all_diag_indices = _compute_all_diag_indices(bcsr_indptr, bcsr_indices, n_augmented)

# Create augmented residual mask (node equations + branch equations)
if masks['residual_mask'] is not None:
# NOI nodes should be masked in residual convergence check
Expand Down Expand Up @@ -393,6 +434,10 @@ def body_fn(state):
csr_data = csr_data.at[noi_diag_indices].set(1.0)
f_solve = f.at[noi_res_indices_arr].set(0.0)

# Add Tikhonov regularization to prevent singular matrix errors
if all_diag_indices is not None:
csr_data = csr_data.at[all_diag_indices].add(1e-14)

delta = spsolve(csr_data, bcsr_indices, bcsr_indptr, -f_solve, tol=1e-6)
else:
# Fallback: sort each iteration
Expand All @@ -407,6 +452,10 @@ def body_fn(state):
data = data.at[noi_diag_indices].set(1.0)
f_solve = f.at[noi_res_indices_arr].set(0.0)

# Add Tikhonov regularization (compute diag indices on-the-fly for fallback path)
fallback_diag_indices = _compute_all_diag_indices(J_bcsr.indptr, J_bcsr.indices, n_augmented)
data = data.at[fallback_diag_indices].add(1e-14)

delta = spsolve(data, J_bcsr.indices, J_bcsr.indptr, -f_solve, tol=1e-6)

# Step limiting
Expand Down Expand Up @@ -686,8 +735,12 @@ def body_fn(state):
J = J.at[noi_res_idx, noi_res_idx].set(1.0)
f = f.at[noi_res_idx].set(0.0)

# Add regularization for numerical stability (prevents singular matrix errors)
reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype)
J_reg = J + reg

# Solve linear system
delta = jax.scipy.linalg.solve(J, -f)
delta = jax.scipy.linalg.solve(J_reg, -f)

# Step limiting
max_delta = jnp.max(jnp.abs(delta))
Expand Down Expand Up @@ -772,6 +825,11 @@ def make_sparse_solver(
noi_diag_indices = masks['noi_diag_indices']
noi_res_indices_arr = masks['noi_res_indices_arr']

# Pre-compute ALL diagonal indices for Tikhonov regularization
all_diag_indices_sparse = None
if use_precomputed:
all_diag_indices_sparse = _compute_all_diag_indices(bcsr_indptr, bcsr_indices, n_unknowns)

def nr_solve(V_init: Array, vsource_vals: Array, isource_vals: Array,
Q_prev: Array, integ_c0: float | Array,
device_arrays_arg: Dict[str, Array],
Expand Down Expand Up @@ -827,6 +885,10 @@ def body_fn(state):
csr_data = csr_data.at[noi_diag_indices].set(1.0)
f_solve = f.at[noi_res_indices_arr].set(0.0)

# Add Tikhonov regularization to prevent singular matrix errors
if all_diag_indices_sparse is not None:
csr_data = csr_data.at[all_diag_indices_sparse].add(1e-14)

delta = spsolve(csr_data, bcsr_indices, bcsr_indptr, -f_solve, tol=1e-6)
else:
# Fallback: sort each iteration
Expand All @@ -841,6 +903,10 @@ def body_fn(state):
data = data.at[noi_diag_indices].set(1.0)
f_solve = f.at[noi_res_indices_arr].set(0.0)

# Add Tikhonov regularization (compute diag indices on-the-fly for fallback path)
fallback_diag_indices = _compute_all_diag_indices(J_bcsr.indptr, J_bcsr.indices, n_unknowns)
data = data.at[fallback_diag_indices].add(1e-14)

delta = spsolve(data, J_bcsr.indices, J_bcsr.indptr, -f_solve, tol=1e-6)

# Step limiting
Expand Down
Loading