Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions rabbit/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,22 @@ def _compute_yields_with_beta(self, profile=True, compute_norm=False, full=True)
if self.binByBinStatType in ["gamma", "normal-multiplicative"]:
betamask = self.betamask[: nexp.shape[0]]
if self.binByBinStatMode == "full":

if self.indata.betavar is not None and full:
# apply beta variations as normal scaling
n0 = self.indata.norm
sbeta = tf.math.sqrt(self.kstat[: self.indata.nbins])
dbeta = sbeta * (betasel[: self.indata.nbins] - 1)
dbeta = tf.where(
betamask[: self.indata.nbins], tf.zeros_like(dbeta), dbeta
)
var = tf.einsum("ijk,jk->ik", self.indata.betavar, dbeta)
safe_n0 = tf.where(
n0 > 0, n0, 1.0
) # Use 1.0 as a dummy to avoid div by zero
ratio = var / safe_n0
norm = tf.where(n0 > 0, norm * (1 + ratio), norm)

norm = tf.where(betamask, norm, betasel * norm)
nexp = tf.reduce_sum(norm, -1)
else:
Expand Down Expand Up @@ -2063,6 +2079,14 @@ def loss_val_grad_hess_beta(self, profile=True):
grad = t1.gradient(val, self.ubeta)
hess = t2.jacobian(grad, self.ubeta)

grad = tf.reshape(grad, [-1])
hess = tf.reshape(hess, [grad.shape[0], grad.shape[0]])

betamask = ~tf.reshape(self.betamask, [-1])
grad = grad[betamask]
hess = tf.boolean_mask(hess, betamask, axis=0)
hess = tf.boolean_mask(hess, betamask, axis=1)

return val, grad, hess

def minimize(self):
Expand Down
4 changes: 4 additions & 0 deletions rabbit/inputdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def __init__(self, filename, pseudodata=None):
else:
self.norm = maketensor(f["hnorm"])
self.logk = maketensor(f["hlogk"])
if "hbetavariations" in f.keys():
self.betavar = maketensor(f["hbetavariations"])
else:
self.betavar = None

# infer some metadata from loaded information
self.dtype = self.data_obs.dtype
Expand Down
84 changes: 84 additions & 0 deletions rabbit/tensorwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
self.dict_logkhalfdiff = {} # [channel][proc][syst]
self.dict_logkavg_indices = {}
self.dict_logkhalfdiff_indices = {}
self.dict_beta_variations = {} # [channel][syst][process]

self.clipSystVariations = False
if self.clipSystVariations > 0.0:
Expand Down Expand Up @@ -160,6 +161,7 @@ def add_channel(self, axes, name=None, masked=False, flow=False):
self.nbinschan[name] = ibins
self.dict_norm[name] = {}
self.dict_sumw2[name] = {}
self.dict_beta_variations[name] = {}

# add masked channels last and not masked channels first
this_channel = {"axes": [a for a in axes], "masked": masked, "flow": flow}
Expand Down Expand Up @@ -392,6 +394,56 @@ def add_systematic(
var_name_out, add_to_data_covariance=add_to_data_covariance, **kargs
)

def add_beta_variations(
    self,
    h,
    process,
    source_channel,
    dest_channel,
):
    """
    Add a template variation in the destination channel that is correlated
    with the beta variation in the source channel for a given process.

    Args:
        h: histogram carrying both the source-channel and destination-channel
            axes (axis order in ``h`` does not matter; it is re-projected
            below). Presumably a boost-histogram-like object exposing
            ``.axes.name``, ``.project(...)`` and ``.values()`` — confirm
            against callers.
        process: name of the process the variation applies to; its nominal
            norm must already be registered in ``self.dict_norm[dest_channel]``.
        source_channel: name of the channel whose beta parameters drive the
            variation.
        dest_channel: name of the channel receiving the variation; must be a
            masked channel.

    Raises:
        NotImplementedError: when the writer is in sparse mode.
        RuntimeError: for an unknown source/destination channel, a
            non-masked destination channel, or when ``h`` is missing a
            required axis.
    """
    if self.sparse:
        raise NotImplementedError("Sparse implementation not yet implemented")

    if source_channel not in self.channels.keys():
        raise RuntimeError(f"Channel {source_channel} not known!")
    if dest_channel not in self.channels.keys():
        raise RuntimeError(f"Channel {dest_channel} not known!")
    if not self.channels[dest_channel]["masked"]:
        # Was an f-string with no placeholders (ruff F541); plain literal now.
        raise RuntimeError(
            "Beta variations can only be applied to masked channels"
        )

    norm = self.dict_norm[dest_channel][process]

    source_axes = self.channels[source_channel]["axes"]
    dest_axes = self.channels[dest_channel]["axes"]

    source_axes_names = [a.name for a in source_axes]
    dest_axes_names = [a.name for a in dest_axes]

    # Both channels' axes must be present in h before projecting; the two
    # original per-channel loops performed the identical check and are merged.
    for a in (*source_axes, *dest_axes):
        if a.name not in h.axes.name:
            raise RuntimeError(
                f"Axis {a.name} not found in histogram h with {h.axes.name}"
            )

    # Order the projection as (dest axes..., source axes...) and flatten all
    # source axes into one trailing dimension: shape (*norm.shape, -1).
    variation = h.project(*dest_axes_names, *source_axes_names).values()
    variation = variation.reshape((*norm.shape, -1))

    # Nested layout is [dest_channel][source_channel][process].
    self.dict_beta_variations[dest_channel].setdefault(source_channel, {})[
        process
    ] = variation

def get_logk(self, syst, norm, kfac=1.0, systematic_type=None):
if not np.all(np.isfinite(syst)):
raise RuntimeError(
Expand Down Expand Up @@ -714,6 +766,7 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", args={}):
else:
logk = np.zeros([nbinsfull, nproc, 2, nsyst], self.dtype)

beta_variations = np.zeros([nbinsfull, nbins, nproc], self.dtype)
for chan in self.channels.keys():
nbinschan = self.nbinschan[chan]
dict_norm_chan = self.dict_norm[chan]
Expand Down Expand Up @@ -745,6 +798,32 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", args={}):
dict_logkhalfdiff_proc[syst]
)

for (
source_channel,
source_channel_dict,
) in self.dict_beta_variations[chan].items():
if proc in source_channel_dict:

# find the bins of the source channel
ibin_start = 0
for c, nb in self.nbinschan.items():
if self.channels[c]["masked"]:
continue # masked channels can not be source channels
if c == source_channel:
ibin_end = ibin_start + nb
break
else:
ibin_start += nb
else:
raise RuntimeError(
f"Did not find source channel {source_channel} in list of channels {[k for k in self.nbinschan.keys()]}"
)

beta_vars = source_channel_dict[proc]
beta_variations[
ibin : ibin + nbinschan, ibin_start:ibin_end, iproc
] = beta_vars

ibin += nbinschan

if self.data_covariance is None and (
Expand Down Expand Up @@ -916,6 +995,11 @@ def create_dataset(
)
logk = None

nbytes += h5pyutils.writeFlatInChunks(
beta_variations, f, "hbetavariations", maxChunkBytes=self.chunkSize
)
beta_variations = None

logger.info(f"Total raw bytes in arrays = {nbytes}")

def get_systsstandard(self):
Expand Down