diff --git a/rabbit/fitter.py b/rabbit/fitter.py
index faba62a..3bc3554 100644
--- a/rabbit/fitter.py
+++ b/rabbit/fitter.py
@@ -1632,6 +1632,22 @@ def _compute_yields_with_beta(self, profile=True, compute_norm=False, full=True)
         if self.binByBinStatType in ["gamma", "normal-multiplicative"]:
             betamask = self.betamask[: nexp.shape[0]]
             if self.binByBinStatMode == "full":
+
+                if self.indata.betavar is not None and full:
+                    # apply beta variations as a normalization scaling
+                    n0 = self.indata.norm
+                    sbeta = tf.math.sqrt(self.kstat[: self.indata.nbins])
+                    dbeta = sbeta * (betasel[: self.indata.nbins] - 1)
+                    dbeta = tf.where(
+                        betamask[: self.indata.nbins], tf.zeros_like(dbeta), dbeta
+                    )
+                    var = tf.einsum("ijk,jk->ik", self.indata.betavar, dbeta)
+                    safe_n0 = tf.where(
+                        n0 > 0, n0, 1.0
+                    )  # use 1.0 as a dummy to avoid division by zero
+                    ratio = var / safe_n0
+                    norm = tf.where(n0 > 0, norm * (1 + ratio), norm)
+
                 norm = tf.where(betamask, norm, betasel * norm)
                 nexp = tf.reduce_sum(norm, -1)
             else:
@@ -2063,6 +2079,14 @@ def loss_val_grad_hess_beta(self, profile=True):
         grad = t1.gradient(val, self.ubeta)
         hess = t2.jacobian(grad, self.ubeta)

+        grad = tf.reshape(grad, [-1])
+        hess = tf.reshape(hess, [grad.shape[0], grad.shape[0]])
+
+        betamask = ~tf.reshape(self.betamask, [-1])
+        grad = grad[betamask]
+        hess = tf.boolean_mask(hess, betamask, axis=0)
+        hess = tf.boolean_mask(hess, betamask, axis=1)
+
         return val, grad, hess

     def minimize(self):
diff --git a/rabbit/inputdata.py b/rabbit/inputdata.py
index 806bd44..88e41dd 100644
--- a/rabbit/inputdata.py
+++ b/rabbit/inputdata.py
@@ -90,6 +90,10 @@ def __init__(self, filename, pseudodata=None):
         else:
             self.norm = maketensor(f["hnorm"])
             self.logk = maketensor(f["hlogk"])
+        if "hbetavariations" in f.keys():
+            self.betavar = maketensor(f["hbetavariations"])
+        else:
+            self.betavar = None

         # infer some metadata from loaded information
         self.dtype = self.data_obs.dtype
diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py
index 9121284..108e5f0 100644
--- a/rabbit/tensorwriter.py
+++ b/rabbit/tensorwriter.py
@@ -55,6 +55,7 @@ def __init__(
         self.dict_logkhalfdiff = {}  # [channel][proc][syst]
         self.dict_logkavg_indices = {}
         self.dict_logkhalfdiff_indices = {}
+        self.dict_beta_variations = {}  # [dest_channel][source_channel][process]

         self.clipSystVariations = False
         if self.clipSystVariations > 0.0:
@@ -160,6 +161,7 @@ def add_channel(self, axes, name=None, masked=False, flow=False):
         self.nbinschan[name] = ibins
         self.dict_norm[name] = {}
         self.dict_sumw2[name] = {}
+        self.dict_beta_variations[name] = {}

         # add masked channels last and not masked channels first
         this_channel = {"axes": [a for a in axes], "masked": masked, "flow": flow}
@@ -392,6 +394,56 @@ def add_systematic(
             var_name_out, add_to_data_covariance=add_to_data_covariance, **kargs
         )

+    def add_beta_variations(
+        self,
+        h,
+        process,
+        source_channel,
+        dest_channel,
+    ):
+        """
+        Add a template variation in the destination channel that is correlated with the beta variations in the source channel for a given process.
+        h: must be a histogram with the axes of both the source channel and the destination channel.
+        """
+        if self.sparse:
+            raise NotImplementedError("Sparse mode is not yet implemented")
+
+        if source_channel not in self.channels.keys():
+            raise RuntimeError(f"Channel {source_channel} not known!")
+        if dest_channel not in self.channels.keys():
+            raise RuntimeError(f"Channel {dest_channel} not known!")
+        if not self.channels[dest_channel]["masked"]:
+            raise RuntimeError(
+                "Beta variations can only be applied to masked channels"
+            )
+
+        norm = self.dict_norm[dest_channel][process]
+
+        source_axes = self.channels[source_channel]["axes"]
+        dest_axes = self.channels[dest_channel]["axes"]
+
+        source_axes_names = [a.name for a in source_axes]
+        dest_axes_names = [a.name for a in dest_axes]
+
+        for a in source_axes:
+            if a.name not in h.axes.name:
+                raise RuntimeError(
+                    f"Axis {a.name} not found in histogram h with {h.axes.name}"
+                )
+        for a in dest_axes:
+            if a.name not in h.axes.name:
+                raise RuntimeError(
+                    f"Axis {a.name} not found in histogram h with {h.axes.name}"
+                )
+
+        variation = h.project(*dest_axes_names, *source_axes_names).values()
+        variation = variation.reshape((*norm.shape, -1))
+
+        if source_channel not in self.dict_beta_variations[dest_channel].keys():
+            self.dict_beta_variations[dest_channel][source_channel] = {}
+
+        self.dict_beta_variations[dest_channel][source_channel][process] = variation
+
     def get_logk(self, syst, norm, kfac=1.0, systematic_type=None):
         if not np.all(np.isfinite(syst)):
             raise RuntimeError(
@@ -714,6 +766,7 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", args={}):
         else:
             logk = np.zeros([nbinsfull, nproc, 2, nsyst], self.dtype)

+        beta_variations = np.zeros([nbinsfull, nbins, nproc], self.dtype)
        for chan in self.channels.keys():
             nbinschan = self.nbinschan[chan]
             dict_norm_chan = self.dict_norm[chan]
@@ -745,6 +798,32 @@
                         dict_logkhalfdiff_proc[syst]
                     )

+                for (
+                    source_channel,
+                    source_channel_dict,
+                ) in self.dict_beta_variations[chan].items():
+                    if proc in source_channel_dict:
+
+                        # find the bins of the source channel
+                        ibin_start = 0
+                        for c, nb in self.nbinschan.items():
+                            if self.channels[c]["masked"]:
+                                continue  # masked channels cannot be source channels
+                            if c == source_channel:
+                                ibin_end = ibin_start + nb
+                                break
+                            else:
+                                ibin_start += nb
+                        else:
+                            raise RuntimeError(
+                                f"Did not find source channel {source_channel} in list of channels {[k for k in self.nbinschan.keys()]}"
+                            )
+
+                        beta_vars = source_channel_dict[proc]
+                        beta_variations[
+                            ibin : ibin + nbinschan, ibin_start:ibin_end, iproc
+                        ] = beta_vars
+
             ibin += nbinschan

         if self.data_covariance is None and (
@@ -916,6 +995,11 @@ def create_dataset(
             )
             logk = None

+        nbytes += h5pyutils.writeFlatInChunks(
+            beta_variations, f, "hbetavariations", maxChunkBytes=self.chunkSize
+        )
+        beta_variations = None
+
         logger.info(f"Total raw bytes in arrays = {nbytes}")

     def get_systsstandard(self):