feat: Device eval branch specialization (Phase 1) #111
Status: Open
Conversation
Force-pushed from 9c9a46a to 05f56e1
Analyzes Jacobian sparsity patterns to report exploitable parallelism:
- Elimination tree with level-set parallelism metrics
- Supernodal detection for BLAS-3 opportunities
- Fill-in analysis via SuperLU (MMD and COLAMD orderings)
- RCM bandwidth reduction
- Device scatter pattern analysis (vmap fan-in conflicts)
- Pattern stability verification across NR iterations

Supports two modes:
- Benchmark mode: runs simulation and captures matrices
- File mode: analyzes existing Matrix Market files

Outputs JSON, human-readable summary, level-set CSV, and etree .npy.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Uses jax.named_scope to annotate NR body phases (build_system, linear_solve, enforce_noi) and jax.debug.callback for CPU-accurate timestamps. Captures Perfetto traces viewable at ui.perfetto.dev.

c6288 finding: build_system (device eval + assembly) takes 99% of NR iteration time. Linear solve (UMFPACK, 25k unknowns) is only 1%. This means IREE/Baspacho optimizations should focus on the assembly pipeline, not just the factorization.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
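The instrumentation pattern can be sketched as follows. The NR body math here is a deliberately tiny stand-in (Newton on x**2 = 2), not the project's solver; only the jax.named_scope / jax.debug.callback structure mirrors the commit.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp

def _stamp(tag):
    # Host-side callback: prints a CPU clock reading when execution
    # reaches this point in the compiled program.
    print(f"{tag}: {time.perf_counter():.6f}", flush=True)

@jax.jit
def nr_body(x):
    # Stand-in NR body: one Newton step for f(x) = x**2 - 2, elementwise.
    with jax.named_scope("build_system"):        # phase label in the trace
        jax.debug.callback(partial(_stamp, "build_system"))
        J = jnp.diag(2.0 * x)                    # Jacobian of f(x) = x**2 - 2
        f = x * x - 2.0
    with jax.named_scope("linear_solve"):
        jax.debug.callback(partial(_stamp, "linear_solve"))
        dx = jnp.linalg.solve(J, f)
    return x - dx
```

Run under jax.profiler.trace(...) this produces a Perfetto-compatible trace in which the named scopes delimit the phases, while the callbacks print host timestamps interleaved with device execution.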
Analyzes compiled device model eval functions to determine how many unique device configurations exist and whether branches (jnp.where) can be eliminated through specialization.

Key finding: PSP103 has 854 real branches (excluding safe-divide guards), and ALL of them trace back to static parameters — zero are voltage-dependent. For c6288 (10,112 transistors), only 2 specialized variants are needed (NMOS/PMOS), meaning all branches could be resolved at compile time for straight-line GPU kernels.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Inline all shared parameter values (2,591 for PSP103) and shared cache values (407) as Python float literals in the generated eval function code. At JAX trace time, constant expressions evaluate eagerly, so jnp.where(const_bool, a, b) only traces the taken branch — eliminating all static-parameter-dependent branches from the compiled XLA program.

This is Phase 1 of device eval branch specialization. The generated source code still contains jnp.where calls, but JAX's tracer constant-folds them when the condition is a Python bool rather than a traced abstract value. The key change: shared_params[N] lookups (which produce abstract values under tracing) are replaced with concrete float literals (which Python evaluates immediately).

Changes:
- function_builder.py: build_with_cache_split(), _emit_param_mapping(), and _emit_cache_mapping() accept optional concrete value lists
- __init__.py: translate_eval_array_with_cache_split() passes through concrete values with logging
- openvaf_models.py: prepare_static_inputs() extracts concrete values from already-computed shared_params_list and shared_cache arrays

Verified: rc, graetz, ring benchmarks produce identical results via compare_vacask.py. Generated code shows 0 shared_params[N] references and 0 shared_cache[N] references (all inlined).

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Remove unused variables (n, openvaf_by_type, widths, config_indices) and fix import sorting flagged by ruff. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Document ruff check and ruff format commands that CI enforces, so they're run before committing. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The translate_eval_array_with_cache_split() method (used by prepare_static_inputs()) now builds sccp_known_values from concrete shared param and cache values, enabling SCCP to eliminate dead branches at Python codegen time — before JAX ever sees the code.

PSP103 results: 695/954 MIR blocks dead, 47 static branches resolved, 7066 constants propagated, jnp.where reduced from 2247 to 1801 (20%).

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
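The idea of resolving branches at codegen time can be illustrated with a toy fold. This is not the project's MIR-level SCCP pass; fold_where and the expression strings are hypothetical, showing only the principle that a condition whose names are all known constants can be decided before any code is emitted.

```python
def fold_where(cond_src, then_src, else_src, known_values):
    # SCCP-flavored fold at codegen time (illustrative): if every name in
    # the condition source is a known constant value, evaluate it now and
    # emit only the taken branch's source text.
    try:
        taken = eval(cond_src, {"__builtins__": {}}, dict(known_values))
    except NameError:
        # Condition depends on a runtime value (e.g. a node voltage):
        # keep the jnp.where in the generated code.
        return f"jnp.where({cond_src}, {then_src}, {else_src})"
    return then_src if taken else else_src
```

With a known static TYPE parameter, fold_where("TYPE >= 0", "nmos_eval(v)", "pmos_eval(v)", {"TYPE": 1}) collapses to "nmos_eval(v)", while a voltage-dependent condition is left as a jnp.where.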
Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The dense Jacobian assembly (assemble_dense_jacobian / _build_system_dense_direct) already adds gmin (1e-12) diagonal regularization. The linear_solve function was adding an additional 1e-14 * eye(n) which was redundant and added unnecessary computation (eye allocation + matrix addition) per NR solve. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Replace jnp.where-wrapped writes to times_out, V_out, I_out with unconditional .at[step_idx].set(). On step rejection, step_idx doesn't advance so stale values get overwritten by the next accepted step. The caller trims output using step_idx, so values beyond it are ignored. This avoids materializing both branches of jnp.where on the full output arrays (up to max_steps x n_nodes) every timestep, saving ~8-12% per step. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
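A minimal sketch of the write pattern described above. record_step and its arguments are simplified stand-ins for the transient loop's state, not the project's actual function signature.

```python
import jax.numpy as jnp

def record_step(V_out, step_idx, v_new, accepted):
    # Always write the row: a cheap .at[...].set. The old pattern,
    #   V_out = jnp.where(accepted, V_out.at[step_idx].set(v_new), V_out)
    # materialized both branches over the full (max_steps, n_nodes) array.
    V_out = V_out.at[step_idx].set(v_new)
    # On rejection step_idx does not advance, so the stale row is
    # overwritten by the next accepted step; callers trim to step_idx.
    step_idx = jnp.where(accepted, step_idx + 1, step_idx)
    return V_out, step_idx
```

A rejected step followed by an accepted one leaves only the accepted values visible below step_idx, which is all the caller ever reads.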
- Fix test_vacask_suite.py VACASK_ROOT to use vendored path (vendor/VACASK) instead of sibling directory (../VACASK). This was causing all 78 discovery/parsing tests to silently skip.
- Add dump_jaxpr() to CircuitEngine for analyzing compiled simulation functions (build_system + nr_solve) via jaxpr and HLO dumps.
- Wire dump_jaxpr into compare_vacask.py --analyze flag, replacing the scan-only analysis path. Now works with the default while_loop path.
- Remove unused analyze_compiled_function from compare_vacask.py.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
SCCP dead-block elimination is kept (eliminates 695/954 MIR blocks for PSP103), but the generated Python code now uses shared_params[i] array reads instead of literal float constants.

Literal inlining (~1300 constants for PSP103) caused a 7.8x GPU regression on ring by embedding too many constants in the XLA kernel, hurting register pressure and instruction cache. CPU was 14% faster with inlining, but GPU went from 1.43ms/step to 11.17ms/step.

The fix: pass concrete values to SCCP for dead-block analysis only, via a new build_sccp_known_values() method. The builder no longer accepts concrete_shared_values/concrete_shared_cache parameters — all generated code uses array reads.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Pyright flagged subscript access on Optional type. Add explicit None guard with descriptive error message. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Script tests 8 XLA flag configurations (autotune levels, command buffers, while-loop double buffering, PGLE) on CUDA benchmarks. Each config runs in a subprocess for clean XLA state. Workflow is dispatch-only (manual trigger) so it doesn't affect regular CI. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
SCCP dead-block elimination changes the generated eval function hash, invalidating the persistent XLA compilation cache. For ring (PSP103), this causes 49.5s cold compilation (vs 0.58s cache hit on main), which inflates the benchmark wall_time by ~152s. The actual per-step execution with SCCP is ~2.3ms vs ~1.4ms without, but the benchmark reports 9.84ms because JIT compilation happens inside run_while() and is counted as execution time.

SCCP provides marginal benefit for the unified eval function (both NMOS+PMOS branches still needed), but the infrastructure is preserved in build_sccp_known_values() for future config-group specialization where each group has a unique TYPE value.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
…ice recompile

extract_results() slices output to n_steps and converts to numpy, then run_transient() was converting back to JAX via jnp.asarray(). This created dynamically-sized JAX arrays (shape = n_steps) that triggered jit(dynamic_slice) recompilation on CUDA whenever the step count changed between warmup and actual run. All consumers already convert to numpy, so the round-trip was unnecessary.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
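The shape of the fix can be sketched like this. The function name matches the commit, but the argument list and body are a simplified assumption about the real code.

```python
import numpy as np
import jax.numpy as jnp

def extract_results(times_dev, V_dev, step_idx):
    # Trim on the host and hand back numpy. Re-wrapping the sliced arrays
    # with jnp.asarray() would create JAX arrays whose shape depends on
    # this run's step count, forcing downstream jitted consumers to
    # recompile whenever n_steps changes between warmup and the real run.
    n = int(step_idx)
    return np.asarray(times_dev)[:n], np.asarray(V_dev)[:n]
```

Since every consumer already works on numpy, keeping the results host-side removes the dynamically-shaped JAX arrays entirely.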
nsys was not installed — only cuda-nvcc and cuda-cudart-dev were. Add nsight-systems-cli package and discover its PATH dynamically. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The NVIDIA CUDA apt repo uses versioned package names like cuda-nsight-systems-12-6, not nsight-systems-cli. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
- Export .sqlite for offline analysis without nsys installed
- Generate CSV stats for kernel, API, memory transfer, and trace reports
- Expand job summary with CUDA API and memory transfer sections
- Upload all artifacts (nsys-rep, sqlite, CSV stats)

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
JAX CUDA plugin requires cuDNN to initialize the GPU backend. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Individual CUDA packages miss runtime libs like cuFFT. Use cuda-toolkit-12-6 (matching benchmark workflow) to get everything. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Match benchmark-comparison CUDA packages: cuda-toolkit, cudnn, cudss. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
New pushes to a PR branch now cancel any in-flight runs of the same workflow, avoiding wasted runner time on stale commits. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Use cudaProfilerStart()/cudaProfilerStop() to capture only the simulation run, excluding prepare() warmup (JIT compilation, module loading, memory allocation). Increase default timesteps from 50 to 500 for representative steady-state profiling. The previous 50-step profile was dominated by one-time startup costs (cuMemHostAlloc 89ms, cuModuleLoadFatBinary 9ms, etc.) that made the timing unrepresentative of actual benchmark behavior. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
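An nsys invocation matching this capture scheme would look roughly like the following; the script path and step count are examples, not the repository's actual benchmark entry point.

```shell
# Record only the region between cudaProfilerStart()/cudaProfilerStop(),
# so prepare() warmup (JIT compile, module load, allocation) is excluded.
nsys profile \
  --capture-range=cudaProfilerApi \
  --capture-range-end=stop \
  -o transient_profile \
  python benchmarks/run_transient.py --steps 500
```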
Force-pushed from d6548a0 to 7356c78
- apt-sources needs `|` (not `>-`) to preserve newlines between source entries
- dpkg -L and nsys --version now tolerate missing packages gracefully

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
scikit-build-core runs cmake in an isolated environment that doesn't inherit GITHUB_PATH additions. Explicitly set CUDAToolkit_ROOT so cmake can find nvcc on runners where CUDA was installed via apt. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
CUDA toolkit apt packages use post-install scripts to create /usr/local/cuda-* directories. Without execute_install_scripts=true, the cached dpkg files don't set up the toolkit properly. Also dynamically discover CUDAToolkit_ROOT instead of hardcoding the path, and export it as env var for cmake. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Previous cache was created without execute_install_scripts, so CUDA toolkit post-install scripts weren't captured. Bumping version forces a fresh install with scripts enabled. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Helps debug whether cache-apt-pkgs-action properly restores CUDA toolkit files and post-install script results. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Prevents cache cross-contamination between runners with different pre-installed packages (runner-1 has CUDA toolkit, runner-2 doesn't). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The NVIDIA runner image already has the CUDA repo configured with its own GPG key (cuda-archive-keyring.gpg). Adding the same repo via apt-sources with a different key file causes APT to fail with "Conflicting values set for option Signed-By", preventing all package installation. Only add the LLVM repo via apt-sources. Bump cache version to invalidate the stale empty cache. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The setup-uv action uses the system Python version (3.12.3) in the cache key when python-version is not specified. Since workflows install a different Python version afterwards (3.10-3.13), the cache was keyed incorrectly — storing cp312 wheels but needing cp31x wheels. GitHub Actions caches are immutable per key, so the stale ~2MB cache could never be updated, causing every run to re-download all packages. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Replace ad-hoc apt-get install and llvm.sh script calls with robtaylor/cache-apt-pkgs-action@feat/apt-sources across all workflows. This caches installed packages between runs, avoiding repeated ~30s apt-get update + install cycles.

Changes per workflow:
- test.yml: Merge LLVM + system deps into single cached step
- benchmark-comparison.yml: Replace 4 apt steps (cache, CPU deps, CUDA deps, LLVM) with 2 conditional cached steps (CPU vs CUDA)
- test-pdk.yml: Replace llvm.sh with cached LLVM packages
- profile-nsys.yml: Remove stale version parameter

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The cache-apt-pkgs-action may restore .so files without proper symlinks, causing CMake's FindBLAS to fail with "Could NOT find BLAS". Running ldconfig after restore regenerates the symlinks. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The cache-apt-pkgs-action registers dpkg metadata but may not restore actual library files to the filesystem. Detect and reinstall openblas if the shared library is missing. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The cache-apt-pkgs-action may create version conflicts between cached and pre-installed packages (e.g., gcc-14-base version mismatch). Run apt --fix-broken install first, then reinstall both openblas and suitesparse to ensure library files are present. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The apt cache action may register newer gcc-14-base dpkg metadata than what's installed on the runner, creating unresolvable dependency conflicts. Use --force-overwrite to install the correct versions. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
The cache-apt-pkgs-action cache was corrupted (gcc-14-base version mismatch with runner). Deleted the cache entry and removed the workaround. Fresh install will build a clean cache. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
cuda-toolkit-12-6 pulls in 216+ packages (cufft, curand, npp, nvjpeg, opencl, documentation, visual tools, etc.) that we don't need. Replace with the 7 specific packages required for building spineax/BaSpaCho: cuda-nvcc-12-6, cuda-cudart-dev-12-6, cuda-driver-dev-12-6, libcublas-dev-12-6, libcusolver-dev-12-6, libcusparse-dev-12-6, libnvjitlink-dev-12-6.

Also removes stale dpkg diagnostics and ldconfig workaround from profile-nsys.yml (the corrupted apt cache was already deleted).

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
JAX's xla_cuda12 plugin checks cuFFT version during initialization even if the application doesn't use FFTs. Without libcufft.so, the CUDA backend fails to initialize and falls back to CPU. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
setup-uv's prune-cache (default: true) runs `uv cache prune` before saving, which was stripping nearly all cached content — leaving only ~2-5 MB of HTTP metadata instead of the ~200+ MB of downloaded wheels. Combined with GitHub Actions' immutable cache keys, this meant every run re-downloaded jaxlib (78 MB), scipy (33 MB), numpy (15 MB), etc. Also deleted the existing stale caches via `gh cache delete` so new properly-populated caches can be saved on the next run. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
JAX's xla_cuda12 plugin checks cuFFT, cuPTI, and other CUDA library versions during initialization. Use cuda-libraries-12-6 meta-package (13 runtime libs) instead of cherry-picking individual ones. Also add cuda-cupti-12-6 which JAX requires but isn't in cuda-libraries. Still much smaller than cuda-toolkit-12-6 (216+ packages including docs, visual tools, nvprof, etc). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Environment variables don't propagate through uv pip install's subprocess. Pass CMAKE_*_COMPILER_LAUNCHER via -C cmake.define so scikit-build-core forwards them to CMake. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
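The resulting install invocation looks roughly like this; sccache as the launcher is an example (ccache works the same way), and the config-settings spelling is scikit-build-core's cmake.define mechanism.

```shell
# Env vars like CMAKE_C_COMPILER_LAUNCHER don't survive into the isolated
# build environment, so pass them as scikit-build-core config settings:
uv pip install . \
  -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=sccache \
  -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=sccache \
  -C cmake.define.CMAKE_CUDA_COMPILER_LAUNCHER=sccache
```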
Scheduled to run on the 1st of every 3rd month to prevent unbounded cache growth. Also supports manual dispatch for ad-hoc cleanup. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
JAX's xla_cuda12 plugin dlopen's all CUDA runtime libraries (cuSPARSE, cuFFT, cuBLAS, etc.) at initialization. Without CUDA_ROOT/lib64 in LD_LIBRARY_PATH, these libraries aren't found even though the apt packages are installed. Previous runs worked because cuda-toolkit-12-6 included ldconfig snippets; our minimal package set does not. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
GITHUB_ENV takes the last value per key within a step. Writing LD_LIBRARY_PATH twice (once for CUDA lib64, once for cuDSS) meant the CUDA path was lost. Build the full path in a shell variable first, then write it once. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
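A sketch of the corrected step; the library directories shown are assumptions (the cuDSS install path in particular), and locally GITHUB_ENV is pointed at a demo file since GitHub Actions normally provides it.

```shell
# GITHUB_ENV keeps only the last value written per key in a step, so
# compose the full search path first and write LD_LIBRARY_PATH once.
GITHUB_ENV=/tmp/github_env.demo                 # provided by Actions; demo path here
: > "$GITHUB_ENV"
LIB_PATH="/usr/local/cuda-12.6/lib64"           # CUDA runtime libraries
LIB_PATH="$LIB_PATH:/opt/nvidia/cudss/lib"      # cuDSS (path is an assumption)
echo "LD_LIBRARY_PATH=$LIB_PATH" >> "$GITHUB_ENV"
```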
Wraps each phase (JAX init, solver imports, vajax import, circuit parse, prepare + JIT warmup, transient simulation) with timing output to identify where startup time is spent. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Add NVTX push/pop around run_transient() and use nsys --capture-range=nvtx to exclude JIT warmup and Python startup from the profiling window. This gives clean data for just the simulation hot path. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
NVTX capture-range=nvtx prevents nsys from writing the report when SIGSEGV occurs during JAX/CUDA teardown. Revert to default capture mode with NVTX markers kept for annotation/filtering in post-analysis. Also remove --no-cache from spineax pip install (unnecessary slowdown) and add flush=True to profiling prints to prevent output loss on crash. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
scikit-build-core bundles CMake 3.31.6 which can't find OpenBLAS via the generic libblas.so symlink (cache-apt-pkgs-action doesn't replay update-alternatives hooks). Setting BLA_VENDOR=OpenBLAS makes FindBLAS search directly for libopenblas.so. Also adds compiler cache (sccache/ccache) to the SuiteSparse FetchContent build, and aligns benchmark CUDA packages with the nsys profiling workflow. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Enables VLOG on command_buffer_conversion_pass (level 2) and while_thunk (level 3) to identify which thunks in the NR while loop body prevent CUDA graph conditional node capture. Without graph capture, each NR iteration requires a D2H sync for the loop predicate (574µs × 82K iterations = 47s overhead). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
TF_CPP_VLOG_FLAGS is not recognized by JAX/XLA. The correct env var for module-level VLOG control is TF_CPP_VMODULE. Also set TF_CPP_MIN_LOG_LEVEL=0 (was 2) so VLOG output is not suppressed. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
VLOG output confirmed:
- NR while_loop NOT captured as CUDA graph (WhileThunk fallback mode)
- BaSpaCho FFI custom call blocks WhileThunk→WhileCmd conversion (lacks kCmdBufferCompatible trait)
- XLA creates 4 partial command buffer graphs within the NR body
- NR NEVER converges (all breaks at iter=20 = tran_itl limit)

Commenting out VLOG to reduce noise for future profiling runs. Re-enable by uncommenting TF_CPP_VMODULE line.

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Enable reduced VLOG set (conversion_pass=2, while_thunk=3) to verify that BaSpaCho FFI Execute handler with kCmdBufferCompatible trait enables WhileThunk→WhileCmd conversion for the NR while_loop. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Summary
- jnp.where(const_bool, a, b) only traces the taken branch, eliminating static-parameter-dependent branches from the compiled XLA program
- Adds NR phase profiling (scripts/profile_nr_phases.py) and parallelism analysis tooling (scripts/analyze_parallelism.py)

How it works
Before (traced): XLA sees abstract values and evaluates both jnp.where branches.

After (concrete): Python evaluates constants; only one branch is traced.
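A minimal sketch of the before/after behavior, with exp/log1p standing in for device math (the real eval functions and shared_params layout are far larger):

```python
import jax
import jax.numpy as jnp

# Before: the condition reads shared_params, which is abstract under
# tracing, so the jaxpr keeps the select plus both branch computations.
def eval_before(v, shared_params):
    return jnp.where(shared_params[0] > 0.0, jnp.exp(v), jnp.log1p(v))

# After: the parameter is inlined as a float literal. "1.5 > 0.0" is a
# plain Python bool decided at trace time, so the untaken branch can be
# eliminated from the compiled program.
def eval_after(v):
    return jnp.where(1.5 > 0.0, jnp.exp(v), jnp.log1p(v))

before_jaxpr = str(jax.make_jaxpr(eval_before)(0.5, jnp.asarray([1.5])))
```

Inspecting before_jaxpr shows the select over both branches; jitting eval_after yields exactly the exp path.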
Verification
Generated PSP103 eval code shows 0
shared_params[N]references and 0shared_cache[N]references (all 2,998 values inlined as literals).Benchmarks produce identical numerical results:
Test plan
- compare_vacask.py --benchmark rc passes
- compare_vacask.py --benchmark graetz passes
- compare_vacask.py --benchmark ring passes