Description
Bug
Axon.Initializers.orthogonal() raises an ArgumentError when the weight shape has more columns than rows. This affects recurrent layers whose kernels place the weights of several gates side by side, e.g. LSTM {hidden, 4*hidden} and GRU {hidden, 3*hidden}.
Reproduce
init_fn = Axon.Initializers.orthogonal()
init_fn.({64, 256}, {:f, 32}, Nx.Random.key(0))
# ** (ArgumentError) length at axis 1 must be less than axis size of 64, got: 256
Cause
orthogonal_impl generates a random matrix of shape {m, n} and then takes its QR decomposition. When n > m, the QR decomposition of an {m, n} matrix yields Q of shape {m, m}, which has only m orthogonal columns. The subsequent slice Q[:m, :n] therefore fails because Q does not have n columns.
Expected
JAX (jax.nn.initializers.orthogonal), PyTorch (torch.nn.init.orthogonal_), and TensorFlow (tf.initializers.Orthogonal) all handle wide matrices. Axon should too.
Fix
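Generate a square random matrix of shape {max(m, n), max(m, n)} so that QR always produces enough orthogonal columns, then slice the result to {m, n}. Because every row and every column of a square orthogonal matrix is orthonormal, the {m, n} slice has orthonormal rows when m < n and orthonormal columns when m > n.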
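A minimal sketch of that fix in plain Nx, written as a standalone module rather than a patch to Axon. The module name `OrthogonalSketch` and the function `orthogonal/3` are hypothetical, and the sketch assumes a normal distribution and the Nx 0.9 APIs `Nx.Random.normal/4` and `Nx.LinAlg.qr/2`. It also applies the sign correction used by JAX and PyTorch (scaling each column of Q by the sign of the matching diagonal entry of R). Whether Axon's orthogonal_impl already does this is an assumption to check against the source.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule OrthogonalSketch do
  # Hypothetical standalone version of the proposed fix; not Axon's actual code.
  # Returns an {m, n} tensor with orthonormal rows when m < n,
  # or orthonormal columns when m >= n.
  def orthogonal({m, n}, type, key) do
    # Square size large enough to supply both m rows and n columns.
    k = max(m, n)

    # Standard-normal {k, k} matrix; Nx.Random.normal returns {tensor, new_key}.
    {z, _new_key} = Nx.Random.normal(key, 0.0, 1.0, shape: {k, k}, type: type)

    # QR of a square matrix yields a full {k, k} orthogonal Q.
    {q, r} = Nx.LinAlg.qr(z)

    # Sign correction (assumption, borrowed from JAX/PyTorch): multiply column j
    # of Q by sign(R[j, j]). Broadcasting a {k} vector against {k, k} scales columns.
    signs = r |> Nx.take_diagonal() |> Nx.sign()
    q = Nx.multiply(q, signs)

    # Keep the top-left {m, n} block.
    Nx.slice(q, [0, 0], [m, n])
  end
end
```

Called as `OrthogonalSketch.orthogonal({64, 256}, {:f, 32}, Nx.Random.key(0))`, this should return a {64, 256} tensor whose 64 rows are orthonormal instead of raising. The sign correction does not change orthogonality (Q times a diagonal matrix of ±1 is still orthogonal); it only removes the bias that QR's sign convention introduces into the distribution.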
Environment
- Axon: main (d5ecacb)
- Nx: 0.9+