Skip to content
123 changes: 123 additions & 0 deletions .github/workflows/test-tpu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
name: TPU Tests

on:
  pull_request:
  workflow_dispatch:
    inputs:
      use_spot:
        description: 'Use Spot VM (cheaper but may be preempted)'
        required: true
        default: true
        type: boolean

env:
  GCP_PROJECT: jax-spice-cuda-test
  TPU_NAME: jax-spice-tpu-${{ github.run_id }}
  TPU_TYPE: v5litepod-8
  TPU_RUNTIME: v2-alpha-tpuv5-lite
  # Single source of truth for the zones tried at creation time and swept
  # by the cleanup fallback (keep both steps in sync via this list).
  TPU_ZONES: us-central1-a us-west4-a us-east1-d us-east5-a

jobs:
  tpu-tests:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      id-token: write

    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          submodules: true

      - name: Authenticate to GCP
        uses: google-github-actions/auth@v2
        with:
          credentials_json: ${{ secrets.GCP_SERVICE_ACCOUNT_KEY }}

      - name: Set up Cloud SDK
        uses: google-github-actions/setup-gcloud@v2

      - name: Create TPU VM
        id: create_tpu
        run: |
          # inputs.use_spot only exists on workflow_dispatch; on pull_request
          # it expands to an empty string, so default PRs to spot capacity.
          USE_SPOT="${{ github.event_name != 'workflow_dispatch' || inputs.use_spot }}"
          SPOT_FLAG=""
          if [ "$USE_SPOT" = "true" ]; then
            SPOT_FLAG="--spot"
          fi

          # Try zones in order until one has capacity.
          for zone in ${{ env.TPU_ZONES }}; do
            echo "Trying zone: $zone"
            if gcloud compute tpus tpu-vm create "${{ env.TPU_NAME }}" \
                --zone="$zone" \
                --accelerator-type="${{ env.TPU_TYPE }}" \
                --version="${{ env.TPU_RUNTIME }}" \
                ${SPOT_FLAG} \
                --quiet 2>&1; then
              echo "TPU created successfully in $zone"
              echo "ACTIVE_ZONE=$zone" >> "$GITHUB_OUTPUT"
              exit 0
            else
              echo "Zone $zone failed, trying next..."
            fi
          done
          echo "All zones exhausted"
          exit 1

      - name: Sync code and setup environment
        run: |
          tar --exclude='.git' --exclude='__pycache__' -czf /tmp/jax-spice.tar.gz .

          gcloud compute tpus tpu-vm scp /tmp/jax-spice.tar.gz \
            "${{ env.TPU_NAME }}":~/jax-spice.tar.gz \
            --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}"

          gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \
            --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \
            --command='mkdir -p ~/jax-spice && cd ~/jax-spice && tar -xzf ~/jax-spice.tar.gz && rm ~/jax-spice.tar.gz && curl -LsSf https://astral.sh/uv/install.sh | sh'

      - name: Install dependencies
        run: |
          gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \
            --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \
            --command='source ~/.local/bin/env && cd ~/jax-spice && uv sync && uv pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

      - name: Run profiling and tests
        run: |
          # GitHub's default `bash -e {0}` has no pipefail, so without this a
          # failing remote pytest would be masked by `tee` and the step would
          # succeed. pipefail propagates the ssh/pytest exit code.
          set -o pipefail
          # TPU only supports F32 for LU decomposition, so we don't enable X64
          gcloud compute tpus tpu-vm ssh "${{ env.TPU_NAME }}" \
            --zone="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" \
            --command='source ~/.local/bin/env && cd ~/jax-spice && export JAX_PLATFORMS=tpu && uv run python scripts/profile_gpu.py && uv run pytest tests/ -v --tb=short -x' \
            | tee /tmp/test_output.txt

      - name: Extract profiling report
        if: always()
        run: |
          echo "## TPU Test Results" >> "$GITHUB_STEP_SUMMARY"
          echo "- **TPU Type:** ${{ env.TPU_TYPE }}" >> "$GITHUB_STEP_SUMMARY"
          echo "- **Zone:** ${{ steps.create_tpu.outputs.ACTIVE_ZONE }}" >> "$GITHUB_STEP_SUMMARY"
          echo "- **Spot VM:** ${{ inputs.use_spot }}" >> "$GITHUB_STEP_SUMMARY"
          echo "" >> "$GITHUB_STEP_SUMMARY"
          if [ -f /tmp/test_output.txt ]; then
            sed -n '/# JAX-SPICE/,/Report written/p' /tmp/test_output.txt | head -n -1 >> "$GITHUB_STEP_SUMMARY" || true
          fi

      - name: Cleanup TPU VM
        if: always()
        run: |
          ZONE="${{ steps.create_tpu.outputs.ACTIVE_ZONE }}"
          if [ -n "$ZONE" ]; then
            gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \
              --zone="$ZONE" \
              --quiet || true
          else
            # Fallback: the create step may have died before recording the
            # zone, so sweep every candidate zone.
            for zone in ${{ env.TPU_ZONES }}; do
              gcloud compute tpus tpu-vm delete "${{ env.TPU_NAME }}" \
                --zone="$zone" \
                --quiet 2>/dev/null || true
            done
          fi
