From 3f1afaefc9908b4837b13515106ed689bfd3fe02 Mon Sep 17 00:00:00 2001 From: Gilad Mishne Date: Wed, 7 Jun 2023 21:53:49 +0000 Subject: [PATCH] Add support for perturbation magnitude/direction --- gears/data_utils.py | 102 ++++++++++++------------ gears/gears.py | 188 ++++++++++++++++++++++---------------------- gears/model.py | 65 ++++++++------- gears/pertdata.py | 188 ++++++++++++++++++++++++-------------------- gears/utils.py | 130 +++++++++++++++++------------- 5 files changed, 362 insertions(+), 311 deletions(-) diff --git a/gears/data_utils.py b/gears/data_utils.py index 5ffad0f..e656c95 100644 --- a/gears/data_utils.py +++ b/gears/data_utils.py @@ -1,7 +1,6 @@ import pandas as pd import numpy as np import scanpy as sc -from random import shuffle sc.settings.verbosity = 0 from tqdm import tqdm import requests @@ -10,7 +9,10 @@ import warnings warnings.filterwarnings("ignore") -from .utils import parse_single_pert, parse_combo_pert, parse_any_pert, print_sys +from .utils import ( + parse_single_pert, parse_combo_pert, parse_any_pert, print_sys, + get_pert_genes, rm_magnitude +) def rank_genes_groups_by_cov( adata, @@ -53,24 +55,24 @@ def rank_genes_groups_by_cov( if return_dict: return gene_dict - + def get_DE_genes(adata, skip_calc_de): adata.obs.loc[:, 'dose_val'] = adata.obs.condition.apply(lambda x: '1+1' if len(x.split('+')) == 2 else '1') adata.obs.loc[:, 'control'] = adata.obs.condition.apply(lambda x: 0 if len(x.split('+')) == 2 else 1) - adata.obs.loc[:, 'condition_name'] = adata.obs.apply(lambda x: '_'.join([x.cell_type, x.condition, x.dose_val]), axis = 1) - + adata.obs.loc[:, 'condition_name'] = adata.obs.apply(lambda x: '_'.join([x.cell_type, x.condition, x.dose_val]), axis = 1) + adata.obs = adata.obs.astype('category') if not skip_calc_de: - rank_genes_groups_by_cov(adata, - groupby='condition_name', - covariate='cell_type', - control_group='ctrl_1', + rank_genes_groups_by_cov(adata, + groupby='condition_name', + covariate='cell_type', + 
control_group='ctrl_1', n_genes=len(adata.var), key_added = 'rank_genes_groups_cov_all') return adata def get_dropout_non_zero_genes(adata): - + # calculate mean expression for each condition unique_conditions = adata.obs.condition.unique() conditions2index = {} @@ -83,7 +85,7 @@ def get_dropout_non_zero_genes(adata): pert_list = np.array(list(condition2mean_expression.keys())) mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.toarray().shape[1]) ctrl = mean_expression[np.where(pert_list == 'ctrl')[0]] - + ## in silico modeling and upperbounding pert2pert_full_id = dict(adata.obs[['condition', 'condition_name']].values) pert_full_id2pert = dict(adata.obs[['condition_name', 'condition']].values) @@ -118,17 +120,17 @@ def get_dropout_non_zero_genes(adata): non_dropout_gene_idx[pert] = np.sort(non_dropouts) top_non_dropout_de_20[pert] = np.array(non_dropout_20_gene_id) top_non_zero_de_20[pert] = np.array(non_zero_20_gene_id) - + non_zero = np.where(np.array(X)[0] != 0)[0] zero = np.where(np.array(X)[0] == 0)[0] true_zeros = np.intersect1d(zero, np.where(np.array(ctrl)[0] == 0)[0]) non_dropouts = np.concatenate((non_zero, true_zeros)) - + adata.uns['top_non_dropout_de_20'] = top_non_dropout_de_20 adata.uns['non_dropout_gene_idx'] = non_dropout_gene_idx adata.uns['non_zeros_gene_idx'] = non_zeros_gene_idx adata.uns['top_non_zero_de_20'] = top_non_zero_de_20 - + return adata @@ -152,11 +154,11 @@ def split_data(self, test_size=0.1, test_pert_genes=None, np.random.seed(seed=seed) unique_perts = [p for p in self.adata.obs['condition'].unique() if p != 'ctrl'] - + if self.split_type == 'simulation': train, test, test_subgroup = self.get_simulation_split(unique_perts, train_gene_set_size, - combo_seen2_train_frac, + combo_seen2_train_frac, seed, test_perts, only_test_set_perts) train, val, val_subgroup = self.get_simulation_split(train, 0.9, @@ -174,17 +176,17 @@ def split_data(self, test_size=0.1, 
test_pert_genes=None, elif self.split_type == 'no_test': print('test_pert_genes',str(test_pert_genes)) print('test_perts',str(test_perts)) - + train, val = self.get_split_list(unique_perts, test_pert_genes=test_pert_genes, test_perts=test_perts, - test_size=test_size) + test_size=test_size) else: train, test = self.get_split_list(unique_perts, test_pert_genes=test_pert_genes, test_perts=test_perts, test_size=test_size) - + train, val = self.get_split_list(train, test_size=val_size) map_dict = {x: 'train' for x in train} @@ -196,19 +198,19 @@ def split_data(self, test_size=0.1, test_pert_genes=None, self.adata.obs[split_name] = self.adata.obs['condition'].map(map_dict) if self.split_type == 'simulation': - return self.adata, {'test_subgroup': test_subgroup, + return self.adata, {'test_subgroup': test_subgroup, 'val_subgroup': val_subgroup } else: return self.adata - + def get_simulation_split_single(self, pert_list, train_gene_set_size = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False): unique_pert_genes = self.get_genes_from_perts(pert_list) - + pert_train = [] pert_test = [] np.random.seed(seed=seed) - + if only_test_set_perts and (test_set_perts is not None): ood_genes = np.array(test_set_perts) train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes) @@ -223,24 +225,24 @@ def get_simulation_split_single(self, pert_list, train_gene_set_size = 0.85, see ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts)) train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False) train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition)) - + ## ood genes - ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) - + ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) + pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single') unseen_single = self.get_perts_from_genes(ood_genes, 
pert_list, 'single') assert len(unseen_single) + len(pert_single_train) == len(pert_list) - + return pert_single_train, unseen_single, {'unseen_single': unseen_single} - + def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen2_train_frac = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False): - + unique_pert_genes = self.get_genes_from_perts(pert_list) - + pert_train = [] pert_test = [] np.random.seed(seed=seed) - + if only_test_set_perts and (test_set_perts is not None): ood_genes = np.array(test_set_perts) train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes) @@ -255,35 +257,35 @@ def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts)) train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False) train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition)) - + ## ood genes - ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) - + ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) + pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single') pert_combo = self.get_perts_from_genes(train_gene_candidates, pert_list,'combo') pert_train.extend(pert_single_train) - + ## the combo set with one of them in OOD - combo_seen1 = [x for x in pert_combo if len([t for t in x.split('+') if + combo_seen1 = [x for x in pert_combo if len([t for t in get_pert_genes(x) if t in train_gene_candidates]) == 1] pert_test.extend(combo_seen1) - + pert_combo = np.setdiff1d(pert_combo, combo_seen1) ## randomly sample the combo seen 2 as a test set, the rest in training set np.random.seed(seed=seed) pert_combo_train = np.random.choice(pert_combo, int(len(pert_combo) * combo_seen2_train_frac), replace = False) - + combo_seen2 = np.setdiff1d(pert_combo, pert_combo_train).tolist() 
pert_test.extend(combo_seen2) pert_train.extend(pert_combo_train) - + ## unseen single unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single') combo_ood = self.get_perts_from_genes(ood_genes, pert_list, 'combo') pert_test.extend(unseen_single) - + ## here only keeps the seen 0, since seen 1 is tackled above - combo_seen0 = [x for x in combo_ood if len([t for t in x.split('+') if + combo_seen0 = [x for x in combo_ood if len([t for t in get_pert_genes(x) if t in train_gene_candidates]) == 0] pert_test.extend(combo_seen0) assert len(combo_seen1) + len(combo_seen0) + len(unseen_single) + len(pert_train) + len(combo_seen2) == len(pert_list) @@ -292,7 +294,7 @@ def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen 'combo_seen1': combo_seen1, 'combo_seen2': combo_seen2, 'unseen_single': unseen_single} - + def get_split_list(self, pert_list, test_size=0.1, test_pert_genes=None, test_perts=None, hold_outs=True): @@ -336,7 +338,7 @@ def get_split_list(self, pert_list, test_size=0.1, if hold_outs: # This just checks that none of the combos have 2 seen genes hold_out = [t for t in combo_perts if - len([t for t in t.split('+') if + len([t for t in get_pert_genes(t) if t not in test_pert_genes]) > 0] combo_perts = [c for c in combo_perts if c not in hold_out] test_perts = single_perts + combo_perts @@ -353,7 +355,7 @@ def get_split_list(self, pert_list, test_size=0.1, if hold_outs: # This just checks that none of the combos have 2 seen genes hold_out = [t for t in combo_perts if - len([t for t in t.split('+') if + len([t for t in get_pert_genes(t) if t not in test_pert_genes]) > 1] combo_perts = [c for c in combo_perts if c not in hold_out] test_perts = single_perts + combo_perts @@ -361,14 +363,14 @@ def get_split_list(self, pert_list, test_size=0.1, elif self.seen == 2: if test_perts is None: test_perts = np.random.choice(combo_perts, - int(len(combo_perts) * test_size)) + int(len(combo_perts) * test_size)) else: test_perts = 
np.array(test_perts) else: if test_perts is None: test_perts = np.random.choice(combo_perts, int(len(combo_perts) * test_size)) - + train_perts = [p for p in pert_list if (p not in test_perts) and (p not in hold_out)] return train_perts, test_perts @@ -380,16 +382,16 @@ def get_perts_from_genes(self, genes, pert_list, type_='both'): single_perts = [p for p in pert_list if ('ctrl' in p) and (p != 'ctrl')] combo_perts = [p for p in pert_list if 'ctrl' not in p] - + perts = [] - + if type_ == 'single': pert_candidate_list = single_perts elif type_ == 'combo': pert_candidate_list = combo_perts elif type_ == 'both': pert_candidate_list = pert_list - + for p in pert_candidate_list: for g in genes: if g in parse_any_pert(p): @@ -404,7 +406,7 @@ def get_genes_from_perts(self, perts): if type(perts) is str: perts = [perts] - gene_list = [p.split('+') for p in np.unique(perts)] + gene_list = [get_pert_genes(p) for p in np.unique(perts)] gene_list = [item for sublist in gene_list for item in sublist] gene_list = [g for g in gene_list if g != 'ctrl'] return np.unique(gene_list) \ No newline at end of file diff --git a/gears/gears.py b/gears/gears.py index 64a730d..137a518 100644 --- a/gears/gears.py +++ b/gears/gears.py @@ -1,11 +1,9 @@ from copy import deepcopy -import argparse -from time import time -import sys, os +import os import pickle -import scanpy as sc import numpy as np +import scipy import torch import torch.optim as optim @@ -14,11 +12,11 @@ from .model import GEARS_Model from .inference import evaluate, compute_metrics, deeper_analysis, \ - non_dropout_analysis, compute_synergy_loss -from .utils import loss_fct, uncertainty_loss_fct, parse_any_pert, \ + non_dropout_analysis +from .utils import loss_fct, uncertainty_loss_fct, \ get_similarity_network, print_sys, GeneSimNetwork, \ create_cell_graph_dataset_for_prediction, get_mean_control, \ - get_GI_genes_idx, get_GI_params + get_GI_genes_idx, get_GI_params, rm_magnitude torch.manual_seed(0) @@ -26,26 +24,26 @@ 
warnings.filterwarnings("ignore") class GEARS: - def __init__(self, pert_data, + def __init__(self, pert_data, device = 'cuda', - weight_bias_track = False, - proj_name = 'GEARS', + weight_bias_track = False, + proj_name = 'GEARS', exp_name = 'GEARS', pred_scalar = False, gi_predict = False): - + self.weight_bias_track = weight_bias_track - + if self.weight_bias_track: import wandb - wandb.init(project=proj_name, name=exp_name) + wandb.init(project=proj_name, name=exp_name) self.wandb = wandb else: self.wandb = None - + self.device = device self.config = None - + self.dataloader = pert_data.dataloader self.adata = pert_data.adata self.node_map = pert_data.node_map @@ -65,7 +63,7 @@ def __init__(self, pert_data, self.default_pert_graph = pert_data.default_pert_graph self.saved_pred = {} self.saved_logvar_sum = {} - + self.ctrl_expression = torch.tensor( np.mean(self.adata.X[self.adata.obs.condition == 'ctrl'], axis=0)).reshape(-1, ).to(self.device) @@ -77,7 +75,7 @@ def __init__(self, pert_data, self.adata.uns['non_zeros_gene_idx'].items() if i in pert_full_id2pert} self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] - + gene_dict = {g:i for i,g in enumerate(self.gene_list)} self.pert2gene = {p: gene_dict[pert] for p, pert in enumerate(self.pert_list) if pert in self.gene_list} @@ -94,33 +92,34 @@ def tunable_parameters(self): 'uncertainty_reg': 'regularization term to balance uncertainty loss and prediction loss, default 1', 'direction_lambda': 'regularization term to balance direction loss and prediction loss, default 1' } - + def model_initialize(self, hidden_size = 64, - num_go_gnn_layers = 1, + num_go_gnn_layers = 1, num_gene_gnn_layers = 1, decoder_hidden_size = 16, num_similar_genes_go_graph = 20, - num_similar_genes_co_express_graph = 20, + num_similar_genes_co_express_graph = 20, coexpress_threshold = 0.4, - uncertainty = False, + uncertainty = False, uncertainty_reg = 1, direction_lambda = 1e-1, G_go = None, G_go_weight = None, G_coexpress 
= None, G_coexpress_weight = None, - no_perturb = False, + no_perturb = False, cell_fitness_pred = False, + go_path: str = None, ): - + self.config = {'hidden_size': hidden_size, - 'num_go_gnn_layers' : num_go_gnn_layers, + 'num_go_gnn_layers' : num_go_gnn_layers, 'num_gene_gnn_layers' : num_gene_gnn_layers, 'decoder_hidden_size' : decoder_hidden_size, 'num_similar_genes_go_graph' : num_similar_genes_go_graph, 'num_similar_genes_co_express_graph' : num_similar_genes_co_express_graph, 'coexpress_threshold': coexpress_threshold, - 'uncertainty' : uncertainty, + 'uncertainty' : uncertainty, 'uncertainty_reg' : uncertainty_reg, 'direction_lambda' : direction_lambda, 'G_go': G_go, @@ -133,10 +132,10 @@ def model_initialize(self, hidden_size = 64, 'no_perturb': no_perturb, 'cell_fitness_pred': cell_fitness_pred, } - + if self.wandb: self.wandb.config.update(self.config) - + if self.config['G_coexpress'] is None: ## calculating co expression similarity graph edge_list = get_similarity_network(network_type='co-express', @@ -152,7 +151,7 @@ def model_initialize(self, hidden_size = 64, sim_network = GeneSimNetwork(edge_list, self.gene_list, node_map = self.node_map) self.config['G_coexpress'] = sim_network.edge_index self.config['G_coexpress_weight'] = sim_network.edge_weight - + if self.config['G_go'] is None: ## calculating gene ontology similarity graph edge_list = get_similarity_network(network_type='go', @@ -165,23 +164,24 @@ def model_initialize(self, hidden_size = 64, split=self.split, seed=self.seed, train_gene_set_size=self.train_gene_set_size, set2conditions=self.set2conditions, - default_pert_graph=self.default_pert_graph) + default_pert_graph=self.default_pert_graph, + go_path=go_path) sim_network = GeneSimNetwork(edge_list, self.pert_list, node_map = self.node_map_pert) self.config['G_go'] = sim_network.edge_index self.config['G_go_weight'] = sim_network.edge_weight - + self.model = GEARS_Model(self.config).to(self.device) self.best_model = deepcopy(self.model) 
- + def load_pretrained(self, path): with open(os.path.join(path, 'config.pkl'), 'rb') as f: config = pickle.load(f) - + del config['device'], config['num_genes'], config['num_perts'] self.model_initialize(**config) self.config = config - + state_dict = torch.load(os.path.join(path, 'model.pt'), map_location = torch.device('cpu')) if next(iter(state_dict))[:7] == 'module.': # the pretrained model is from data-parallel module @@ -191,42 +191,42 @@ def load_pretrained(self, path): name = k[7:] # remove `module.` new_state_dict[name] = v state_dict = new_state_dict - + self.model.load_state_dict(state_dict) self.model = self.model.to(self.device) self.best_model = self.model - + def save_model(self, path): if not os.path.exists(path): os.mkdir(path) - + if self.config is None: raise ValueError('No model is initialized...') - + with open(os.path.join(path, 'config.pkl'), 'wb') as f: pickle.dump(self.config, f) - + torch.save(self.best_model.state_dict(), os.path.join(path, 'model.pt')) - - def predict(self, pert_list): + + def predict(self, pert_list, batch_size=300, cache_results=True): ## given a list of single/combo genes, return the transcriptome ## if uncertainty mode is on, also return uncertainty score. - + self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] for pert in pert_list: for i in pert: - if i not in self.pert_list: + if rm_magnitude(i) not in self.pert_list: raise ValueError(i+ " is not in the perturbation graph. 
" "Please select from GEARS.pert_list!") - + if self.config['uncertainty']: results_logvar = {} - + self.best_model = self.best_model.to(self.device) self.best_model.eval() results_pred = {} results_logvar_sum = {} - + from torch_geometric.data import DataLoader for pert in pert_list: try: @@ -237,10 +237,10 @@ def predict(self, pert_list): continue except: pass - + cg = create_cell_graph_dataset_for_prediction(pert, self.ctrl_adata, self.pert_list, self.device) - loader = DataLoader(cg, 300, shuffle = False) + loader = DataLoader(cg, batch_size, shuffle = False) batch = next(iter(loader)) batch.to(self.device) @@ -251,21 +251,22 @@ def predict(self, pert_list): results_logvar_sum['_'.join(pert)] = np.exp(-np.mean(results_logvar['_'.join(pert)])) else: p = self.best_model(batch) - + results_pred['_'.join(pert)] = np.mean(p.detach().cpu().numpy(), axis = 0) - - self.saved_pred.update(results_pred) - + + if cache_results: + self.saved_pred.update(results_pred) + if self.config['uncertainty']: self.saved_logvar_sum.update(results_logvar_sum) return results_pred, results_logvar_sum else: return results_pred - + def GI_predict(self, combo, GI_genes_file='./genes_with_hi_mean.npy'): - ## given a gene pair, return (1) transcriptome of A,B,A+B and (2) GI scores. + ## given a gene pair, return (1) transcriptome of A,B,A+B and (2) GI scores. ## if uncertainty mode is on, also return uncertainty score. 
- + try: # If prediction is already saved, then skip inference pred = {} @@ -278,23 +279,22 @@ def GI_predict(self, combo, GI_genes_file='./genes_with_hi_mean.npy'): else: pred = self.predict([[combo[0]], [combo[1]], combo]) - mean_control = get_mean_control(self.adata).values - pred = {p:pred[p]-mean_control for p in pred} + mean_control = get_mean_control(self.adata).values + pred = {p:pred[p]-mean_control for p in pred} if GI_genes_file is not None: # If focussing on a specific subset of genes for calculating metrics - GI_genes_idx = get_GI_genes_idx(self.adata, GI_genes_file) + GI_genes_idx = get_GI_genes_idx(self.adata, GI_genes_file) else: GI_genes_idx = np.arange(len(self.adata.var.gene_name.values)) - + pred = {p:pred[p][GI_genes_idx] for p in pred} return get_GI_params(pred, combo) - + def plot_perturbation(self, query, save_file = None): import seaborn as sns - import numpy as np import matplotlib.pyplot as plt - + sns.set_theme(style="ticks", rc={"axes.facecolor": (0, 0, 0, 0)}, font_scale=1.5) adata = self.adata @@ -307,7 +307,7 @@ def plot_perturbation(self, query, save_file = None): genes = [gene_raw2id[i] for i in adata.uns['top_non_dropout_de_20'][cond2name[query]]] truth = adata[adata.obs.condition == query].X.toarray()[:, de_idx] - + query_ = [q for q in query.split('+') if q != 'ctrl'] pred = self.predict([query_])['_'.join(query_)][de_idx] ctrl_means = adata[adata.obs['condition'] == 'ctrl'].to_df().mean()[ @@ -315,11 +315,12 @@ def plot_perturbation(self, query, save_file = None): pred = pred - ctrl_means truth = truth - ctrl_means - + plt.figure(figsize=[16.5,4.5]) - plt.title(query) - plt.boxplot(truth, showfliers=False, - medianprops = dict(linewidth=0)) + spearmanr = scipy.stats.spearmanr(pred, truth.mean(axis=0)).statistic + plt.title(f"{query} (spr={spearmanr:.2f})") + plt.boxplot(truth, showfliers=False, showmeans=True, + medianprops = dict(linewidth=0)) for i in range(pred.shape[0]): _ = plt.scatter(i+1, pred[i], color='red') @@ -333,20 
+334,20 @@ def plot_perturbation(self, query, save_file = None): plt.tick_params(axis='x', which='major', pad=5) plt.tick_params(axis='y', which='major', pad=5) sns.despine() - + if save_file: plt.savefig(save_file, bbox_inches='tight') plt.show() - - - def train(self, epochs = 20, + + + def train(self, epochs = 20, lr = 1e-3, weight_decay = 5e-4 ): - + train_loader = self.dataloader['train_loader'] val_loader = self.dataloader['val_loader'] - + self.model = self.model.to(self.device) best_model = deepcopy(self.model) optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay = weight_decay) @@ -366,13 +367,13 @@ def train(self, epochs = 20, pred, logvar = self.model(batch) loss = uncertainty_loss_fct(pred, logvar, y, batch.pert, reg = self.config['uncertainty_reg'], - ctrl = self.ctrl_expression, + ctrl = self.ctrl_expression, dict_filter = self.dict_filter, direction_lambda = self.config['direction_lambda']) else: pred = self.model(batch) loss = loss_fct(pred, y, batch.pert, - ctrl = self.ctrl_expression, + ctrl = self.ctrl_expression, dict_filter = self.dict_filter, direction_lambda = self.config['direction_lambda']) loss.backward() @@ -382,8 +383,8 @@ def train(self, epochs = 20, if self.wandb: self.wandb.log({'training_loss': loss.item()}) - if step % 50 == 0: - log = "Epoch {} Step {} Train Loss: {:.4f}" + if (step + 1) % 100 == 0: + log = "Epoch {} Step {} Train Loss: {:.4f}" print_sys(log.format(epoch + 1, step + 1, loss.item())) scheduler.step() @@ -398,15 +399,15 @@ def train(self, epochs = 20, # Print epoch performance log = "Epoch {}: Train Overall MSE: {:.4f} " \ "Validation Overall MSE: {:.4f}. " - print_sys(log.format(epoch + 1, train_metrics['mse'], + print_sys(log.format(epoch + 1, train_metrics['mse'], val_metrics['mse'])) - + # Print epoch performance for DE genes log = "Train Top 20 DE MSE: {:.4f} " \ "Validation Top 20 DE MSE: {:.4f}. 
" print_sys(log.format(train_metrics['mse_de'], val_metrics['mse_de'])) - + if self.wandb: metrics = ['mse', 'pearson'] for m in metrics: @@ -414,48 +415,48 @@ def train(self, epochs = 20, 'val_'+m: val_metrics[m], 'train_de_' + m: train_metrics[m + '_de'], 'val_de_'+m: val_metrics[m + '_de']}) - + if val_metrics['mse_de'] < min_val: min_val = val_metrics['mse_de'] best_model = deepcopy(self.model) - + print_sys("Done!") self.best_model = best_model if 'test_loader' not in self.dataloader: print_sys('Done! No test dataloader detected.') return - + # Model testing test_loader = self.dataloader['test_loader'] print_sys("Start Testing...") test_res = evaluate(test_loader, self.best_model, self.config['uncertainty'], self.device) - test_metrics, test_pert_res = compute_metrics(test_res) + test_metrics, test_pert_res = compute_metrics(test_res) log = "Best performing model: Test Top 20 DE MSE: {:.4f}" print_sys(log.format(test_metrics['mse_de'])) - + if self.wandb: metrics = ['mse', 'pearson'] for m in metrics: self.wandb.log({'test_' + m: test_metrics[m], - 'test_de_'+m: test_metrics[m + '_de'] + 'test_de_'+m: test_metrics[m + '_de'] }) - + out = deeper_analysis(self.adata, test_res) out_non_dropout = non_dropout_analysis(self.adata, test_res) - + metrics = ['pearson_delta'] metrics_non_dropout = ['frac_opposite_direction_top20_non_dropout', 'frac_sigma_below_1_non_dropout', 'mse_top20_de_non_dropout'] - + if self.wandb: for m in metrics: self.wandb.log({'test_' + m: np.mean([j[m] for i,j in out.items() if m in j])}) for m in metrics_non_dropout: - self.wandb.log({'test_' + m: np.mean([j[m] for i,j in out_non_dropout.items() if m in j])}) + self.wandb.log({'test_' + m: np.mean([j[m] for i,j in out_non_dropout.items() if m in j])}) if self.split == 'simulation': print_sys("Start doing subgroup analysis for simulation split...") @@ -473,11 +474,14 @@ def train(self, epochs = 20, for name, result in subgroup_analysis.items(): for m in result.keys(): - 
subgroup_analysis[name][m] = np.mean(subgroup_analysis[name][m]) + mean_subgroup_analysis = np.mean(subgroup_analysis[name][m]) + subgroup_analysis[name][m] = mean_subgroup_analysis + if np.isnan(mean_subgroup_analysis): + continue if self.wandb: - self.wandb.log({'test_' + name + '_' + m: subgroup_analysis[name][m]}) + self.wandb.log({'test_' + name + '_' + m: mean_subgroup_analysis}) - print_sys('test_' + name + '_' + m + ': ' + str(subgroup_analysis[name][m])) + print_sys('test_' + name + '_' + m + ': ' + str(mean_subgroup_analysis)) ## deeper analysis subgroup_analysis = {} diff --git a/gears/model.py b/gears/model.py index 8dd0086..76efdff 100644 --- a/gears/model.py +++ b/gears/model.py @@ -1,10 +1,10 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Sequential, Linear, ReLU from torch_geometric.nn import SGConv +from .utils import DEFAULT_MAGNITUDE + class MLP(torch.nn.Module): def __init__(self, sizes, batch_norm=True, last_layer_act="linear"): @@ -33,7 +33,7 @@ class GEARS_Model(torch.nn.Module): def __init__(self, args): super(GEARS_Model, self).__init__() - self.args = args + self.args = args self.num_genes = args['num_genes'] self.num_perts = args['num_perts'] hidden_size = args['hidden_size'] @@ -44,21 +44,21 @@ def __init__(self, args): self.no_perturb = args['no_perturb'] self.cell_fitness_pred = args['cell_fitness_pred'] self.pert_emb_lambda = 0.2 - + # perturbation positional embedding added only to the perturbed genes self.pert_w = nn.Linear(1, hidden_size) - - # gene/globel perturbation embedding dictionary lookup + + # gene/globel perturbation embedding dictionary lookup self.gene_emb = nn.Embedding(self.num_genes, hidden_size, max_norm=True) self.pert_emb = nn.Embedding(self.num_perts, hidden_size, max_norm=True) - + # transformation layer self.emb_trans = nn.ReLU() self.pert_base_trans = nn.ReLU() self.transform = nn.ReLU() self.emb_trans_v2 = MLP([hidden_size, hidden_size, hidden_size], 
last_layer_act='ReLU') self.pert_fuse = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU') - + # gene co-expression GNN self.G_coexpress = args['G_coexpress'].to(args['device']) self.G_coexpress_weight = args['G_coexpress_weight'].to(args['device']) @@ -67,7 +67,7 @@ def __init__(self, args): self.layers_emb_pos = torch.nn.ModuleList() for i in range(1, self.num_layers_gene_pos + 1): self.layers_emb_pos.append(SGConv(hidden_size, hidden_size, 1)) - + ### perturbation gene ontology GNN self.G_sim = args['G_go'].to(args['device']) self.G_sim_weight = args['G_go_weight'].to(args['device']) @@ -75,10 +75,10 @@ def __init__(self, args): self.sim_layers = torch.nn.ModuleList() for i in range(1, self.num_layers + 1): self.sim_layers.append(SGConv(hidden_size, hidden_size, 1)) - + # decoder shared MLP self.recovery_w = MLP([hidden_size, hidden_size*2, hidden_size], last_layer_act='linear') - + # gene specific decoder self.indv_w1 = nn.Parameter(torch.rand(self.num_genes, hidden_size, 1)) @@ -86,7 +86,7 @@ def __init__(self, args): self.act = nn.ReLU() nn.init.xavier_normal_(self.indv_w1) nn.init.xavier_normal_(self.indv_b1) - + # Cross gene MLP self.cross_gene_state = MLP([self.num_genes, hidden_size, hidden_size]) @@ -96,32 +96,32 @@ def __init__(self, args): self.indv_b2 = nn.Parameter(torch.rand(1, self.num_genes)) nn.init.xavier_normal_(self.indv_w2) nn.init.xavier_normal_(self.indv_b2) - + # batchnorms self.bn_emb = nn.BatchNorm1d(hidden_size) self.bn_pert_base = nn.BatchNorm1d(hidden_size) self.bn_pert_base_trans = nn.BatchNorm1d(hidden_size) - + # uncertainty mode if self.uncertainty: self.uncertainty_w = MLP([hidden_size, hidden_size*2, hidden_size, 1], last_layer_act='linear') - + #if self.cell_fitness_pred: self.cell_fitness_mlp = MLP([self.num_genes, hidden_size*2, hidden_size, 1], last_layer_act='linear') def forward(self, data): - x, pert_idx = data.x, data.pert_idx + x, pert_idx, pert_magnitude = data.x, data.pert_idx, data.pert_magnitude if 
self.no_perturb: out = x.reshape(-1,1) - out = torch.split(torch.flatten(out), self.num_genes) + out = torch.split(torch.flatten(out), self.num_genes) return torch.stack(out) else: num_graphs = len(data.batch.unique()) ## get base gene embeddings - emb = self.gene_emb(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) + emb = self.gene_emb(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) emb = self.bn_emb(emb) - base_emb = self.emb_trans(emb) + base_emb = self.emb_trans(emb) pos_emb = self.emb_pos(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) for idx, layer in enumerate(self.layers_emb_pos): @@ -136,12 +136,12 @@ def forward(self, data): pert_index = [] for idx, i in enumerate(pert_idx): - for j in i: + for gene_idx, j in enumerate(i): if j != -1: - pert_index.append([idx, j]) + pert_index.append([idx, j, pert_magnitude[idx][gene_idx]]) pert_index = torch.tensor(pert_index).T - pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.args['device'])) + pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.args['device'])) ## augment global perturbation embedding with GNN for idx, layer in enumerate(self.sim_layers): @@ -156,10 +156,17 @@ def forward(self, data): ### in case all samples in the batch are controls, then there is no indexing for pert_index. pert_track = {} for i, j in enumerate(pert_index[0]): + magnitude = pert_index[2][i] / DEFAULT_MAGNITUDE + if magnitude < 1.0: + # `magnitude` is interpreted as a percentage of the original expression. + # As a multiplier of the embedding, values below 1.0 are translated to a + # negative multiplier, i.e. a magnitude of 0.2 means a multiplier of -0.8. 
+ magnitude -= 1 + embedding = pert_global_emb[pert_index[1][i]] if j.item() in pert_track: - pert_track[j.item()] = pert_track[j.item()] + pert_global_emb[pert_index[1][i]] + pert_track[j.item()] = pert_track[j.item()] + magnitude * embedding else: - pert_track[j.item()] = pert_global_emb[pert_index[1][i]] + pert_track[j.item()] = magnitude * embedding if len(list(pert_track.values())) > 0: if len(list(pert_track.values())) == 1: @@ -175,7 +182,7 @@ def forward(self, data): base_emb = self.bn_pert_base(base_emb) ## apply the first MLP - base_emb = self.transform(base_emb) + base_emb = self.transform(base_emb) out = self.recovery_w(base_emb) out = out.reshape(num_graphs, self.num_genes, -1) out = out.unsqueeze(-1) * self.indv_w1 @@ -191,7 +198,7 @@ def forward(self, data): cross_gene_out = cross_gene_out * self.indv_w2 cross_gene_out = torch.sum(cross_gene_out, axis=2) - out = cross_gene_out + self.indv_b2 + out = cross_gene_out + self.indv_b2 out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1) out = torch.split(torch.flatten(out), self.num_genes) @@ -200,9 +207,9 @@ def forward(self, data): out_logvar = self.uncertainty_w(base_emb) out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes) return torch.stack(out), torch.stack(out_logvar) - + if self.cell_fitness_pred: return torch.stack(out), self.cell_fitness_mlp(torch.stack(out)) - + return torch.stack(out) - + diff --git a/gears/pertdata.py b/gears/pertdata.py index 1cacb5f..78589d5 100644 --- a/gears/pertdata.py +++ b/gears/pertdata.py @@ -15,18 +15,21 @@ from .data_utils import get_DE_genes, get_dropout_non_zero_genes, DataSplitter from .utils import print_sys, zip_data_download_wrapper, dataverse_download,\ - filter_pert_in_go, get_genes_from_perts + filter_pert_in_go, get_genes_from_perts, \ + get_pert_genes, get_pert_magnitude, DEFAULT_MAGNITUDE class PertData: - - def __init__(self, data_path, - gene_set_path=None, - default_pert_graph=True): - + + def __init__(self, data_path, + 
gene_set_path=None, + default_pert_graph=True, + gene2go_path: str = None): + # Dataset/Dataloader attributes self.data_path = data_path self.default_pert_graph = default_pert_graph self.gene_set_path = gene_set_path + self.gene2go_path = gene2go_path self.dataset_name = None self.dataset_path = None self.adata = None @@ -46,49 +49,57 @@ def __init__(self, data_path, server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417' dataverse_download(server_path, os.path.join(self.data_path, 'gene2go_all.pkl')) - with open(os.path.join(self.data_path, 'gene2go_all.pkl'), 'rb') as f: + if not gene2go_path: + server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417' + gene2go_path = os.path.join(self.data_path, 'gene2go_all.pkl') + dataverse_download(server_path, gene2go_path) + with open(gene2go_path, 'rb') as f: self.gene2go = pickle.load(f) - + def set_pert_genes(self): """ - Set the list of genes that can be perturbed and are to be included in + Set the list of genes that can be perturbed and are to be included in perturbation graph """ - - if self.gene_set_path is not None: - # If gene set specified for perturbation graph, use that - path_ = self.gene_set_path - self.default_pert_graph = False - with open(path_, 'rb') as f: - essential_genes = pickle.load(f) - - elif self.default_pert_graph is False: - # Use a smaller perturbation graph - all_pert_genes = get_genes_from_perts(self.adata.obs['condition']) - essential_genes = list(self.adata.var['gene_name'].values) - essential_genes += all_pert_genes - + + if self.gene2go_path: + # External GO term mapping provided for essential genes. 
+ gene2go = self.gene2go else: - # Otherwise, use a large set of genes to create perturbation graph - server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934320' - path_ = os.path.join(self.data_path, - 'essential_all_data_pert_genes.pkl') - dataverse_download(server_path, path_) - with open(path_, 'rb') as f: - essential_genes = pickle.load(f) - - gene2go = {i: self.gene2go[i] for i in essential_genes if i in self.gene2go} + if self.gene_set_path is not None: + # If gene set specified for perturbation graph, use that + path_ = self.gene_set_path + self.default_pert_graph = False + with open(path_, 'rb') as f: + essential_genes = pickle.load(f) + + elif self.default_pert_graph is False: + # Use a smaller perturbation graph + all_pert_genes = get_genes_from_perts(self.adata.obs['condition']) + essential_genes = list(self.adata.var['gene_name'].values) + essential_genes += all_pert_genes + + else: + # Otherwise, use a large set of genes to create perturbation graph + server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934320' + path_ = os.path.join(self.data_path, + 'essential_all_data_pert_genes.pkl') + dataverse_download(server_path, path_) + with open(path_, 'rb') as f: + essential_genes = pickle.load(f) + + gene2go = {i: self.gene2go[i] for i in essential_genes if i in self.gene2go} self.pert_names = np.unique(list(gene2go.keys())) self.node_map_pert = {x: it for it, x in enumerate(self.pert_names)} - + def load(self, data_name = None, data_path = None): """ Load existing dataloader Use data_name for loading 'norman', 'adamson', 'dixit' datasets For other datasets use data_path """ - + if data_name in ['norman', 'adamson', 'dixit']: ## load from harvard dataverse if data_name == 'norman': @@ -98,7 +109,7 @@ def load(self, data_name = None, data_path = None): elif data_name == 'dixit': url = 'https://dataverse.harvard.edu/api/access/datafile/6154416' data_path = os.path.join(self.data_path, data_name) - zip_data_download_wrapper(url, 
data_path, self.data_path) + zip_data_download_wrapper(url, data_path, self.data_path) self.dataset_name = data_path.split('/')[-1] self.dataset_path = data_path adata_path = os.path.join(data_path, 'perturb_processed.h5ad') @@ -112,7 +123,7 @@ def load(self, data_name = None, data_path = None): else: raise ValueError("data attribute is either Norman/Adamson/Dixit " "or a path to an h5ad file") - + self.set_pert_genes() print_sys('These perturbations are not in the GO graph and their ' 'perturbation can thus not be predicted') @@ -121,7 +132,7 @@ def load(self, data_name = None, data_path = None): lambda x:not filter_pert_in_go(x, self.pert_names))].condition.unique()) print_sys(not_in_go_pert) - + filter_go = self.adata.obs[self.adata.obs.condition.apply( lambda x: filter_pert_in_go(x, self.pert_names))] self.adata = self.adata[filter_go.index.values, :] @@ -129,37 +140,37 @@ def load(self, data_name = None, data_path = None): if not os.path.exists(pyg_path): os.mkdir(pyg_path) dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl') - + if os.path.isfile(dataset_fname): print_sys("Local copy of pyg dataset is detected. 
Loading...") - self.dataset_processed = pickle.load(open(dataset_fname, "rb")) + self.dataset_processed = pickle.load(open(dataset_fname, "rb")) print_sys("Done!") else: self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] self.gene_names = self.adata.var.gene_name - - + + print_sys("Creating pyg object for each cell in the data...") self.dataset_processed = self.create_dataset_file() - print_sys("Saving new dataset pyg object at " + dataset_fname) - pickle.dump(self.dataset_processed, open(dataset_fname, "wb")) + print_sys("Saving new dataset pyg object at " + dataset_fname) + pickle.dump(self.dataset_processed, open(dataset_fname, "wb")) print_sys("Done!") - + def new_data_process(self, dataset_name, adata = None, skip_calc_de = False): - + if 'condition' not in adata.obs.columns.values: raise ValueError("Please specify condition") if 'gene_name' not in adata.var.columns.values: raise ValueError("Please specify gene name") if 'cell_type' not in adata.obs.columns.values: raise ValueError("Please specify cell type") - + dataset_name = dataset_name.lower() self.dataset_name = dataset_name save_data_folder = os.path.join(self.data_path, dataset_name) - + if not os.path.exists(save_data_folder): os.mkdir(save_data_folder) self.dataset_path = save_data_folder @@ -167,7 +178,7 @@ def new_data_process(self, dataset_name, if not skip_calc_de: self.adata = get_dropout_non_zero_genes(self.adata) self.adata.write_h5ad(os.path.join(save_data_folder, 'perturb_processed.h5ad')) - + self.set_pert_genes() self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] self.gene_names = self.adata.var.gene_name @@ -177,12 +188,12 @@ def new_data_process(self, dataset_name, dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl') print_sys("Creating pyg object for each cell in the data...") self.dataset_processed = self.create_dataset_file() - print_sys("Saving new dataset pyg object at " + dataset_fname) - pickle.dump(self.dataset_processed, open(dataset_fname, 
"wb")) + print_sys("Saving new dataset pyg object at " + dataset_fname) + pickle.dump(self.dataset_processed, open(dataset_fname, "wb")) print_sys("Done!") - - def prepare_split(self, split = 'simulation', - seed = 1, + + def prepare_split(self, split = 'simulation', + seed = 1, train_gene_set_size = 0.75, combo_seen2_train_frac = 0.75, combo_single_split_test_set_fraction = 0.1, @@ -198,17 +209,17 @@ def prepare_split(self, split = 'simulation', self.seed = seed self.subgroup = None self.train_gene_set_size = train_gene_set_size - + split_folder = os.path.join(self.dataset_path, 'splits') if not os.path.exists(split_folder): os.mkdir(split_folder) split_file = self.dataset_name + '_' + split + '_' + str(seed) + '_' \ + str(train_gene_set_size) + '.pkl' split_path = os.path.join(split_folder, split_file) - + if test_perts: split_path = split_path[:-4] + '_' + test_perts + '.pkl' - + if os.path.exists(split_path): print_sys("Local copy of split is detected. Loading...") set2conditions = pickle.load(open(split_path, "rb")) @@ -220,11 +231,11 @@ def prepare_split(self, split = 'simulation', print_sys("Creating new splits....") if test_perts: test_perts = test_perts.split('_') - + if split in ['simulation', 'simulation_single']: DS = DataSplitter(self.adata, split_type=split) - - adata, subgroup = DS.split_data(train_gene_set_size = train_gene_set_size, + + adata, subgroup = DS.split_data(train_gene_set_size = train_gene_set_size, combo_seen2_train_frac = combo_seen2_train_frac, seed=seed, test_perts = test_perts, @@ -233,40 +244,40 @@ def prepare_split(self, split = 'simulation', subgroup_path = split_path[:-4] + '_subgroup.pkl' pickle.dump(subgroup, open(subgroup_path, "wb")) self.subgroup = subgroup - + elif split[:5] == 'combo': split_type = 'combo' seen = int(split[-1]) if test_pert_genes: test_pert_genes = test_pert_genes.split('_') - + DS = DataSplitter(self.adata, split_type=split_type, seen=int(seen)) adata = 
DS.split_data(test_size=combo_single_split_test_set_fraction, test_perts=test_perts, test_pert_genes=test_pert_genes, seed=seed) - + elif split == 'single': DS = DataSplitter(self.adata, split_type=split) adata = DS.split_data(test_size=combo_single_split_test_set_fraction, seed=seed) - + elif split == 'no_test': DS = DataSplitter(self.adata, split_type=split) adata = DS.split_data(test_size=combo_single_split_test_set_fraction, seed=seed) - - elif split == 'no_split': + + elif split == 'no_split': adata = self.adata adata.obs['split'] = 'test' - + set2conditions = dict(adata.obs.groupby('split').agg({'condition': lambda x: x}).condition) - set2conditions = {i: j.unique().tolist() for i,j in set2conditions.items()} + set2conditions = {i: j.unique().tolist() for i,j in set2conditions.items()} pickle.dump(set2conditions, open(split_path, "wb")) print_sys("Saving new splits at " + split_path) - + self.set2conditions = set2conditions if split == 'simulation': @@ -274,14 +285,14 @@ def prepare_split(self, split = 'simulation', for i,j in subgroup['test_subgroup'].items(): print_sys(i + ':' + str(len(j))) print_sys("Done!") - + def get_dataloader(self, batch_size, test_batch_size = None): if test_batch_size is None: test_batch_size = batch_size - + self.node_map = {x: it for it, x in enumerate(self.adata.var.gene_name)} self.gene_names = self.adata.var.gene_name - + # Create cell graphs cell_graphs = {} if self.split == 'no_split': @@ -290,7 +301,7 @@ def get_dataloader(self, batch_size, test_batch_size = None): for p in self.set2conditions[i]: if p != 'ctrl': cell_graphs[i].extend(self.dataset_processed[p]) - + print_sys("Creating dataloaders....") # Set up dataloaders test_loader = DataLoader(cell_graphs['test'], @@ -309,13 +320,13 @@ def get_dataloader(self, batch_size, test_batch_size = None): cell_graphs[i].extend(self.dataset_processed[p]) print_sys("Creating dataloaders....") - + # Set up dataloaders train_loader = DataLoader(cell_graphs['train'], 
batch_size=batch_size, shuffle=True, drop_last = True) val_loader = DataLoader(cell_graphs['val'], batch_size=batch_size, shuffle=True) - + if self.split !='no_test': test_loader = DataLoader(cell_graphs['test'], batch_size=batch_size, shuffle=False) @@ -323,39 +334,38 @@ def get_dataloader(self, batch_size, test_batch_size = None): 'val_loader': val_loader, 'test_loader': test_loader} - else: + else: self.dataloader = {'train_loader': train_loader, 'val_loader': val_loader} print_sys("Done!") #del self.dataset_processed # clean up some memory - - + + def create_dataset_file(self): dl = {} for p in tqdm(self.adata.obs['condition'].unique()): cell_graph_dataset = self.create_cell_graph_dataset(self.adata, p, num_samples=1) dl[p] = cell_graph_dataset return dl - + def get_pert_idx(self, pert_category, adata_): + names = [p for p in get_pert_genes(pert_category) if p != 'ctrl'] try: - pert_idx = [np.where(p == self.pert_names)[0][0] - for p in pert_category.split('+') - if p != 'ctrl'] + pert_idx = [np.where(p == self.pert_names)[0][0] for p in names] except: - print(pert_category) + print('no index found for', pert_category) pert_idx = None - + return pert_idx # Set up feature matrix and output - + def create_cell_graph(self, X, y, de_idx, pert, pert_idx=None): #pert_feats = np.expand_dims(pert_feats, 0) #feature_mat = torch.Tensor(np.concatenate([X, pert_feats])).T feature_mat = torch.Tensor(X).T - + ''' pert_feats = np.zeros(len(self.pert_names)) if pert_idx is not None: @@ -365,8 +375,12 @@ def create_cell_graph(self, X, y, de_idx, pert, pert_idx=None): ''' if pert_idx is None: pert_idx = [-1] + pert_magnitude = [DEFAULT_MAGNITUDE] + else: + pert_magnitude = get_pert_magnitude(pert) return Data(x=feature_mat, pert_idx=pert_idx, - y=torch.Tensor(y), de_idx=de_idx, pert=pert) + y=torch.Tensor(y), de_idx=de_idx, pert=pert, + pert_magnitude=pert_magnitude) def create_cell_graph_dataset(self, split_adata, pert_category, num_samples=1): @@ -374,7 +388,7 @@ def 
create_cell_graph_dataset(self, split_adata, pert_category,

        Combine cell graphs to create a dataset of cell graphs
        """
-        num_de_genes = 20
+        num_de_genes = 20
        adata_ = split_adata[split_adata.obs['condition'] == pert_category]
        if 'rank_genes_groups_cov_all' in adata_.uns:
            de_genes = adata_.uns['rank_genes_groups_cov_all']
diff --git a/gears/utils.py b/gears/utils.py
index 6e39799..49ecaa8 100644
--- a/gears/utils.py
+++ b/gears/utils.py
@@ -12,7 +12,27 @@
 from sklearn.linear_model import TheilSenRegressor
 from dcor import distance_correlation
 from multiprocessing import Pool
-from functools import partial
+
+
+# Measured as "percent relative to control", e.g. 100 is the same as control,
+# 200 is a 2X overexpression, 20 is 5X underexpression, 0 is knockout.
+DEFAULT_MAGNITUDE = 100
+
+def rm_magnitude(gene_with_magnitude: str) -> str:
+    return gene_with_magnitude.split('*')[0]
+
+def get_pert_genes(pert: str) -> list[str]:
+    # Remove perturbation magnitude.
+    # TODO: can be combined with `get_genes_from_perts`
+    return [rm_magnitude(g) for g in pert.split('+')]
+
+def get_pert_magnitude(pert: str) -> list[int]:
+    # Returns magnitudes for genes in `pert`. 
+ return [ + int(gene.split('*')[-1]) if '*' in gene else DEFAULT_MAGNITUDE + for gene in pert.split('+') + if gene != 'ctrl' + ] def parse_single_pert(i): a = i.split('+')[0] @@ -21,10 +41,10 @@ def parse_single_pert(i): pert = b else: pert = a - return pert + return rm_magnitude(pert) def parse_combo_pert(i): - return i.split('+')[0], i.split('+')[1] + return get_pert_genes(i)[:2] def combine_res(res_1, res_2): res_out = {} @@ -56,7 +76,7 @@ def dataverse_download(url, save_path): url (str): the url of the dataset path (str): the path to save the dataset """ - + if os.path.exists(save_path): print_sys('Found local copy...') else: @@ -71,7 +91,7 @@ def dataverse_download(url, save_path): file.write(data) progress_bar.close() - + def zip_data_download_wrapper(url, save_path, data_path): if os.path.exists(save_path): @@ -81,8 +101,8 @@ def zip_data_download_wrapper(url, save_path, data_path): print_sys('Extracting zip file...') with ZipFile((save_path + '.zip'), 'r') as zip: zip.extractall(path = data_path) - print_sys("Done!") - + print_sys("Done!") + def tar_data_download_wrapper(url, save_path, data_path): if os.path.exists(save_path): @@ -92,11 +112,11 @@ def tar_data_download_wrapper(url, save_path, data_path): print_sys('Extracting tar file...') with tarfile.open(save_path + '.tar.gz') as tar: tar.extractall(path= data_path) - print_sys("Done!") - + print_sys("Done!") + def get_go_auto(gene_list, data_path, data_name): go_path = os.path.join(data_path, data_name, 'go.csv') - + if os.path.exists(go_path): return pd.read_csv(go_path) else: @@ -123,7 +143,7 @@ def get_go_auto(gene_list, data_path, data_name): df_edge_list = df_edge_list.rename(columns = {'gene1': 'source', 'gene2': 'target', 'score': 'importance'}) - df_edge_list.to_csv(go_path, index = False) + df_edge_list.to_csv(go_path, index = False) return df_edge_list def get_go(df_gene2go): @@ -162,18 +182,18 @@ def __init__(self, edge_list, gene_list, node_map): self.edge_list = edge_list self.G = 
nx.from_pandas_edgelist(self.edge_list, source='source', target='target', edge_attr=['importance'], - create_using=nx.DiGraph()) + create_using=nx.DiGraph()) self.gene_list = gene_list for n in self.gene_list: if n not in self.G.nodes(): self.G.add_node(n) - + edge_index_ = [(node_map[e[0]], node_map[e[1]]) for e in self.G.edges] self.edge_index = torch.tensor(edge_index_, dtype=torch.long).T #self.edge_weight = torch.Tensor(self.edge_list['importance'].values) - - edge_attr = nx.get_edge_attributes(self.G, 'importance') + + edge_attr = nx.get_edge_attributes(self.G, 'importance') importance = np.array([edge_attr[e] for e in self.G.edges]) self.edge_weight = torch.Tensor(importance) @@ -186,7 +206,7 @@ def get_GO_edge_list(args): if score > 0.1: edge_list.append((g1, g2, score)) return edge_list - + def make_GO(data_path, pert_list, data_name, num_workers=25, save=True): """ Creates Gene Ontology graph from a custom set of genes @@ -211,7 +231,7 @@ def make_GO(data_path, pert_list, data_name, num_workers=25, save=True): df_edge_list = pd.DataFrame(edge_list).rename( columns={0: 'source', 1: 'target', 2: 'importance'}) - + if save: print('Saving edge_list to file') df_edge_list.to_csv(fname, index=False) @@ -220,24 +240,28 @@ def make_GO(data_path, pert_list, data_name, num_workers=25, save=True): def get_similarity_network(network_type, adata, threshold, k, data_path, data_name, split, seed, train_gene_set_size, - set2conditions, default_pert_graph=True, pert_list=None): - + set2conditions, default_pert_graph=True, pert_list=None, + go_path: str = None): + if network_type == 'co-express': df_out = get_coexpression_network_from_train(adata, threshold, k, data_path, data_name, split, seed, train_gene_set_size, set2conditions) elif network_type == 'go': - if default_pert_graph: - server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934319' - tar_data_download_wrapper(server_path, - os.path.join(data_path, 'go_essential_all'), - data_path) - df_jaccard = 
pd.read_csv(os.path.join(data_path, - 'go_essential_all/go_essential_all.csv')) - + if go_path is not None: + df_jaccard = pd.read_csv(go_path) + print_sys(f'Loaded {len(df_jaccard)} similarities from {go_path}') else: - df_jaccard = make_GO(data_path, pert_list, data_name) + if default_pert_graph: + server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934319' + tar_data_download_wrapper(server_path, + os.path.join(data_path, 'go_essential_all'), + data_path) + df_jaccard = pd.read_csv(os.path.join(data_path, + 'go_essential_all/go_essential_all.csv')) + else: + df_jaccard = make_GO(data_path, pert_list, data_name) df_out = df_jaccard.groupby('target').apply(lambda x: x.nlargest(k + 1, ['importance'])).reset_index(drop = True) @@ -247,17 +271,17 @@ def get_similarity_network(network_type, adata, threshold, k, def get_coexpression_network_from_train(adata, threshold, k, data_path, data_name, split, seed, train_gene_set_size, set2conditions): - + fname = os.path.join(os.path.join(data_path, data_name), split + '_' + str(seed) + '_' + str(train_gene_set_size) + '_' + str(threshold) + '_' + str(k) + '_co_expression_network.csv') - + if os.path.exists(fname): return pd.read_csv(fname) else: gene_list = [f for f in adata.var.gene_name.values] - idx2gene = dict(zip(range(len(gene_list)), gene_list)) + idx2gene = dict(zip(range(len(gene_list)), gene_list)) X = adata.X train_perts = set2conditions['train'] X_tr = X[np.isin(adata.obs.condition, [i for i in train_perts if 'ctrl' in i])] @@ -283,23 +307,22 @@ def get_coexpression_network_from_train(adata, threshold, k, data_path, 2: 'importance'}) df_co_expression.to_csv(fname, index = False) return df_co_expression - + def filter_pert_in_go(condition, pert_names): if condition == 'ctrl': return True else: - cond1 = condition.split('+')[0] - cond2 = condition.split('+')[1] + cond1, cond2 = get_pert_genes(condition)[:2] num_ctrl = (cond1 == 'ctrl') + (cond2 == 'ctrl') num_in_perts = (cond1 in pert_names) + (cond2 
in pert_names) if num_ctrl + num_in_perts == 2: return True else: return False - + def uncertainty_loss_fct(pred, logvar, y, perts, reg = 0.1, ctrl = None, direction_lambda = 1e-3, dict_filter = None): - gamma = 2 + gamma = 2 perts = np.array(perts) losses = torch.tensor(0.0, requires_grad=True).to(pred.device) for p in set(perts): @@ -312,12 +335,12 @@ def uncertainty_loss_fct(pred, logvar, y, perts, reg = 0.1, ctrl = None, pred_p = pred[np.where(perts==p)[0]] y_p = y[np.where(perts==p)[0]] logvar_p = logvar[np.where(perts==p)[0]] - + # uncertainty based loss losses += torch.sum((pred_p - y_p)**(2 + gamma) + reg * torch.exp( -logvar_p) * (pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1] - - # direction loss + + # direction loss if p!= 'ctrl': losses += torch.sum(direction_lambda * (torch.sign(y_p - ctrl[retain_idx]) - @@ -328,7 +351,7 @@ def uncertainty_loss_fct(pred, logvar, y, perts, reg = 0.1, ctrl = None, (torch.sign(y_p - ctrl) - torch.sign(pred_p - ctrl))**2)/\ pred_p.shape[0]/pred_p.shape[1] - + return losses/(len(set(perts))) @@ -340,7 +363,7 @@ def loss_fct(pred, y, perts, ctrl = None, direction_lambda = 1e-3, dict_filter = for p in set(perts): pert_idx = np.where(perts == p)[0] - + # during training, we remove the all zero genes into calculation of loss. # this gives a cleaner direction loss. empirically, the performance stays the same. 
if p!= 'ctrl': @@ -352,7 +375,7 @@ def loss_fct(pred, y, perts, ctrl = None, direction_lambda = 1e-3, dict_filter = y_p = y[pert_idx] #losses += torch.sum((torch.mean(pred_p, 0) - y_p)**(2+gamma))/pred_p.shape[1] losses = losses + torch.sum((pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1] - + ## direction loss if (p!= 'ctrl'): losses = losses + torch.sum(direction_lambda * @@ -373,20 +396,21 @@ def print_sys(s): s (str): the string to print """ print(s, flush = True, file = sys.stderr) - + def create_cell_graph_for_prediction(X, pert_idx, pert_gene): if pert_idx is None: pert_idx = [-1] - return Data(x=torch.Tensor(X).T, pert_idx = pert_idx, pert=pert_gene) - + pert_magnitude = [get_pert_magnitude(g)[0] for g in pert_gene] + return Data(x=torch.Tensor(X).T, pert_idx=pert_idx, pert=pert_gene, pert_magnitude=pert_magnitude) + def create_cell_graph_dataset_for_prediction(pert_gene, ctrl_adata, gene_names, device, num_samples = 300): # Get the indices (and signs) of applied perturbation - pert_idx = [np.where(p == np.array(gene_names))[0][0] for p in pert_gene] - + pert_gene_no_magnitude = [rm_magnitude(p) for p in pert_gene] + pert_idx = [np.where(p == np.array(gene_names))[0][0] for p in pert_gene_no_magnitude] Xs = ctrl_adata[np.random.randint(0, len(ctrl_adata), num_samples), :].X.toarray() # Create cell graphs cell_graphs = [create_cell_graph_for_prediction(X, pert_idx, pert_gene).to(device) for X in Xs] @@ -401,7 +425,7 @@ def get_coeffs(singles_expr, first_expr, second_expr, double_expr): results['ts'] = TheilSenRegressor(fit_intercept=False, max_subpopulation=1e5, max_iter=1000, - random_state=1000) + random_state=1000) X = singles_expr y = double_expr results['ts'].fit(X, y.ravel()) @@ -409,7 +433,7 @@ def get_coeffs(singles_expr, first_expr, second_expr, double_expr): results['c1'] = results['ts'].coef_[0] results['c2'] = results['ts'].coef_[1] results['mag'] = np.sqrt((results['c1']**2 + results['c2']**2)) - + results['dcor'] = 
distance_correlation(singles_expr, double_expr) results['dcor_singles'] = distance_correlation(first_expr, second_expr) results['dcor_first'] = distance_correlation(first_expr, double_expr) @@ -418,23 +442,23 @@ def get_coeffs(singles_expr, first_expr, second_expr, double_expr): results['dominance'] = np.abs(np.log10(results['c1']/results['c2'])) results['eq_contr'] = np.min([results['dcor_first'], results['dcor_second']])/\ np.max([results['dcor_first'], results['dcor_second']]) - + return results def get_GI_params(preds, combo): - + singles_expr = np.array([preds[combo[0]], preds[combo[1]]]).T first_expr = np.array(preds[combo[0]]).T second_expr = np.array(preds[combo[1]]).T double_expr = np.array(preds[combo[0]+'_'+combo[1]]).T - + return get_coeffs(singles_expr, first_expr, second_expr, double_expr) def get_GI_genes_idx(adata, GI_gene_file): # Genes used for linear model fitting GI_genes = np.load(GI_gene_file, allow_pickle=True) GI_genes_idx = np.where([g in GI_genes for g in adata.var.gene_name.values])[0] - + return GI_genes_idx def get_mean_control(adata):