From d8b085844be67755423c3a0a110f02d12fdada65 Mon Sep 17 00:00:00 2001 From: francesco-vaselli Date: Wed, 15 Mar 2023 15:28:05 +0100 Subject: [PATCH 1/3] float precision check --- nflows/transforms/splines/rational_quadratic.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index bb6a8c4..5d752cf 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -139,6 +139,10 @@ def rational_quadratic_spline( c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c + + float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6 + discriminant[float_precision_mask] = 0 + assert (discriminant >= 0).all() root = (2 * c) / (-b - torch.sqrt(discriminant)) From 4934a85b71752640a70b00a8c41e54e53562de78 Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Sat, 30 Sep 2023 10:45:48 +0200 Subject: [PATCH 2/3] comments and torch.where --- nflows/transforms/splines/rational_quadratic.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index 5d752cf..eccbb61 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -140,8 +140,13 @@ def rational_quadratic_spline( discriminant = b.pow(2) - 4 * a * c - float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6 - discriminant[float_precision_mask] = 0 + # Correcting for floating-point errors in the discriminant calculation. + # The float_precision_mask identifies elements where the discriminant is essentially zero, + # but appears nonzero due to machine precision limitations. + # Threshold values (1e-8 and 1e-6) are heuristic-based to manage numerical stability. + float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6 + discriminant = torch.where(float_precision_mask, + torch.zeros_like(discriminant), discriminant) assert (discriminant >= 0).all() From 4bcc21d0aae56ecad799ba93029682c83b8cf133 Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Sat, 30 Sep 2023 10:47:27 +0200 Subject: [PATCH 3/3] comments 1 --- nflows/transforms/splines/rational_quadratic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index eccbb61..57cd8a5 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -142,6 +142,7 @@ def rational_quadratic_spline( # Correcting for floating-point errors in the discriminant calculation. # The float_precision_mask identifies elements where the discriminant is essentially zero, + # compared to the magnitude of b.pow(2), # but appears nonzero due to machine precision limitations. # Threshold values (1e-8 and 1e-6) are heuristic-based to manage numerical stability. float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6