diff --git a/conda_reqs.txt b/conda_reqs.txt new file mode 100644 index 0000000..bef1e74 --- /dev/null +++ b/conda_reqs.txt @@ -0,0 +1,12 @@ +numpy +#torch +transformers +pyyaml +tqdm +loguru +ipython +torchsummary +pyjet +scipy +git+https://github.com/allenai/longformer.git +matplotlib diff --git a/grapple/data/torch.py b/grapple/data/torch.py index c1f3e7c..a1b9964 100644 --- a/grapple/data/torch.py +++ b/grapple/data/torch.py @@ -5,6 +5,7 @@ import numpy as np from tqdm import tqdm from itertools import chain +import math class PUDataset(IterableDataset): @@ -69,10 +70,20 @@ def __iter__(self): puppimet = data['puppimet'] pfmet = data['pfmet'] - X = X[:, :self.n_particles, :] - Y = Y[:, :self.n_particles] - P = P[:, :self.n_particles] - Q = Q[:, :self.n_particles] + n_particles_raw = Y.shape[1] + + if self.n_particles < n_particles_raw: + X = X[:, :self.n_particles, :] + Y = Y[:, :self.n_particles] + P = P[:, :self.n_particles] + Q = Q[:, :self.n_particles] + elif n_particles_raw < self.n_particles: + diff = self.n_particles - n_particles_raw + X = np.pad(X, (0, diff, 0)) + Y = np.pad(Y, (0, diff)) + P = np.pad(P, (0, diff)) + Q = np.pad(Q, (0, diff)) + mask_base = np.arange(X.shape[1]) idx = np.arange(X.shape[0]) @@ -276,14 +287,33 @@ def _get_len(self): return n_tot def __iter__(self): - np.random.shuffle(self._files) - for f in self._files[:self.num_max_files]: + files = self._files[:] + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + per_worker = int(math.ceil(len(files) / worker_info.num_workers)) + this_worker = worker_info.id + files = files[this_worker * per_worker : (this_worker+1) * per_worker] + np.random.shuffle(files) + for f in files[:self.num_max_files]: raw_data = np.load(f) data = raw_data['x'] met = raw_data['met'] jet1 = raw_data['jet1'] - - data = data[:, :self.n_particles, :] + if 'genz' in raw_data: + genz = raw_data['genz'] + else: + genz = np.copy(met) + + n_particles_raw = data.shape[1] + if n_particles_raw > self.n_particles: + data = data[:, :self.n_particles, :] + elif n_particles_raw < self.n_particles: + diff = self.n_particles - n_particles_raw + data = np.pad( + data, + pad_width=((0, 0), (0, diff), (0, 0)), + mode='constant', constant_values=0 + ) X = data[:,:,[i for b,i in self.b2i.items() if b not in ('hardfrac', 'puppi')]] y = data[:,:,self.b2i['hardfrac']] @@ -303,6 +333,7 @@ def __iter__(self): 'mask': mask[i, :], 'neutral_mask': neutral_mask[i, :], 'genmet': met[i, :], + 'genv': genz[i, :], 'jet1': jet1[i, :], } diff --git a/grapple/metrics.py b/grapple/metrics.py index 893ae6a..83bbd01 100644 --- a/grapple/metrics.py +++ b/grapple/metrics.py @@ -6,183 +6,60 @@ import matplotlib.pyplot as plt import numpy as np import pickle +import pandas as pd +import seaborn as sns import pyjet +from pyjet.testdata import get_event -EPS = 1e-4 - - -class Metrics(object): - def __init__(self, device, softmax=True): - self.loss_calc = nn.CrossEntropyLoss( - ignore_index=-1, - reduction='none' - # weight=torch.FloatTensor([1, 5]).to(device) - ) - self.reset() - self.apply_softmax = softmax - - def reset(self): - self.loss = 0 - self.acc = 0 - self.pos_acc = 0 - self.neg_acc = 0 - self.n_pos = 0 - self.n_particles = 0 - self.n_steps = 0 - self.hists = {} - - @staticmethod - def make_roc(pos_hist, neg_hist): - pos_hist = pos_hist / pos_hist.sum() - neg_hist = neg_hist / neg_hist.sum() - tp, fp = [], [] - for i in np.arange(pos_hist.shape[0], -1, -1): - tp.append(pos_hist[i:].sum()) - fp.append(neg_hist[i:].sum()) - auc = np.trapz(tp, 
x=fp) - plt.plot(fp, tp, label=f'AUC={auc:.3f}') - return fp, tp - - def add_values(self, yhat, y, label, idx, w=None): - if w is not None: - w = w[y==label] - hist, self.bins = np.histogram(yhat[y==label], bins=np.linspace(0, 1, 100), weights=w) - if idx not in self.hists: - self.hists[idx] = hist + EPS - else: - self.hists[idx] += hist - - def compute(self, yhat, y, orig_y, w=None, m=None): - # yhat = [batch, particles, labels]; y = [batch, particles] - loss = self.loss_calc(yhat.view(-1, yhat.shape[-1]), y.view(-1)) - if w is not None: - wv = w.view(-1) - loss *= wv - if m is None: - m = np.ones_like(t2n(y), dtype=bool) - loss = torch.mean(loss) - self.loss += t2n(loss).mean() - - mask = (y != -1) - n_particles = t2n(mask.sum()) - - pred = torch.argmax(yhat, dim=-1) # [batch, particles] - pred = t2n(pred) - y = t2n(y) - mask = t2n(mask) - - acc = (pred == y)[mask].sum() / n_particles - self.acc += acc - - n_pos = np.logical_and(m, y == 1).sum() - pos_acc = (pred == y)[np.logical_and(m, y == 1)].sum() / n_pos - self.pos_acc += pos_acc - neg_acc = (pred == y)[np.logical_and(m, y == 0)].sum() / (n_particles - n_pos) - self.neg_acc += neg_acc - - self.n_pos += n_pos - self.n_particles += n_particles - - self.n_steps += 1 - - if self.apply_softmax: - yhat = t2n(nn.functional.softmax(yhat, dim=-1)) - else: - yhat = t2n(yhat) - if w is not None: - w = t2n(w).reshape(orig_y.shape) - wm = w[m] - wnm = w[~m] - else: - wm = wnm = None - self.add_values(yhat[:,:,1][m], orig_y[m], 0, 0, wm) - self.add_values(yhat[:,:,1][m], orig_y[m], 1, 1, wm) - self.add_values(yhat[:,:,1][~m], orig_y[~m], 0, 2, wnm) - self.add_values(yhat[:,:,1][~m], orig_y[~m], 1, 3, wnm) - - if self.n_steps % 50 == 0 and False: - print(t2n(y[0])[:10]) - print(t2n(pred[0])[:10]) - print(t2n(yhat[0])[:10, :]) - - return loss, acc - - def mean(self): - return ([x / self.n_steps - for x in [self.loss, self.acc, self.pos_acc, self.neg_acc]] - + [self.n_pos / self.n_particles]) - - def plot(self, path): - plt.clf() - x = (self.bins[:-1] + self.bins[1:]) * 0.5 - hist_args = { - 'histtype': 'step', - #'alpha': 0.25, - 'bins': self.bins, - 'log': True, - 'x': x, - 'density': True - } - plt.hist(weights=self.hists[0], label='PU Neutral', **hist_args) - plt.hist(weights=self.hists[1], label='Hard Neutral', **hist_args) - plt.hist(weights=self.hists[2], label='PU Charged', **hist_args) - plt.hist(weights=self.hists[3], label='Hard Charged', **hist_args) - plt.ylim(bottom=0.001, top=5e3) - plt.xlabel('P(Hard|p,e)') - plt.legend() - for ext in ('pdf', 'png'): - plt.savefig(path + '.' + ext) - - plt.clf() - fig_handle = plt.figure() - fp, tp, = self.make_roc(self.hists[1], self.hists[0]) - plt.ylabel('True Neutral Positive Rate') - plt.xlabel('False Neutral Positive Rate') - path += '_roc' - plt.legend() - for ext in ('pdf', 'png'): - plt.savefig(path + '.' 
+ ext) - pickle.dump({'fp': fp, 'tp': tp}, open(path + '.pkl', 'wb')) - plt.close(fig_handle) +from loguru import logger +EPS = 1e-4 -class METMetrics(Metrics): - def __init__(self, device, softmax=True): - super().__init__(device, softmax) +from ._old import * - self.mse = nn.MSELoss() - self.met_loss_weight = 1 - def compute(self, yhat, y, orig_y, met, methat, w=None, m=None): - met_loss = self.met_loss_weight * self.mse(methat.view(-1), met.view(-1)) - pu_loss, acc = super().compute(yhat, y, orig_y, w, m) - loss = met_loss + pu_loss - return loss, acc +def square(x): + return torch.pow(x, 2) class JetResolution(object): + bins = { + 'pt': np.linspace(-100, 100, 100), + 'm': np.linspace(-50, 50, 100), + 'mjj': np.linspace(-1000, 1000, 100), + } + bins_2 = { + 'pt': np.linspace(0, 500, 100), + 'm': np.linspace(0, 50, 100), + 'mjj': np.linspace(0, 3000, 100), + } + labels = { + 'pt': 'Jet $p_T$', + 'm': 'Jet $m$', + 'mjj': '$m_{jj}$', + } + methods = { + 'model': r'$\mathrm{PUMA}$', + 'puppi': r'$\mathrm{PUPPI}$', + 'truth': 'Truth+PF', + } def __init__(self): - self.bins = { - 'pt': np.linspace(-100, 100, 100), - 'm': np.linspace(-50, 50, 100), - 'mjj': np.linspace(-1000, 1000, 100), - } - self.bins_2 = { - 'pt': np.linspace(0, 500, 100), - 'm': np.linspace(0, 50, 100), - 'mjj': np.linspace(0, 3000, 100), - } - self.labels = { - 'pt': 'Jet $p_T$', - 'm': 'Jet $m$', - 'mjj': '$m_{jj}$', - } self.reset() def reset(self): - self.dists = {k:[] for k in ['pt', 'm', 'mjj']} - self.dists_2 = {k:([], []) for k in ['pt', 'm', 'mjj']} + self.dists = {k:{m:[] for m in self.methods} for k in ['pt', 'm', 'mjj']} + self.dists_2 = {k:{m:([], []) for m in self.methods} for k in ['pt', 'm', 'mjj']} + self.truth_pt = {m:[] for m in self.methods} + + @staticmethod + def compute_p4(x): + pt, eta, phi, e = (x[:,:,i] for i in range(4)) + px = pt * np.cos(phi) + py = pt * np.sin(phi) + pz = pt * np.sinh(eta) + p4 = np.stack([e, px, py, pz], axis=-1) + return p4 @staticmethod def compute_mass(x): @@ -191,33 +68,46 @@ def compute_mass(x): m = np.sqrt(np.clip(e**2 - p**2, 0, None)) return m - def compute(self, x, weight, mask, pt0, m0=None, mjj=None): - x = np.copy(x[:, :, :4]) - m = self.compute_mass(x) - x[:, :, 0] = x[:, :, 0] * weight - x[:,:,3] = m - #print(x[:,:,3]) - #x[:, :, 3] = 0 # temporary override to approximate mass - x = x.astype(np.float64) - n_batch = x.shape[0] + def compute(self, x, weights, mask, pt0, m0=None, mjj=None): + for k,v in weights.items(): + self._internal_compute(k, x, v, mask, pt0, m0, mjj) + + def _internal_compute(self, tag, x, weight, mask, pt0, m0=None, mjj=None): + p4 = self.compute_p4(x[:, :, :4]) + p4 *= weight[:, :, None] + p4 = p4.astype(np.float64) + n_batch = p4.shape[0] + pt = x[:, :, 0] * weight for i in range(n_batch): - evt = x[i][np.logical_and(mask[i].astype(bool), x[i,:,0]>0)] + particle_mask = pt[i, :] > 0 + particle_mask = np.logical_and( + particle_mask, + np.logical_and( + ~np.isnan(p4[i]).sum(-1), + ~np.isinf(p4[i]).sum(-1) + ) + ) + evt = p4[i][np.logical_and( + mask[i].astype(bool), + particle_mask + )] evt = np.core.records.fromarrays( evt.T, - names='pt, eta, phi, m', + names='E, px, py, pz', formats='f8, f8, f8, f8' ) - seq = pyjet.cluster(evt, R=0.4, p=-1) + seq = pyjet.cluster(evt, R=0.4, p=-1, ep=True) jets = seq.inclusive_jets() if len(jets) > 0: - self.dists['pt'].append(jets[0].pt - pt0[i]) - self.dists_2['pt'][0].append(pt0[i]) - self.dists_2['pt'][1].append(jets[0].pt) + self.dists['pt'][tag].append(jets[0].pt - pt0[i]) + 
self.truth_pt[tag].append(pt0[i]) + self.dists_2['pt'][tag][0].append(pt0[i]) + self.dists_2['pt'][tag][1].append(jets[0].pt) if m0 is not None: self.dists['m'].append(jets[0].mass - m0[i]) - self.dists_2['m'][0].append(m0[i]) - self.dists_2['m'][1].append(jets[0].mass) + self.dists_2['m'][tag][0].append(m0[i]) + self.dists_2['m'][tag][1].append(jets[0].mass) if mjj is not None: if len(jets) > 1: @@ -231,38 +121,80 @@ def compute(self, x, weight, mask, pt0, m0=None, mjj=None): else: mjj_pred = 0 if mjj[i] > 0: - self.dists['mjj'].append(mjj_pred - mjj[i]) - self.dists_2['mjj'][0].append(mjj[i]) - self.dists_2['mjj'][1].append(mjj_pred) + self.dists['mjj'][tag].append(mjj_pred - mjj[i]) + self.dists_2['mjj'][tag][0].append(mjj[i]) + self.dists_2['mjj'][tag][1].append(mjj_pred) + + @staticmethod + def _compute_moments(x): + return np.mean(x), np.std(x) + + def plot(self, path): + plt.clf() + x = (self.bins[:-1] + self.bins[1:]) * 0.5 + + mean_p, var_p = self._compute_moments(x, self.dist_p) + mean_pup, var_pup = self._compute_moments(x, self.dist_pup) + def plot(self, path): - for k, data in self.dists.items(): - if len(data) == 0: - continue + for k, m_data in self.dists.items(): plt.clf() - plt.hist(data, bins=self.bins[k]) + for m, data in m_data.items(): + if len(data) == 0: + continue + mean, var = self._compute_moments(data) + label = fr'{self.methods[m]} ($\delta=' + f'{mean:.1f}' + r'\pm' + f'{np.sqrt(var):.1f})$' + plt.hist(data, bins=self.bins[k], label=label, + histtype='step') + plt.legend() plt.xlabel(f'Predicted-True {self.labels[k]} [GeV]') for ext in ('pdf', 'png'): plt.savefig(f'{path}_{k}_err.{ext}') - with open(f'{path}_{k}_err.pkl', 'wb') as fpkl: - pickle.dump( - {'data': data, 'bins':self.bins[k]}, - fpkl - ) - for k, data in self.dists_2.items(): - if len(data[0]) == 0: - continue - plt.clf() - plt.hist2d(data[0], data[1], bins=self.bins_2[k]) - plt.xlabel(f'True {self.labels[k]} [GeV]') - plt.ylabel(f'Predicted {self.labels[k]} [GeV]') - for ext in ('pdf', 'png'): - plt.savefig(f'{path}_{k}_corr.{ext}') + plt.clf() + dfs = [] + lo, hi = 0, 500 + n_bins = 10 + bins = np.linspace(lo, hi, n_bins) + for m,m_label in self.methods.items(): + truth = np.digitize(self.truth_pt[m], bins) + truth = (truth * (hi - lo) / n_bins) + lo + df = pd.DataFrame({ + 'x': truth, + 'y': self.dists['pt'][m], + 'Method': [m_label] * truth.shape[0] + }) + dfs.append(df) + df = pd.concat(dfs, axis=0) + sns.boxplot( + x='x', y='y', hue='Method', data=df, + order=(bins * (hi - lo) / n_bins) + lo, + ) + plt.xlabel(rf'True {self.labels["pt"]} [GeV]') + plt.ylabel(rf'Error {self.labels["pt"]} [GeV]') + for ext in ('pdf', 'png'): + plt.savefig(f'{path}_differr.{ext}') + + for k, m_data in self.dists_2.items(): + for m, data in m_data.items(): + if len(data) == 0: + continue + plt.clf() + plt.hist2d(data[0], data[1], bins=self.bins_2[k], cmin=0.1) + plt.xlabel(f'True {self.labels[k]} [GeV]') + plt.ylabel(f'Predicted {self.labels[k]} [GeV]') + for ext in ('pdf', 'png'): + plt.savefig(f'{path}_{k}_{m}_corr.{ext}') class METResolution(object): + methods = { + 'model': r'$\mathrm{PUMA}$', + 'puppi': r'$\mathrm{PUPPI}$', + 'truth': 'Truth+PF', + } def __init__(self, bins=np.linspace(-100, 100, 40)): self.bins = bins self.bins_2 = (0, 400) @@ -334,10 +266,10 @@ def plot(self, path): mean_p, var_p = self._compute_moments(x, self.dist_p) mean_pup, var_pup = self._compute_moments(x, self.dist_pup) - label = r'Model ($\delta=' + f'{mean:.1f}' + r'\pm' + f'{np.sqrt(var):.1f})$' + label = r'$\mathrm{PUMA}$ 
($\delta=' + f'{mean:.1f}' + r'\pm' + f'{np.sqrt(var):.1f})$' plt.hist(x=x, weights=self.dist, label=label, histtype='step', bins=self.bins) - label = r'Puppi ($\delta=' + f'{mean_pup:.1f}' + r'\pm' + f'{np.sqrt(var_pup):.1f})$' + label = r'$\mathrm{PUPPI}$ ($\delta=' + f'{mean_pup:.1f}' + r'\pm' + f'{np.sqrt(var_pup):.1f})$' plt.hist(x=x, weights=self.dist_pup, label=label, histtype='step', bins=self.bins) label = r'Ground Truth ($\delta=' + f'{mean_p:.1f}' + r'\pm' + f'{np.sqrt(var_p):.1f})$' @@ -384,21 +316,55 @@ def plot(self, path): class ParticleMETResolution(METResolution): + def __init__(self, which='mag', bins=np.linspace(-100, 100, 40)): + super().__init__(bins) + self.bins_phi = np.linspace(0, 3.142, 40) + self.which = which + + def reset(self): + self.dist = None + self.dist_p = None + self.dist_pup = None + self.dist_2 = None + self.dist_met = None + self.dist_pred = None + self.dist_2_p = None + self.dist_2_pup = None + + self.pred = {k: [] for k in self.methods} + self.truth = [] + self.predphi = {k: [] for k in self.methods} + self.truthphi = [] + @staticmethod - def _compute_res(pt, phi, w, gm): + def _compute_res(pt, phi, w, gm, gmphi, which): pt = pt * w - px = pt * np.cos(phi) - py = pt * np.sin(phi) - metx = np.sum(px, axis=-1) - mety = np.sum(py, axis=-1) - met = np.sqrt(np.power(metx, 2) + np.power(mety, 2)) + if which in ('mag', 'x'): + px = pt * np.cos(phi) + metx = -np.sum(px, axis=-1) + if which in ('mag', 'y'): + py = pt * np.sin(phi) + mety = -np.sum(py, axis=-1) + if which == 'mag': + met = np.sqrt(np.power(metx, 2) + np.power(mety, 2)) + met_vec = np.stack([metx, mety], axis=-1) + gm_vec = np.stack([gm*np.cos(gmphi), gm*np.sin(gmphi)], axis=-1) + resphi = np.arccos( + np.einsum('ij,ij->i', met_vec, gm_vec) / (met * gm) + ) + elif which == 'x': + met = metx + resphi = np.zeros_like(met) + elif which == 'y': + met = mety + resphi = np.zeros_like(met) res = met - gm # (met / gm) - 1 - return res + return res, resphi - def compute(self, pt, phi, w, y, baseline, gm): - res = self._compute_res(pt, phi, w, gm) - res_t = self._compute_res(pt, phi, y, gm) - res_p = self._compute_res(pt, phi, baseline, gm) + def compute(self, pt, phi, w, y, baseline, gm, gmphi): + res, resphi = self._compute_res(pt, phi, w, gm, gmphi, self.which) + res_t, resphi_t = self._compute_res(pt, phi, y, gm, gmphi, self.which) + res_p, resphi_p = self._compute_res(pt, phi, baseline, gm, gmphi, self.which) hist, _ = np.histogram(res, bins=self.bins) hist_p, _ = np.histogram(res_p, bins=self.bins) @@ -412,6 +378,15 @@ def compute(self, pt, phi, w, y, baseline, gm): self.dist_p += hist_p self.dist_met += hist_met + self.pred['model'] += res.tolist() + self.pred['puppi'] += res_p.tolist() + self.pred['truth'] += res_t.tolist() + self.truth += gm.tolist() + self.predphi['model'] += resphi.tolist() + self.predphi['puppi'] += resphi_p.tolist() + self.predphi['truth'] += resphi_t.tolist() + self.truthphi += gmphi.tolist() + def plot(self, path): plt.clf() x = (self.bins[:-1] + self.bins[1:]) * 0.5 @@ -420,25 +395,70 @@ def plot(self, path): mean_p, var_p = self._compute_moments(x, self.dist_p) mean_met, var_met = self._compute_moments(x, self.dist_met) - label = r'Model ($\delta=' + f'{mean:.1f}' + r'\pm' + f'{np.sqrt(var):.1f})$' + label = r'$\mathrm{PUMA}$ ($\delta=' + f'{mean:.1f}' + r'\pm' + f'{np.sqrt(var):.1f})$' plt.hist(x=x, weights=self.dist, label=label, histtype='step', bins=self.bins) - label = r'Puppi ($\delta=' + f'{mean_p:.1f}' + r'\pm' + f'{np.sqrt(var_p):.1f})$' + label = 
r'$\mathrm{PUPPI}$ ($\delta=' + f'{mean_p:.1f}' + r'\pm' + f'{np.sqrt(var_p):.1f})$' plt.hist(x=x, weights=self.dist_p, label=label, histtype='step', bins=self.bins) label = r'Truth+PF ($\delta=' + f'{mean_met:.1f}' + r'\pm' + f'{np.sqrt(var_met):.1f})$' plt.hist(x=x, weights=self.dist_met, label=label, histtype='step', bins=self.bins) - plt.xlabel('(Predicted-True)') + plt.xlabel(r'Predicted-True $p_\mathrm{T}^\mathrm{miss}$') plt.legend() for ext in ('pdf', 'png'): plt.savefig(path + '.' + ext) + plt.clf() + dfs = [] + lo, hi = 0, 500 + n_bins = 10 + bins = np.linspace(lo, hi, n_bins) + truth = np.digitize(self.truth, bins) + truth = (truth * (hi - lo) / n_bins) + lo + for m,m_label in self.methods.items(): + df = pd.DataFrame({ + 'x': truth, + 'y': self.pred[m], + 'Method': [m_label] * truth.shape[0] + }) + dfs.append(df) + df = pd.concat(dfs, axis=0) + sns.boxplot( + x='x', y='y', hue='Method', data=df, + order=(bins * (hi - lo) / n_bins) + lo, + ) + plt.xlabel(r'True $p_\mathrm{T}^\mathrm{miss}$ [GeV]') + plt.ylabel(r'Error $p_\mathrm{T}^\mathrm{miss}$ [GeV]') + for ext in ('pdf', 'png'): + plt.savefig(f'{path}_differr.{ext}') + + plt.clf() + dfs = [] + for m,m_label in self.methods.items(): + df = pd.DataFrame({ + 'x': truth, + 'y': self.predphi[m], + 'Method': [m_label] * truth.shape[0] + }) + dfs.append(df) + df = pd.concat(dfs, axis=0) + sns.boxplot( + x='x', y='y', hue='Method', data=df, + order=(bins * (hi - lo) / n_bins) + lo, + ) + plt.xlabel(r'True $p_\mathrm{T}^\mathrm{miss}$ [GeV]') + plt.ylabel(r'Error $\phi^\mathrm{miss}$ [GeV]') + for ext in ('pdf', 'png'): + plt.savefig(f'{path}_diffphierr.{ext}') + + return {'model': (mean, np.sqrt(var)), 'puppi': (mean_p, np.sqrt(var_p))} class PapuMetrics(object): - def __init__(self, beta=False): + def __init__(self, beta=False, met_weight=0): + self.met_weight = met_weight self.beta = beta if not self.beta: self.loss_calc = nn.MSELoss( @@ -492,7 +512,34 @@ def add_values(self, val, key, w=None, lo=0, hi=1): else: self.hists[key] += hist - def compute(self, yhat, y, w=None, m=None, plot_m=None): + def _compute_met_constraint(self, yhat, y, x): + if self.met_weight == 0: + return 0 + + assert (x is not None) + + pt, phi = x[:,:,0], x[:,:,2] + px = pt * torch.cos(phi) + py = pt * torch.sin(phi) + + def calc(scale, p): + print(scale.shape, pt.shape, p.shape) + return torch.sum(scale * p, dim=-1) + + yhat = yhat.reshape(pt.shape) + y = y.reshape(pt.shape) + + pred_metx = calc(yhat, px) + pred_mety = calc(yhat, py) + true_metx = calc(y, px) + true_mety = calc(y, py) + + err = square(true_metx - pred_metx) + square(true_mety - pred_mety) + err = err.mean() + return self.met_weight * err + + + def compute(self, yhat, y, w=None, m=None, plot_m=None, x=None): y = y.view(-1) if not self.beta: yhat = yhat.view(-1) @@ -519,17 +566,18 @@ def compute(self, yhat, y, w=None, m=None, plot_m=None): nan_mask = t2n(torch.isnan(loss)).astype(bool) loss = torch.mean(loss) + loss += self._compute_met_constraint(yhat, y, x) + self.loss += t2n(loss).mean() if nan_mask.sum() > 0: yhat = t2n(yhat) - print(nan_mask) - print(yhat[nan_mask]) + logger.info(nan_mask) + logger.info(yhat[nan_mask]) if self.beta: p, q = t2n(p), t2n(q) - print(p[nan_mask]) - print(q[nan_mask]) - print() + logger.info(p[nan_mask]) + logger.info(q[nan_mask]) plot_m = t2n(plot_m).astype(bool) y = t2n(y)[plot_m] diff --git a/grapple/model/_longformer_helpers.py b/grapple/model/_longformer_helpers.py new file mode 100644 index 0000000..0d5a077 --- /dev/null +++ 
b/grapple/model/_longformer_helpers.py @@ -0,0 +1,586 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True): + """ + Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is + True` else after `sep_token_id`. + """ + question_end_index = _get_question_end_index(input_ids, sep_token_id) + question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 + # bool attention mask with True in locations of global attention + attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device) + if before_sep_token is True: + attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8) + else: + # last token is separation token and should not be counted and in the middle are two separation tokens + attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * ( + attention_mask.expand_as(input_ids) < input_ids.shape[-1] + ).to(torch.uint8) + + return attention_mask + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indices.long() + padding_idx + + +class LongformerSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_heads = config.num_attention_heads + self.head_dim = int(config.hidden_size / config.num_attention_heads) + self.embed_dim = config.hidden_size + + self.query = nn.Linear(config.hidden_size, self.embed_dim) + self.key = nn.Linear(config.hidden_size, self.embed_dim) + self.value = nn.Linear(config.hidden_size, self.embed_dim) + + # separate projection layers for tokens with global attention + self.query_global = nn.Linear(config.hidden_size, self.embed_dim) + self.key_global = nn.Linear(config.hidden_size, self.embed_dim) + self.value_global = nn.Linear(config.hidden_size, self.embed_dim) + + attention_band = config.attention_band + assert ( + attention_band % 2 == 0 + ), f"`attention_band` for layer has to be an even value. Given {attention_band}" + assert ( + attention_band > 0 + ), f"`attention_band` for layer has to be positive. Given {attention_band}" + + self.one_sided_attn_window_size = attention_band // 2 + + def forward( + self, + hidden_states, + attention_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, + ): + """ + :class:`LongformerSelfAttention` expects `len(hidden_states)` to be multiple of `attention_band`. Padding to + `attention_band` happens in :meth:`LongformerModel.forward` to avoid redoing the padding on each layer. 
+ + The `attention_mask` is changed in :meth:`LongformerModel.forward` from 0, 1, 2 to: + + * -10000: no attention + * 0: local attention + * +10000: global attention + """ + hidden_states = hidden_states.transpose(0, 1) + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + + seq_len, batch_size, embed_dim = hidden_states.size() + assert ( + embed_dim == self.embed_dim + ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + + # normalize query + query_vectors /= math.sqrt(self.head_dim) + + query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None] + + # cast to fp32/fp16 then replace 1's with -inf + float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( + remove_from_windowed_attention_mask, -10000.0 + ) + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size + ) + + # pad local attention probs + attn_scores += diagonal_mask + + assert list(attn_scores.size()) == [ + batch_size, + seq_len, + self.num_heads, + self.one_sided_attn_window_size * 2 + 1, + ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" + + # compute local attention probs from global attention keys and contact over window dim + if is_global_attn: + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + # calculate global attn probs from global key + + global_key_attn_scores = self._concat_with_global_key_attn_probs( + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + # concat to local_attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) + + # free memory + del global_key_attn_scores + + attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0) + attn_probs = attn_probs.type_as(attn_scores) + + # free memory + del attn_scores + + # apply dropout + attn_probs = F.dropout(attn_probs, p=self.config.attention_probs_dropout_prob, training=self.training) + + value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # compute local attention output with global attention value 
and add + if is_global_attn: + # compute sum of global and local attn + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + # compute local attn only + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() + + # compute value for global attention and overwrite to attention output + # TODO: remove the redundant computation + if is_global_attn: + global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + ) + + # get only non zero global attn output + nonzero_global_attn_output = global_attn_output[ + is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] + ] + + # overwrite values with global attention + attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( + len(is_local_index_global_attn_nonzero[0]), -1 + ) + # The attention weights for tokens with global attention are + # just filler values, they were never used to compute the output. + # Fill with 0 now, the correct values are in 'global_attn_probs'. + attn_probs[is_index_global_attn_nonzero] = 0 + + outputs = (attn_output.transpose(0, 1),) + + if output_attentions: + outputs += (attn_probs,) + + return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs + + @staticmethod + def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): + """pads rows and then flips rows and columns""" + hidden_states_padded = F.pad( + hidden_states_padded, padding + ) # padding value is not important because it will be overwritten + hidden_states_padded = hidden_states_padded.view( + *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2) + ) + return hidden_states_padded + + @staticmethod + def _pad_and_diagonalize(chunked_hidden_states): + """ + shift every row 1 step right, converting columns into diagonals. + + Example:: + + chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492, + -1.8348, 0.7672, 0.2986, 0.0285, + -0.7584, 0.4206, -0.0405, 0.1599, + 2.0514, -1.1600, 0.5372, 0.2629 ] + window_overlap = num_rows = 4 + (pad & diagonalize) => + [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 + 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 + 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000 + 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] + """ + total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() + chunked_hidden_states = F.pad( + chunked_hidden_states, (0, window_overlap + 1) + ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). 
Padding value is not important because it'll be overwritten + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, -1 + ) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap + chunked_hidden_states = chunked_hidden_states[ + :, :, :-window_overlap + ] # total_num_heads x num_chunks x window_overlap*window_overlap + chunked_hidden_states = chunked_hidden_states.view( + total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim + ) + chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] + return chunked_hidden_states + + @staticmethod + def _chunk(hidden_states, window_overlap): + """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" + + # non-overlapping chunks of size = 2w + hidden_states = hidden_states.view( + hidden_states.size(0), + hidden_states.size(1) // (window_overlap * 2), + window_overlap * 2, + hidden_states.size(2), + ) + + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap + chunk_size = list(hidden_states.size()) + chunk_size[1] = chunk_size[1] * 2 - 1 + + chunk_stride = list(hidden_states.stride()) + chunk_stride[1] = chunk_stride[1] // 2 + return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) + + @staticmethod + def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor: + beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) + beginning_mask = beginning_mask_2d[None, :, None, :] + ending_mask = beginning_mask.flip(dims=(1, 3)) + beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] + beginning_mask = beginning_mask.expand(beginning_input.size()) + beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] + ending_mask = ending_mask.expand(ending_input.size()) + ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + + def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): + """ + Matrix multiplication of query and key tensors using with a sliding window attention pattern. This + implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an + overlap of size window_overlap + """ + batch_size, seq_len, num_heads, head_dim = query.size() + assert ( + seq_len % (window_overlap * 2) == 0 + ), f"Sequence length should be multiple of {window_overlap * 2}. 
Given {seq_len}" + assert query.size() == key.size() + + chunks_count = seq_len // window_overlap - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + + query = self._chunk(query, window_overlap) + key = self._chunk(key, window_overlap) + + # matrix multiplication + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap + diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply + + # convert diagonals into columns + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( + diagonal_chunked_attention_scores, padding=(0, 0, 0, 1) + ) + + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. + + diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( + (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) + ) + + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions + # - copying the main diagonal and the upper triangle + diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ] + diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[ + :, -1, window_overlap:, : window_overlap + 1 + ] + # - copying the lower triangle + diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ + :, :, -(window_overlap + 1) : -1, window_overlap + 1 : + ] + + diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ + :, 0, : window_overlap - 1, 1 - window_overlap : + ] + + # separate batch_size and num_heads dimensions again + diagonal_attention_scores = diagonal_attention_scores.view( + batch_size, num_heads, seq_len, 2 * window_overlap + 1 + ).transpose(2, 1) + + self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + return diagonal_attention_scores + + def _sliding_chunks_matmul_attn_probs_value( + self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int + ): + """ + Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. 
Returned tensor will be of the + same shape as `attn_probs` + """ + batch_size, seq_len, num_heads, head_dim = value.size() + + assert seq_len % (window_overlap * 2) == 0 + assert attn_probs.size()[:3] == value.size()[:3] + assert attn_probs.size(3) == 2 * window_overlap + 1 + chunks_count = seq_len // window_overlap - 1 + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + + chunked_attn_probs = attn_probs.transpose(1, 2).reshape( + batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 + ) + + # group batch_size and num_heads dimensions into one + value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value_stride = padded_value.stride() + chunked_value_stride = ( + chunked_value_stride[0], + window_overlap * chunked_value_stride[1], + chunked_value_stride[1], + chunked_value_stride[2], + ) + chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) + + chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) + + context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) + return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) + + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """ compute global attn indices required throughout forward pass """ + # helper variable + num_global_attn_indices = is_index_global_attn.long().sum(dim=1) + + # max number of global attn indices in batch + max_num_global_attn_indices = num_global_attn_indices.max() + + # indices of global attn + is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True) + + # helper variable + is_local_index_global_attn = torch.arange( + max_num_global_attn_indices, device=is_index_global_attn.device + ) < num_global_attn_indices.unsqueeze(dim=-1) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True) + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) + + def _concat_with_global_key_attn_probs( + self, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ): + batch_size = key_vectors.shape[0] + + # create only global key vectors + key_vectors_only_global = key_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + + key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero] + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + + attn_probs_from_global_key[ + is_local_index_no_global_attn_nonzero[0], :, :, 
is_local_index_no_global_attn_nonzero[1] + ] = -10000.0 + + return attn_probs_from_global_key + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = attn_probs.shape[0] + + # cut local attn probs to global only + attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices) + # get value vectors for global only + value_vectors_only_global = value_vectors.new_zeros( + batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim + ) + value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero] + + # use `matmul` because `einsum` crashes sometimes with fp16 + # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) + # compute attn output only global + attn_output_only_global = torch.matmul( + attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) + ).transpose(1, 2) + + # reshape attn probs + attn_probs_without_global = attn_probs.narrow( + -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices + ).contiguous() + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size + ) + return attn_output_only_global + attn_output_without_global + + def _compute_global_attn_output_from_hidden( + self, + hidden_states, + max_num_global_attn_indices, + is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + ): + seq_len, batch_size = hidden_states.shape[:2] + + # prepare global hidden states + global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim) + global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[ + is_index_global_attn_nonzero[::-1] + ] + + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= math.sqrt(self.head_dim) + + # reshape + global_query_vectors_only_global = ( + global_query_vectors_only_global.contiguous() + .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) + global_key_vectors = ( + global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + global_value_vectors = ( + global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + ) # batch_size * self.num_heads, seq_len, head_dim) + + # compute attn scores + global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2)) + + assert list(global_attn_scores.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + seq_len, + ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}." 
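+ # Mask scores at padded global-attention slots and at masked sequence positions
+ # with -10000.0 before the fp32 softmax below, mirroring the local-attention branch.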
+ + global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + + global_attn_scores[ + is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : + ] = -10000.0 + + global_attn_scores = global_attn_scores.masked_fill( + is_index_masked[:, None, None, :], + -10000.0, + ) + + global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + + # compute global attn probs + global_attn_probs_float = F.softmax( + global_attn_scores, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + + global_attn_probs = F.dropout( + global_attn_probs_float.type_as(global_attn_scores), p=self.config.attention_probs_dropout_prob, training=self.training + ) + + # global attn output + global_attn_output = torch.bmm(global_attn_probs, global_value_vectors) + + assert list(global_attn_output.size()) == [ + batch_size * self.num_heads, + max_num_global_attn_indices, + self.head_dim, + ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}." + + global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + global_attn_output = global_attn_output.view( + batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim + ) + return global_attn_output, global_attn_probs diff --git a/grapple/model/sparse.py b/grapple/model/sparse.py index 6905784..217cd32 100644 --- a/grapple/model/sparse.py +++ b/grapple/model/sparse.py @@ -24,11 +24,15 @@ import math import os +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss, MSELoss +from copy import deepcopy +from typing import * +import math from transformers.modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer, gelu_new @@ -37,11 +41,15 @@ from .met_layer import METLayer +from ._longformer_helpers import * + VERBOSE = False -class OskarAttention(BertSelfAttention): + + +class OskarAttention(LongformerSelfAttention): def __init__(self, config): super().__init__(config) @@ -55,39 +63,37 @@ def __init__(self, config): self.pruned_heads = set() self.attention_band = config.attention_band - def prune_heads(self, heads): - if len(heads) == 0: - return - mask = torch.ones(self.num_attention_heads, self.attention_head_size) - heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads - for head in heads: - # Compute how many pruned heads are before the head and move the index accordingly - head = head - sum(1 if h < head else 0 for h in self.pruned_heads) - mask[head] = 0 - mask = mask.view(-1).contiguous().eq(1) - index = torch.arange(len(mask))[mask].long() - - # Prune linear layers - self.query = prune_linear_layer(self.query, index) - self.key = prune_linear_layer(self.key, index) - self.value = prune_linear_layer(self.value, index) - self.dense = prune_linear_layer(self.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.num_attention_heads = self.num_attention_heads - len(heads) - self.all_head_size = self.attention_head_size * self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward(self, input_ids, attention_mask=None, head_mask=None): - mixed_query_layer = self.query(input_ids) - mixed_key_layer = self.key(input_ids) - mixed_value_layer = self.value(input_ids) - - 
query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + self.query_global = nn.Linear(config.hidden_size, self.embed_dim) + self.key_global = nn.Linear(config.hidden_size, self.embed_dim) + self.value_global = nn.Linear(config.hidden_size, self.embed_dim) + + self.config = config + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + input_layer, + attention_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, + ): if self.attention_band is not None: + return super().forward( + input_layer, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn + ) + + elif False: query_layer = query_layer.permute(0, 2, 1, 3) key_layer = key_layer.permute(0, 2, 1, 3) value_layer = value_layer.permute(0, 2, 1, 3) @@ -99,61 +105,41 @@ def forward(self, input_ids, attention_mask=None, head_mask=None): query_layer /= math.sqrt(self.attention_head_size) query_layer = query_layer.float().contiguous() key_layer = key_layer.float().contiguous() - if False: - attention_scores = diagonaled_mm_tvm( - query_layer, key_layer, - attn_band, - 1, False, 0, False # dilation, is_t1_diag, padding, autoregressive - ) - else: - attention_scores = sliding_chunks_matmul_qk( - query_layer, key_layer, - attn_band, padding_value=0 - ) + attention_scores = sliding_chunks_matmul_qk( + query_layer, key_layer, + attn_band, padding_value=0 + ) mask_invalid_locations(attention_scores, attn_band, 1, False) if attention_mask is not None: remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1) float_mask = remove_from_windowed_attention_mask.type_as(query_layer).masked_fill(remove_from_windowed_attention_mask, -10000.0) float_mask = float_mask.repeat(1, 1, 1, 1) # don't think I need this ones = float_mask.new_ones(size=float_mask.size()) - if False: - d_mask = diagonaled_mm_tvm(ones, float_mask, attn_band, 1, False, 0, False) - else: - d_mask = sliding_chunks_matmul_qk(ones, float_mask, attn_band, padding_value=0) + d_mask = sliding_chunks_matmul_qk(ones, float_mask, attn_band, padding_value=0) attention_scores += d_mask attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32) attention_probs = self.dropout(attention_probs) value_layer = value_layer.float().contiguous() - if False: - context_layer = diagonaled_mm_tvm(attention_probs, value_layer, attn_band, 1, True, 0, False) - else: - context_layer = sliding_chunks_matmul_pv(attention_probs, value_layer, attn_band) + context_layer = sliding_chunks_matmul_pv(attention_probs, value_layer, attn_band) else: - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. 
- attention_probs = nn.Softmax(dim=-1)(attention_scores) - if VERBOSE: - # print(attention_probs[0, :8, :8]) - print(torch.max(attention_probs), torch.min(attention_probs)) + mixed_query_layer = self.query(input_layer) + mixed_key_layer = self.key(input_layer) + mixed_value_layer = self.value(input_layer) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask + query_layer = (torch.nn.functional.elu(query_layer) + 1) + key_layer = (torch.nn.functional.elu(key_layer) + 1) + key_layer = attention_mask * key_layer - context_layer = torch.matmul(attention_probs, value_layer) + D_inv = 1. / torch.einsum('...nd,...d->...n', query_layer, key_layer.sum(dim=2)) + context = torch.einsum('...nd,...ne->...de', key_layer, value_layer) + context_layer = torch.einsum('...de,...nd,...n->...ne', context, query_layer, D_inv) context_layer = context_layer.permute(0, 2, 1, 3) @@ -169,7 +155,7 @@ def forward(self, input_ids, attention_mask=None, head_mask=None): projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b projected_context_layer_dropout = self.dropout(projected_context_layer) - layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout) + layernormed_context_layer = self.LayerNorm(input_layer + projected_context_layer_dropout) return (layernormed_context_layer, attention_probs) if self.output_attentions else (layernormed_context_layer,) @@ -187,14 +173,23 @@ def __init__(self, config): except KeyError: self.activation = config.hidden_act - def forward(self, hidden_states, attention_mask=None, head_mask=None): - attention_output = self.attention(hidden_states, attention_mask, head_mask) + def forward(self, hidden_states, attention_mask=None, + is_index_masked=None, is_index_global_attn=None, + is_global_attn=None): + attention_output = self.attention( + hidden_states, attention_mask, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + ) + ffn_output = self.ffn(attention_output[0]) ffn_output = self.activation(ffn_output) ffn_output = self.ffn_output(ffn_output) hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0]) - return (hidden_states,) + attention_output[1:] # add attentions if we output them + + return (hidden_states, ) + attention_output[2:] # add attentions if we output them class OskarLayerGroup(nn.Module): @@ -205,21 +200,23 @@ def __init__(self, config): self.output_hidden_states = config.output_hidden_states self.albert_layers = nn.ModuleList([OskarLayer(config) for _ in range(config.inner_group_num)]) - def forward(self, hidden_states, attention_mask=None, head_mask=None): + def forward(self, hidden_states, attention_mask=None, + is_index_masked=None, is_index_global_attn=None, + is_global_attn=None): layer_hidden_states = () layer_attentions = () for layer_index, albert_layer in enumerate(self.albert_layers): - layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index]) + layer_output = albert_layer( + hidden_states, + attention_mask, + is_index_masked=is_index_masked, + 
is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + ) hidden_states = layer_output[0] - if self.output_attentions: - layer_attentions = layer_attentions + (layer_output[1],) - - if self.output_hidden_states: - layer_hidden_states = layer_hidden_states + (hidden_states,) - - outputs = (hidden_states,) + outputs = (hidden_states, ) if self.output_hidden_states: outputs = outputs + (layer_hidden_states,) if self.output_attentions: @@ -245,6 +242,11 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None): if self.output_hidden_states: all_hidden_states = (hidden_states,) + attention_mask = attention_mask.type(hidden_states.dtype) + is_index_masked = attention_mask < 0 + is_index_global_attn = attention_mask > 0 + is_global_attn = is_index_global_attn.flatten().any().item() + for i in range(self.config.num_hidden_layers): # Number of layers in a hidden group layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) @@ -255,19 +257,15 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None): layer_group_output = self.albert_layer_groups[group_idx]( hidden_states, attention_mask, - head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group], + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, ) hidden_states = layer_group_output[0] - if self.output_attentions: - all_attentions = all_attentions + layer_group_output[-1] - - if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = (hidden_states,) + outputs = (hidden_states, ) if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) + outputs = outputs + (all_hidden_states, ) if self.output_attentions: outputs = outputs + (all_attentions,) return outputs # last-layer hidden state, (all hidden states), (all attentions) @@ -282,28 +280,15 @@ def __init__(self, config): config.output_attentions = False config.output_hidden_states = False - config.num_hidden_groups = 1 - config.inner_group_num = 1 config.layer_norm_eps = 1e-12 config.hidden_dropout_prob = 0 config.attention_probs_dropout_prob = 0 - config.hidden_act = self.tanh #"gelu_new" + config.hidden_act = "gelu_new" self.embedder = nn.Linear(config.feature_size, config.embedding_size) - self.encoders = nn.ModuleList([OskarTransformer(config) for _ in range(config.num_encoders)]) + self.encoder = OskarTransformer(config) self.decoders = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size), nn.Linear(config.hidden_size, 1)]) - self.tests = nn.ModuleList( - [ - nn.Linear(config.feature_size, 1, bias=False), - # nn.Linear(config.feature_size, config.hidden_size), - # nn.Linear(config.hidden_size, config.hidden_size), - # nn.Linear(config.hidden_size, config.hidden_size), - # nn.Linear(config.hidden_size, config.hidden_size), - # nn.Linear(config.hidden_size, 1) - ] - ) - self.loss_fn = nn.MSELoss() self.config = config @@ -349,8 +334,7 @@ def forward(self, x, mask=None, y=None): h = self.embedder(x) h = torch.tanh(h) - for e in self.encoders: - h = e(h, mask, head_mask)[0] + h = self.encoder(h, None, mask, head_mask)[0] h = self.decoders[0](h[:, 0, :]) h = self.tanh(h) h = self.decoders[1](h).squeeze(-1) @@ -367,13 +351,15 @@ class Bruno(nn.Module): def __init__(self, config): super().__init__() + self.config = config + self.relu = gelu_new #nn.ReLU() self.tanh = nn.Tanh() config.output_attentions = False config.output_hidden_states = False - config.num_hidden_groups 
= 1 - config.inner_group_num = 1 + # config.num_hidden_groups = 1 + # config.inner_group_num = 1 config.layer_norm_eps = 1e-12 config.hidden_dropout_prob = 0 config.attention_probs_dropout_prob = 0 @@ -384,35 +370,23 @@ def __init__(self, config): self.embedder = nn.Linear(config.feature_size, config.embedding_size) self.embed_bn = nn.BatchNorm1d(config.embedding_size) - self.encoders = nn.ModuleList([OskarTransformer(config) for _ in range(config.num_encoders)]) + self.encoder = OskarTransformer(config) self.decoders = nn.ModuleList([ - nn.Linear(config.hidden_size, config.hidden_size), - nn.Linear(config.hidden_size, config.hidden_size), nn.Linear(config.hidden_size, config.label_size) ]) self.decoder_bn = nn.ModuleList([nn.BatchNorm1d(config.hidden_size) for _ in self.decoders[:-1]]) - self.tests = nn.ModuleList( - [ - nn.Linear(config.feature_size, 1, bias=False), - # nn.Linear(config.feature_size, config.hidden_size), - # nn.Linear(config.hidden_size, config.hidden_size), - # nn.Linear(config.hidden_size, config.hidden_size), - # nn.Linear(config.hidden_size, config.hidden_size), - # nn.Linear(config.hidden_size, 1) - ] - ) - - self.config = config + if self.config.num_global_objects is not None: + self.global_x = torch.FloatTensor( + np.random.normal(size=(1, self.config.num_global_objects, self.config.feature_size)) + ) self.apply(self._init_weights) def _init_weights(self, module): """ Initialize the weights. """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=0.02) if isinstance(module, (nn.Linear)) and module.bias is not None: module.bias.data.zero_() @@ -420,14 +394,144 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def forward(self, x, mask=None): - if mask is None: - mask = torch.ones(x.size()[:-1], device=self.config.device) - if len(mask.shape) == 3: - attn_mask = mask.unsqueeze(1) # [B, P, P] -> [B, 1, P, P] + def _pad_to_window_size( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + ): + """A helper function to pad tokens and mask to work with implementation of Longformer self-attention.""" + # padding + attention_band = ( + self.config.attention_band + if isinstance(self.config.attention_band, int) + else max(self.config.attention_band) + ) + + assert attention_band % 2 == 0, f"`attention_band` should be an even value. 
Given {attention_band}" + input_shape = input_ids.shape + batch_size, seq_len = input_shape[:2] + + padding_len = (attention_band - seq_len % attention_band) % attention_band + if padding_len > 0: + logger.info( + "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_band`: {}".format( + seq_len, seq_len + padding_len, attention_band + ) + ) + if input_ids is not None: + input_ids = F.pad(input_ids, (0, 0, 0, padding_len), value=0) + attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens + + return padding_len, input_ids, attention_mask + + def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor) -> torch.Tensor: + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention, 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) else: - attn_mask = mask.unsqueeze(1).unsqueeze(2) # [B, P] -> [B, 1, P, 1] - attn_mask = (1 - attn_mask) * -1e9 + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + return attention_mask + + def get_extended_attention_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int], device, dtype) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + Returns: + :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if False and self.config.is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + + def forward(self, x, attn_mask=None): + if attn_mask is None: + attn_mask = torch.ones(x.size()[:-1], device=self.config.device) + + if self.config.num_global_objects is not None: + batch_size = x.shape[0] + global_x = self.global_x.repeat((batch_size, 1, 1)).to(self.config.device) + global_attn_mask = torch.cat( + [torch.ones(global_x.shape[:-1], device=self.config.device), torch.zeros(x.shape[:-1], device=self.config.device)], + dim=1 + ).long() + attn_mask = torch.cat( + [torch.ones(global_x.shape[:-1], device=self.config.device).long(), attn_mask], + dim=1 + ) + x = torch.cat([global_x, x], dim=1) + + attn_mask = self._merge_to_attention_mask(attn_mask, global_attn_mask) + + # if len(attn_mask.shape) == 3: + # attn_mask = attn_mask.unsqueeze(1) # [B, P, P] -> [B, 1, P, P] + # else: + # attn_mask = attn_mask.unsqueeze(1).unsqueeze(-1) # [B, P] -> [B, 1, P, 1] + + # if self.config.attention_band is not None: + # # attn_mask = self.get_extended_attention_mask(attn_mask) + # attn_mask = (1 - attn_mask) * -1e9 # needed for (sparse) softmax attention + attn_mask: torch.Tensor = self.get_extended_attention_mask(attn_mask, x.shape[:-1], self.config.device, x.dtype)[ + :, 0, 0, : + ] + + padding_len, x, attn_mask = self._pad_to_window_size( + input_ids=x, + attention_mask=attn_mask, + ) head_mask = [None] * self.config.num_hidden_layers @@ -437,17 +541,11 @@ def forward(self, x, mask=None): h = torch.relu(h) h = self.embed_bn(h.permute(0, 2, 1)).permute(0, 2, 1) - for e in self.encoders: - h = e(h, attn_mask, head_mask)[0] + h, = self.encoder(h, attn_mask, head_mask) h = self.decoders[0](h) - h = self.relu(h) - h = self.decoder_bn[0](h.permute(0, 2, 1)).permute(0, 2, 1) - - h = self.decoders[1](h) - h = self.relu(h) - h = self.decoder_bn[1](h.permute(0, 2, 1)).permute(0, 2, 1) - h = self.decoders[2](h) + if self.config.num_global_objects is not None: + h = h[:, self.config.num_global_objects:] return h diff --git a/pip_reqs.txt b/pip_reqs.txt new file mode 100644 index 0000000..d47116a --- /dev/null +++ b/pip_reqs.txt @@ -0,0 +1,12 @@ +numpy +#torch +#transformers +pyyaml +tqdm +loguru +ipython +torchsummary +pyjet +scipy +git+https://github.com/allenai/longformer.git +matplotlib diff --git a/scripts/training/papu/infer_pu.py b/scripts/training/papu/infer_pu.py index 288ca2f..bfcb325 100755 --- a/scripts/training/papu/infer_pu.py +++ b/scripts/training/papu/infer_pu.py @@ -4,6 +4,7 @@ p = utils.ArgumentParser() p.add_args( '--dataset_pattern', '--output', ('--n_epochs', p.INT), + '--checkpoint_path', ('--embedding_size', p.INT), ('--hidden_size', p.INT), ('--feature_size', p.INT), ('--num_attention_heads', p.INT), ('--intermediate_size', p.INT), ('--label_size', p.INT), ('--num_hidden_layers', p.INT), ('--batch_size', p.INT), @@ -33,6 +34,8 @@ import os from apex import amp from functools import partial +from glob import glob +import re def scale_fn(c, decay): @@ -50,10 +53,8 @@ def scale_fn(c, decay): to_t = lambda x: torch.Tensor(x).to(device) to_lt = lambda x: torch.LongTensor(x).to(device) - if config.grad_acc is not None: - config.batch_size //= config.grad_acc - else: - config.grad_acc = 1 + # override + config.batch_size = 128 if torch.cuda.device_count() > 1: config.batch_size *= torch.cuda.device_count() @@ -66,38 +67,52 @@ def scale_fn(c, decay): logger.info(f'Building model') + def 
load_checkpoint(path): + existing_checkpoints = glob(snapshot.get_path(path)) + if existing_checkpoints: + ckpt = sorted(existing_checkpoints)[-1] + config.from_snapshot = ckpt + epoch = int(re.sub(r'.*epoch', '', re.sub(r'\.pt$', '', ckpt))) + config.epoch_offset = epoch + + return True + else: + return False + + if config.from_snapshot is None: + loaded = load_checkpoint('model_weights_best_epoch*pt') + if not loaded: + # if best doesn't exist, take the latest + loaded = load_checkpoint('model_weights_epoch*pt') + + # if config.from_snapshot is None: + # existing_checkpoints = glob(snapshot.get_path('model_weights_epoch*pt')) + # if existing_checkpoints: + # ckpt = sorted(existing_checkpoints)[-1] + # config.from_snapshot = ckpt + # if config.from_snapshot is not None: + # epoch = int(re.sub(r'.*epoch', '', re.sub(r'\.pt$', '', config.from_snapshot))) + # config.epoch_offset = epoch + model = Bruno(config) - opt = torch.optim.Adam(model.parameters(), lr=config.lr) if config.from_snapshot is not None: - # original saved file with DataParallel state_dicts = torch.load(config.from_snapshot) - # create new OrderedDict that does not contain `module.` - from collections import OrderedDict - model_state_dict = OrderedDict() - for k, v in state_dicts['model'].items(): - name = k - if k.startswith('module'): - name = k[7:] # remove `module.` - model_state_dict[name] = v - # load params - model.load_state_dict(model_state_dict) - - opt.load_state_dict(state_dict['opt']) - - logger.info(f'Snapshot {config.from_snapshot} loaded.') - - # lr = torch.optim.lr_scheduler.ReduceLROnPlateau( - # opt, - # factor=config.lr_decay, - # patience=3 - # ) + if 'model' in state_dicts: + state_dicts = state_dicts['model'] + state_dicts = {re.sub(r'^module\.', '', k):v for k,v in state_dicts.items()} + model.load_state_dict(state_dicts) + + logger.info(f'Model ckpt {config.from_snapshot} loaded.') + metrics = PapuMetrics(config.beta) metrics_puppi = PapuMetrics() metres = ParticleMETResolution() + metresx = ParticleMETResolution(which='x') + metresy = ParticleMETResolution(which='y') jetres = JetResolution() model = model.to(device) - model, opt = amp.initialize(model, opt, opt_level='O1') + model = amp.initialize(model, opt_level='O1') if torch.cuda.device_count() > 1: logger.info(f'Distributing model across {torch.cuda.device_count()} GPUs') model = nn.DataParallel(model) @@ -127,6 +142,7 @@ def scale_fn(c, decay): qm = to_lt(batch['mask'] & batch['neutral_mask']) cqm = to_lt(batch['mask'] & ~batch['neutral_mask']) genmet = batch['genmet'][:, 0] + genmetphi = batch['genmet'][:, 1] if config.pt_weight: weight = x[:, :, 0] / x[:, 0, 0].reshape(-1, 1) @@ -153,33 +169,59 @@ def scale_fn(c, decay): if config.beta: p, q = yhat[:, :, 0], yhat[:, :, 1] - # logger.info(' '.join([str(x) for x in [p.max(), p.min(), q.max(), q.min()]])) yhat = p / (p + q + 1e-5) score = t2n(torch.clamp(yhat.squeeze(-1), 0, 1)) charged_mask = ~batch['neutral_mask'] score[charged_mask] = batch['y'][charged_mask] - metres.compute(pt=batch['x'][:, :, 0], - phi=batch['x'][:, :, 2], + pt = batch['x'][:, :, 0] + phi = batch['x'][:, :, 2] + + metres.compute(pt=pt, + phi=phi, w=score, y=batch['y'], baseline=batch['puppi'], - gm=genmet) + gm=genmet, + gmphi=genmetphi) + + metresx.compute(pt=pt, + phi=phi, + w=score, + y=batch['y'], + baseline=batch['puppi'], + gm=genmet * np.cos(genmetphi), + gmphi=np.zeros_like(genmetphi)) + + metresy.compute(pt=pt, + phi=phi, + w=score, + y=batch['y'], + baseline=batch['puppi'], + gm=genmet * np.sin(genmetphi), + 
gmphi=np.zeros_like(genmetphi)) jetres.compute(x=batch['x'], - weight=score, + weights={ + 'model': score, + 'puppi': batch['puppi'], + 'truth': batch['y'] + }, mask=batch['mask'], pt0=batch['jet1'][:,0]) avg_loss_tensor /= n_batch - plot_path = f'{config.plot}/resolution_{e:03d}' + plot_path = f'{config.plot}/resolution_inference' metrics.plot(plot_path + '_model') metrics_puppi.plot(plot_path + '_puppi') resolution = metres.plot(plot_path + '_met') + metresx.plot(plot_path + '_metx') + metresy.plot(plot_path + '_mety') + jetres.plot(plot_path + '_jet') avg_loss, avg_acc, avg_posacc, avg_negacc, avg_posfrac = metrics.mean() logger.info(f'Epoch {e}: Average fraction of hard particles = {avg_posfrac}') diff --git a/scripts/training/papu/puma_64_4.slurm b/scripts/training/papu/puma_64_4.slurm deleted file mode 100644 index a2c6543..0000000 --- a/scripts/training/papu/puma_64_4.slurm +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -#SBATCH -J sluma_4GPUs -#SBATCH -o sluma_4GPUs_%j.out -#SBATCH -e sluma_4GPUs_%j.err -#SBATCH --mail-user=sidn@mit.edu -#SBATCH --mail-type=ALL -#SBATCH --gres=gpu:4 -#SBATCH --gpus-per-node=4 -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --mem=0 -#SBATCH --time=12:00:00 -#SBATCH --exclusive -#SBATCH --exclude=node0024 - -## User python environment -HOME2=/nobackup/users/$(whoami) -PYTHON_VIRTUAL_ENVIRONMENT=pumappc -CONDA_ROOT=$HOME2/anaconda3 - -## Activate WMLCE virtual environment -source ${CONDA_ROOT}/etc/profile.d/conda.sh -conda activate $PYTHON_VIRTUAL_ENVIRONMENT - -cd /home/${USER}/puma/grapple/ -export PYTHONPATH=${PYTHONPATH}:${PWD} -cd - -nvidia-smi - - -ulimit -s unlimited - -## Creating SLURM nodes list -export NODELIST=nodelist.$ -srun -l bash -c 'hostname' | sort -k 2 -u | awk -vORS=, '{print $2":4"}' | sed 's/,$//' > $NODELIST - -## Number of total processes -echo " " -echo " Nodelist:= " $SLURM_JOB_NODELIST -echo " Number of nodes:= " $SLURM_JOB_NUM_NODES -echo " GPUs per node:= " $SLURM_JOB_GPUS -echo " Ntasks per node:= " $SLURM_NTASKS_PER_NODE - - -#### Use MPI for communication with Horovod - this can be hard-coded during installation as well. 
-echo " Run started at:- " -date - -## Horovod execution -python /home/${USER}/puma/grapple/scripts/training/papu/train_pu.py -c /home/${USER}/puma/grapple/scripts/training/papu/latest.yaml - -echo "Run completed at:- " -date diff --git a/scripts/training/papu/train_pu.py b/scripts/training/papu/train_pu.py index 16bd69b..b4e1703 100755 --- a/scripts/training/papu/train_pu.py +++ b/scripts/training/papu/train_pu.py @@ -5,15 +5,18 @@ p.add_args( '--dataset_pattern', '--output', ('--n_epochs', p.INT), '--checkpoint_path', + ('--met_constraint', p.FLOAT), + ('--num_global_objects', p.INT), ('--embedding_size', p.INT), ('--hidden_size', p.INT), ('--feature_size', p.INT), ('--num_attention_heads', p.INT), ('--intermediate_size', p.INT), ('--label_size', p.INT), ('--num_hidden_layers', p.INT), ('--batch_size', p.INT), + ('--num_hidden_groups', p.INT), ('--inner_group_num', p.INT), ('--mask_charged', p.STORE_TRUE), ('--lr', {'type': float}), ('--attention_band', p.INT), ('--epoch_offset', p.INT), ('--from_snapshot'), ('--lr_schedule', p.STORE_TRUE), '--plot', - ('--pt_weight', p.STORE_TRUE), ('--num_max_files', p.INT), + ('--pt_weight', p.INT), ('--num_max_files', p.INT), ('--num_max_particles', p.INT), ('--dr_adj', p.FLOAT), ('--beta', p.STORE_TRUE), ('--lr_policy'), ('--grad_acc', p.INT), @@ -33,15 +36,23 @@ from torch.utils.data import RandomSampler import os from apex import amp +from apex.optimizers import FusedAdam, FusedLAMB from functools import partial from glob import glob import re +from transformers.optimization import get_linear_schedule_with_warmup as linear_sched def scale_fn(c, decay): return decay ** c +def check_mem_stats(device=None): + cached = torch.cuda.memory_cached(device) // 1.e9 + allocated = torch.cuda.memory_allocated(device) // 1.e9 + logger.info(f'Device {device} memory: {allocated} GB ({cached} GB) allocated (cached)') + + if __name__ == '__main__': snapshot = utils.Snapshot(config.output, config) @@ -61,52 +72,74 @@ def scale_fn(c, decay): if torch.cuda.device_count() > 1: config.batch_size *= torch.cuda.device_count() + logger.info(f'Computational batch size is {config.batch_size}, accumulated over {config.grad_acc} steps') + logger.info(f'Reading dataset at {config.dataset_pattern}') + num_workers = 8 ds = PapuDataset(config) dl = DataLoader(ds, batch_size=config.batch_size, - collate_fn=PapuDataset.collate_fn) - steps_per_epoch = len(ds) // config.batch_size + collate_fn=PapuDataset.collate_fn, + num_workers=num_workers, pin_memory=True) + batches_per_epoch = len(ds) // config.batch_size * num_workers + steps_per_epoch = batches_per_epoch // config.grad_acc + + config.epoch_offset = 0 logger.info(f'Building model') - if config.from_snapshot is None: - existing_checkpoints = glob(snapshot.get_path('model_weights_epoch*pt')) + def load_checkpoint(path): + existing_checkpoints = glob(snapshot.get_path(path)) if existing_checkpoints: ckpt = sorted(existing_checkpoints)[-1] config.from_snapshot = ckpt epoch = int(re.sub(r'.*epoch', '', re.sub(r'\.pt$', '', ckpt))) config.epoch_offset = epoch + + if config.lr_policy != 'linear': + lr_steps = epoch // (4 if config.lr_policy == 'cyclic' else 1) + config.lr *= (config.lr_decay ** lr_steps) + return True + else: + return False + + if config.from_snapshot is None: + loaded = load_checkpoint('model_weights_best_epoch*pt') + if not loaded: + # if best doesn't exist, take the latest + loaded = load_checkpoint('model_weights_epoch*pt') model = Bruno(config) if config.from_snapshot is not None: - # original saved file with 
DataParallel state_dicts = torch.load(config.from_snapshot) - # create new OrderedDict that does not contain `module.` - from collections import OrderedDict - model_state_dict = OrderedDict() - for k, v in state_dicts['model'].items(): - name = k - if k.startswith('module'): - name = k[7:] # remove `module.` - model_state_dict[name] = v - # load params - model.load_state_dict(model_state_dict) - - # opt.load_state_dict(state_dicts['opt']) - # lr.load_state_dict(state_dicts['lr']) - - logger.info(f'Snapshot {config.from_snapshot} loaded.') - - opt = torch.optim.Adam(model.parameters(), lr=config.lr) + model.load_state_dict(state_dicts['model']) + + logger.info(f'Model ckpt {config.from_snapshot} loaded.') + + model = model.to(device) + + # opt = torch.optim.Adam(model.parameters(), lr=config.lr) + opt = FusedLAMB(model.parameters(), lr=config.lr) + + if config.from_snapshot is not None: + state_dicts = torch.load(config.from_snapshot) + opt.load_state_dict(state_dicts['opt']) + if config.lr_policy == 'exp' or config.lr_policy is None: lr = torch.optim.lr_scheduler.ExponentialLR(opt, config.lr_decay) elif config.lr_policy == 'cyclic': lr = torch.optim.lr_scheduler.CyclicLR(opt, 0, config.lr, step_size_up=steps_per_epoch*2, scale_fn=partial(scale_fn, decay=config.lr_decay), cycle_momentum=False) + elif config.lr_policy == 'linear': + lr = linear_sched(opt, num_warmup_steps=steps_per_epoch*2, num_training_steps=128*steps_per_epoch, + last_epoch=(config.epoch_offset*steps_per_epoch)-1) # 'epoch' here refers to batch for us - model = model.to(device) - model, opt = amp.initialize(model, opt, opt_level='O1') + # if config.from_snapshot is not None: + # state_dicts = torch.load(config.from_snapshot) + # opt.load_state_dict(state_dicts['lr']) + + if config.attention_band is not None: + model, opt = amp.initialize(model, opt, opt_level='O1') # lr = torch.optim.lr_scheduler.ReduceLROnPlateau( @@ -114,7 +147,7 @@ def scale_fn(c, decay): # factor=config.lr_decay, # patience=3 # ) - metrics = PapuMetrics(config.beta) + metrics = PapuMetrics(config.beta, config.met_constraint) metrics_puppi = PapuMetrics() metres = ParticleMETResolution() @@ -130,6 +163,9 @@ def scale_fn(c, decay): else: min_epoch = 0 + if config.pt_weight is None: + config.pt_weight = 0 + for e in range(min_epoch, config.n_epochs+min_epoch): logger.info(f'Epoch {e}: Start') current_lr = [group['lr'] for group in opt.param_groups][0] @@ -152,10 +188,13 @@ def scale_fn(c, decay): metrics_puppi.reset() metres.reset() + best_loss = np.inf + avg_loss_tensor = 0 # tqdm = lambda x, **kwargs: x opt.zero_grad() - for n_batch, batch in enumerate(tqdm(dl, total=steps_per_epoch)): + ready_for_lr = False + for n_batch, batch in enumerate(tqdm(dl, total=batches_per_epoch)): sparse.VERBOSE = (n_batch == 0) x = to_t(batch['x']) @@ -165,32 +204,37 @@ def scale_fn(c, decay): qm = to_lt(batch['mask'] & batch['neutral_mask']) cqm = to_lt(batch['mask'] & ~batch['neutral_mask']) genmet = batch['genmet'][:, 0] + genmetphi = batch['genmet'][:, 1] - if config.pt_weight: - weight = x[:, :, 0] / x[:, 0, 0].reshape(-1, 1) - weight = weight ** 2 + if config.pt_weight == 0: + weight = None else: - weight = None + weight = x[:, :, 0] / x[:, 0, 0].reshape(-1, 1) + weight = weight ** config.pt_weight if True or e < 3: loss_mask = m else: loss_mask = qm - yhat = model(x, mask=m) + yhat = model(x, attn_mask=m) if not config.beta: yhat = torch.sigmoid(yhat) else: yhat = torch.relu(yhat) - loss, _ = metrics.compute(yhat, y, w=weight, m=loss_mask, plot_m=qm) + loss, _ = 
metrics.compute(yhat, y, w=weight, m=loss_mask, plot_m=qm, x=x) loss /= config.grad_acc - with amp.scale_loss(loss, opt) as scaled_loss: - scaled_loss.backward() + if config.attention_band is not None: + with amp.scale_loss(loss, opt) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() if (n_batch+1) % config.grad_acc == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 5) opt.step() opt.zero_grad() + ready_for_lr = True metrics_puppi.compute(p, y, w=weight, m=loss_mask, plot_m=qm) @@ -210,16 +254,24 @@ def scale_fn(c, decay): w=score, y=batch['y'], baseline=batch['puppi'], - gm=genmet) + gm=genmet, + gmphi=genmetphi) if config.lr_policy == 'cyclic': - lr.step() - # current_lr = [group['lr'] for group in opt.param_groups][0] - # logger.info(f'Epoch {e}: Step {n_batch}: Current LR = {current_lr}') + if ready_for_lr: + lr.step() + # current_lr = [group['lr'] for group in opt.param_groups][0] + # logger.info(f'Epoch {e}: Step {n_batch}: Current LR = {current_lr}') + + if e == 0 & n_batch == 100: + check_mem_stats(device) + + check_mem_stats(device) avg_loss_tensor /= n_batch if config.lr_policy != 'cyclic': - lr.step() + if ready_for_lr: + lr.step() plot_path = f'{config.plot}/resolution_{e:03d}' @@ -232,6 +284,11 @@ def scale_fn(c, decay): logger.info(f'Epoch {e}: MODEL:') logger.info(f'Epoch {e}: Loss = {avg_loss}; Accuracy = {avg_acc}') logger.info(f'Epoch {e}: Hard ID = {avg_posacc}; PU ID = {avg_negacc}') + + is_best = False + if avg_loss < best_loss: + best_loss = avg_loss + is_best = True avg_loss, avg_acc, avg_posacc, avg_negacc, _ = metrics_puppi.mean() logger.info(f'Epoch {e}: PUPPI:') @@ -247,5 +304,7 @@ def scale_fn(c, decay): 'lr': lr.state_dict()} torch.save(state_dicts, snapshot.get_path(f'model_weights_epoch{e:06d}.pt')) + if is_best: + torch.save(state_dicts, snapshot.get_path(f'model_weights_best_epoch{e:06d}.pt')) # ds.n_particles = min(2000, ds.n_particles + 50) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..53750ca --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="grapple", + version="0.1", + packages=find_packages(), +)
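For readers following the attention-mask changes above: the encoder now expects a single Longformer-style mask in which 0 means no attention, 1 local attention, and 2 global attention, and the particle axis must be padded to a multiple of `config.attention_band`. Below is a minimal, self-contained sketch of that convention; `merge_masks` and `pad_to_band` are illustrative names, not functions from this repository, and the shapes are toy values.

```python
import torch
import torch.nn.functional as F


def merge_masks(attention_mask: torch.Tensor, global_attention_mask: torch.Tensor) -> torch.Tensor:
    # 0 = no attention, 1 = local attention, 2 = global attention
    return attention_mask * (global_attention_mask + 1)


def pad_to_band(x: torch.Tensor, mask: torch.Tensor, band: int):
    # Pad the particle axis of a [batch, particles, features] tensor so its
    # length is a multiple of the sliding-window attention band.
    seq_len = x.shape[1]
    padding_len = (band - seq_len % band) % band
    if padding_len > 0:
        x = F.pad(x, (0, 0, 0, padding_len), value=0)   # pads dim 1 only
        mask = F.pad(mask, (0, padding_len), value=0)    # padded slots get no attention
    return padding_len, x, mask


# Toy usage: 2 events, 100 particles, 16 features, attention band of 32.
x = torch.randn(2, 100, 16)
mask = torch.ones(2, 100, dtype=torch.long)          # 1 = real particle
global_mask = torch.zeros(2, 100, dtype=torch.long)  # no global tokens in this toy case
merged = merge_masks(mask, global_mask)
padding_len, x, merged = pad_to_band(x, merged, band=32)
print(padding_len, x.shape, merged.shape)  # 28 torch.Size([2, 128, 16]) torch.Size([2, 128])
```

The pad spec `(0, 0, 0, padding_len)` touches only the particle dimension of a `[batch, particles, features]` tensor and satisfies the even-length tuple that `F.pad` requires.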
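The checkpoint-resume logic added to `infer_pu.py` and `train_pu.py` shares one pattern: prefer the newest `model_weights_best_epoch*pt` file, fall back to `model_weights_epoch*pt`, and recover the epoch offset from the filename. A standalone sketch of that pattern, assuming checkpoints sit in a plain directory rather than behind the repository's `Snapshot` helper (`find_checkpoint` is an illustrative name):

```python
import re
from glob import glob
from typing import Optional, Tuple


def find_checkpoint(run_dir: str) -> Tuple[Optional[str], int]:
    """Return (checkpoint path, epoch offset), preferring 'best' checkpoints."""
    for pattern in ('model_weights_best_epoch*pt', 'model_weights_epoch*pt'):
        candidates = sorted(glob(f'{run_dir}/{pattern}'))
        if candidates:
            ckpt = candidates[-1]
            # Filenames look like model_weights_epoch000123.pt
            epoch = int(re.sub(r'.*epoch', '', re.sub(r'\.pt$', '', ckpt)))
            return ckpt, epoch
    return None, 0


# ckpt, epoch_offset = find_checkpoint('/path/to/run')
# if ckpt is not None:
#     state_dicts = torch.load(ckpt)
#     weights = state_dicts['model'] if 'model' in state_dicts else state_dicts
#     model.load_state_dict({re.sub(r'^module\.', '', k): v for k, v in weights.items()})
```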
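Finally, the training loop now accumulates gradients over `config.grad_acc` computational batches and only routes the backward pass through `apex.amp` when banded attention is enabled. A minimal sketch of that accumulation pattern under those assumptions (`use_amp`, `batches`, and `loss_fn` are placeholders, not names from the scripts):

```python
import torch


def run_epoch(model, opt, batches, loss_fn, grad_acc=1, use_amp=False):
    opt.zero_grad()
    for n_batch, (x, y) in enumerate(batches):
        # Scale each batch loss so the accumulated gradient matches a large-batch step.
        loss = loss_fn(model(x), y) / grad_acc
        if use_amp:
            from apex import amp  # optional dependency, mirroring the training script
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        if (n_batch + 1) % grad_acc == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            opt.step()
            opt.zero_grad()
```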