47 changes: 44 additions & 3 deletions bin/rabbit_fit.py
@@ -13,6 +13,7 @@
from rabbit.mappings import helpers as mh
from rabbit.mappings import mapping as mp
from rabbit.poi_models import helpers as ph
from rabbit.regularization import helpers as rh
from rabbit.tfhelpers import edmval_cov

from wums import output_tools, logging # isort: skip
@@ -149,6 +150,12 @@ def make_parser():
type=str,
help="Specify result from external postfit file",
)
parser.add_argument(
"--noFit",
default=False,
action="store_true",
help="Do not not perform the minimization.",
)
parser.add_argument(
"--noPostfitProfileBB",
default=False,
@@ -182,6 +189,18 @@ def make_parser():
action="store_true",
help="compute impacts of frozen (non-profiled) systematics",
)
parser.add_argument(
"--lCurveScan",
default=False,
action="store_true",
help="For use with regularization, scan the L curve versus values for tau",
)
parser.add_argument(
"--lCurveOptimize",
default=False,
action="store_true",
help="For use with regularization, find the value of tau that maximizes the curvature",
)

return parser.parse_args()

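The helpers rh.l_curve_scan_tau and rh.l_curve_optimize_tau come from the new rabbit/regularization package and are used in fit() below: the scan returns parallel arrays of tau values and L-curve values, and the optimizer returns the tau of maximal curvature. A minimal, self-contained sketch of the curvature criterion such an optimizer can use, assuming the standard L-curve corner method; the function l_curve_curvature and its finite-difference scheme are illustrative, not the actual rabbit implementation:

    import numpy as np

    def l_curve_curvature(log_rho, log_eta):
        """Signed curvature of the L-curve (log_rho(tau), log_eta(tau)).

        log_rho: log of the data-misfit term for each scanned tau
        log_eta: log of the regularization penalty for each scanned tau
        The 'corner' of the L-curve, i.e. the preferred tau, is the argmax.
        """
        drho, deta = np.gradient(log_rho), np.gradient(log_eta)
        ddrho, ddeta = np.gradient(drho), np.gradient(deta)
        # curvature of a parametric curve: (x'y'' - y'x'') / (x'^2 + y'^2)^(3/2)
        return (drho * ddeta - ddrho * deta) / (drho**2 + deta**2) ** 1.5

On a scan like the one written out by --lCurveScan, np.argmax of this quantity selects the corner tau, which is what --lCurveOptimize reports as best_tau and max_curvature.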
@@ -281,7 +300,21 @@ def fit(args, fitter, ws, dofit=True):
edmval = None

if args.externalPostfit is not None:
fitter.load_fitresult(args.externalPostfit, args.externalPostfitResult)
fitter.load_fitresult(
args.externalPostfit,
args.externalPostfitResult,
profile=not args.noPostfitProfileBB,
)

if args.lCurveScan:
tau_values, l_curve_values = rh.l_curve_scan_tau(fitter)
ws.add_1D_integer_hist(tau_values, "step", "tau")
ws.add_1D_integer_hist(l_curve_values, "step", "lcurve")

if args.lCurveOptimize:
best_tau, max_curvature = rh.l_curve_optimize_tau(fitter)
ws.add_1D_integer_hist([best_tau], "best", "tau")
ws.add_1D_integer_hist([max_curvature], "best", "lcurve")

if dofit:
cb = fitter.minimize()
@@ -293,7 +326,8 @@
fitter._profile_beta()

if cb is not None:
ws.add_loss_time_hist(cb.loss_history, cb.time_history)
ws.add_1D_integer_hist(cb.loss_history, "epoch", "loss")
ws.add_1D_integer_hist(cb.time_history, "epoch", "time")

if not args.noHessian:
# compute the covariance matrix and estimated distance to minimum
@@ -477,6 +511,13 @@ def main():
mp.CompositeMapping(mappings),
]

regularizers = []
for margs in args.regularization:
mapping = mh.load_mapping(margs[1], indata, *margs[2:])
regularizer = rh.load_regularizer(margs[0], mapping, dtype=indata.dtype)
regularizers.append(regularizer)
ifitter.regularizers = regularizers

np.random.seed(args.seed)
tf.random.set_seed(args.seed)

@@ -560,7 +601,7 @@ def main():

if not args.prefitOnly:
ifitter.set_blinding_offsets(blind=blinded_fits[i])
fit(args, ifitter, ws, dofit=ifit >= 0)
fit(args, ifitter, ws, dofit=ifit >= 0 and not args.noFit)
fit_time.append(time.time())

if args.saveHists:
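The rabbit/regularization package itself is outside this diff, but the interface a regularizer must satisfy can be read off the call sites above and in fitter.py: rh.load_regularizer(name, mapping, dtype=...) constructs it, set_expectations(x, nexp) is called once from defaultassign() with the initial parameters and the full prefit yields, and compute_nll_penalty(x, nexp) must return a scalar tensor that is added to the NLL. A minimal Tikhonov-style sketch against that inferred interface — the class, its constructor signature, and the assumption that a mapping is callable as mapping(x, nexp) are illustrative, not the actual rabbit API:

    import tensorflow as tf

    class MyCustomRegularization:
        """Quadratic penalty tau * ||nout/nout0 - 1||^2 on a mapping's output."""

        def __init__(self, mapping, tau=1.0, dtype=tf.float64):
            self.mapping = mapping  # assumed callable: (x, nexp) -> regularized output nout
            self.tau = tf.Variable(tau, dtype=dtype, trainable=False)
            self.nout0 = None  # prefit reference, cached in set_expectations

        def set_expectations(self, x, nexp):
            # cache the prefit output so the penalty measures deviations from it
            self.nout0 = tf.stop_gradient(self.mapping(x, nexp))

        def compute_nll_penalty(self, x, nexp):
            nout = self.mapping(x, nexp)
            return self.tau * tf.reduce_sum((nout / self.nout0 - 1.0) ** 2)

Because the penalty is assembled from TensorFlow ops on x and nexp, gradients and Hessian-vector products propagate through loss_val_grad and loss_val_grad_hessp without further changes.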
88 changes: 73 additions & 15 deletions rabbit/fitter.py
@@ -33,7 +33,7 @@ def match_regexp_params(regular_expressions, parameter_names):


class FitterCallback:
def __init__(self, xv):
def __init__(self, xv, early_stopping=-1):
self.iiter = 0
self.xval = xv

@@ -42,6 +42,8 @@ def __init__(self, xv):

self.t0 = time.time()

self.early_stopping = early_stopping

def __call__(self, intermediate_result):
loss = intermediate_result.fun

@@ -51,6 +53,15 @@
if np.isnan(loss):
raise ValueError(f"Loss value is NaN at iteration {self.iiter}")

if (
self.early_stopping > 0
and len(self.loss_history) > self.early_stopping
and self.loss_history[-self.early_stopping] <= loss
):
raise ValueError(
f"No reduction in loss after {self.early_stopping} iterations, early stopping."
)

self.loss_history.append(loss)
self.time_history.append(time.time() - self.t0)

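The early-stopping rule compares the incoming loss against the one recorded early_stopping iterations earlier; raising from inside the callback aborts the scipy.optimize.minimize call, and the except branch in minimize() below restores the parameters saved on the last completed iteration (callback.xval). A standalone illustration of the window criterion, with hypothetical loss sequences:

    def should_stop(loss_history, loss, window=10):
        # stop once the loss from `window` iterations ago is no better
        # than the current one, i.e. no net improvement over the window
        return window > 0 and len(loss_history) > window and loss_history[-window] <= loss

    flat = [8.4] * 11
    assert should_stop(flat, loss=8.4, window=10)  # stalled for 10 iterations
    improving = [10.0 - 0.1 * i for i in range(11)]
    assert not should_stop(improving, loss=8.9, window=10)  # still improving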
@@ -67,6 +78,7 @@ def __init__(
):
self.indata = indata

self.earlyStopping = options.earlyStopping
self.globalImpactsFromJVP = globalImpactsFromJVP
self.binByBinStat = not options.noBinByBinStat
self.binByBinStatMode = options.binByBinStatMode
@@ -292,6 +304,9 @@ def __init__(
name="cov",
)

# regularization
self.regularizers = []

# determine if problem is linear (ie likelihood is purely quadratic)
self.is_linear = (
(self.chisqFit or self.covarianceFit)
@@ -301,7 +316,7 @@
and ((not self.binByBinStat) or self.binByBinStatType == "normal-additive")
)

def load_fitresult(self, fitresult_file, fitresult_key):
def load_fitresult(self, fitresult_file, fitresult_key, profile=True):
# load results from external fit and set postfit value and covariance elements for common parameters
cov_ext = None
with h5py.File(fitresult_file, "r") as fext:
@@ -337,19 +352,24 @@ def load_fitresult(self, fitresult_file, fitresult_key):
covval[np.ix_(idxs, idxs)] = cov_ext[np.ix_(idxs_ext, idxs_ext)]
self.cov.assign(tf.constant(covval))

if profile:
self._profile_beta()

def update_frozen_params(self):
new_mask_np = np.isin(self.parms, self.frozen_params)

self.frozen_params_mask.assign(new_mask_np)
self.frozen_indices = np.where(new_mask_np)[0]

def freeze_params(self, frozen_parmeter_expressions):
logger.debug(f"Freeze params with {frozen_parmeter_expressions}")
self.frozen_params.extend(
match_regexp_params(frozen_parmeter_expressions, self.parms)
)
self.update_frozen_params()

def defreeze_params(self, unfrozen_parmeter_expressions):
logger.debug(f"Freeze params with {unfrozen_parmeter_expressions}")
unfrozen_parmeter = match_regexp_params(
unfrozen_parmeter_expressions, self.parms
)
@@ -359,6 +379,7 @@ def defreeze_params(self, unfrozen_parmeter_expressions):
self.update_frozen_params()

def init_blinding_values(self, unblind_parameter_expressions=[]):
logger.debug(f"Unblind parameters with {unblind_parameter_expressions}")
unblind_parameters = match_regexp_params(
unblind_parameter_expressions,
[
@@ -452,6 +473,9 @@ def get_poi(self):
else:
return poi

def get_x(self):
return tf.concat([self.get_poi(), self.get_theta()], axis=0)

def _default_beta0(self):
if self.binByBinStatType in ["gamma", "normal-multiplicative"]:
return tf.ones(self.beta_shape, dtype=self.indata.dtype)
@@ -534,6 +558,11 @@ def defaultassign(self):
if self.do_blinding:
self.set_blinding_offsets(False)

xinit = self.get_x()
nexp0 = self.expected_yield(full=True)
for reg in self.regularizers:
reg.set_expectations(xinit, nexp0)

def bayesassign(self):
# FIXME use theta0 as the mean and constraintweight to scale the width
if self.poi_model.npoi == 0:
@@ -1978,15 +2007,7 @@ def _compute_lbeta(self, beta, full_nll=False):

return None

def _compute_nll_components(self, profile=True, full_nll=False):
nexpfullcentral, _, beta = self._compute_yields_with_beta(
profile=profile,
compute_norm=False,
full=False,
)

nexp = nexpfullcentral

def _compute_ln(self, nexp, full_nll=False):
if self.chisqFit:
ln = 0.5 * tf.reduce_sum((nexp - self.nobs) ** 2 / self.varnobs, axis=-1)
elif self.covarianceFit:
@@ -2014,22 +2035,52 @@ def _compute_nll_components(self, profile=True, full_nll=False):
ln = tf.reduce_sum(
-self.nobs * (lognexp - self.lognobs) + nexp - self.nobs, axis=-1
)
return ln

def _compute_nll_components(self, profile=True, full_nll=False):
nexpfullcentral, _, beta = self._compute_yields_with_beta(
profile=profile,
compute_norm=False,
full=len(self.regularizers) > 0,
)

nexp = nexpfullcentral[: self.indata.nbins]

ln = self._compute_ln(nexp, full_nll)

lc = self._compute_lc(full_nll)

lbeta = self._compute_lbeta(beta, full_nll)

return ln, lc, lbeta, beta
# logger.debug(f"L(nobs) = {ln}")
# logger.debug(f"L(const.) = {lc}")
# logger.debug(f"L(beta) = {lbeta}")

if len(self.regularizers):
x = self.get_x()
penalties = [
reg.compute_nll_penalty(x, nexpfullcentral) for reg in self.regularizers
]
lpenalty = tf.add_n(penalties)
else:
lpenalty = None

# logger.debug(f"L(penalty) = {lpenalty}")

return ln, lc, lbeta, lpenalty, beta

def _compute_nll(self, profile=True, full_nll=False):
ln, lc, lbeta, beta = self._compute_nll_components(
ln, lc, lbeta, lpenalty, beta = self._compute_nll_components(
profile=profile, full_nll=full_nll
)
l = ln + lc

if lbeta is not None:
l = l + lbeta

if lpenalty is not None:
l = l + lpenalty

return l

def _compute_loss(self, profile=True):
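With regularizers attached, the quantity minimized is the sum of the pieces assembled above; writing the data term ln, the constraint term lc, and the bin-by-bin statistical term lbeta as in the code:

    -\log L \,=\, \ell_n + \ell_c + \ell_\beta + \sum_r P_r\!\left(x,\, n_{\mathrm{exp}}\right)

where each P_r is the compute_nll_penalty of one configured regularizer, evaluated on the full expected yields.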
Expand Down Expand Up @@ -2111,7 +2162,7 @@ def loss_val_grad_hess_beta(self, profile=True):

return val, grad, hess

def minimize(self):
def minimize(self):
if self.is_linear:
logger.info(
"Likelihood is purely quadratic, solving by Cholesky decomposition instead of iterative fit"
@@ -2138,16 +2189,22 @@

callback = None
else:
logger.info("Perform iterative fit")

def scipy_loss(xval):
self.x.assign(xval)
val, grad = self.loss_val_grad()
# logger.debug(f"xval = {xval}")
# logger.debug(f"val = {val}; grad = {grad}")
return val.__array__(), grad.__array__()

def scipy_hessp(xval, pval):
self.x.assign(xval)
p = tf.convert_to_tensor(pval)
val, grad, hessp = self.loss_val_grad_hessp(p)
# logger.debug(f"xval = {xval}")
# logger.debug(f"p = {p}")
# logger.debug(f"val = {val}; grad = {grad}; hessp = {hessp}")
return hessp.__array__()

def scipy_hess(xval):
@@ -2162,7 +2219,7 @@

xval = self.x.numpy()

callback = FitterCallback(xval)
callback = FitterCallback(xval, self.earlyStopping)

if self.minimizer_method in [
"trust-krylov",
@@ -2190,6 +2247,7 @@ def scipy_hess(xval):
except Exception as ex:
# minimizer could have called the loss or hessp functions with "random" values, so restore the
# state from the end of the last iteration before the exception
logger.debug("Fitter exception")
xval = callback.xval
logger.debug(ex)
else:
19 changes: 19 additions & 0 deletions rabbit/parsing.py
@@ -43,6 +43,12 @@ def common_parser():
action="store_true",
help="Calculate and print additional info for diagnostics (condition number, edm value)",
)
parser.add_argument(
"--earlyStopping",
default=10,
type=int,
help="Number of iterations with no improvement after which training will be stopped. Specify -1 to disable.",
)
parser.add_argument(
"--minimizerMethod",
default="trust-krylov",
@@ -211,5 +217,18 @@ def common_parser():
action="store_true",
help="Make a composite mapping and compute the covariance matrix across all mappings.",
)
parser.add_argument(
"-r",
"--regularization",
nargs="+",
action="append",
default=[],
help="""
Apply regularization on the output "nout" of a mapping by including a penalty term P(nout) in the -log(L) of the minimization.
As arguments, specify the regularization defined in rabbit/regularization/, followed by a mapping using the same syntax as discussed above,
e.g. '-r SVD Select ch0_masked' to apply SVD regularization on the channel 'ch0_masked', or '-r SVD Project ch0 pt' for the 1D projection to pt.
Custom regularization can be specified with the full path, e.g. '-r custom_regularization.MyCustomRegularization Project ch0 pt'.
""",
)

return parser
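Since the option combines nargs='+' with action='append', each -r occurrence yields one token list, which the loading loop in bin/rabbit_fit.py consumes as [regularizer, mapping, *mapping_args]. A worked example of the resulting structure (channel names taken from the help text above):

    # '-r SVD Project ch0 pt -r SVD Select ch0_masked' parses to:
    args_regularization = [["SVD", "Project", "ch0", "pt"], ["SVD", "Select", "ch0_masked"]]
    for margs in args_regularization:
        reg_name, mapping_name, *mapping_args = margs  # margs[0], margs[1], margs[2:]
        print(reg_name, mapping_name, mapping_args)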