@doc ~S"""
Adds an RMS normalization layer to the network.

Unlike layer normalization, RMS normalization rescales the input
using only the root mean square along the feature dimension — the
mean is never subtracted — which is computationally cheaper while
performing comparably in practice.

See `Axon.Layers.rms_norm/3` for more details.

$$y = \frac{x}{\sqrt{E[x^2] + \epsilon}} * (\text{shift} + \gamma)$$

## Options

  * `:name` - layer name.

  * `:gamma_initializer` - gamma parameter initializer. Defaults
    to `:ones`.

  * `:channel_index` - input feature index used for calculating
    the root mean square. Defaults to `-1`.

  * `:epsilon` - numerical stability term. Defaults to `1.0e-6`.

  * `:shift` - numeric shift added to gamma before scaling.
    Defaults to `0.0`.

  * `:upcast` - adds explicit type casting to make sure the norm
    is computed in high numerical precision. Either of:

    * `:normalization` (default) - upcasts only the input normalization
      part

    * `:all` - upcasts both input normalization and the scaling
      expression

## References

  * [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)

"""
@doc type: :normalization
def rms_norm(%Axon{} = x, opts \\ []) do
  opts =
    Keyword.validate!(opts, [
      :name,
      :meta,
      gamma_initializer: :ones,
      channel_index: -1,
      epsilon: 1.0e-6,
      shift: 0.0,
      upcast: :normalization
    ])

  axis = opts[:channel_index]

  # One gamma entry per feature along `axis`; the concrete shape is
  # resolved lazily from the input shape when the model is built.
  gamma =
    param("gamma", &Axon.Shape.norm_param(&1, axis),
      initializer: opts[:gamma_initializer]
    )

  layer(:rms_norm, [x, gamma],
    name: opts[:name],
    meta: opts[:meta],
    epsilon: opts[:epsilon],
    channel_index: axis,
    shift: opts[:shift],
    upcast: opts[:upcast],
    op_name: :rms_norm
  )
end
@doc ~S"""
Functional implementation of RMS normalization.

Rescales `input` by the reciprocal of the root mean square taken
along the feature dimension given by `:channel_index`. In contrast
to layer normalization, RMS normalization does not center the
input by subtracting the mean.

$$y = \frac{x}{\sqrt{E[x^2] + \epsilon}} * (\text{shift} + \gamma)$$

`gamma` is often a trainable parameter. This method does not maintain
an EMA of variance.

## Options

  * `:epsilon` - numerical stability term. $\epsilon$ in the above
    formulation. Defaults to `1.0e-6`.

  * `:channel_index` - channel index used to determine reduction
    axes for RMS calculation. Defaults to `-1`.

  * `:shift` - numeric shift added to gamma before scaling.
    Defaults to `0.0`.

  * `:upcast` - controls type casting for numerical precision.
    Either `:normalization` (default) to upcast only the normalization
    part, or `:all` to upcast the entire computation.

## References

  * [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)
"""
@doc type: :normalization
defn rms_norm(input, gamma, opts \\ []) do
  opts =
    keyword!(opts,
      epsilon: 1.0e-6,
      channel_index: -1,
      shift: 0.0,
      upcast: :normalization,
      mode: :inference
    )

  dispatch_rms_norm(input, gamma, opts)
end

# Compile-time dispatch on the `:upcast` strategy. Runs outside defn
# tracing so an invalid option raises a descriptive ArgumentError.
deftransformp dispatch_rms_norm(input, gamma, opts) do
  case opts[:upcast] do
    :normalization ->
      rms_norm_cast_norm_only(input, gamma, opts)

    :all ->
      rms_norm_cast_all(input, gamma, opts)

    other ->
      raise ArgumentError,
            "expected :upcast to be either :all or :normalization, got: #{inspect(other)}"
  end
end

# Upcasts only the normalization to f32; the gamma scaling happens in
# the input's original type.
defnp rms_norm_cast_norm_only(input, gamma, opts) do
  channels = Nx.axis_size(input, opts[:channel_index])
  gamma = Nx.reshape(gamma, norm_parameter_reshape(input, channels, opts[:channel_index]))

  input
  |> Nx.as_type(:f32)
  |> scale_by_rms(opts)
  |> Nx.as_type(Nx.type(input))
  |> Nx.multiply(opts[:shift] + gamma)
end

