From 504866ccc36306400532d70c59a99d888add4b94 Mon Sep 17 00:00:00 2001 From: Bradley Fargo Date: Sat, 28 Feb 2026 04:21:58 -0600 Subject: [PATCH 1/3] fix: Handle wide matrices in orthogonal initializer QR decomposition of an {m, n} matrix produces Q of shape {m, m}, which fails when n > m (e.g. LSTM weights {hidden, 4*hidden}). Generate a {max(m,n), max(m,n)} square random matrix so QR always produces enough orthogonal columns, then slice to {m, n}. Adds tests for wide 2D and high-rank shapes. Co-Authored-By: Claude Opus 4.6 --- lib/axon/initializers.ex | 9 +++++++-- test/axon/initializers_test.exs | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/lib/axon/initializers.ex b/lib/axon/initializers.ex index a947f7a1d..5ae8ac2a5 100644 --- a/lib/axon/initializers.ex +++ b/lib/axon/initializers.ex @@ -686,13 +686,18 @@ defmodule Axon.Initializers do {m, n} = get_flat_shape(shape) + # Generate a square random matrix of size max(m, n) so QR + # produces enough orthogonal columns for wide matrices + # (e.g. LSTM weights shaped {hidden, 4*hidden}) + k = max(m, n) + random_seed = case distribution do :uniform -> - Nx.Random.uniform_split(key, 0.0, 1.0, shape: {m, n}, type: type) + Nx.Random.uniform_split(key, 0.0, 1.0, shape: {k, k}, type: type) :normal -> - Nx.Random.normal_split(key, 0.0, 1.0, shape: {m, n}, type: type) + Nx.Random.normal_split(key, 0.0, 1.0, shape: {k, k}, type: type) dist -> raise ArgumentError, diff --git a/test/axon/initializers_test.exs b/test/axon/initializers_test.exs index 09aa65e1e..b1b1bf785 100644 --- a/test/axon/initializers_test.exs +++ b/test/axon/initializers_test.exs @@ -164,6 +164,31 @@ defmodule Axon.InitializersTest do ) end + test "works with wide matrices (n > m, e.g. 
LSTM/GRU weights)" do + init_fn = Axon.Initializers.orthogonal() + + # Wide matrix like LSTM kernel {hidden, 4*hidden} + # Using small dims to keep QR fast on BinaryBackend + t = init_fn.({8, 32}, {:f, 32}, Nx.Random.key(1)) + assert Nx.shape(t) == {8, 32} + + # Rows should be orthonormal: t * t^T = I + identity = t |> Nx.dot(Nx.transpose(t)) + + assert_all_close(identity, Nx.eye(Nx.shape(identity)), + atol: 1.0e-3, + rtol: 1.0e-3 + ) + end + + test "works with wide high-rank shapes" do + init_fn = Axon.Initializers.orthogonal() + + # Shape that flattens to wide: {2, 8} -> {2, 8} where n > m + t = init_fn.({2, 8}, {:f, 32}, Nx.Random.key(1)) + assert Nx.shape(t) == {2, 8} + end + test "raises on input rank less than 2" do assert_raise ArgumentError, ~r/Axon.Initializers.orthogonal: expected input_shape shape to have at least rank 2/, From ee95893f95162167768a20c09e495f87623e7787 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:21:32 -0300 Subject: [PATCH 2/3] Apply suggestions from code review --- lib/axon/initializers.ex | 2 ++ test/axon/initializers_test.exs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/axon/initializers.ex b/lib/axon/initializers.ex index 5ae8ac2a5..7d5262307 100644 --- a/lib/axon/initializers.ex +++ b/lib/axon/initializers.ex @@ -689,6 +689,8 @@ defmodule Axon.Initializers do # Generate a square random matrix of size max(m, n) so QR # produces enough orthogonal columns for wide matrices # (e.g. LSTM weights shaped {hidden, 4*hidden}) + # This is because mode: :complete returns Q {m, m} for an {m, n} tensor. 
+ # mode: :reduced returns Q {m, min(m, n)} k = max(m, n) random_seed = diff --git a/test/axon/initializers_test.exs b/test/axon/initializers_test.exs index b1b1bf785..e09166a43 100644 --- a/test/axon/initializers_test.exs +++ b/test/axon/initializers_test.exs @@ -173,7 +173,7 @@ defmodule Axon.InitializersTest do assert Nx.shape(t) == {8, 32} # Rows should be orthonormal: t * t^T = I - identity = t |> Nx.dot(Nx.transpose(t)) + identity = Nx.dot(t, 1, t, 1) assert_all_close(identity, Nx.eye(Nx.shape(identity)), atol: 1.0e-3, From 8a162dede54534c7f24fbe7b9767c5724ddc2891 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:26:37 -0300 Subject: [PATCH 3/3] Apply suggestion from @polvalente --- test/axon/initializers_test.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/axon/initializers_test.exs b/test/axon/initializers_test.exs index e09166a43..28113128c 100644 --- a/test/axon/initializers_test.exs +++ b/test/axon/initializers_test.exs @@ -173,7 +173,7 @@ defmodule Axon.InitializersTest do assert Nx.shape(t) == {8, 32} # Rows should be orthonormal: t * t^T = I - identity = Nx.dot(t, 1, t, 1) + identity = Nx.dot(t, [1], t, [1]) assert_all_close(identity, Nx.eye(Nx.shape(identity)), atol: 1.0e-3,