60 changes: 58 additions & 2 deletions jax_spice/analysis/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
selects the best backend based on the JAX platform:
- CPU: Uses scipy.sparse.linalg.spsolve via jax.pure_callback
- GPU: Uses jax.experimental.sparse.linalg.spsolve (cuSOLVER)
- TPU: Uses dense solve via jnp.linalg.solve (spsolve not supported on TPU)

The solver supports reverse-mode autodiff through jax.custom_vjp using
the adjoint method for implicit differentiation.
Expand All @@ -19,6 +20,7 @@

from typing import Tuple
import jax
import jax.numpy as jnp
from jax import Array
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix
Expand Down Expand Up @@ -50,8 +52,35 @@ def sparse_solve(
backend = jax.default_backend()
if backend == 'cpu':
return _spsolve_cpu(data, indices, indptr, b, shape)
else:
elif backend in ('gpu', 'cuda'):
return _spsolve_gpu(data, indices, indptr, b, shape)
else:
# TPU and other backends: use dense solve (spsolve not supported)
return _solve_dense_csc(data, indices, indptr, b, shape)


def _solve_dense_csc(
data: Array,
indices: Array,
indptr: Array,
b: Array,
shape: Tuple[int, int]
) -> Array:
"""Dense solve by reconstructing matrix from CSC format.

Used for TPU and other backends that don't support sparse solve.
"""
from jax.experimental.sparse import BCOO

# Convert CSC to BCOO and then to dense
# CSC: data[k] is at (indices[k], col) where col is determined by indptr
n = shape[0]
col_indices = jnp.repeat(jnp.arange(n), jnp.diff(indptr))
bcoo_indices = jnp.stack([indices, col_indices], axis=1)
A_bcoo = BCOO((data, bcoo_indices), shape=shape)
A_dense = A_bcoo.todense()

return jnp.linalg.solve(A_dense, b)


def _spsolve_cpu(
Expand Down Expand Up @@ -199,8 +228,35 @@ def sparse_solve_csr(
backend = jax.default_backend()
if backend == 'cpu':
return _spsolve_cpu_csr(data, indices, indptr, b, shape)
else:
elif backend in ('gpu', 'cuda'):
return _spsolve_gpu(data, indices, indptr, b, shape)
else:
# TPU and other backends: use dense solve (spsolve not supported)
return _solve_dense_csr(data, indices, indptr, b, shape)


def _solve_dense_csr(
data: Array,
indices: Array,
indptr: Array,
b: Array,
shape: Tuple[int, int]
) -> Array:
"""Dense solve by reconstructing matrix from CSR format.

Used for TPU and other backends that don't support sparse solve.
"""
from jax.experimental.sparse import BCOO

# Convert CSR to BCOO and then to dense
# CSR: data[k] is at (row, indices[k]) where row is determined by indptr
n = shape[0]
row_indices = jnp.repeat(jnp.arange(n), jnp.diff(indptr))
bcoo_indices = jnp.stack([row_indices, indices], axis=1)
A_bcoo = BCOO((data, bcoo_indices), shape=shape)
A_dense = A_bcoo.todense()

return jnp.linalg.solve(A_dense, b)


def _spsolve_cpu_csr(
Expand Down
5 changes: 3 additions & 2 deletions scripts/profile_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import jax.numpy as jnp
import numpy as np

# Enable float64, except on TPU, which only supports F32 for LU decomposition.
# JAX_PLATFORMS may name several platforms (e.g. "tpu,cpu"), so look for a
# "tpu" entry rather than comparing the whole string for equality.
_platforms = os.environ.get('JAX_PLATFORMS', '')
if 'tpu' not in [p.strip().lower() for p in _platforms.split(',')]:
    jax.config.update('jax_enable_x64', True)


@dataclass
Expand Down
Loading
Loading