From dc1b0a2fc39009b8a1f35a62d84a06330b49542c Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 22:18:13 +0000 Subject: [PATCH 01/10] Add TPU CI support via GCP TPU VMs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add test-tpu.yml workflow for running tests on TPU v5e (v5litepod-8) - Add setup_gcp_tpu_ci.sh for configuring TPU quota and permissions - Update conftest.py to recognize TPU backend Workflow features: - Manual dispatch with Spot VM option for cost savings - Runs same profiling + tests as GPU Cloud Run workflow - Creates TPU VM on demand, cleans up after tests - Extracts profiling report to GitHub job summary Estimated costs: - On-demand: ~$9.60/hour (8 chips × $1.20) - Spot: ~$1-2/hour (up to 91% discount) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .github/workflows/test-tpu.yml | 93 +++++++++++ scripts/setup_gcp_tpu_ci.sh | 286 +++++++++++++++++++++++++++++++++ tests/conftest.py | 13 +- 3 files changed, 389 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/test-tpu.yml create mode 100644 scripts/setup_gcp_tpu_ci.sh diff --git a/.github/workflows/test-tpu.yml b/.github/workflows/test-tpu.yml new file mode 100644 index 00000000..9e5eeaf2 --- /dev/null +++ b/.github/workflows/test-tpu.yml @@ -0,0 +1,93 @@ +name: TPU Tests + +on: + workflow_dispatch: + inputs: + use_spot: + description: 'Use Spot VM (cheaper but may be preempted)' + required: true + default: true + type: boolean + +env: + GCP_PROJECT: jax-spice-cuda-test + GCP_ZONE: us-central1-a + TPU_NAME: jax-spice-tpu-${{ github.run_id }} + TPU_TYPE: v5litepod-8 + TPU_RUNTIME: v2-alpha-tpuv5-lite + +jobs: + tpu-tests: + runs-on: ubuntu-latest + permissions: + contents: read + id-token: write + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Authenticate to GCP + uses: google-github-actions/auth@v2 + with: + credentials_json: ${{ secrets.GCP_SERVICE_ACCOUNT_KEY }} + + - name: Set up Cloud SDK + uses: google-github-actions/setup-gcloud@v2 + + - name: Create TPU VM + run: | + SPOT_FLAG="" + if [ "${{ inputs.use_spot }}" = "true" ]; then + SPOT_FLAG="--spot" + fi + + gcloud compute tpus tpu-vm create "${{ env.TPU_NAME }}" \ + --zone="${{ env.GCP_ZONE }}" \ + --accelerator-type="${{ env.TPU_TYPE }}" \ + --version="${{ env.TPU_RUNTIME }}" \ + ${SPOT_FLAG} \ + --quiet + + - name: Sync code and setup environment + run: | + tar --exclude='.git' --exclude='__pycache__' -czf /tmp/jax-spice.tar.gz . + + gcloud compute tpus tpu-vm scp /tmp/jax-spice.tar.gz \ + "${{ env.TPU_NAME }}":~/jax-spice.tar.gz \ + --zone="${{ env.GCP_ZONE }}" + + gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ + --zone="${{ env.GCP_ZONE }}" \ + --command='mkdir -p ~/jax-spice && cd ~/jax-spice && tar -xzf ~/jax-spice.tar.gz && rm ~/jax-spice.tar.gz && curl -LsSf https://astral.sh/uv/install.sh | sh' + + - name: Install dependencies + run: | + gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ + --zone="${{ env.GCP_ZONE }}" \ + --command='source ~/.local/bin/env && cd ~/jax-spice && uv pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && uv sync' + + - name: Run profiling and tests + run: | + gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ + --zone="${{ env.GCP_ZONE }}" \ + --command='source ~/.local/bin/env && cd ~/jax-spice && export JAX_PLATFORMS=tpu && export JAX_ENABLE_X64=1 && uv run python scripts/profile_gpu.py && uv run pytest tests/ -v --tb=short -x' \ + | tee /tmp/test_output.txt + + - name: Extract profiling report + if: always() + run: | + echo "## TPU Test Results" >> "$GITHUB_STEP_SUMMARY" + echo "- **TPU Type:** ${{ env.TPU_TYPE }}" >> "$GITHUB_STEP_SUMMARY" + echo "- **Spot VM:** ${{ inputs.use_spot }}" >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + if [ -f /tmp/test_output.txt ]; then + sed -n '/# JAX-SPICE/,/Report written/p' /tmp/test_output.txt | head -n -1 >> "$GITHUB_STEP_SUMMARY" || true + fi + + - name: Cleanup TPU VM + if: always() + run: | + gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \ + --zone="${{ env.GCP_ZONE }}" \ + --quiet || true diff --git a/scripts/setup_gcp_tpu_ci.sh b/scripts/setup_gcp_tpu_ci.sh new file mode 100644 index 00000000..34acb13c --- /dev/null +++ b/scripts/setup_gcp_tpu_ci.sh @@ -0,0 +1,286 @@ +#!/bin/bash +# Setup script for GCP TPU CI infrastructure +# This script is fully idempotent - safe to run multiple times +# +# Usage: ./setup_gcp_tpu_ci.sh +# +# Features: +# - Enables TPU API +# - Creates/updates service account with TPU permissions +# - Stores service account key in GCP Secret Manager +# - Syncs secret to GitHub Actions +# +# Prerequisites: +# - gcloud CLI installed and authenticated +# - gh CLI installed and authenticated (for GitHub secret sync) +# - Billing enabled on the GCP project +# - TPU quota available in the zone (request at https://cloud.google.com/tpu/docs/quota) +# +# TPU Quota: +# - You need quota for the TPU type you want to use +# - For v5e: request "TPU v5 Lite PodSlice chips" quota +# - For Spot VMs: request preemptible quota separately +# - Quota request: https://console.cloud.google.com/iam-admin/quotas + +set -euo pipefail + +# Configuration +PROJECT_ID="${GCP_PROJECT:-jax-spice-cuda-test}" +REGION="${GCP_REGION:-us-central1}" +ZONE="${GCP_ZONE:-us-central1-a}" # Must have TPU v5e availability +SA_NAME="github-gpu-ci" # Reuse existing service account +SA_EMAIL="${SA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com" +SECRET_NAME="github-gpu-ci-key" +GITHUB_REPO="${GITHUB_REPO:-ChipFlow/jax-spice}" + +# TPU Configuration +TPU_TYPE="${TPU_TYPE:-v5litepod-8}" # Smallest v5e configuration +TPU_RUNTIME="${TPU_RUNTIME:-v2-alpha-tpuv5-lite}" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +echo "==========================================" +echo " JAX-SPICE TPU CI Setup (Idempotent)" +echo "==========================================" +echo "" +echo "Project: ${PROJECT_ID}" +echo "Zone: ${ZONE}" +echo "TPU Type: ${TPU_TYPE}" +echo "Runtime: ${TPU_RUNTIME}" +echo "GitHub Repo: ${GITHUB_REPO}" +echo "" + +# Set project +gcloud config set project "${PROJECT_ID}" --quiet + +# ============================================================================= +# Step 1: Enable required APIs (idempotent) +# ============================================================================= +log_info "Enabling required APIs..." +gcloud services enable tpu.googleapis.com --quiet +gcloud services enable compute.googleapis.com --quiet +gcloud services enable secretmanager.googleapis.com --quiet +gcloud services enable iam.googleapis.com --quiet +log_info "APIs enabled" + +# ============================================================================= +# Step 2: Check/create service account (idempotent) +# ============================================================================= +log_info "Setting up service account..." +if gcloud iam service-accounts describe "${SA_EMAIL}" &>/dev/null; then + log_info "Service account already exists: ${SA_EMAIL}" +else + gcloud iam service-accounts create "${SA_NAME}" \ + --display-name="GitHub GPU/TPU CI Runner" \ + --quiet + log_info "Created service account: ${SA_EMAIL}" +fi + +# ============================================================================= +# Step 3: Grant TPU permissions (idempotent) +# ============================================================================= +log_info "Configuring IAM permissions for TPU..." + +# Define required roles for TPU access +ROLES=( + "roles/tpu.admin" # Create/delete/manage TPU VMs + "roles/compute.networkUser" # Use VPC networks for TPU + "roles/iam.serviceAccountUser" # Use service account on TPU VM + "roles/logging.viewer" # View logs + "roles/storage.objectViewer" # Read from GCS (for JAX wheels) +) + +for ROLE in "${ROLES[@]}"; do + gcloud projects add-iam-policy-binding "${PROJECT_ID}" \ + --member="serviceAccount:${SA_EMAIL}" \ + --role="${ROLE}" \ + --quiet \ + --condition=None 2>/dev/null || true +done +log_info "IAM permissions configured" + +# ============================================================================= +# Step 4: Check TPU quota +# ============================================================================= +log_info "Checking TPU quota..." + +# Extract chip count from TPU type (e.g., v5litepod-8 -> 8) +CHIP_COUNT=$(echo "${TPU_TYPE}" | grep -oE '[0-9]+$') + +echo "" +echo "Required quota for ${TPU_TYPE}:" +echo " - TPU v5 Lite PodSlice chips: ${CHIP_COUNT} (on-demand)" +echo " - Preemptible TPU v5 Lite PodSlice chips: ${CHIP_COUNT} (for Spot VMs)" +echo "" +echo "Check your quota at:" +echo " https://console.cloud.google.com/iam-admin/quotas?project=${PROJECT_ID}" +echo "" +echo "Filter by: 'tpu' and region '${REGION}'" +echo "" + +# ============================================================================= +# Step 5: Verify TPU availability in zone +# ============================================================================= +log_info "Checking TPU availability in ${ZONE}..." + +# List available accelerator types +AVAILABLE_TYPES=$(gcloud compute tpus accelerator-types list \ + --zone="${ZONE}" \ + --format="value(type)" 2>/dev/null | grep -E "^v5litepod" || echo "") + +if [ -z "${AVAILABLE_TYPES}" ]; then + log_warn "No v5e TPUs found in ${ZONE}. Available zones for v5e:" + echo " - us-central1-a" + echo " - us-south1-a" + echo " - us-west1-c" + echo " - us-west4-a" + echo " - europe-west4-b" + echo "" + echo "Update GCP_ZONE environment variable and re-run." +else + log_info "Available TPU types in ${ZONE}:" + echo "${AVAILABLE_TYPES}" | sed 's/^/ - /' +fi + +# ============================================================================= +# Step 6: Create/update secret in Secret Manager (idempotent) +# ============================================================================= +log_info "Setting up Secret Manager..." + +# Check if secret exists +if gcloud secrets describe "${SECRET_NAME}" &>/dev/null; then + log_info "Secret already exists: ${SECRET_NAME}" + + # Check if we need to create a new key version + LATEST_VERSION=$(gcloud secrets versions list "${SECRET_NAME}" \ + --filter="state=ENABLED" \ + --sort-by="~createTime" \ + --limit=1 \ + --format="value(name)" 2>/dev/null || echo "") + + if [ -n "${LATEST_VERSION}" ]; then + log_info "Using existing secret version" + NEED_NEW_KEY=false + else + log_warn "No enabled secret versions found, creating new key" + NEED_NEW_KEY=true + fi +else + log_info "Creating new secret: ${SECRET_NAME}" + gcloud secrets create "${SECRET_NAME}" \ + --replication-policy="automatic" \ + --quiet + NEED_NEW_KEY=true +fi + +# Create new key and add to secret if needed +if [ "${NEED_NEW_KEY:-false}" = true ]; then + log_info "Generating new service account key..." + + # Create temporary key file + KEY_FILE=$(mktemp) + trap "rm -f ${KEY_FILE}" EXIT + + gcloud iam service-accounts keys create "${KEY_FILE}" \ + --iam-account="${SA_EMAIL}" \ + --quiet + + # Add new version to secret + gcloud secrets versions add "${SECRET_NAME}" \ + --data-file="${KEY_FILE}" \ + --quiet + + log_info "Service account key stored in Secret Manager" + + # Clean up old keys (keep only the 2 most recent) + log_info "Cleaning up old service account keys..." + OLD_KEYS=$(gcloud iam service-accounts keys list \ + --iam-account="${SA_EMAIL}" \ + --format="value(name)" \ + --filter="keyType=USER_MANAGED" \ + --sort-by="~validAfterTime" 2>/dev/null | tail -n +3) + + for KEY_ID in ${OLD_KEYS}; do + gcloud iam service-accounts keys delete "${KEY_ID}" \ + --iam-account="${SA_EMAIL}" \ + --quiet 2>/dev/null || true + done +fi + +# ============================================================================= +# Step 7: Sync secret to GitHub (idempotent) +# ============================================================================= +log_info "Syncing secret to GitHub..." + +if command -v gh &>/dev/null; then + if gh auth status &>/dev/null; then + SECRET_VALUE=$(gcloud secrets versions access latest --secret="${SECRET_NAME}" 2>/dev/null) + + if [ -n "${SECRET_VALUE}" ]; then + echo "${SECRET_VALUE}" | gh secret set GCP_SERVICE_ACCOUNT_KEY \ + --repo="${GITHUB_REPO}" 2>/dev/null && \ + log_info "GitHub secret 'GCP_SERVICE_ACCOUNT_KEY' updated" || \ + log_warn "Failed to update GitHub secret (check gh permissions)" + else + log_error "Could not retrieve secret from Secret Manager" + fi + else + log_warn "gh CLI not authenticated. Run 'gh auth login' to sync secrets" + fi +else + log_warn "gh CLI not installed. Install it to auto-sync GitHub secrets" +fi + +# ============================================================================= +# Step 8: Test TPU creation (optional, commented out) +# ============================================================================= +# Uncomment to test TPU VM creation: +# +# log_info "Testing TPU VM creation..." +# gcloud compute tpus tpu-vm create test-tpu-vm \ +# --zone="${ZONE}" \ +# --accelerator-type="${TPU_TYPE}" \ +# --version="${TPU_RUNTIME}" \ +# --spot \ +# --quiet +# +# log_info "TPU VM created successfully, deleting..." +# gcloud compute tpus tpu-vm delete test-tpu-vm \ +# --zone="${ZONE}" \ +# --quiet + +# ============================================================================= +# Summary +# ============================================================================= +echo "" +echo "==========================================" +echo " TPU CI Setup Complete!" +echo "==========================================" +echo "" +echo "Resources configured:" +echo " - Service Account: ${SA_EMAIL}" +echo " - Secret (GCP): ${SECRET_NAME}" +echo " - GitHub Secret: GCP_SERVICE_ACCOUNT_KEY (shared with GPU CI)" +echo "" +echo "TPU Configuration:" +echo " - Type: ${TPU_TYPE} (8 chips, 128 GB HBM2)" +echo " - Zone: ${ZONE}" +echo " - Runtime: ${TPU_RUNTIME}" +echo "" +echo "Estimated costs:" +echo " - On-demand: ~\$9.60/hour (8 × \$1.20/chip)" +echo " - Spot: ~\$1-2/hour (up to 91% discount)" +echo "" +echo "Next steps:" +echo " 1. Request TPU quota if needed (link above)" +echo " 2. Push the test-tpu.yml workflow" +echo " 3. Trigger workflow manually or via PR" +echo "" diff --git a/tests/conftest.py b/tests/conftest.py index 04c65bae..64153d94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,9 @@ Handles platform-specific JAX configuration: - macOS: Forces CPU backend since Metal doesn't support triangular_solve - Linux with CUDA: Preloads CUDA libraries to help JAX discover them +- Linux with TPU: Uses TPU backend when JAX_PLATFORMS=tpu is set -Uses pytest_configure hook to ensure CUDA setup happens before any test imports. +Uses pytest_configure hook to ensure backend setup happens before any test imports. """ import os @@ -36,16 +37,22 @@ def pytest_configure(config): """ Pytest hook that runs before test collection. - This ensures CUDA libraries are preloaded and JAX is configured + This ensures backend libraries are preloaded and JAX is configured BEFORE any test modules are imported. """ + jax_platforms = os.environ.get('JAX_PLATFORMS', '') + # Platform-specific configuration BEFORE importing JAX if sys.platform == 'darwin': # macOS: Force CPU backend - Metal doesn't support triangular_solve os.environ['JAX_PLATFORMS'] = 'cpu' - elif sys.platform == 'linux' and os.environ.get('JAX_PLATFORMS', '').startswith('cuda'): + elif sys.platform == 'linux' and jax_platforms.startswith('cuda'): # Linux with CUDA: Preload CUDA libraries before JAX import _setup_cuda_libraries() + elif sys.platform == 'linux' and jax_platforms == 'tpu': + # Linux with TPU: JAX handles TPU initialization via libtpu + # No special preloading needed - libtpu is installed with jax[tpu] + pass # Import JAX and configure it import jax From 46bdc6a3ac6fd02259b0f4e7a86ce1b455b30ae0 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 22:30:01 +0000 Subject: [PATCH 02/10] CI: TPU: Enable test-tpu for pull_requests --- .github/workflows/test-tpu.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test-tpu.yml b/.github/workflows/test-tpu.yml index 9e5eeaf2..9611b505 100644 --- a/.github/workflows/test-tpu.yml +++ b/.github/workflows/test-tpu.yml @@ -1,6 +1,7 @@ name: TPU Tests on: + pull_request: workflow_dispatch: inputs: use_spot: From 330906e827d67464304ab3b0165fe6f0ff1a611f Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 22:44:18 +0000 Subject: [PATCH 03/10] Fix uv pip install order: sync first to create venv MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit uv pip install requires a virtual environment. Run uv sync first to create the venv, then install jax[tpu] into it. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .github/workflows/test-tpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-tpu.yml b/.github/workflows/test-tpu.yml index 9611b505..b1e41693 100644 --- a/.github/workflows/test-tpu.yml +++ b/.github/workflows/test-tpu.yml @@ -66,7 +66,7 @@ jobs: run: | gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ --zone="${{ env.GCP_ZONE }}" \ - --command='source ~/.local/bin/env && cd ~/jax-spice && uv pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && uv sync' + --command='source ~/.local/bin/env && cd ~/jax-spice && uv sync && uv pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html' - name: Run profiling and tests run: | From f2f1d66b9abf21e5310eb3a1d6008a6009498566 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 22:54:59 +0000 Subject: [PATCH 04/10] CI: TPU: Add submodules: true to checkout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Required for openvaf-py submodule to be included in the tarball. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .github/workflows/test-tpu.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test-tpu.yml b/.github/workflows/test-tpu.yml index b1e41693..a057cb3a 100644 --- a/.github/workflows/test-tpu.yml +++ b/.github/workflows/test-tpu.yml @@ -27,6 +27,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + submodules: true - name: Authenticate to GCP uses: google-github-actions/auth@v2 From 06b933788a66fa8f3b131e5051fe3e7c2e08f629 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 23:08:56 +0000 Subject: [PATCH 05/10] CI: TPU: Use F32 mode (TPU doesn't support F64 LU decomposition) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TPU v5e only supports F32 and C64 for LuDecomposition operations. Disable JAX_ENABLE_X64 when running on TPU to use float32 instead. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .github/workflows/test-tpu.yml | 3 ++- tests/conftest.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-tpu.yml b/.github/workflows/test-tpu.yml index a057cb3a..dbf247f3 100644 --- a/.github/workflows/test-tpu.yml +++ b/.github/workflows/test-tpu.yml @@ -72,9 +72,10 @@ jobs: - name: Run profiling and tests run: | + # TPU only supports F32 for LU decomposition, so we don't enable X64 gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ --zone="${{ env.GCP_ZONE }}" \ - --command='source ~/.local/bin/env && cd ~/jax-spice && export JAX_PLATFORMS=tpu && export JAX_ENABLE_X64=1 && uv run python scripts/profile_gpu.py && uv run pytest tests/ -v --tb=short -x' \ + --command='source ~/.local/bin/env && cd ~/jax-spice && export JAX_PLATFORMS=tpu && uv run python scripts/profile_gpu.py && uv run pytest tests/ -v --tb=short -x' \ | tee /tmp/test_output.txt - name: Extract profiling report diff --git a/tests/conftest.py b/tests/conftest.py index 64153d94..a78f1075 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,4 +58,6 @@ def pytest_configure(config): import jax # Enable float64 for numerical precision in tests - jax.config.update('jax_enable_x64', True) + # (except on TPU which only supports F32 for LU decomposition) + if jax_platforms != 'tpu': + jax.config.update('jax_enable_x64', True) From e05e7e546634e9e401d78bec54a74705622e1870 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 23:23:01 +0000 Subject: [PATCH 06/10] CI: TPU: Disable X64 in profile_gpu.py for TPU MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The profiler script hardcoded jax_enable_x64 = True, which overrode the environment setting. Now it checks JAX_PLATFORMS before enabling. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- scripts/profile_gpu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/profile_gpu.py b/scripts/profile_gpu.py index af49b887..a23050c2 100644 --- a/scripts/profile_gpu.py +++ b/scripts/profile_gpu.py @@ -21,8 +21,9 @@ import jax.numpy as jnp import numpy as np -# Enable float64 -jax.config.update('jax_enable_x64', True) +# Enable float64 (except on TPU which only supports F32 for LU decomposition) +if os.environ.get('JAX_PLATFORMS') != 'tpu': + jax.config.update('jax_enable_x64', True) @dataclass From 88707048c62d21e9ee842bfb19efd2b5de01e612 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 23:28:10 +0000 Subject: [PATCH 07/10] Add TPU fallback to CPU for sparse solve MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TPU doesn't have native sparse solve support (no XLA sparse ops). Fall back to CPU via scipy pure_callback, same as the existing CPU path. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- jax_spice/analysis/sparse.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jax_spice/analysis/sparse.py b/jax_spice/analysis/sparse.py index e06af8b4..18bdaf16 100644 --- a/jax_spice/analysis/sparse.py +++ b/jax_spice/analysis/sparse.py @@ -50,6 +50,9 @@ def sparse_solve( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu(data, indices, indptr, b, shape) + elif backend == 'tpu': + # TPU doesn't have native sparse solve; fall back to CPU via callback + return _spsolve_cpu(data, indices, indptr, b, shape) else: return _spsolve_gpu(data, indices, indptr, b, shape) @@ -199,6 +202,9 @@ def sparse_solve_csr( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu_csr(data, indices, indptr, b, shape) + elif backend == 'tpu': + # TPU doesn't have native sparse solve; fall back to CPU via callback + return _spsolve_cpu_csr(data, indices, indptr, b, shape) else: return _spsolve_gpu(data, indices, indptr, b, shape) From ce355c9f9b3fda59888c4748f90eb361d7f0bdff Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 23:30:03 +0000 Subject: [PATCH 08/10] Revert sparse.py TPU fallback - run native F32 experiment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the CPU fallback for TPU sparse solve. Let the experiment run with native TPU operations in F32 mode to see what works. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- jax_spice/analysis/sparse.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jax_spice/analysis/sparse.py b/jax_spice/analysis/sparse.py index 18bdaf16..e06af8b4 100644 --- a/jax_spice/analysis/sparse.py +++ b/jax_spice/analysis/sparse.py @@ -50,9 +50,6 @@ def sparse_solve( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu(data, indices, indptr, b, shape) - elif backend == 'tpu': - # TPU doesn't have native sparse solve; fall back to CPU via callback - return _spsolve_cpu(data, indices, indptr, b, shape) else: return _spsolve_gpu(data, indices, indptr, b, shape) @@ -202,9 +199,6 @@ def sparse_solve_csr( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu_csr(data, indices, indptr, b, shape) - elif backend == 'tpu': - # TPU doesn't have native sparse solve; fall back to CPU via callback - return _spsolve_cpu_csr(data, indices, indptr, b, shape) else: return _spsolve_gpu(data, indices, indptr, b, shape) From fc2582cdfb78c520365f21a8dca533b22f185952 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 8 Dec 2025 23:39:22 +0000 Subject: [PATCH 09/10] Add TPU support to sparse solvers via dense fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sparse.py now properly detects TPU backend and uses dense solve (via BCOO.todense() + jnp.linalg.solve) instead of spsolve which only works on GPU/CUDA. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- jax_spice/analysis/sparse.py | 60 ++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/jax_spice/analysis/sparse.py b/jax_spice/analysis/sparse.py index e06af8b4..faf0824b 100644 --- a/jax_spice/analysis/sparse.py +++ b/jax_spice/analysis/sparse.py @@ -4,6 +4,7 @@ selects the best backend based on the JAX platform: - CPU: Uses scipy.sparse.linalg.spsolve via jax.pure_callback - GPU: Uses jax.experimental.sparse.linalg.spsolve (cuSOLVER) +- TPU: Uses dense solve via jnp.linalg.solve (spsolve not supported on TPU) The solver supports reverse-mode autodiff through jax.custom_vjp using the adjoint method for implicit differentiation. @@ -19,6 +20,7 @@ from typing import Tuple import jax +import jax.numpy as jnp from jax import Array import numpy as np from scipy.sparse import csc_matrix, csr_matrix @@ -50,8 +52,35 @@ def sparse_solve( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu(data, indices, indptr, b, shape) - else: + elif backend in ('gpu', 'cuda'): return _spsolve_gpu(data, indices, indptr, b, shape) + else: + # TPU and other backends: use dense solve (spsolve not supported) + return _solve_dense_csc(data, indices, indptr, b, shape) + + +def _solve_dense_csc( + data: Array, + indices: Array, + indptr: Array, + b: Array, + shape: Tuple[int, int] +) -> Array: + """Dense solve by reconstructing matrix from CSC format. + + Used for TPU and other backends that don't support sparse solve. + """ + from jax.experimental.sparse import BCOO + + # Convert CSC to BCOO and then to dense + # CSC: data[k] is at (indices[k], col) where col is determined by indptr + n = shape[0] + col_indices = jnp.repeat(jnp.arange(n), jnp.diff(indptr)) + bcoo_indices = jnp.stack([indices, col_indices], axis=1) + A_bcoo = BCOO((data, bcoo_indices), shape=shape) + A_dense = A_bcoo.todense() + + return jnp.linalg.solve(A_dense, b) def _spsolve_cpu( @@ -199,8 +228,35 @@ def sparse_solve_csr( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu_csr(data, indices, indptr, b, shape) - else: + elif backend in ('gpu', 'cuda'): return _spsolve_gpu(data, indices, indptr, b, shape) + else: + # TPU and other backends: use dense solve (spsolve not supported) + return _solve_dense_csr(data, indices, indptr, b, shape) + + +def _solve_dense_csr( + data: Array, + indices: Array, + indptr: Array, + b: Array, + shape: Tuple[int, int] +) -> Array: + """Dense solve by reconstructing matrix from CSR format. + + Used for TPU and other backends that don't support sparse solve. + """ + from jax.experimental.sparse import BCOO + + # Convert CSR to BCOO and then to dense + # CSR: data[k] is at (row, indices[k]) where row is determined by indptr + n = shape[0] + row_indices = jnp.repeat(jnp.arange(n), jnp.diff(indptr)) + bcoo_indices = jnp.stack([row_indices, indices], axis=1) + A_bcoo = BCOO((data, bcoo_indices), shape=shape) + A_dense = A_bcoo.todense() + + return jnp.linalg.solve(A_dense, b) def _spsolve_cpu_csr( From 3c005bc65d885d49dea029764f9ecee4d67268e6 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 9 Dec 2025 01:06:44 +0000 Subject: [PATCH 10/10] CI: TPU: Add zone fallback for capacity issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Try multiple zones (us-central1-a, us-west4-a, us-east1-d, us-east5-a) when creating TPU VM to handle temporary capacity exhaustion in any single zone. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .github/workflows/test-tpu.yml | 52 +++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test-tpu.yml b/.github/workflows/test-tpu.yml index dbf247f3..aaa8be1b 100644 --- a/.github/workflows/test-tpu.yml +++ b/.github/workflows/test-tpu.yml @@ -39,18 +39,32 @@ jobs: uses: google-github-actions/setup-gcloud@v2 - name: Create TPU VM + id: create_tpu run: | SPOT_FLAG="" if [ "${{ inputs.use_spot }}" = "true" ]; then SPOT_FLAG="--spot" fi - gcloud compute tpus tpu-vm create "${{ env.TPU_NAME }}" \ - --zone="${{ env.GCP_ZONE }}" \ - --accelerator-type="${{ env.TPU_TYPE }}" \ - --version="${{ env.TPU_RUNTIME }}" \ - ${SPOT_FLAG} \ - --quiet + # Try zones in order until one succeeds + ZONES="us-central1-a us-west4-a us-east1-d us-east5-a" + for zone in $ZONES; do + echo "Trying zone: $zone" + if gcloud compute tpus tpu-vm create "${{ env.TPU_NAME }}" \ + --zone="$zone" \ + --accelerator-type="${{ env.TPU_TYPE }}" \ + --version="${{ env.TPU_RUNTIME }}" \ + ${SPOT_FLAG} \ + --quiet 2>&1; then + echo "TPU created successfully in $zone" + echo "ACTIVE_ZONE=$zone" >> "$GITHUB_OUTPUT" + exit 0 + else + echo "Zone $zone failed, trying next..." + fi + done + echo "All zones exhausted" + exit 1 - name: Sync code and setup environment run: | @@ -58,23 +72,23 @@ jobs: gcloud compute tpus tpu-vm scp /tmp/jax-spice.tar.gz \ "${{ env.TPU_NAME }}":~/jax-spice.tar.gz \ - --zone="${{ env.GCP_ZONE }}" + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ - --zone="${{ env.GCP_ZONE }}" \ + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \ --command='mkdir -p ~/jax-spice && cd ~/jax-spice && tar -xzf ~/jax-spice.tar.gz && rm ~/jax-spice.tar.gz && curl -LsSf https://astral.sh/uv/install.sh | sh' - name: Install dependencies run: | gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ - --zone="${{ env.GCP_ZONE }}" \ + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \ --command='source ~/.local/bin/env && cd ~/jax-spice && uv sync && uv pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html' - name: Run profiling and tests run: | # TPU only supports F32 for LU decomposition, so we don't enable X64 gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ - --zone="${{ env.GCP_ZONE }}" \ + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \ --command='source ~/.local/bin/env && cd ~/jax-spice && export JAX_PLATFORMS=tpu && uv run python scripts/profile_gpu.py && uv run pytest tests/ -v --tb=short -x' \ | tee /tmp/test_output.txt @@ -83,6 +97,7 @@ jobs: run: | echo "## TPU Test Results" >> "$GITHUB_STEP_SUMMARY" echo "- **TPU Type:** ${{ env.TPU_TYPE }}" >> "$GITHUB_STEP_SUMMARY" + echo "- **Zone:** ${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" >> "$GITHUB_STEP_SUMMARY" echo "- **Spot VM:** ${{ inputs.use_spot }}" >> "$GITHUB_STEP_SUMMARY" echo "" >> "$GITHUB_STEP_SUMMARY" if [ -f /tmp/test_output.txt ]; then @@ -92,6 +107,17 @@ jobs: - name: Cleanup TPU VM if: always() run: | - gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \ - --zone="${{ env.GCP_ZONE }}" \ - --quiet || true + # Try to delete in all possible zones (in case we don't know which one was used) + ZONE="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" + if [ -n "$ZONE" ]; then + gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \ + --zone="$ZONE" \ + --quiet || true + else + # Fallback: try all zones + for zone in us-central1-a us-west4-a us-east1-d us-east5-a; do + gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \ + --zone="$zone" \ + --quiet 2>/dev/null || true + done + fi