# Upcasts both the normalization and the gamma scaling to f32.
defnp rms_norm_cast_all(input, gamma, opts) do
  channels = Nx.axis_size(input, opts[:channel_index])
  gamma = Nx.reshape(gamma, norm_parameter_reshape(input, channels, opts[:channel_index]))

  input = Nx.as_type(input, :f32)
  gamma = Nx.as_type(gamma, :f32)

  Nx.multiply(scale_by_rms(input, opts), opts[:shift] + gamma)
end

# Divides the input by the root mean square along `:channel_index`,
# with `:epsilon` added under the square root for stability.
defnp scale_by_rms(input, opts) do
  mean_square =
    input
    |> Nx.pow(2)
    |> Nx.mean(axes: [opts[:channel_index]], keep_axes: true)

  input * Nx.rsqrt(mean_square + opts[:epsilon])
end
describe "rms_norm" do
  test "matches pytorch 2D input" do
    input =
      Nx.tensor([
        [1.9269, 1.4873, 0.9007, -2.1055, 0.6784, -1.2345, -0.0431, -1.6047],
        [-0.7521, 1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688, 0.7624]
      ])

    gamma =
      Nx.tensor([0.4617, 0.2674, 0.5349, 0.8094, 1.1103, -1.6898, -0.9890, 0.9580])

    result = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)

    # Reference values generated with PyTorch.
    expected =
      Nx.tensor([
        [0.6344, 0.2836, 0.3436, -1.2153, 0.5372, 1.4877, 0.0304, -1.0962],
        [-0.3605, 0.4576, -0.2179, -1.1793, -0.8390, 0.9814, 0.7893, 0.7582]
      ])

    assert_all_close(expected, result, atol: 1.0e-3)
  end

  test "matches pytorch 3D input" do
    input =
      Nx.tensor([
        [
          [-1.3847, -0.8712, -0.2234, 1.7174, 0.3189, -0.4245],
          [0.3057, -0.7746, -1.5576, 0.9956, -0.8798, -0.6011],
          [-1.2742, 2.1228, -1.2347, -0.4879, -0.9138, -0.6581],
          [0.0780, 0.5258, -0.4880, 1.1914, -0.8140, -0.7360]
        ],
        [
          [-1.4032, 0.0360, -0.0635, 0.6756, -0.0978, 1.8446],
          [-1.1845, 1.3835, 1.4451, 0.8564, 2.2181, 0.5232],
          [0.3466, -0.1973, -1.0546, 1.2780, -0.1722, 0.5238],
          [0.0566, 0.4263, 0.5750, -0.6417, -2.2064, -0.7508]
        ]
      ])

    gamma =
      Nx.tensor([0.4679, -0.2049, -0.7409, 0.3618, 1.9199, -0.2254])

    result = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)

    # Reference values generated with PyTorch.
    expected =
      Nx.tensor([
        [
          [-0.6502, 0.1792, 0.1661, 0.6236, 0.6144, 0.0960],
          [0.1530, 0.1698, 1.2341, 0.3853, -1.8064, 0.1449],
          [-0.4825, -0.3521, 0.7403, -0.1429, -1.4199, 0.1201],
          [0.0504, -0.1488, 0.4994, 0.5955, -2.1588, 0.2291]
        ],
        [
          [-0.6653, -0.0075, 0.0477, 0.2477, -0.1903, -0.4213],
          [-0.4033, -0.2063, -0.7791, 0.2255, 3.0986, -0.0858],
          [0.2218, 0.0553, 1.0685, 0.6324, -0.4521, -0.1614],
          [0.0257, -0.0849, -0.4138, -0.2255, -4.1147, 0.1644]
        ]
      ])

    assert_all_close(expected, result, atol: 1.0e-3)
  end

  test "matches pytorch with ones weight" do
    input =
      Nx.tensor([
        [0.6127, -1.1754, -0.7646, -0.6666],
        [0.7444, -0.6453, -1.3890, -0.2730]
      ])

    gamma =
      Nx.tensor([1.0000, 1.0000, 1.0000, 1.0000])

    result = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)

    # With unit gamma the output is simply the RMS-normalized input.
    expected =
      Nx.tensor([
        [0.7342, -1.4084, -0.9163, -0.7987],
        [0.8632, -0.7483, -1.6108, -0.3165]
      ])

    assert_all_close(expected, result, atol: 1.0e-3)
  end
end