Runtime shape contracts and diagnostics for NumPy and JAX.
pip install shapeguardfrom shapeguard import Dim, expects
n, m, k = Dim("n"), Dim("m"), Dim("k")
@expects(a=(n, m), b=(m, k))
def matmul(a, b):
return a @ bWhen shapes don't match, you get clear errors:
ShapeGuardError:
function: matmul
argument: b
expected: (m, k)
actual: (5, 7)
reason: dimension 'm' bound to 4 from a.shape[1], but got 5 from b.shape[0]
MIT