From e206ebb479547680365fb5136d85ff32c4820234 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 26 Jan 2026 23:23:14 +0000 Subject: [PATCH 1/2] Add Tikhonov regularization to dense solvers to prevent singular matrix errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JAX's scipy.linalg.solve raises a hard error on singular matrices, unlike UMFPACK which can issue a warning and continue. This was causing GPU CI tests to fail with "INTERNAL: Singular matrix in linear solve" for circuits with high condition numbers (like graetz with 1GΩ grounding resistors). The fix adds small Tikhonov regularization (1e-14 * I) to the Jacobian before solving, matching the approach already used in jax_spice/analysis/solver.py. This prevents numerical singularity without meaningfully affecting results. Applied to both make_dense_full_mna_solver and make_dense_solver factory functions. Co-developed-by: Claude Code v2.1.19 (claude-opus-4-5-20251101) --- jax_spice/analysis/solver_factories.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jax_spice/analysis/solver_factories.py b/jax_spice/analysis/solver_factories.py index 1747b191..096f4f58 100644 --- a/jax_spice/analysis/solver_factories.py +++ b/jax_spice/analysis/solver_factories.py @@ -226,8 +226,12 @@ def body_fn(state): J = J.at[noi_res_idx, noi_res_idx].set(1.0) f = f.at[noi_res_idx].set(0.0) + # Add regularization for numerical stability (prevents singular matrix errors) + reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype) + J_reg = J + reg + # Solve linear system J @ delta = -f - delta = jax.scipy.linalg.solve(J, -f) + delta = jax.scipy.linalg.solve(J_reg, -f) # Step limiting max_delta = jnp.max(jnp.abs(delta)) @@ -686,8 +690,12 @@ def body_fn(state): J = J.at[noi_res_idx, noi_res_idx].set(1.0) f = f.at[noi_res_idx].set(0.0) + # Add regularization for numerical stability (prevents singular matrix errors) + reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype) + J_reg = J + reg + # Solve linear system - delta = jax.scipy.linalg.solve(J, -f) + delta = jax.scipy.linalg.solve(J_reg, -f) # Step limiting max_delta = jnp.max(jnp.abs(delta)) From 0ea81d847d36cc3e4eb6f8c116e3cdb06560bf1e Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 27 Jan 2026 22:23:27 +0000 Subject: [PATCH 2/2] Add Tikhonov regularization to sparse solvers Extend the regularization fix to sparse solvers (make_sparse_full_mna_solver and make_sparse_solver) which were still failing with "Singular matrix in linear solve" errors on GPU. For sparse matrices, regularization is added by: 1. Pre-computing CSR indices for all diagonal elements 2. Adding 1e-14 to those diagonal entries before spsolve This matches the dense solver fix but adapted for CSR sparse format. Co-developed-by: Claude Code v2.1.19 (claude-opus-4-5-20251101) --- jax_spice/analysis/solver_factories.py | 58 ++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/jax_spice/analysis/solver_factories.py b/jax_spice/analysis/solver_factories.py index 096f4f58..a31cb2be 100644 --- a/jax_spice/analysis/solver_factories.py +++ b/jax_spice/analysis/solver_factories.py @@ -37,6 +37,38 @@ DEFAULT_ABSTOL = 1e4 # Corresponds to ~10nV voltage accuracy with G=1e12 +def _compute_all_diag_indices( + bcsr_indptr: Array, + bcsr_indices: Array, + n_unknowns: int, +) -> Array: + """Pre-compute CSR indices for ALL diagonal elements. + + Used to add Tikhonov regularization to sparse Jacobians to prevent + singular matrix errors in spsolve. + + Args: + bcsr_indptr: CSR row pointers + bcsr_indices: CSR column indices + n_unknowns: Number of unknowns (rows/cols) + + Returns: + Array of CSR data indices corresponding to diagonal elements + """ + indptr_np = np.asarray(bcsr_indptr) + indices_np = np.asarray(bcsr_indices) + + diag_indices = [] + for row in range(n_unknowns): + row_start, row_end = int(indptr_np[row]), int(indptr_np[row + 1]) + for j in range(row_start, row_end): + if indices_np[j] == row: + diag_indices.append(j) + break + + return jnp.array(diag_indices, dtype=jnp.int32) + + def _compute_noi_masks( noi_indices: Optional[Array], n_nodes: int, @@ -332,6 +364,11 @@ def make_sparse_full_mna_solver( noi_diag_indices = masks['noi_diag_indices'] noi_res_indices_arr = masks['noi_res_indices_arr'] + # Pre-compute ALL diagonal indices for Tikhonov regularization + all_diag_indices = None + if use_precomputed: + all_diag_indices = _compute_all_diag_indices(bcsr_indptr, bcsr_indices, n_augmented) + # Create augmented residual mask (node equations + branch equations) if masks['residual_mask'] is not None: # NOI nodes should be masked in residual convergence check @@ -397,6 +434,10 @@ def body_fn(state): csr_data = csr_data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization to prevent singular matrix errors + if all_diag_indices is not None: + csr_data = csr_data.at[all_diag_indices].add(1e-14) + delta = spsolve(csr_data, bcsr_indices, bcsr_indptr, -f_solve, tol=1e-6) else: # Fallback: sort each iteration @@ -411,6 +452,10 @@ def body_fn(state): data = data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization (compute diag indices on-the-fly for fallback path) + fallback_diag_indices = _compute_all_diag_indices(J_bcsr.indptr, J_bcsr.indices, n_augmented) + data = data.at[fallback_diag_indices].add(1e-14) + delta = spsolve(data, J_bcsr.indices, J_bcsr.indptr, -f_solve, tol=1e-6) # Step limiting @@ -780,6 +825,11 @@ def make_sparse_solver( noi_diag_indices = masks['noi_diag_indices'] noi_res_indices_arr = masks['noi_res_indices_arr'] + # Pre-compute ALL diagonal indices for Tikhonov regularization + all_diag_indices_sparse = None + if use_precomputed: + all_diag_indices_sparse = _compute_all_diag_indices(bcsr_indptr, bcsr_indices, n_unknowns) + def nr_solve(V_init: Array, vsource_vals: Array, isource_vals: Array, Q_prev: Array, integ_c0: float | Array, device_arrays_arg: Dict[str, Array], @@ -835,6 +885,10 @@ def body_fn(state): csr_data = csr_data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization to prevent singular matrix errors + if all_diag_indices_sparse is not None: + csr_data = csr_data.at[all_diag_indices_sparse].add(1e-14) + delta = spsolve(csr_data, bcsr_indices, bcsr_indptr, -f_solve, tol=1e-6) else: # Fallback: sort each iteration @@ -849,6 +903,10 @@ def body_fn(state): data = data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization (compute diag indices on-the-fly for fallback path) + fallback_diag_indices = _compute_all_diag_indices(J_bcsr.indptr, J_bcsr.indices, n_unknowns) + data = data.at[fallback_diag_indices].add(1e-14) + delta = spsolve(data, J_bcsr.indices, J_bcsr.indptr, -f_solve, tol=1e-6) # Step limiting