From 4a03bf9a5aa72eec3844538ec4cc2d35ce60bf8b Mon Sep 17 00:00:00 2001 From: davidwalter2 Date: Tue, 3 Feb 2026 14:51:35 -0500 Subject: [PATCH 1/6] First version of regularization --- bin/rabbit_fit.py | 8 +++ rabbit/fitter.py | 47 +++++++++++++-- rabbit/parsing.py | 13 ++++ rabbit/regularization/__init__.py | 0 rabbit/regularization/helpers.py | 13 ++++ rabbit/regularization/regularizer.py | 17 ++++++ rabbit/regularization/svd.py | 89 ++++++++++++++++++++++++++++ 7 files changed, 183 insertions(+), 4 deletions(-) create mode 100644 rabbit/regularization/__init__.py create mode 100644 rabbit/regularization/helpers.py create mode 100644 rabbit/regularization/regularizer.py create mode 100644 rabbit/regularization/svd.py diff --git a/bin/rabbit_fit.py b/bin/rabbit_fit.py index 44893f1..0815206 100755 --- a/bin/rabbit_fit.py +++ b/bin/rabbit_fit.py @@ -13,6 +13,7 @@ from rabbit.mappings import helpers as mh from rabbit.mappings import mapping as mp from rabbit.poi_models import helpers as ph +from rabbit.regularization import helpers as rh from rabbit.tfhelpers import edmval_cov from wums import output_tools, logging # isort: skip @@ -477,6 +478,13 @@ def main(): mp.CompositeMapping(mappings), ] + regularizers = [] + for margs in args.regularization: + mapping = mh.load_mapping(margs[1], indata, *margs[2:]) + regularizer = rh.load_regularizer(margs[0], mapping, dtype=indata.dtype) + regularizers.append(regularizer) + ifitter.regularizers = regularizers + np.random.seed(args.seed) tf.random.set_seed(args.seed) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 538eb12..4c5387c 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -292,6 +292,9 @@ def __init__( name="cov", ) + # regularization + self.regularizers = [] + # determine if problem is linear (ie likelihood is purely quadratic) self.is_linear = ( (self.chisqFit or self.covarianceFit) @@ -344,12 +347,14 @@ def update_frozen_params(self): self.frozen_indices = np.where(new_mask_np)[0] def freeze_params(self, frozen_parmeter_expressions): + logger.debug(f"Freeze params with {frozen_parmeter_expressions}") self.frozen_params.extend( match_regexp_params(frozen_parmeter_expressions, self.parms) ) self.update_frozen_params() def defreeze_params(self, unfrozen_parmeter_expressions): + logger.debug(f"Freeze params with {unfrozen_parmeter_expressions}") unfrozen_parmeter = match_regexp_params( unfrozen_parmeter_expressions, self.parms ) @@ -359,6 +364,7 @@ def defreeze_params(self, unfrozen_parmeter_expressions): self.update_frozen_params() def init_blinding_values(self, unblind_parameter_expressions=[]): + logger.debug(f"Unblind parameters with {unblind_parameter_expressions}") unblind_parameters = match_regexp_params( unblind_parameter_expressions, [ @@ -452,6 +458,9 @@ def get_poi(self): else: return poi + def get_x(self): + return tf.concat([self.get_poi(), self.get_theta()], axis=0) + def _default_beta0(self): if self.binByBinStatType in ["gamma", "normal-multiplicative"]: return tf.ones(self.beta_shape, dtype=self.indata.dtype) @@ -534,6 +543,11 @@ def defaultassign(self): if self.do_blinding: self.set_blinding_offsets(False) + xinit = self.get_x() + nexp0 = self.expected_yield(full=True) + for reg in self.regularizers: + reg.set_expectations(xinit, nexp0) + def bayesassign(self): # FIXME use theta0 as the mean and constraintweight to scale the width if self.poi_model.npoi == 0: @@ -1982,10 +1996,10 @@ def _compute_nll_components(self, profile=True, full_nll=False): nexpfullcentral, _, beta = self._compute_yields_with_beta( 
profile=profile, compute_norm=False, - full=False, + full=len(self.regularizers), ) - nexp = nexpfullcentral + nexp = nexpfullcentral[: self.indata.nbins] if self.chisqFit: ln = 0.5 * tf.reduce_sum((nexp - self.nobs) ** 2 / self.varnobs, axis=-1) @@ -2019,10 +2033,25 @@ def _compute_nll_components(self, profile=True, full_nll=False): lbeta = self._compute_lbeta(beta, full_nll) - return ln, lc, lbeta, beta + # logger.debug(f"L(nobs) = {ln}") + # logger.debug(f"L(const.) = {lc}") + # logger.debug(f"L(beta) = {lbeta}") + + if len(self.regularizers): + x = self.get_x() + penalties = [ + reg.compute_nll_penalty(x, nexpfullcentral) for reg in self.regularizers + ] + lpenalty = tf.add_n(penalties) + else: + lpenalty = None + + # logger.debug(f"L(penalty) = {lpenalty}") + + return ln, lc, lbeta, lpenalty, beta def _compute_nll(self, profile=True, full_nll=False): - ln, lc, lbeta, beta = self._compute_nll_components( + ln, lc, lbeta, lpenalty, beta = self._compute_nll_components( profile=profile, full_nll=full_nll ) l = ln + lc @@ -2030,6 +2059,9 @@ def _compute_nll(self, profile=True, full_nll=False): if lbeta is not None: l = l + lbeta + if lpenalty is not None: + l = l + lpenalty + return l def _compute_loss(self, profile=True): @@ -2138,16 +2170,22 @@ def minimize(self): callback = None else: + logger.info("Perform iterative fit") def scipy_loss(xval): self.x.assign(xval) val, grad = self.loss_val_grad() + # logger.debug(f"xval = {xval}") + # logger.debug(f"val = {val}; grad = {grad}") return val.__array__(), grad.__array__() def scipy_hessp(xval, pval): self.x.assign(xval) p = tf.convert_to_tensor(pval) val, grad, hessp = self.loss_val_grad_hessp(p) + # logger.debug(f"xval = {xval}") + # logger.debug(f"p = {p}") + # logger.debug(f"val = {val}; grad = {grad}; hessp = {hessp}") return hessp.__array__() def scipy_hess(xval): @@ -2190,6 +2228,7 @@ def scipy_hess(xval): except Exception as ex: # minimizer could have called the loss or hessp functions with "random" values, so restore the # state from the end of the last iteration before the exception + logger.debug("Fitter exception") xval = callback.xval logger.debug(ex) else: diff --git a/rabbit/parsing.py b/rabbit/parsing.py index c2dca87..4b874c9 100644 --- a/rabbit/parsing.py +++ b/rabbit/parsing.py @@ -211,5 +211,18 @@ def common_parser(): action="store_true", help="Make a composite mapping and compute the covariance matrix across all mappings.", ) + parser.add_argument( + "-r", + "--regularization", + nargs="+", + action="append", + default=[], + help=""" + apply regularization on the output "nout" of a mapping by including a penalty term P(nout) in the -log(L) of the minimization. + As argument, specify the regulaization defined in rabbit/regularization/, followed by a mapping using the same syntax as discussed above. + e.g. '-r SVD Select ch0_masked' to apply SVD regularization on the channel 'ch0_masked' or '-r SVD Project ch0 pt' for the 1D projection to pt. + Custom regularization can be specified with the full path e.g. '-r custom_regularization.MyCustomRegularization Project ch0 pt'. 
+ """, + ) return parser diff --git a/rabbit/regularization/__init__.py b/rabbit/regularization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rabbit/regularization/helpers.py b/rabbit/regularization/helpers.py new file mode 100644 index 0000000..c98aae9 --- /dev/null +++ b/rabbit/regularization/helpers.py @@ -0,0 +1,13 @@ +from rabbit import common + +# dictionary with class name and the corresponding filename where it is defined +baseline_regularizations = { + "SVD": "svd", +} + + +def load_regularizer(class_name, *args, **kwargs): + regularization = common.load_class_from_module( + class_name, baseline_regularizations, base_dir="rabbit.regularization" + ) + return regularization(*args, **kwargs) diff --git a/rabbit/regularization/regularizer.py b/rabbit/regularization/regularizer.py new file mode 100644 index 0000000..3621cc3 --- /dev/null +++ b/rabbit/regularization/regularizer.py @@ -0,0 +1,17 @@ +class Regularizer: + + def __init__(self, mapping, dtype): + """ + Initialize the regularization depending on the mapping + """ + + def set_expectations(self, initial_params, initial_observables): + """ + Set the expectations to use in the regularization, this step should be called once per fit configuration + """ + + def compute_nll_penalty(self, params, observables): + """ + Compute the penalty term that gets added to -ln(L), this function should be called in each step of the minimization + """ + return 0 diff --git a/rabbit/regularization/svd.py b/rabbit/regularization/svd.py new file mode 100644 index 0000000..776daf2 --- /dev/null +++ b/rabbit/regularization/svd.py @@ -0,0 +1,89 @@ +import tensorflow as tf + +from rabbit.regularization.regularizer import Regularizer + + +class SVD(Regularizer): + """ + Singular Value Decomposition (SVD) see: https://arxiv.org/abs/hep-ph/9509307 + """ + + def __init__(self, mapping, dtype): + self.strength = 1.0 + + if len(mapping.channel_info) > 1: + raise NotImplementedError( + "Regularization currently only works for 1 channel at a time; use multiple regularizers if you want to penalize multiple channels." + ) + + self.mapping = mapping + self.input_shape = [ + len(a) for v in mapping.channel_info.values() for a in v["axes"] + ] + + self.ndims = len(self.input_shape) + + if self.ndims == 1: + kernel = tf.constant([1, -2, 1], dtype=dtype) + self.kernel = kernel[:, tf.newaxis, tf.newaxis] # (W, In, Out) + elif self.ndims == 2: + kernel = tf.constant([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=dtype) + self.kernel = kernel[:, :, tf.newaxis, tf.newaxis] # (H, W, In, Out) + elif self.ndims == 3: + # Axial neighbors are 1, center is -6 + kernel = tf.zeros((3, 3, 3), dtype=dtype) + indices = [ + [1, 1, 1], + [0, 1, 1], + [2, 1, 1], + [1, 0, 1], + [1, 2, 1], + [1, 1, 0], + [1, 1, 2], + ] + values = [-6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + kernel = tf.tensor_scatter_nd_update(kernel, indices, values) + self.kernel = kernel[:, :, :, tf.newaxis, tf.newaxis] + else: + raise NotImplementedError("SVD regularization only implemented in up to 3D") + + self.paddings = [[0, 0]] + [[1, 1]] * self.ndims + [[0, 0]] + + def set_expectations(self, initial_params, initial_observables): + # TODO: do we need to include this in autodiff for global impacts, since initial_params = (poi0, theta0)? 
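As a usage illustration of the Regularizer interface defined above (a hypothetical sketch, not part of the patch): a custom penalty such as the 'custom_regularization.MyCustomRegularization' mentioned in the --regularization help could be a simple Tikhonov-style pull toward the prefit expectation. Only the interface methods and mapping.compute_flat shown in this patch are assumed; the class body and its name are illustrative.

import tensorflow as tf

from rabbit.regularization.regularizer import Regularizer


class MyCustomRegularization(Regularizer):
    """Illustrative Tikhonov-style penalty pulling mapped yields toward their prefit values."""

    def __init__(self, mapping, dtype):
        self.mapping = mapping
        self.strength = tf.constant(1.0, dtype=dtype)

    def set_expectations(self, initial_params, initial_observables):
        # prefit expectation used as the regularization target
        self.nexp0 = self.mapping.compute_flat(initial_params, initial_observables)

    def compute_nll_penalty(self, params, observables):
        nexp = self.mapping.compute_flat(params, observables)
        # penalize relative deviations from the prefit expectation, guarding against empty bins
        nexp0_safe = tf.where(self.nexp0 != 0, self.nexp0, tf.ones_like(self.nexp0))
        rel = (nexp - self.nexp0) / nexp0_safe
        return self.strength * tf.reduce_sum(tf.square(rel))

Such a class would be enabled with the same syntax as the baseline SVD regularizer, e.g. '-r custom_regularization.MyCustomRegularization Select ch0_masked'.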
+ nexp0 = self.mapping.compute_flat(initial_params, initial_observables) + self.nexp0 = tf.reshape(nexp0, self.input_shape) + + def compute_nll_penalty(self, params, observables): + mask = self.nexp0 != 0 + nexp0_safe = tf.where(mask, self.nexp0, tf.cast(1.0, self.nexp0.dtype)) + + nexp = self.mapping.compute_flat(params, observables) + nexp = tf.reshape(nexp, self.input_shape) + + dexp = nexp / nexp0_safe + dexp = tf.where(mask, dexp, tf.ones_like(dexp)) + + # add batch (first) and channel (last) dimensions + dexp = dexp[tf.newaxis, ..., tf.newaxis] + + # padding 'SYMMETRIC' means copy the element at the edge, i.e. apply the kernel (1 -2 1) to (x x y) -> x -2x + y = -x y + # which is equivalent to applying a "modified kernal" of (-1, 1) to (x y) -> -x y + padded_input = tf.pad(dexp, self.paddings, mode="SYMMETRIC") + + if self.ndims == 1: + curvature_map = tf.nn.conv1d( + padded_input, self.kernel, stride=1, padding="VALID" + ) + elif self.ndims == 2: + curvature_map = tf.nn.conv2d( + padded_input, self.kernel, strides=[1, 1, 1, 1], padding="VALID" + ) + elif self.ndims == 3: + curvature_map = tf.nn.conv3d( + padded_input, self.kernel, strides=[1, 1, 1, 1, 1], padding="VALID" + ) + + penalty = self.strength * tf.reduce_mean(tf.square(curvature_map)) + + return penalty From bf4221ee45781749845c09d1ee8cc83ac2329a21 Mon Sep 17 00:00:00 2001 From: davidwalter2 Date: Wed, 4 Feb 2026 19:06:28 -0500 Subject: [PATCH 2/6] First implementation of curvature scan for regularization --- bin/rabbit_fit.py | 16 +++- rabbit/fitter.py | 27 ++++--- rabbit/regularization/helpers.py | 123 +++++++++++++++++++++++++++++++ rabbit/regularization/svd.py | 22 ++++-- 4 files changed, 171 insertions(+), 17 deletions(-) diff --git a/bin/rabbit_fit.py b/bin/rabbit_fit.py index 0815206..051b98e 100755 --- a/bin/rabbit_fit.py +++ b/bin/rabbit_fit.py @@ -150,6 +150,12 @@ def make_parser(): type=str, help="Specify result from external postfit file", ) + parser.add_argument( + "--noFit", + default=False, + action="store_true", + help="Do not not perform the minimization.", + ) parser.add_argument( "--noPostfitProfileBB", default=False, @@ -282,7 +288,13 @@ def fit(args, fitter, ws, dofit=True): edmval = None if args.externalPostfit is not None: - fitter.load_fitresult(args.externalPostfit, args.externalPostfitResult) + fitter.load_fitresult( + args.externalPostfit, + args.externalPostfitResult, + profile=not args.noPostfitProfileBB, + ) + + rh.optimize_tau(fitter) if dofit: cb = fitter.minimize() @@ -568,7 +580,7 @@ def main(): if not args.prefitOnly: ifitter.set_blinding_offsets(blind=blinded_fits[i]) - fit(args, ifitter, ws, dofit=ifit >= 0) + fit(args, ifitter, ws, dofit=ifit >= 0 and not args.noFit) fit_time.append(time.time()) if args.saveHists: diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 4c5387c..b73487b 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -304,7 +304,7 @@ def __init__( and ((not self.binByBinStat) or self.binByBinStatType == "normal-additive") ) - def load_fitresult(self, fitresult_file, fitresult_key): + def load_fitresult(self, fitresult_file, fitresult_key, profile=True): # load results from external fit and set postfit value and covariance elements for common parameters cov_ext = None with h5py.File(fitresult_file, "r") as fext: @@ -340,6 +340,9 @@ def load_fitresult(self, fitresult_file, fitresult_key): covval[np.ix_(idxs, idxs)] = cov_ext[np.ix_(idxs_ext, idxs_ext)] self.cov.assign(tf.constant(covval)) + if profile: + self._profile_beta() + def update_frozen_params(self): 
new_mask_np = np.isin(self.parms, self.frozen_params) @@ -1992,15 +1995,7 @@ def _compute_lbeta(self, beta, full_nll=False): return None - def _compute_nll_components(self, profile=True, full_nll=False): - nexpfullcentral, _, beta = self._compute_yields_with_beta( - profile=profile, - compute_norm=False, - full=len(self.regularizers), - ) - - nexp = nexpfullcentral[: self.indata.nbins] - + def _compute_ln(self, nexp, full_nll=False): if self.chisqFit: ln = 0.5 * tf.reduce_sum((nexp - self.nobs) ** 2 / self.varnobs, axis=-1) elif self.covarianceFit: @@ -2028,6 +2023,18 @@ def _compute_nll_components(self, profile=True, full_nll=False): ln = tf.reduce_sum( -self.nobs * (lognexp - self.lognobs) + nexp - self.nobs, axis=-1 ) + return ln + + def _compute_nll_components(self, profile=True, full_nll=False): + nexpfullcentral, _, beta = self._compute_yields_with_beta( + profile=profile, + compute_norm=False, + full=len(self.regularizers), + ) + + nexp = nexpfullcentral[: self.indata.nbins] + + ln = self._compute_ln(nexp, full_nll) lc = self._compute_lc(full_nll) diff --git a/rabbit/regularization/helpers.py b/rabbit/regularization/helpers.py index c98aae9..402ee43 100644 --- a/rabbit/regularization/helpers.py +++ b/rabbit/regularization/helpers.py @@ -1,4 +1,10 @@ +import tensorflow as tf +from wums import logging + from rabbit import common +from rabbit.regularization.svd import SVD + +logger = logging.child_logger(__name__) # dictionary with class name and the corresponding filename where it is defined baseline_regularizations = { @@ -11,3 +17,120 @@ def load_regularizer(class_name, *args, **kwargs): class_name, baseline_regularizations, base_dir="rabbit.regularization" ) return regularization(*args, **kwargs) + + +def compute_curvature(fitter, tau): + """ + Following Eq.(4.3) from https://iopscience.iop.org/article/10.1088/1748-0221/7/10/T10003/pdf + """ + + # the full derivative d(Li) / d(tau) = pd(Li)/pd(tau) + pd(Li)/pd(x) * d(x)/d(tau) + # = pd(Li)/pd(tau) - pd(Li)/pd(x) * (pd2(L)/pd(x^2))^-1 * pd2(L)/pd(x)pd(tau) + # there is no dependency of Li on tau, thus, the first term is 0 + # (pd2(L)/pd(x^2))^-1 is the covariance matrix + with tf.GradientTape(persistent=True) as t3: + t3.watch(tau) + + # 1) compute dx/dtau + with tf.GradientTape(persistent=True) as t2: + t2.watch(tau) + with tf.GradientTape() as t1: + t1.watch(tau) + nll = fitter._compute_nll() + + pdLpdx = t1.gradient(nll, fitter.x) + + pd2Lpdx2 = t2.jacobian(pdLpdx, fitter.x) + pd2Lpdxpdtau = t2.jacobian(pdLpdx, tau) + + chol = tf.linalg.cholesky(pd2Lpdx2) + dxdtau = -tf.linalg.cholesky_solve(chol, pd2Lpdxpdtau[:, None]) + dxdtau = tf.reshape(dxdtau, -1) + + # 2) compute pdLx/pdx, pdLy/pdx and pd^2Lx/pdx^2, pd^2Ly/pdx^2 + with tf.GradientTape(persistent=True) as t_inner: + t_inner.watch(tau) + nexpfullcentral, _, beta = fitter._compute_yields_with_beta( + profile=False, + compute_norm=False, + full=len(fitter.regularizers), + ) + + nexp = nexpfullcentral[: fitter.indata.nbins] + + ln = fitter._compute_ln(nexp) + lc = fitter._compute_lc() + lbeta = fitter._compute_lbeta(beta) + lx = tf.math.log(ln + lc + lbeta) + + x = fitter.get_x() + penalties = [ + reg.compute_nll_penalty_unweighted(x, nexpfullcentral) + for reg in fitter.regularizers + ] + ly = tf.math.log(tf.add_n(penalties)) + + pdLxpdx = t_inner.gradient(lx, fitter.x) + pdLypdx = t_inner.gradient(ly, fitter.x) + + pdLxpdx_dxdtau = tf.reduce_sum(pdLxpdx * dxdtau) + pdLypdx_dxdtau = tf.reduce_sum(pdLypdx * dxdtau) + + pdLxpdx_d2xdtau2 = t3.gradient(pdLxpdx_dxdtau, tau) + 
pdLypdx_d2xdtau2 = t3.gradient(pdLypdx_dxdtau, tau) + + pd2Lxpdx2 = t3.jacobian(pdLxpdx, fitter.x) + pd2Lypdx2 = t3.jacobian(pdLypdx, fitter.x) + + dLxdtau = tf.reduce_sum(pdLxpdx * dxdtau) + dLydtau = tf.reduce_sum(pdLypdx * dxdtau) + + d2Lxdtau2 = ( + tf.reduce_sum(dxdtau * tf.linalg.matvec(pd2Lxpdx2, dxdtau)) + pdLxpdx_d2xdtau2 + ) + d2Lydtau2 = ( + tf.reduce_sum(dxdtau * tf.linalg.matvec(pd2Lypdx2, dxdtau)) + pdLypdx_d2xdtau2 + ) + + curvature = (d2Lydtau2 * dLxdtau - d2Lxdtau2 * dLydtau) / tf.pow( + tf.square(dLxdtau) + tf.square(dLydtau), 1.5 + ) + + return curvature + + +def optimize_tau(fitter): + # find the tau where the curvature is maximum, minimize curvature w.r.t. tau + + tau = SVD.tau + + fitter.minimize() + curvature = -compute_curvature(fitter, tau) + + edm = 1 + i = 0 + while i < 50 and edm > 1e-10: + fitter.minimize() + logger.info(f"Iteration {i}") + + with tf.GradientTape() as t2: + t2.watch(tau) + with tf.GradientTape() as t1: + t1.watch(tau) + curvature = -compute_curvature(fitter, tau) + logger.info(f"Curvature (value) = {curvature}") + grad = t1.gradient(curvature, tau) + logger.info(f"Curvature (gradient) = {grad}") + hess = t2.gradient(grad, tau) + logger.info(f"Curvature (hessian) = {hess}") + + # eps = 1e-8 + # safe_hess = tf.where(hess > 0, hess, tf.ones_like(hess)) + step = grad / hess # (safe_hess + eps) + logger.info(f"Curvature (step) = {-step}") + tau.assign_sub(step) + edm = tf.reduce_max(0.5 * grad * step) + i = i + 1 + + logger.debug(f"Curvature edm = {edm}") + logger.debug(f"Curvature tau = {tau}") diff --git a/rabbit/regularization/svd.py b/rabbit/regularization/svd.py index 776daf2..09102e0 100644 --- a/rabbit/regularization/svd.py +++ b/rabbit/regularization/svd.py @@ -8,17 +8,23 @@ class SVD(Regularizer): Singular Value Decomposition (SVD) see: https://arxiv.org/abs/hep-ph/9509307 """ - def __init__(self, mapping, dtype): - self.strength = 1.0 + # one common regularization strength parameter + tau = tf.Variable(1.0, trainable=True, name="tau", dtype=tf.float64) + def __init__(self, mapping, dtype): if len(mapping.channel_info) > 1: raise NotImplementedError( "Regularization currently only works for 1 channel at a time; use multiple regularizers if you want to penalize multiple channels." ) self.mapping = mapping + + # there is an embiguity about what to do with the flow bins. 
+ # they are not part of the fit, thus, the flow bins are not taken except for masked channels self.input_shape = [ - len(a) for v in mapping.channel_info.values() for a in v["axes"] + a.extent if v["flow"] else a.size + for v in mapping.channel_info.values() + for a in v["axes"] ] self.ndims = len(self.input_shape) @@ -54,7 +60,7 @@ def set_expectations(self, initial_params, initial_observables): nexp0 = self.mapping.compute_flat(initial_params, initial_observables) self.nexp0 = tf.reshape(nexp0, self.input_shape) - def compute_nll_penalty(self, params, observables): + def compute_nll_penalty_unweighted(self, params, observables): mask = self.nexp0 != 0 nexp0_safe = tf.where(mask, self.nexp0, tf.cast(1.0, self.nexp0.dtype)) @@ -84,6 +90,12 @@ def compute_nll_penalty(self, params, observables): padded_input, self.kernel, strides=[1, 1, 1, 1, 1], padding="VALID" ) - penalty = self.strength * tf.reduce_mean(tf.square(curvature_map)) + penalty = tf.reduce_sum(tf.square(curvature_map)) return penalty + + def compute_nll_penalty(self, params, observables): + + penalty = self.compute_nll_penalty_unweighted(params, observables) + + return penalty * tf.exp(2 * self.tau) From fd228b30cc58c009c7e7800176c6aad2896ef86e Mon Sep 17 00:00:00 2001 From: davidwalter2 Date: Thu, 5 Feb 2026 08:23:52 -0500 Subject: [PATCH 3/6] Add 'earlyStopping' feature to stop minimization if no reduction after 'x' iterations is obtained --- rabbit/fitter.py | 18 +++++++++++++++--- rabbit/parsing.py | 6 ++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index b73487b..aa22d76 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -33,7 +33,7 @@ def match_regexp_params(regular_expressions, parameter_names): class FitterCallback: - def __init__(self, xv): + def __init__(self, xv, early_stopping=-1): self.iiter = 0 self.xval = xv @@ -42,6 +42,8 @@ def __init__(self, xv): self.t0 = time.time() + self.early_stopping = early_stopping + def __call__(self, intermediate_result): loss = intermediate_result.fun @@ -51,6 +53,15 @@ def __call__(self, intermediate_result): if np.isnan(loss): raise ValueError(f"Loss value is NaN at iteration {self.iiter}") + if ( + self.early_stopping > 0 + and len(self.loss_history) > self.early_stopping + and self.loss_history[self.early_stopping] <= loss + ): + raise ValueError( + f"No reduction in loss after {self.early_stopping} iterations, early stopping." 
+ ) + self.loss_history.append(loss) self.time_history.append(time.time() - self.t0) @@ -67,6 +78,7 @@ def __init__( ): self.indata = indata + self.earlyStopping = options.earlyStopping self.globalImpactsFromJVP = globalImpactsFromJVP self.binByBinStat = not options.noBinByBinStat self.binByBinStatMode = options.binByBinStatMode @@ -2150,7 +2162,7 @@ def loss_val_grad_hess_beta(self, profile=True): return val, grad, hess - def minimize(self): + def minimize(self, eraly_stopping=10): if self.is_linear: logger.info( "Likelihood is purely quadratic, solving by Cholesky decomposition instead of iterative fit" @@ -2207,7 +2219,7 @@ def scipy_hess(xval): xval = self.x.numpy() - callback = FitterCallback(xval) + callback = FitterCallback(xval, self.earlyStopping) if self.minimizer_method in [ "trust-krylov", diff --git a/rabbit/parsing.py b/rabbit/parsing.py index 4b874c9..a11d77b 100644 --- a/rabbit/parsing.py +++ b/rabbit/parsing.py @@ -43,6 +43,12 @@ def common_parser(): action="store_true", help="Calculate and print additional info for diagnostics (condition number, edm value)", ) + parser.add_argument( + "--earlyStopping", + default=10, + type=int, + help="Number of iterations with no improvement after which training will be stopped. Specify -1 to disable.", + ) parser.add_argument( "--minimizerMethod", default="trust-krylov", From a1a6f4de00e6b9f47eac5887286f646a577b812b Mon Sep 17 00:00:00 2001 From: davidwalter2 Date: Thu, 5 Feb 2026 13:46:07 -0500 Subject: [PATCH 4/6] Implement curvature scan and support for plotting it --- bin/rabbit_fit.py | 25 ++++++- rabbit/regularization/helpers.py | 84 ++++++++++++++++------ rabbit/workspace.py | 9 +-- tests/plot_epoch_loss_time.py | 115 ++++++++++++++++++++++++------- 4 files changed, 179 insertions(+), 54 deletions(-) diff --git a/bin/rabbit_fit.py b/bin/rabbit_fit.py index 051b98e..aa48e05 100755 --- a/bin/rabbit_fit.py +++ b/bin/rabbit_fit.py @@ -189,6 +189,18 @@ def make_parser(): action="store_true", help="compute impacts of frozen (non-profiled) systematics", ) + parser.add_argument( + "--lCurveScan", + default=False, + action="store_true", + help="For use with regularization, scan the L curve versus values for tau", + ) + parser.add_argument( + "--lCurveOptimize", + default=False, + action="store_true", + help="For use with regularization, find the value of tau that maximizes the curvature", + ) return parser.parse_args() @@ -294,7 +306,15 @@ def fit(args, fitter, ws, dofit=True): profile=not args.noPostfitProfileBB, ) - rh.optimize_tau(fitter) + if args.lCurveScan: + tau_values, l_curve_values = rh.l_curve_scan_tau(fitter) + ws.add_1D_integer_hist(tau_values, "step", "tau") + ws.add_1D_integer_hist(l_curve_values, "step", "lcurve") + + if args.lCurveOptimize: + best_tau, max_curvature = rh.l_curve_optimize_tau(fitter) + ws.add_1D_integer_hist([best_tau], "best", "tau") + ws.add_1D_integer_hist([max_curvature], "best", "lcurve") if dofit: cb = fitter.minimize() @@ -306,7 +326,8 @@ def fit(args, fitter, ws, dofit=True): fitter._profile_beta() if cb is not None: - ws.add_loss_time_hist(cb.loss_history, cb.time_history) + ws.add_1D_integer_hist(cb.loss_history, "epoch", "loss") + ws.add_1D_integer_hist(cb.time_history, "epoch", "time") if not args.noHessian: # compute the covariance matrix and estimated distance to minimum diff --git a/rabbit/regularization/helpers.py b/rabbit/regularization/helpers.py index 402ee43..6eca7c9 100644 --- a/rabbit/regularization/helpers.py +++ b/rabbit/regularization/helpers.py @@ -1,3 +1,4 @@ +import numpy as np 
import tensorflow as tf from wums import logging @@ -19,7 +20,7 @@ def load_regularizer(class_name, *args, **kwargs): return regularization(*args, **kwargs) -def compute_curvature(fitter, tau): +def _compute_curvature(fitter, tau): """ Following Eq.(4.3) from https://iopscience.iop.org/article/10.1088/1748-0221/7/10/T10003/pdf """ @@ -35,7 +36,6 @@ def compute_curvature(fitter, tau): with tf.GradientTape(persistent=True) as t2: t2.watch(tau) with tf.GradientTape() as t1: - t1.watch(tau) nll = fitter._compute_nll() pdLpdx = t1.gradient(nll, fitter.x) @@ -45,11 +45,10 @@ def compute_curvature(fitter, tau): chol = tf.linalg.cholesky(pd2Lpdx2) dxdtau = -tf.linalg.cholesky_solve(chol, pd2Lpdxpdtau[:, None]) - dxdtau = tf.reshape(dxdtau, -1) + dxdtau = tf.reshape(dxdtau, [-1]) # 2) compute pdLx/pdx, pdLy/pdx and pd^2Lx/pdx^2, pd^2Ly/pdx^2 with tf.GradientTape(persistent=True) as t_inner: - t_inner.watch(tau) nexpfullcentral, _, beta = fitter._compute_yields_with_beta( profile=False, compute_norm=False, @@ -99,38 +98,79 @@ def compute_curvature(fitter, tau): return curvature -def optimize_tau(fitter): - # find the tau where the curvature is maximum, minimize curvature w.r.t. tau +@tf.function +def compute_curvature(fitter, tau): + return _compute_curvature(fitter, tau) + + +@tf.function +def neg_curvature_val_grad_hess(fitter, tau): + with tf.GradientTape() as t2: + t2.watch(tau) + with tf.GradientTape() as t1: + t1.watch(tau) + val = -1 * _compute_curvature(fitter, tau) + grad = t1.gradient(val, tau) + hess = t2.gradient(grad, tau) + + return val, grad, hess + +def l_curve_scan_tau(fitter, min=-5, max=5.1, step=0.1): tau = SVD.tau + tau0 = tau.numpy() + + curvatures = [] + tau_steps = np.arange(min, max, step) + + for i, v in enumerate(tau_steps): + logger.info(f"Iteration {i} with tau = {v}") + + tau.assign(v) + cb = fitter.minimize() + val = compute_curvature(fitter, tau).numpy() + curvatures.append(val) + + logger.info(f"Curvature (value) = {val}") - fitter.minimize() - curvature = -compute_curvature(fitter, tau) + # set tau back to the original value + tau.assign(tau0) + + return tau_steps, np.array(curvatures) + + +def l_curve_optimize_tau(fitter): + # find the tau where the curvature is maximum, minimize curvature w.r.t. 
tau + + tau = SVD.tau edm = 1 i = 0 - while i < 50 and edm > 1e-10: - fitter.minimize() + while i < 50 and edm > 1e-16: + cb = fitter.minimize() logger.info(f"Iteration {i}") - with tf.GradientTape() as t2: - t2.watch(tau) - with tf.GradientTape() as t1: - t1.watch(tau) - curvature = -compute_curvature(fitter, tau) - logger.info(f"Curvature (value) = {curvature}") - grad = t1.gradient(curvature, tau) - logger.info(f"Curvature (gradient) = {grad}") - hess = t2.gradient(grad, tau) + val, grad, hess = neg_curvature_val_grad_hess(fitter, tau) + + logger.info(f"Curvature (value) = {-val}") + logger.info(f"Curvature (gradient) = {grad}") logger.info(f"Curvature (hessian) = {hess}") # eps = 1e-8 - # safe_hess = tf.where(hess > 0, hess, tf.ones_like(hess)) - step = grad / hess # (safe_hess + eps) + # safe_hess = tf.where(hess != 0, hess, tf.ones_like(hess)) + # step = grad / (tf.abs(safe_hess) + eps) + step = grad / hess logger.info(f"Curvature (step) = {-step}") tau.assign_sub(step) - edm = tf.reduce_max(0.5 * grad * step) + edm = tf.reduce_max(0.5 * tf.square(grad) * tf.abs(hess)) i = i + 1 logger.debug(f"Curvature edm = {edm}") logger.debug(f"Curvature tau = {tau}") + + logger.info(f"Optimization terminated") + logger.info(f" edm: {edm}") + logger.info(f" maximum curvature: {-val}") + logger.info(f" tau: {tau.numpy()}") + + return tau.numpy(), val.numpy() diff --git a/rabbit/workspace.py b/rabbit/workspace.py index d2f8e4d..12b09e6 100644 --- a/rabbit/workspace.py +++ b/rabbit/workspace.py @@ -510,12 +510,13 @@ def add_expected_hists( return name, label - def add_loss_time_hist(self, loss, time, name="epoch"): + def add_1D_integer_hist(self, values, name_x, name_y): axis_epoch = hist.axis.Integer( - 0, len(loss), underflow=False, overflow=False, name="epoch" + 0, len(values), underflow=False, overflow=False, name=name_x + ) + self.add_hist( + f"{name_x}_{name_y}", axis_epoch, values, label=f"{name_x} {name_y}" ) - self.add_hist(f"{name}_loss", axis_epoch, loss, label=f"{name} loss") - self.add_hist(f"{name}_time", axis_epoch, time, label=f"{name} time") def write_meta(self, meta): ioutils.pickle_dump_h5py("meta", meta, self.fout) diff --git a/tests/plot_epoch_loss_time.py b/tests/plot_epoch_loss_time.py index 844d8cd..6d3c8d7 100644 --- a/tests/plot_epoch_loss_time.py +++ b/tests/plot_epoch_loss_time.py @@ -50,6 +50,13 @@ "--legCols", type=int, default=2, help="Number of columns in legend" ) parser.add_argument("--startEpoch", type=int, default=0, help="Epoch to start plotting") +parser.add_argument( + "--types", + nargs="+", + default=["loss"], + choices=["loss", "lcurve"], + help="Make 1D plot as function of epoch/step/...", +) args = parser.parse_args() outdir = output_tools.make_plot_dir(args.outpath) @@ -58,19 +65,30 @@ times = [] losses = [] dlosses = [] +tau_steps = [] +lcurves = [] +best_tau = [] +best_curvature = [] for infile in args.infile: fitresult, meta = rabbit.io_tools.get_fitresult(infile, args.result, meta=True) - h_time = fitresult["epoch_time"].get() - h_loss = fitresult["epoch_loss"].get() + if "loss" in args.types: + h_loss = fitresult["epoch_loss"].get() + loss = 2 * h_loss.values() + losses.append(loss) + dlosses.append(-np.diff(loss)) # reduction of loss after each epoch + epochs.append(np.arange(1, len(loss) + 1)) - times.append(h_time.values()) - loss = 2 * h_loss.values() + h_time = fitresult["epoch_time"].get() + times.append(h_time.values()) - epochs.append(np.arange(1, len(loss) + 1)) + if "lcurve" in args.types: + 
tau_steps.append(fitresult["step_tau"].get().values()) + lcurves.append(fitresult["step_lcurve"].get().values()) - losses.append(loss) - dlosses.append(-np.diff(loss)) # reduction of loss after each epoch + if "best_tau" in fitresult.keys(): + best_tau.append(fitresult["best_tau"].get().values()) + best_curvature.append(fitresult["best_lcurve"].get().values()) linestyles = [ "-", @@ -82,27 +100,23 @@ ":", "-.", ] -linestyles = [linestyles[i % len(linestyles)] for i in range(len(epochs))] +linestyles = [linestyles[i % len(linestyles)] for i in range(len(args.infile))] start = args.startEpoch stop = None -for x, y, xlabel, ylabel, stop, suffix in ( - (times, losses, "time [s]", r"$-2\Delta \ln(L)$", None, "loss"), - (epochs, losses, "epoch", r"$-2\Delta \ln(L)$", None, "loss_time"), - (times, dlosses, "epoch", r"$-2(\ln(L_{t}) - \ln(L_{t-1}))$", -1, "reduction_loss"), - ( - epochs, - dlosses, - "time [s]", - r"$-2(\ln(L_{t}) - \ln(L_{t-1}))$", - -1, - "reduction_loss_time", - ), -): - ymin = min([min(iy) for iy in y]) +if args.labels: + labels = args.labels +else: + labels = [None] * len(args.infile) + + +def plot(x, y, xlabel, ylabel, stop, suffix, points=[]): + if any(x in suffix for x in ["loss"]): + # Normalize to 0 + ymin = min([min(iy) for iy in y]) + y = [iy - ymin for iy in y] - y = [iy - ymin for iy in y] x = [ix[start:stop] for ix in x] if args.logy: @@ -111,22 +125,37 @@ max([max(iy) for iy in y]) * 2, ] else: - ylim = [0, max([max(iy) for iy in y]) * 1.1] + ymin = min([min(iy) for iy in y]) + ylim = [ymin * 1.1 if ymin < 0 else 0, max([max(iy) for iy in y]) * 1.1] fig, ax1 = plot_tools.figure( None, xlabel, ylabel, width_scale=1, - xlim=[0, max([max(ix) for ix in x])], + xlim=[min([min(ix) for ix in x]), max([max(ix) for ix in x])], ylim=ylim, automatic_scale=False, logy=args.logy, ) - for ix, iy, l, s in zip(x, y, args.labels, linestyles): + ax1.plot([0.8, 0.8], ylim, color="grey", linestyle="--") + + for ix, iy, l, s in zip(x, y, labels, linestyles): ax1.plot(ix, iy, label=l, linestyle=s) + for point_x, point_y in points: + ax1.plot( + point_x, + point_y, + marker="*", + markersize=15, + markerfacecolor="yellow", + markeredgecolor="black", + markeredgewidth=1.5, + linestyle="None", + ) + plot_tools.add_decor( ax1, args.title, @@ -156,3 +185,37 @@ outfile, args=args, ) + + +combinations = [] +if "loss" in args.types: + plot(epochs, losses, "epoch", r"$-2\Delta \ln(L)$", None, "loss") + plot( + epochs, + dlosses, + "time [s]", + r"$-2(\ln(L_{t}) - \ln(L_{t-1}))$", + -1, + "reduction_loss", + ) + if "time" in args.types: + plot(times, losses, "time [s]", r"$-2\Delta \ln(L)$", None, "loss_time") + plot( + times, + dlosses, + "epoch", + r"$-2(\ln(L_{t}) - \ln(L_{t-1}))$", + -1, + "reduction_loss_time", + ) + +if "lcurve" in args.types: + plot( + tau_steps, + lcurves, + r"$\tau$", + "Curvature", + None, + "lcurve", + points=zip(best_tau, best_curvature), + ) From 93f069dad16003109c300dde1dd76f17e9251ff9 Mon Sep 17 00:00:00 2001 From: davidwalter2 Date: Thu, 5 Feb 2026 13:47:42 -0500 Subject: [PATCH 5/6] Add flag to ensure numerical reproducebility --- setup.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.sh b/setup.sh index 5c038f1..6c2efa1 100644 --- a/setup.sh +++ b/setup.sh @@ -1,3 +1,4 @@ +export TF_ENABLE_ONEDNN_OPTS=0 export RABBIT_BASE=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) export PYTHONPATH="${RABBIT_BASE}:$PYTHONPATH" export PATH="$PATH:${RABBIT_BASE}/bin" From 1ecfbfab030836108ee9fe30654bf9a5b5776137 Mon Sep 17 00:00:00 2001 From: davidwalter2 Date: 
Thu, 5 Feb 2026 15:04:10 -0500 Subject: [PATCH 6/6] Fix early stopping functionality --- rabbit/fitter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index aa22d76..ee89d50 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -56,7 +56,7 @@ def __call__(self, intermediate_result): if ( self.early_stopping > 0 and len(self.loss_history) > self.early_stopping - and self.loss_history[self.early_stopping] <= loss + and self.loss_history[-self.early_stopping] <= loss ): raise ValueError( f"No reduction in loss after {self.early_stopping} iterations, early stopping."
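For intuition, the corrected indexing compares the current loss to the loss recorded early_stopping iterations earlier, rather than to a fixed entry near the start of the history; a minimal standalone sketch of the same check with illustrative numbers:

# Sketch of the corrected early-stopping condition (illustrative values only):
# stop when the loss is no better than it was `early_stopping` iterations ago.
loss_history = [12.0, 9.5, 9.1, 9.05, 9.05, 9.05]
early_stopping = 3
current_loss = 9.05  # loss of the current iteration, not yet appended to the history

should_stop = (
    early_stopping > 0
    and len(loss_history) > early_stopping
    and loss_history[-early_stopping] <= current_loss  # 9.05 <= 9.05 -> stop
)
print(should_stop)  # True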