Add Cox proportional hazards model for survival analysis #9
Conversation
src/shears/pp/__init__.py
Outdated
def cell_corrs(
    adata_sc,
    adata_bulk,
    *,
    inplace=True,
    layer_sc="quantile_norm",
    layer_bulk="quantile_norm",
    key_added="pearson",
    random_state=0,
    n_jobs=None,
) -> Optional[pd.DataFrame]:
    """
    Computes a bulk_sample x cell matrix assigning each cell a Pearson correlation with each bulk sample.
    If inplace is True, stores the resulting matrix in adata_sc.obsm[key_added]
    """
What is this function for? This would basically be the old Scissor way?
Yes, I have had a side project where I was asked to do this, and the implementation in Python is much faster than the R equivalent. This shows the comparison between the "old" cell corrs and the "new" shears weights.
I thought it doesn't hurt to add this function to shears as it fits the topic.
grst left a comment
Looks good overall! From an API design perspective, I think I would make separate functions for the survival and the GLM case.
The global variables are a reasonable solution to the parallelization. My go-to solution tends to be joblib nowadays: https://joblib.readthedocs.io/en/latest/parallel.html, example in Scirpy: usage example | helper function. Joblib overcame an obscure bug in multiprocessing that a user was facing, and it also allows parallelizing beyond a single node using dask. The problem with serializing remains, though, and as you did here, one needs to make sure the data gets serialized only once.
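A minimal joblib sketch of that pattern (the worker function and the weights array below are just placeholders, not code from this PR):

import numpy as np
from joblib import Parallel, delayed


def worker(row: np.ndarray) -> float:
    # stand-in for the per-cell model fit
    return float(row.sum())


weights = np.random.default_rng(0).random((1000, 50))

# each delayed(...) task pickles only its own arguments, so keep the per-task
# payload small (one row here) instead of shipping the full matrices every call
results = Parallel(n_jobs=-1, backend="loky")(
    delayed(worker)(weights[i, :]) for i in range(weights.shape[0])
)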
src/shears/tl/__init__.py
Outdated
fam = sm.families.Binomial() if family == "binomial" else sm.families.Gaussian()
res = smf.glm(formula, data=df, family=fam).fit()
Maybe allow passing a "family" directly? E.g. one could use a negative binomial or Poisson family for count data.
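A hedged sketch of what that could look like (the fit_glm wrapper and the example column names are illustrative, not part of the PR):

import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf


def fit_glm(formula: str, df: pd.DataFrame, family: sm.families.Family = sm.families.Binomial()):
    # accept any statsmodels family object instead of a hard-coded string
    return smf.glm(formula, data=df, family=family).fit()


# e.g. for (overdispersed) count data:
# res = fit_glm("n_mutations ~ cell_weight + age", bulk_obs, family=sm.families.NegativeBinomial())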
src/shears/tl/__init__.py
Outdated
    adata_sc: AnnData,
    adata_bulk: AnnData,
    *,
    dep_var,
    covariate_str="",
    inplace=True,
    cell_weights_key="cell_weights",
    key_added="shears",
    n_jobs=None,
):
    dep_var: str = "",
    covariate_str: str = "",
    inplace: bool = True,
    cell_weights_key: str = "cell_weights",
    key_added: str = "shears",
    family: Literal["binomial", "gaussian", "cox"] = "binomial",
    duration_col: str = "OS_time",
    event_col: str = "OS_status",
    n_jobs: Optional[int] = None,
    **kwargs: Any,
) -> pd.DataFrame:
I'm not a big fan of an API interface where some variables are only used in some cases (i.e. dep_var and duration_col + event_col are mutually exclusive).
I think this is confusing from a user-experience perspective, and it also makes it difficult to get the type hints right (s.t. a linter can warn the user about invalid parameter configurations). See also https://stackoverflow.com/questions/75757890/python-overload-single-argument.
An alternative interface would split this up into two different functions:
def shears_survival(
    adata_sc,
    adata_bulk,
    *,
    event_col,
    duration_col,
    covariate_str,
    strata: list[str] | None = None,
) -> pd.DataFrame:
    ...


def shears_glm(
    adata_sc,
    adata_bulk,
    *,
    dep_var,
    covariate_str,
    family: sm.families.Family | Literal["binomial", "gaussian"] = "binomial",
    ...
) -> pd.DataFrame:
    ...
Those functions could call the same helper functions to reduce code duplication where appropriate.
src/shears/tl/__init__.py
Outdated
init_kwargs = {k: v for k, v in kwargs.items() if k in init_params}
fit_kwargs = {k: v for k, v in kwargs.items() if k in fit_params}
If there were any argument that has the same name in __init__ and .fit(), this might fall apart.
I think I'd just accept two dicts, init_kwargs and fit_kwargs, in the function signature.
These can then be passed directly to __init__ and .fit(), respectively.
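A rough sketch of that interface for the Cox case, assuming the lifelines CoxPHFitter used later in this PR (the fit_cox name and the example kwargs are illustrative):

from typing import Any, Optional

import pandas as pd
from lifelines import CoxPHFitter


def fit_cox(
    df: pd.DataFrame,
    duration_col: str,
    event_col: str,
    init_kwargs: Optional[dict[str, Any]] = None,
    fit_kwargs: Optional[dict[str, Any]] = None,
) -> CoxPHFitter:
    # one dict goes verbatim to the constructor, the other to .fit(),
    # so identically named arguments can never collide
    cph = CoxPHFitter(**(init_kwargs or {}))  # e.g. {"penalizer": 0.1}
    cph.fit(df, duration_col=duration_col, event_col=event_col, **(fit_kwargs or {}))  # e.g. {"robust": True}
    return cph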
A helper function to call _parallelize_with_joblib. Any reason you import logging from scanpy? See here.
from scanpy import logging
from joblib import Parallel, delayed
import pandas as pd


def _cell_worker_map(cell_weights, worker, *, n_jobs=None, backend="loky"):
    cell_names = list(cell_weights.index)
    weights_arr = cell_weights.values
    # one joblib task per cell; each task only receives that cell's weight vector
    jobs = (delayed(worker)(weights_arr[i, :]) for i in range(len(cell_names)))
    res_list = list(
        _parallelize_with_joblib(
            jobs,
            total=len(cell_names),
            n_jobs=n_jobs,
            backend=backend,
        )
    )
    return pd.DataFrame(res_list, index=cell_names, columns=["pvalue", "coef"])

Updated shears code:
import functools

import numpy as np
import pandas as pd
import patsy
import statsmodels.api as sm
from threadpoolctl import threadpool_limits

from shears._util import _cell_worker_map


def _test_cell_glm(endog_array, exog_array, cell_weights, family, init_kwargs, fit_kwargs):
    # replace the placeholder "cell_weight" column (last design column) with this cell's weights
    X = exog_array.copy()
    X[:, -1] = cell_weights
    with threadpool_limits(limits=1):
        res = sm.GLM(endog_array, X, family=family, **init_kwargs).fit(**fit_kwargs)
    return float(res.pvalues[-1]), float(res.params[-1])


def shears_glm(
    adata_sc,
    adata_bulk,
    dep_var,
    *,
    family=sm.families.Binomial(),
    covariate_str=None,
    cell_weights_key="cell_weights",
    key_added="shears",
    n_jobs=None,
    inplace=True,
    init_kwargs=None,
    fit_kwargs=None,
):
    covariate_str = (covariate_str or "").lstrip("+").strip()
    covariate_str = f" + {covariate_str}" if covariate_str else ""
    formula = f"{dep_var} ~ cell_weight{covariate_str}"
    print("Formula:", formula)
    keep = [c for c in adata_bulk.obs.columns if c in (dep_var + covariate_str)]
    assert "cell_weight" not in keep, "cell_weight is reserved"
    bulk_obs = adata_bulk.obs.loc[:, keep].copy()
    weights_df = adata_sc.obsm[cell_weights_key].loc[:, bulk_obs.index]
    # placeholder column so patsy builds the design once; the worker swaps in the real weights
    bulk_obs["cell_weight"] = 0.0
    response_df, predictors_df = patsy.dmatrices(formula, bulk_obs, return_type="dataframe")
    endog_array = response_df.iloc[:, 0].values
    exog_array = predictors_df.values
    init_kwargs = init_kwargs or {}
    fit_kwargs = fit_kwargs or {}
    worker = functools.partial(
        _test_cell_glm,
        endog_array,
        exog_array,
        family=family,
        init_kwargs=init_kwargs,
        fit_kwargs=fit_kwargs,
    )
    df_res = _cell_worker_map(weights_df, worker, n_jobs=n_jobs, backend="loky")
    if inplace:
        adata_sc.obs[key_added] = df_res["coef"]
    return df_res

Using smf.glm(formula, data=df, family=fam).fit() is roughly 3 times slower than building the design up front as I do here, since patsy needs to build the design matrix for every iteration. The output is not exactly the same, but very close. Not sure if smf.glm and sm.GLM have other defaults, or maybe patsy does something differently than I do here?
| obs_names | p-value (sm.GLM) | coef (sm.GLM) | p-value (smf.glm) | coef (smf.glm) |
|---|---|---|---|---|
| MUI_Innsbruck-P11-1960 | 2.645133e-01 | 6.219972e+05 | 2.578861e-01 | 6.316554e+05 |
| MUI_Innsbruck-P11-6178 | 2.560558e-06 | -1.393915e+06 | 2.646097e-06 | -1.392864e+06 |
| MUI_Innsbruck-P11-11945 | 6.353457e-07 | 2.413854e+06 | 6.647526e-07 | 2.410789e+06 |
| MUI_Innsbruck-P11-17309 | 8.411141e-24 | -2.552637e+06 | 8.937308e-24 | -2.552816e+06 |
| MUI_Innsbruck-P11-20381 | 4.545991e-23 | -3.283744e+06 | 5.102287e-23 | -3.295412e+06 |
| MUI_Innsbruck-P9-14038323 | 3.712091e-21 | -1.988764e+06 | 4.014555e-21 | -1.986937e+06 |
| MUI_Innsbruck-P9-14039124 | 2.293766e-20 | -2.463100e+06 | 2.419365e-20 | -2.460761e+06 |
| MUI_Innsbruck-P9-14041415 | 8.545003e-15 | -3.110951e+06 | 8.828607e-15 | -3.108735e+06 |
| MUI_Innsbruck-P9-14044479 | 1.789800e-02 | 2.140314e+05 | 1.846450e-02 | 2.131536e+05 |
| MUI_Innsbruck-P9-14044842 | 7.931423e-07 | 1.116282e+06 | 8.407039e-07 | 1.116369e+06 |
I have also noticed that I cannot get the reference level for dep_var with
res = shears_glm(
    adata_sc,
    adata_bulk,
    dep_var="C(microsatellite_status, Treatment(reference='MSS'))",
    covariate_str=covariate_str,
    n_jobs=-1,
)

If I understood correctly, patsy does not consider this notation for the first variable in the formula; instead I need to do something like

adata_bulk.obs["microsatellite_status"] = pd.Categorical(
    adata_bulk.obs["microsatellite_status"], categories=["MSI", "MSS"]
)

This is only going to affect the sign of the coef, but I guess it's good to be able to define the base level of your dep_var.
A helper function to call _parallelize_with_joblib. Any reason you import logging from scanpy? See here.
I wanted the logging settings to be consistent with scanpy. I'm not sure I would do it again that way and it certainly doesn't make sense for shears.
I wasn't aware the formula interface is so much slower; good idea to build the matrices up front! I don't know why you get different results with smf.glm and sm.GLM, but as you said, it's really close anyway.
I would suggest using formulaic over patsy. It is considered the successor of patsy, which is not actively developed anymore, and it is also much faster for large design matrices. I thought that the reference level is considered also for the first variable, but either I never noticed or it's different between formulaic and patsy.
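For illustration, a minimal formulaic sketch of building the matrices up front (the toy data frame is made up, and I'm assuming model_matrix returns an lhs/rhs pair for a two-sided formula):

import pandas as pd
from formulaic import model_matrix

bulk_obs = pd.DataFrame(
    {"response": [0, 1, 1, 0], "cell_weight": 0.0, "age": [61.0, 47.0, 55.0, 70.0]}
)
mm = model_matrix("response ~ cell_weight + age", bulk_obs)
endog_array = mm.lhs.values.ravel()  # response vector
exog_array = mm.rhs.values           # design matrix incl. intercept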
src/shears/pp/__init__.py
Outdated
def _compute_block(sc_mat: np.ndarray, bulk_mat: np.ndarray, i0: int, i1: int, j0: int, j1: int) -> List[Tuple[float, float, int, int]]:
    """
    Calculate the Pearson correlation coefficient and its p-value for each (cell, bulk) pair in the given block.
    Returns a list of tuples (r, p, global_cell_idx, global_bulk_idx).
    """
    res = []
    for ii in range(i0, i1):
        for jj in range(j0, j1):
            r, p = pearsonr(sc_mat[ii], bulk_mat[jj])
            res.append((r, p, ii, jj))
    return res
I don't really care that much, because this is only for comparison with the legacy Scissor, so feel free to leave as is...
but shouldn't numpy.corrcoef directly compute the correlation between two arrays? I'd assume it to be more efficient and already use parallelism implicitly due to the underlying BLAS library.
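For reference, a small sketch of what that could look like with numpy.corrcoef (correlations only, no p-values; the random matrices are placeholders):

import numpy as np

rng = np.random.default_rng(0)
sc_mat = rng.random((100, 2000))   # cells x genes
bulk_mat = rng.random((20, 2000))  # bulk samples x genes

# np.corrcoef treats each row as a variable and stacks both inputs, so the
# cell-vs-bulk correlations are the off-diagonal block of the full matrix
n_cells = sc_mat.shape[0]
full = np.corrcoef(sc_mat, bulk_mat)       # (n_cells + n_bulk) x (n_cells + n_bulk)
cell_bulk_corr = full[:n_cells, n_cells:]  # cells x bulk samples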
scipy.stats.pearsonr also computes the p value. I will just drop the sh.pp.cell_corrs() function from the pull request.
We can ignore the failed docs for now and fix it in a separate PR.
A few comments:
shears/src/shears/tl/__init__.py
Lines 47 to 51 in 659346d
Instead of having n threads running at 100%, I had only one or two, probably because the overhead of copying the whole data for every cell was too large.
I tried 2 approaches and settled with the one using module-level globals to share large data objects with worker processes without re-serializing them on every call. This is also compatible with the process_map function from shears._util. I am not sure if this has any drawbacks. The other approach is not compatible with process_map and uses an initializer plus a ProcessPoolExecutor.
For the Cox model I first looked at statsmodels' phreg (from statsmodels.formula.api import phreg), but ultimately settled with lifelines' CoxPHFitter (from lifelines import CoxPHFitter). The package seems to be well maintained, can handle patsy formulas, and it also accepts additional args to tweak the fit. I am not sure the way the function passes the kwargs on is the neatest, but it works ... You can pass, for example:
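(An illustrative call; the function name shears and the covariate names are placeholders, not the exact API in this PR.)

res = shears(                    # placeholder name for the tl function added in this PR
    adata_sc,
    adata_bulk,
    family="cox",
    duration_col="OS_time",
    event_col="OS_status",
    covariate_str="age + sex",   # hypothetical covariates
    penalizer=0.1,               # forwarded to lifelines' CoxPHFitter
    strata=["dataset"],          # forwarded to the Cox fit: one baseline hazard per dataset
    n_jobs=-1,
)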
Especially the strata argument is compelling for batches that do not meet the proportional-hazards assumption. Using strata fits a separate baseline hazard for each batch. lifelines also has a built-in function to check if this is the case:
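That check is presumably CoxPHFitter.check_assumptions; a minimal, self-contained sketch using lifelines' bundled rossi example data:

from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi

df = load_rossi()
cph = CoxPHFitter().fit(df, duration_col="week", event_col="arrest")
# tests the proportional-hazards assumption per covariate and prints advice
# (e.g. which variables to stratify on) when it is violated
cph.check_assumptions(df, p_value_threshold=0.05, show_plots=False)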