Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,72 @@ metric.fit(Xi, Xj)
dist = metric.score(Xi, Xj)
```

## Dynamic stochastic shape metrics

In addition to above, we provide methods to compare between stochastic and dynamic neural responses (e.g. biological neural network responses to stimulus repetitions as a function of time, or latent dynamic activations in diffusion models). The API is similar to `LinearMetric()`, but requires differently-formatted inputs.


**1) Dynamic stochastic shape metrics using** `GPStochasticMetric()`

The first method models network response distributions as Gaussian Process, and computes distances based on the analytic solution to the bi-causal optimal transport distance between two stochastic processes. This involves computing class-conditional means and covariances for each network, then computing the metric as follows.

```python
# Given
# -----
# Xi : Tuple[ndarray, ndarray]
# The first array is (num_neurons*num_times x 1) array of means and the second array is (num_neurons*num_times x num_neurons*num_times) covariance matrices of first network.
#
# Xj : Tuple[ndarray, ndarray]
# Same as Xi, but for the second network's responses.
#
# alpha: float between [0, 2].
# When alpha=2, this reduces to the deterministic shape metric. When alpha=1, this is the 2-Wasserstein between two Gaussians. When alpha=0, this is the Bures metric between the two sets of covariance matrices.

# Fit alignment

metric = GPStochasticMetric(
n_dims=num_neurons, # number of neurons
group="orth", # nuisance transformation
type='adapted', # adapted or non-adapted optimal transport distance
alpha=alpha # alpha described above
)

metric.fit(Xi, Xj)

# Evaluate the distance between the two networks
dist = metric.score(Xi, Xj)
```

**2) Dynamic stochastic shape metrics using** `GPStochasticDiff()`

We also provide dynamic stochastic shape metrics based on the differentiable optimization. The metric computes the same metric as in the previous section, but instead of alternating minimization it uses a differentiable optimization strategy.

```python
# Given
# -----
# Xi : ndarray, (num_neurons*num_times x 1)
# First network's responses.
#
# Xj : ndarray, (num_neurons*num_times x num_neurons*num_times)
# Same as Xi, but for the second network's responses.
#

# Fit alignment
GPStochasticDiff(
n_dims=num_neurons, # number of neurons
n_times=num_times, # number of time points
type="Bures" # distance type, options are Bures, Adapted_Bures, Knothe_Rosenblatt, Marginal_Bures
)

