
Runtime shape contracts and diagnostics for NumPy and JAX arrays


jayendra13/jax-shape-guard


ShapeGuard


Runtime shape contracts and diagnostics for NumPy and JAX.

Installation

pip install shapeguard

Quick Start

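from shapeguard import Dim, expects

# Named dimensions; each name must take a single size across all checked arguments
n, m, k = Dim("n"), Dim("m"), Dim("k")

# a must have shape (n, m) and b must have shape (m, k), so a.shape[1] must equal b.shape[0]
@expects(a=(n, m), b=(m, k))
def matmul(a, b):
    return a @ b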

When shapes don't match, you get a clear error naming the function, the offending argument, the expected and actual shapes, and the reason:

ShapeGuardError:
  function: matmul
  argument: b
  expected: (m, k)
  actual:   (5, 7)
  reason:   dimension 'm' bound to 4 from a.shape[1], but got 5 from b.shape[0]
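
The reason line above suggests how the check works: each named dimension is bound to a size the first time it appears, and every later appearance must match that size. The block below is a minimal sketch of that idea under stated assumptions, not ShapeGuard's implementation. `expects_sketch`, the local `Dim` stand-in, `ShapeMismatch`, and `FakeArray` are hypothetical names. The sketch reads only the `.shape` attribute, so NumPy and JAX arrays would be checked the same way as the stand-in array used here.

```python
import inspect
from functools import wraps


class Dim:
    """Hypothetical stand-in for a named dimension (not shapeguard's class)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


class ShapeMismatch(ValueError):
    """Hypothetical error type used only in this sketch."""


def expects_sketch(**specs):
    """Check each named argument's .shape against its spec of Dims.

    A Dim is bound (by name) the first time it is seen; every later
    occurrence in the same call must have the same size.
    """

    def decorator(fn):
        sig = inspect.signature(fn)

        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Map positional and keyword arguments onto parameter names
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            bindings = {}  # dim name -> (size, "arg.shape[i]" where first bound)
            for arg_name, spec in specs.items():
                shape = tuple(bound.arguments[arg_name].shape)
                if len(shape) != len(spec):
                    raise ShapeMismatch(
                        f"{fn.__name__}: {arg_name} expected rank {len(spec)}, got shape {shape}"
                    )
                for i, (dim, size) in enumerate(zip(spec, shape)):
                    where = f"{arg_name}.shape[{i}]"
                    if dim.name not in bindings:
                        bindings[dim.name] = (size, where)  # first sighting: bind it
                    elif bindings[dim.name][0] != size:
                        first_size, first_where = bindings[dim.name]
                        raise ShapeMismatch(
                            f"{fn.__name__}: dimension '{dim.name}' bound to {first_size} "
                            f"from {first_where}, but got {size} from {where}"
                        )
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# --- usage with a stand-in array type that only carries a shape ---
class FakeArray:
    def __init__(self, *shape):
        self.shape = shape


n, m, k = Dim("n"), Dim("m"), Dim("k")


@expects_sketch(a=(n, m), b=(m, k))
def output_shape(a, b):
    return (a.shape[0], b.shape[1])


print(output_shape(FakeArray(3, 4), FakeArray(4, 7)))  # (3, 7)

err_msg = None
try:
    output_shape(FakeArray(3, 4), FakeArray(5, 7))
except ShapeMismatch as err:
    err_msg = str(err)
print(err_msg)  # output_shape: dimension 'm' bound to 4 from a.shape[1], but got 5 from b.shape[0]
```

Binding by the first occurrence is what lets the message point at both sources of the conflict (here `a.shape[1]` and `b.shape[0]`), rather than only reporting that the second argument is wrong.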

License

MIT
