From 0c42ff5371cda3819ba3d2d3172eb7bd0121b968 Mon Sep 17 00:00:00 2001 From: Zach Denton Date: Thu, 18 Dec 2025 11:47:24 -0500 Subject: [PATCH 1/2] Fix logic error in normalize. --- lib/axon/shared.ex | 7 +------ test/axon/layers_test.exs | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/lib/axon/shared.ex b/lib/axon/shared.ex index 87eff5ae7..f31d847d4 100644 --- a/lib/axon/shared.ex +++ b/lib/axon/shared.ex @@ -245,12 +245,7 @@ defmodule Axon.Shared do defn normalize(input, mean, variance, gamma, bias, opts \\ []) do [epsilon: epsilon] = keyword!(opts, epsilon: 1.0e-6) - - # The select is so that we improve numerical stability by clipping - # both insignificant values of variance and NaNs to epsilon. - scale = - gamma * Nx.select(variance >= epsilon, Nx.rsqrt(variance + epsilon), Nx.rsqrt(epsilon)) - + scale = gamma * Nx.rsqrt(variance + epsilon) scale * (input - mean) + bias end diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index 3637e89db..d33a1fd50 100644 --- a/test/axon/layers_test.exs +++ b/test/axon/layers_test.exs @@ -1722,4 +1722,27 @@ defmodule Axon.LayersTest do assert_all_close(expected, actual, atol: 1.0e-3) end end + + describe "batch_norm" do + test "matches pytorch when variance < epsilon" do + input_val = -0.002805 + mean = -0.008561 + variance = 0.000412 + weight = 1.0 + bias = -0.144881 + epsilon = 0.001 + + expected = Nx.tensor([0.0083]) + + actual = + Axon.Layers.batch_norm( + Nx.tensor([[[[input_val]]]]), + Nx.tensor([weight]), + Nx.tensor([bias]), + Nx.tensor([mean]), + Nx.tensor([variance]), mode: :inference, epsilon: epsilon) + + assert_all_close(expected, actual, atol: 1.0e-3) + end + end end From 73b0fc0eb224338bd716a916792672c2866f9a08 Mon Sep 17 00:00:00 2001 From: Zach Denton Date: Thu, 18 Dec 2025 12:08:42 -0500 Subject: [PATCH 2/2] Fix formatting. --- test/axon/layers_test.exs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index d33a1fd50..17a61631f 100644 --- a/test/axon/layers_test.exs +++ b/test/axon/layers_test.exs @@ -1740,7 +1740,10 @@ defmodule Axon.LayersTest do Nx.tensor([weight]), Nx.tensor([bias]), Nx.tensor([mean]), - Nx.tensor([variance]), mode: :inference, epsilon: epsilon) + Nx.tensor([variance]), + mode: :inference, + epsilon: epsilon + ) assert_all_close(expected, actual, atol: 1.0e-3) end