diff --git a/.github/workflows/test-tpu.yml b/.github/workflows/test-tpu.yml new file mode 100644 index 00000000..aaa8be1b --- /dev/null +++ b/.github/workflows/test-tpu.yml @@ -0,0 +1,123 @@ +name: TPU Tests + +on: + pull_request: + workflow_dispatch: + inputs: + use_spot: + description: 'Use Spot VM (cheaper but may be preempted)' + required: true + default: true + type: boolean + +env: + GCP_PROJECT: jax-spice-cuda-test + GCP_ZONE: us-central1-a + TPU_NAME: jax-spice-tpu-${{ github.run_id }} + TPU_TYPE: v5litepod-8 + TPU_RUNTIME: v2-alpha-tpuv5-lite + +jobs: + tpu-tests: + runs-on: ubuntu-latest + permissions: + contents: read + id-token: write + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: true + + - name: Authenticate to GCP + uses: google-github-actions/auth@v2 + with: + credentials_json: ${{ secrets.GCP_SERVICE_ACCOUNT_KEY }} + + - name: Set up Cloud SDK + uses: google-github-actions/setup-gcloud@v2 + + - name: Create TPU VM + id: create_tpu + run: | + SPOT_FLAG="" + if [ "${{ inputs.use_spot }}" = "true" ]; then + SPOT_FLAG="--spot" + fi + + # Try zones in order until one succeeds + ZONES="us-central1-a us-west4-a us-east1-d us-east5-a" + for zone in $ZONES; do + echo "Trying zone: $zone" + if gcloud compute tpus tpu-vm create "${{ env.TPU_NAME }}" \ + --zone="$zone" \ + --accelerator-type="${{ env.TPU_TYPE }}" \ + --version="${{ env.TPU_RUNTIME }}" \ + ${SPOT_FLAG} \ + --quiet 2>&1; then + echo "TPU created successfully in $zone" + echo "ACTIVE_ZONE=$zone" >> "$GITHUB_OUTPUT" + exit 0 + else + echo "Zone $zone failed, trying next..." + fi + done + echo "All zones exhausted" + exit 1 + + - name: Sync code and setup environment + run: | + tar --exclude='.git' --exclude='__pycache__' -czf /tmp/jax-spice.tar.gz . + + gcloud compute tpus tpu-vm scp /tmp/jax-spice.tar.gz \ + "${{ env.TPU_NAME }}":~/jax-spice.tar.gz \ + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" + + gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \ + --command='mkdir -p ~/jax-spice && cd ~/jax-spice && tar -xzf ~/jax-spice.tar.gz && rm ~/jax-spice.tar.gz && curl -LsSf https://astral.sh/uv/install.sh | sh' + + - name: Install dependencies + run: | + gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \ + --command='source ~/.local/bin/env && cd ~/jax-spice && uv sync && uv pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html' + + - name: Run profiling and tests + run: | + # TPU only supports F32 for LU decomposition, so we don't enable X64 + gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \ + --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \ + --command='source ~/.local/bin/env && cd ~/jax-spice && export JAX_PLATFORMS=tpu && uv run python scripts/profile_gpu.py && uv run pytest tests/ -v --tb=short -x' \ + | tee /tmp/test_output.txt + + - name: Extract profiling report + if: always() + run: | + echo "## TPU Test Results" >> "$GITHUB_STEP_SUMMARY" + echo "- **TPU Type:** ${{ env.TPU_TYPE }}" >> "$GITHUB_STEP_SUMMARY" + echo "- **Zone:** ${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" >> "$GITHUB_STEP_SUMMARY" + echo "- **Spot VM:** ${{ inputs.use_spot }}" >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + if [ -f /tmp/test_output.txt ]; then + sed -n '/# JAX-SPICE/,/Report written/p' /tmp/test_output.txt | head -n -1 >> "$GITHUB_STEP_SUMMARY" || true + fi + + - name: Cleanup TPU VM + if: always() + run: | + # Try to delete in all possible zones (in case we don't know which one was used) + ZONE="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" + if [ -n "$ZONE" ]; then + gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \ + --zone="$ZONE" \ + --quiet || true + else + # Fallback: try all zones + for zone in us-central1-a us-west4-a us-east1-d us-east5-a; do + gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \ + --zone="$zone" \ + --quiet 2>/dev/null || true + done + fi diff --git a/jax_spice/analysis/sparse.py b/jax_spice/analysis/sparse.py index e06af8b4..faf0824b 100644 --- a/jax_spice/analysis/sparse.py +++ b/jax_spice/analysis/sparse.py @@ -4,6 +4,7 @@ selects the best backend based on the JAX platform: - CPU: Uses scipy.sparse.linalg.spsolve via jax.pure_callback - GPU: Uses jax.experimental.sparse.linalg.spsolve (cuSOLVER) +- TPU: Uses dense solve via jnp.linalg.solve (spsolve not supported on TPU) The solver supports reverse-mode autodiff through jax.custom_vjp using the adjoint method for implicit differentiation. @@ -19,6 +20,7 @@ from typing import Tuple import jax +import jax.numpy as jnp from jax import Array import numpy as np from scipy.sparse import csc_matrix, csr_matrix @@ -50,8 +52,35 @@ def sparse_solve( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu(data, indices, indptr, b, shape) - else: + elif backend in ('gpu', 'cuda'): return _spsolve_gpu(data, indices, indptr, b, shape) + else: + # TPU and other backends: use dense solve (spsolve not supported) + return _solve_dense_csc(data, indices, indptr, b, shape) + + +def _solve_dense_csc( + data: Array, + indices: Array, + indptr: Array, + b: Array, + shape: Tuple[int, int] +) -> Array: + """Dense solve by reconstructing matrix from CSC format. + + Used for TPU and other backends that don't support sparse solve. + """ + from jax.experimental.sparse import BCOO + + # Convert CSC to BCOO and then to dense + # CSC: data[k] is at (indices[k], col) where col is determined by indptr + n = shape[0] + col_indices = jnp.repeat(jnp.arange(n), jnp.diff(indptr)) + bcoo_indices = jnp.stack([indices, col_indices], axis=1) + A_bcoo = BCOO((data, bcoo_indices), shape=shape) + A_dense = A_bcoo.todense() + + return jnp.linalg.solve(A_dense, b) def _spsolve_cpu( @@ -199,8 +228,35 @@ def sparse_solve_csr( backend = jax.default_backend() if backend == 'cpu': return _spsolve_cpu_csr(data, indices, indptr, b, shape) - else: + elif backend in ('gpu', 'cuda'): return _spsolve_gpu(data, indices, indptr, b, shape) + else: + # TPU and other backends: use dense solve (spsolve not supported) + return _solve_dense_csr(data, indices, indptr, b, shape) + + +def _solve_dense_csr( + data: Array, + indices: Array, + indptr: Array, + b: Array, + shape: Tuple[int, int] +) -> Array: + """Dense solve by reconstructing matrix from CSR format. + + Used for TPU and other backends that don't support sparse solve. + """ + from jax.experimental.sparse import BCOO + + # Convert CSR to BCOO and then to dense + # CSR: data[k] is at (row, indices[k]) where row is determined by indptr + n = shape[0] + row_indices = jnp.repeat(jnp.arange(n), jnp.diff(indptr)) + bcoo_indices = jnp.stack([row_indices, indices], axis=1) + A_bcoo = BCOO((data, bcoo_indices), shape=shape) + A_dense = A_bcoo.todense() + + return jnp.linalg.solve(A_dense, b) def _spsolve_cpu_csr( diff --git a/scripts/profile_gpu.py b/scripts/profile_gpu.py index af49b887..a23050c2 100644 --- a/scripts/profile_gpu.py +++ b/scripts/profile_gpu.py @@ -21,8 +21,9 @@ import jax.numpy as jnp import numpy as np -# Enable float64 -jax.config.update('jax_enable_x64', True) +# Enable float64 (except on TPU which only supports F32 for LU decomposition) +if os.environ.get('JAX_PLATFORMS') != 'tpu': + jax.config.update('jax_enable_x64', True) @dataclass diff --git a/scripts/setup_gcp_tpu_ci.sh b/scripts/setup_gcp_tpu_ci.sh new file mode 100644 index 00000000..34acb13c --- /dev/null +++ b/scripts/setup_gcp_tpu_ci.sh @@ -0,0 +1,286 @@ +#!/bin/bash +# Setup script for GCP TPU CI infrastructure +# This script is fully idempotent - safe to run multiple times +# +# Usage: ./setup_gcp_tpu_ci.sh +# +# Features: +# - Enables TPU API +# - Creates/updates service account with TPU permissions +# - Stores service account key in GCP Secret Manager +# - Syncs secret to GitHub Actions +# +# Prerequisites: +# - gcloud CLI installed and authenticated +# - gh CLI installed and authenticated (for GitHub secret sync) +# - Billing enabled on the GCP project +# - TPU quota available in the zone (request at https://cloud.google.com/tpu/docs/quota) +# +# TPU Quota: +# - You need quota for the TPU type you want to use +# - For v5e: request "TPU v5 Lite PodSlice chips" quota +# - For Spot VMs: request preemptible quota separately +# - Quota request: https://console.cloud.google.com/iam-admin/quotas + +set -euo pipefail + +# Configuration +PROJECT_ID="${GCP_PROJECT:-jax-spice-cuda-test}" +REGION="${GCP_REGION:-us-central1}" +ZONE="${GCP_ZONE:-us-central1-a}" # Must have TPU v5e availability +SA_NAME="github-gpu-ci" # Reuse existing service account +SA_EMAIL="${SA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com" +SECRET_NAME="github-gpu-ci-key" +GITHUB_REPO="${GITHUB_REPO:-ChipFlow/jax-spice}" + +# TPU Configuration +TPU_TYPE="${TPU_TYPE:-v5litepod-8}" # Smallest v5e configuration +TPU_RUNTIME="${TPU_RUNTIME:-v2-alpha-tpuv5-lite}" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +echo "==========================================" +echo " JAX-SPICE TPU CI Setup (Idempotent)" +echo "==========================================" +echo "" +echo "Project: ${PROJECT_ID}" +echo "Zone: ${ZONE}" +echo "TPU Type: ${TPU_TYPE}" +echo "Runtime: ${TPU_RUNTIME}" +echo "GitHub Repo: ${GITHUB_REPO}" +echo "" + +# Set project +gcloud config set project "${PROJECT_ID}" --quiet + +# ============================================================================= +# Step 1: Enable required APIs (idempotent) +# ============================================================================= +log_info "Enabling required APIs..." +gcloud services enable tpu.googleapis.com --quiet +gcloud services enable compute.googleapis.com --quiet +gcloud services enable secretmanager.googleapis.com --quiet +gcloud services enable iam.googleapis.com --quiet +log_info "APIs enabled" + +# ============================================================================= +# Step 2: Check/create service account (idempotent) +# ============================================================================= +log_info "Setting up service account..." +if gcloud iam service-accounts describe "${SA_EMAIL}" &>/dev/null; then + log_info "Service account already exists: ${SA_EMAIL}" +else + gcloud iam service-accounts create "${SA_NAME}" \ + --display-name="GitHub GPU/TPU CI Runner" \ + --quiet + log_info "Created service account: ${SA_EMAIL}" +fi + +# ============================================================================= +# Step 3: Grant TPU permissions (idempotent) +# ============================================================================= +log_info "Configuring IAM permissions for TPU..." + +# Define required roles for TPU access +ROLES=( + "roles/tpu.admin" # Create/delete/manage TPU VMs + "roles/compute.networkUser" # Use VPC networks for TPU + "roles/iam.serviceAccountUser" # Use service account on TPU VM + "roles/logging.viewer" # View logs + "roles/storage.objectViewer" # Read from GCS (for JAX wheels) +) + +for ROLE in "${ROLES[@]}"; do + gcloud projects add-iam-policy-binding "${PROJECT_ID}" \ + --member="serviceAccount:${SA_EMAIL}" \ + --role="${ROLE}" \ + --quiet \ + --condition=None 2>/dev/null || true +done +log_info "IAM permissions configured" + +# ============================================================================= +# Step 4: Check TPU quota +# ============================================================================= +log_info "Checking TPU quota..." + +# Extract chip count from TPU type (e.g., v5litepod-8 -> 8) +CHIP_COUNT=$(echo "${TPU_TYPE}" | grep -oE '[0-9]+$') + +echo "" +echo "Required quota for ${TPU_TYPE}:" +echo " - TPU v5 Lite PodSlice chips: ${CHIP_COUNT} (on-demand)" +echo " - Preemptible TPU v5 Lite PodSlice chips: ${CHIP_COUNT} (for Spot VMs)" +echo "" +echo "Check your quota at:" +echo " https://console.cloud.google.com/iam-admin/quotas?project=${PROJECT_ID}" +echo "" +echo "Filter by: 'tpu' and region '${REGION}'" +echo "" + +# ============================================================================= +# Step 5: Verify TPU availability in zone +# ============================================================================= +log_info "Checking TPU availability in ${ZONE}..." + +# List available accelerator types +AVAILABLE_TYPES=$(gcloud compute tpus accelerator-types list \ + --zone="${ZONE}" \ + --format="value(type)" 2>/dev/null | grep -E "^v5litepod" || echo "") + +if [ -z "${AVAILABLE_TYPES}" ]; then + log_warn "No v5e TPUs found in ${ZONE}. Available zones for v5e:" + echo " - us-central1-a" + echo " - us-south1-a" + echo " - us-west1-c" + echo " - us-west4-a" + echo " - europe-west4-b" + echo "" + echo "Update GCP_ZONE environment variable and re-run." +else + log_info "Available TPU types in ${ZONE}:" + echo "${AVAILABLE_TYPES}" | sed 's/^/ - /' +fi + +# ============================================================================= +# Step 6: Create/update secret in Secret Manager (idempotent) +# ============================================================================= +log_info "Setting up Secret Manager..." + +# Check if secret exists +if gcloud secrets describe "${SECRET_NAME}" &>/dev/null; then + log_info "Secret already exists: ${SECRET_NAME}" + + # Check if we need to create a new key version + LATEST_VERSION=$(gcloud secrets versions list "${SECRET_NAME}" \ + --filter="state=ENABLED" \ + --sort-by="~createTime" \ + --limit=1 \ + --format="value(name)" 2>/dev/null || echo "") + + if [ -n "${LATEST_VERSION}" ]; then + log_info "Using existing secret version" + NEED_NEW_KEY=false + else + log_warn "No enabled secret versions found, creating new key" + NEED_NEW_KEY=true + fi +else + log_info "Creating new secret: ${SECRET_NAME}" + gcloud secrets create "${SECRET_NAME}" \ + --replication-policy="automatic" \ + --quiet + NEED_NEW_KEY=true +fi + +# Create new key and add to secret if needed +if [ "${NEED_NEW_KEY:-false}" = true ]; then + log_info "Generating new service account key..." + + # Create temporary key file + KEY_FILE=$(mktemp) + trap "rm -f ${KEY_FILE}" EXIT + + gcloud iam service-accounts keys create "${KEY_FILE}" \ + --iam-account="${SA_EMAIL}" \ + --quiet + + # Add new version to secret + gcloud secrets versions add "${SECRET_NAME}" \ + --data-file="${KEY_FILE}" \ + --quiet + + log_info "Service account key stored in Secret Manager" + + # Clean up old keys (keep only the 2 most recent) + log_info "Cleaning up old service account keys..." + OLD_KEYS=$(gcloud iam service-accounts keys list \ + --iam-account="${SA_EMAIL}" \ + --format="value(name)" \ + --filter="keyType=USER_MANAGED" \ + --sort-by="~validAfterTime" 2>/dev/null | tail -n +3) + + for KEY_ID in ${OLD_KEYS}; do + gcloud iam service-accounts keys delete "${KEY_ID}" \ + --iam-account="${SA_EMAIL}" \ + --quiet 2>/dev/null || true + done +fi + +# ============================================================================= +# Step 7: Sync secret to GitHub (idempotent) +# ============================================================================= +log_info "Syncing secret to GitHub..." + +if command -v gh &>/dev/null; then + if gh auth status &>/dev/null; then + SECRET_VALUE=$(gcloud secrets versions access latest --secret="${SECRET_NAME}" 2>/dev/null) + + if [ -n "${SECRET_VALUE}" ]; then + echo "${SECRET_VALUE}" | gh secret set GCP_SERVICE_ACCOUNT_KEY \ + --repo="${GITHUB_REPO}" 2>/dev/null && \ + log_info "GitHub secret 'GCP_SERVICE_ACCOUNT_KEY' updated" || \ + log_warn "Failed to update GitHub secret (check gh permissions)" + else + log_error "Could not retrieve secret from Secret Manager" + fi + else + log_warn "gh CLI not authenticated. Run 'gh auth login' to sync secrets" + fi +else + log_warn "gh CLI not installed. Install it to auto-sync GitHub secrets" +fi + +# ============================================================================= +# Step 8: Test TPU creation (optional, commented out) +# ============================================================================= +# Uncomment to test TPU VM creation: +# +# log_info "Testing TPU VM creation..." +# gcloud compute tpus tpu-vm create test-tpu-vm \ +# --zone="${ZONE}" \ +# --accelerator-type="${TPU_TYPE}" \ +# --version="${TPU_RUNTIME}" \ +# --spot \ +# --quiet +# +# log_info "TPU VM created successfully, deleting..." +# gcloud compute tpus tpu-vm delete test-tpu-vm \ +# --zone="${ZONE}" \ +# --quiet + +# ============================================================================= +# Summary +# ============================================================================= +echo "" +echo "==========================================" +echo " TPU CI Setup Complete!" +echo "==========================================" +echo "" +echo "Resources configured:" +echo " - Service Account: ${SA_EMAIL}" +echo " - Secret (GCP): ${SECRET_NAME}" +echo " - GitHub Secret: GCP_SERVICE_ACCOUNT_KEY (shared with GPU CI)" +echo "" +echo "TPU Configuration:" +echo " - Type: ${TPU_TYPE} (8 chips, 128 GB HBM2)" +echo " - Zone: ${ZONE}" +echo " - Runtime: ${TPU_RUNTIME}" +echo "" +echo "Estimated costs:" +echo " - On-demand: ~\$9.60/hour (8 × \$1.20/chip)" +echo " - Spot: ~\$1-2/hour (up to 91% discount)" +echo "" +echo "Next steps:" +echo " 1. Request TPU quota if needed (link above)" +echo " 2. Push the test-tpu.yml workflow" +echo " 3. Trigger workflow manually or via PR" +echo "" diff --git a/tests/conftest.py b/tests/conftest.py index 04c65bae..a78f1075 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,9 @@ Handles platform-specific JAX configuration: - macOS: Forces CPU backend since Metal doesn't support triangular_solve - Linux with CUDA: Preloads CUDA libraries to help JAX discover them +- Linux with TPU: Uses TPU backend when JAX_PLATFORMS=tpu is set -Uses pytest_configure hook to ensure CUDA setup happens before any test imports. +Uses pytest_configure hook to ensure backend setup happens before any test imports. """ import os @@ -36,19 +37,27 @@ def pytest_configure(config): """ Pytest hook that runs before test collection. - This ensures CUDA libraries are preloaded and JAX is configured + This ensures backend libraries are preloaded and JAX is configured BEFORE any test modules are imported. """ + jax_platforms = os.environ.get('JAX_PLATFORMS', '') + # Platform-specific configuration BEFORE importing JAX if sys.platform == 'darwin': # macOS: Force CPU backend - Metal doesn't support triangular_solve os.environ['JAX_PLATFORMS'] = 'cpu' - elif sys.platform == 'linux' and os.environ.get('JAX_PLATFORMS', '').startswith('cuda'): + elif sys.platform == 'linux' and jax_platforms.startswith('cuda'): # Linux with CUDA: Preload CUDA libraries before JAX import _setup_cuda_libraries() + elif sys.platform == 'linux' and jax_platforms == 'tpu': + # Linux with TPU: JAX handles TPU initialization via libtpu + # No special preloading needed - libtpu is installed with jax[tpu] + pass # Import JAX and configure it import jax # Enable float64 for numerical precision in tests - jax.config.update('jax_enable_x64', True) + # (except on TPU which only supports F32 for LU decomposition) + if jax_platforms != 'tpu': + jax.config.update('jax_enable_x64', True)