diff --git a/jax_spice/analysis/solver_factories.py b/jax_spice/analysis/solver_factories.py index 1747b191..a31cb2be 100644 --- a/jax_spice/analysis/solver_factories.py +++ b/jax_spice/analysis/solver_factories.py @@ -37,6 +37,38 @@ DEFAULT_ABSTOL = 1e4 # Corresponds to ~10nV voltage accuracy with G=1e12 +def _compute_all_diag_indices( + bcsr_indptr: Array, + bcsr_indices: Array, + n_unknowns: int, +) -> Array: + """Pre-compute CSR indices for ALL diagonal elements. + + Used to add Tikhonov regularization to sparse Jacobians to prevent + singular matrix errors in spsolve. + + Args: + bcsr_indptr: CSR row pointers + bcsr_indices: CSR column indices + n_unknowns: Number of unknowns (rows/cols) + + Returns: + Array of CSR data indices corresponding to diagonal elements + """ + indptr_np = np.asarray(bcsr_indptr) + indices_np = np.asarray(bcsr_indices) + + diag_indices = [] + for row in range(n_unknowns): + row_start, row_end = int(indptr_np[row]), int(indptr_np[row + 1]) + for j in range(row_start, row_end): + if indices_np[j] == row: + diag_indices.append(j) + break + + return jnp.array(diag_indices, dtype=jnp.int32) + + def _compute_noi_masks( noi_indices: Optional[Array], n_nodes: int, @@ -226,8 +258,12 @@ def body_fn(state): J = J.at[noi_res_idx, noi_res_idx].set(1.0) f = f.at[noi_res_idx].set(0.0) + # Add regularization for numerical stability (prevents singular matrix errors) + reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype) + J_reg = J + reg + # Solve linear system J @ delta = -f - delta = jax.scipy.linalg.solve(J, -f) + delta = jax.scipy.linalg.solve(J_reg, -f) # Step limiting max_delta = jnp.max(jnp.abs(delta)) @@ -328,6 +364,11 @@ def make_sparse_full_mna_solver( noi_diag_indices = masks['noi_diag_indices'] noi_res_indices_arr = masks['noi_res_indices_arr'] + # Pre-compute ALL diagonal indices for Tikhonov regularization + all_diag_indices = None + if use_precomputed: + all_diag_indices = _compute_all_diag_indices(bcsr_indptr, bcsr_indices, n_augmented) + # Create augmented residual mask (node equations + branch equations) if masks['residual_mask'] is not None: # NOI nodes should be masked in residual convergence check @@ -393,6 +434,10 @@ def body_fn(state): csr_data = csr_data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization to prevent singular matrix errors + if all_diag_indices is not None: + csr_data = csr_data.at[all_diag_indices].add(1e-14) + delta = spsolve(csr_data, bcsr_indices, bcsr_indptr, -f_solve, tol=1e-6) else: # Fallback: sort each iteration @@ -407,6 +452,10 @@ def body_fn(state): data = data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization (compute diag indices on-the-fly for fallback path) + fallback_diag_indices = _compute_all_diag_indices(J_bcsr.indptr, J_bcsr.indices, n_augmented) + data = data.at[fallback_diag_indices].add(1e-14) + delta = spsolve(data, J_bcsr.indices, J_bcsr.indptr, -f_solve, tol=1e-6) # Step limiting @@ -686,8 +735,12 @@ def body_fn(state): J = J.at[noi_res_idx, noi_res_idx].set(1.0) f = f.at[noi_res_idx].set(0.0) + # Add regularization for numerical stability (prevents singular matrix errors) + reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype) + J_reg = J + reg + # Solve linear system - delta = jax.scipy.linalg.solve(J, -f) + delta = jax.scipy.linalg.solve(J_reg, -f) # Step limiting max_delta = jnp.max(jnp.abs(delta)) @@ -772,6 +825,11 @@ def make_sparse_solver( noi_diag_indices = masks['noi_diag_indices'] noi_res_indices_arr = masks['noi_res_indices_arr'] + # Pre-compute ALL diagonal indices for Tikhonov regularization + all_diag_indices_sparse = None + if use_precomputed: + all_diag_indices_sparse = _compute_all_diag_indices(bcsr_indptr, bcsr_indices, n_unknowns) + def nr_solve(V_init: Array, vsource_vals: Array, isource_vals: Array, Q_prev: Array, integ_c0: float | Array, device_arrays_arg: Dict[str, Array], @@ -827,6 +885,10 @@ def body_fn(state): csr_data = csr_data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization to prevent singular matrix errors + if all_diag_indices_sparse is not None: + csr_data = csr_data.at[all_diag_indices_sparse].add(1e-14) + delta = spsolve(csr_data, bcsr_indices, bcsr_indptr, -f_solve, tol=1e-6) else: # Fallback: sort each iteration @@ -841,6 +903,10 @@ def body_fn(state): data = data.at[noi_diag_indices].set(1.0) f_solve = f.at[noi_res_indices_arr].set(0.0) + # Add Tikhonov regularization (compute diag indices on-the-fly for fallback path) + fallback_diag_indices = _compute_all_diag_indices(J_bcsr.indptr, J_bcsr.indices, n_unknowns) + data = data.at[fallback_diag_indices].add(1e-14) + delta = spsolve(data, J_bcsr.indices, J_bcsr.indptr, -f_solve, tol=1e-6) # Step limiting