Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions lib/axon/initializers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,20 @@ defmodule Axon.Initializers do

{m, n} = get_flat_shape(shape)

# Generate a square random matrix of size max(m, n) so that QR
# produces enough orthogonal columns for wide matrices
# (e.g. LSTM weights shaped {hidden, 4*hidden}).
# For an {m, n} input, mode: :complete returns Q with shape {m, m} and
# mode: :reduced returns Q with shape {m, min(m, n)}, so when n > m
# neither yields the n columns needed.
k = max(m, n)

random_seed =
case distribution do
:uniform ->
Nx.Random.uniform_split(key, 0.0, 1.0, shape: {m, n}, type: type)
Nx.Random.uniform_split(key, 0.0, 1.0, shape: {k, k}, type: type)

:normal ->
Nx.Random.normal_split(key, 0.0, 1.0, shape: {m, n}, type: type)
Nx.Random.normal_split(key, 0.0, 1.0, shape: {k, k}, type: type)

dist ->
raise ArgumentError,
Expand Down
25 changes: 25 additions & 0 deletions test/axon/initializers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,31 @@ defmodule Axon.InitializersTest do
)
end

test "works with wide matrices (n > m, e.g. LSTM/GRU weights)" do
  # A wide kernel in the spirit of an LSTM weight {hidden, 4*hidden}; the
  # dimensions are deliberately small so QR stays cheap on BinaryBackend.
  shape = {8, 32}
  weights = Axon.Initializers.orthogonal().(shape, {:f, 32}, Nx.Random.key(1))

  assert Nx.shape(weights) == shape

  # Contracting axis 1 of the result with itself gives weights * weights^T,
  # which must equal the identity when the rows are orthonormal.
  gram = Nx.dot(weights, [1], weights, [1])
  expected = Nx.eye(Nx.shape(gram))

  assert_all_close(gram, expected, atol: 1.0e-3, rtol: 1.0e-3)
end

test "works with wide high-rank shapes" do
  init_fn = Axon.Initializers.orthogonal()

  # A genuinely high-rank (rank 3) shape whose 2-D flattening is wide
  # (n > m) whichever side the extra axis is folded into:
  # {2, 2, 16} -> {4, 16} or {2, 32}.
  # Dimensions are kept small so QR stays fast on BinaryBackend.
  shape = {2, 2, 16}
  t = init_fn.(shape, {:f, 32}, Nx.Random.key(1))

  # The initializer must hand back the requested shape, not the flattened one.
  assert Nx.shape(t) == shape
end

test "raises on input rank less than 2" do
assert_raise ArgumentError,
~r/Axon.Initializers.orthogonal: expected input_shape shape to have at least rank 2/,
Expand Down
Loading