# Evaluate the distance between the two networks
dist = metric.fit_score(
Xi, Xj,
lr=1e-3, # learning rate
tol=1e-5, # tolerance of optimization
epsilon=1e-6 # used for well-conditioning covariances
)
```

### Computing distances between many networks

Things start to get really interesting when we start to consider larger cohorts containing more than just two networks. The `netrep.multiset` file contains some useful methods. Let `Xs = [X1, X2, X3, ..., Xk]` be a list of `num_samples x num_neurons` matrices similar to those described above. We can do the following:
Expand Down
278 changes: 278 additions & 0 deletions examples/multiset_demo.ipynb

Large diffs are not rendered by default.

211 changes: 211 additions & 0 deletions examples/stochastic_process_metrics.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/home/anejatbakhsh/anaconda3/envs/netrep/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"# %%\n",
"\"\"\"\n",
"Tests metrics betwen stochastic process neural representations.\n",
"\"\"\"\n",
"import numpy as np\n",
"from netrep.metrics import GPStochasticMetric,GaussianStochasticMetric,GPStochasticDiff\n",
"from netrep.utils import rand_orth\n",
"from sklearn.utils.validation import check_random_state\n",
"from sklearn.covariance import EmpiricalCovariance\n",
"\n",
"from numpy import random as rand\n",
"from netrep.utils import rand_orth\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# %% Class for sampling from a gaussian process given a kernel\n",
"class GaussianProcess:\n",
" def __init__(self,kernel,D):\n",
" self.kernel = kernel\n",
" self.D = D\n",
"\n",
" def evaluate_kernel(self, xs, ys):\n",
" fun = np.vectorize(self.kernel)\n",
" return fun(xs[:, None], ys)\n",
"\n",
" def sample(self,ts,seed=0):\n",
" np.random.seed(seed)\n",
"\n",
" T = ts.shape[0]\n",
" c_g = self.evaluate_kernel(ts,ts)\n",
" fs = rand.multivariate_normal(\n",
" mean=np.zeros(T),\n",
" cov=c_g,\n",
" size=self.D\n",
" )\n",
" return fs"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"seed = 0\n",
"t = 5\n",
"n = 2\n",
"k = 100\n",
"\n",
"# Set random seed, draw random rotation\n",
"rs = check_random_state(seed)\n",
"if n > 1:\n",
" Q = rand_orth(n, n, random_state=rs)\n",
"else:\n",
" Q = 1\n",
" \n",
"# Generate data from a gaussian process with RBF kernel\n",
"ts = np.linspace(0,1,t)\n",
"gpA = GaussianProcess(\n",
" kernel = lambda x, y: 1e-2*(1e-6*(x==y)+np.exp(-np.linalg.norm(x-y)**2/(2*1.**2))),\n",
" D=n\n",
")\n",
"sA = np.array([gpA.sample(ts,seed=i) for i in range(k)]).reshape(k,n*t)\n",
"\n",
"# Transform GP according to a rotation applied to individiual \n",
"# blocks of the full covariance matrix\n",
"A = [sA.mean(0),EmpiricalCovariance().fit(sA).covariance_]\n",
"B = [\n",
" np.kron(np.eye(t),Q)@A[0],\n",
" np.kron(np.eye(t),Q)@A[1]@(np.kron(np.eye(t),Q)).T\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DSSD: -7.450580596923828e-09 , Marginal SSD: 0.0 , Adapted SSD: 9.125060374972147e-09\n"
]
}
],
"source": [
"# Using alternating optimization and Orthogonal Procrustes\n",
"# Compute dSSD\n",
"metric = GPStochasticMetric(n_dims=n,group=\"orth\")\n",
"dssd = metric.fit_score(A,B)\n",
"\n",
"# Compute aSSD\n",
"metric = GPStochasticMetric(\n",
" n_dims=n,\n",
" group=\"orth\",\n",
" type='adapted',\n",
")\n",
"assd = metric.fit_score(A,B)\n",
"\n",
"# Compute mSSD\n",
"metric = GaussianStochasticMetric(group=\"orth\")\n",
"A_marginal = [\n",
" A[0].reshape(t,n),\n",
" np.array([A[1][i*n:(i+1)*n,i*n:(i+1)*n] for i in range(t)])\n",
"]\n",
"B_marginal = [\n",
" B[0].reshape(t,n),\n",
" np.array([B[1][i*n:(i+1)*n,i*n:(i+1)*n] for i in range(t)])\n",
"]\n",
"mssd = metric.fit_score(A_marginal,B_marginal)\n",
"\n",
"print('DSSD: ', dssd, ', Marginal SSD: ', mssd, ', Adapted SSD: ', assd)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 700, loss 0.04: : 0it [00:01, ?it/s]\n",
"Iteration 200, loss 0.00: : 0it [00:00, ?it/s]\n",
"Iteration 200, loss 0.00: : 0it [00:00, ?it/s]\n",
"Iteration 700, loss 0.01: : 0it [00:02, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DSSD: 0.03943214 , Adapted DSSD: 0.0023483392 , Marginal SSD: 0.005165683 , Knothe Rosenblatt SSD: 0.0023267935\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Using differentiable optimization and Cayley orthogonal parameterization\n",
"\n",
"metric = GPStochasticDiff(n_dims=n,n_times=t,type=\"Bures\")\n",
"dssd = metric.fit_score(A,B,lr=1e-3,tol=1e-5,epsilon=1e-6)\n",
"\n",
"metric = GPStochasticDiff(n_dims=n,n_times=t,type=\"Adapted_Bures\")\n",
"assd = metric.fit_score(A,B,lr=1e-3,tol=1e-5,epsilon=1e-6)\n",
"\n",
"metric = GPStochasticDiff(n_dims=n,n_times=t,type=\"Knothe_Rosenblatt\")\n",
"kssd = metric.fit_score(A,B,lr=1e-3,tol=1e-5,epsilon=1e-6)\n",
"\n",
"metric = GPStochasticDiff(n_dims=n,n_times=t,type=\"Marginal_Bures\")\n",
"mssd = metric.fit_score(A,B,lr=1e-3,tol=1e-5,epsilon=1e-6)\n",
"\n",
"print('DSSD: ', dssd, ', Adapted DSSD: ', assd, ', Marginal SSD: ', mssd, ', Knothe Rosenblatt SSD: ', kssd)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "netrep",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading