Skip to content

Commit fe341e1

Browse files
committed
Add integrations/decoders.py
1 parent de300a9 commit fe341e1

File tree

2 files changed

+220
-38
lines changed

2 files changed

+220
-38
lines changed

cebra/integrations/decoders.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
from typing import List, Optional, Tuple, Union
2+
3+
import numpy as np
4+
import sklearn.metrics
5+
import torch
6+
import torch.nn as nn
7+
import torch.optim as optim
8+
from sklearn.linear_model import Ridge
9+
from sklearn.model_selection import GridSearchCV
10+
from torch.utils.data import DataLoader
11+
from torch.utils.data import TensorDataset
12+
13+
14+
def ridge_decoding(
    embedding_train: Union[torch.Tensor, dict],
    embedding_valid: Union[torch.Tensor, dict],
    label_train: Union[torch.Tensor, dict],
    label_valid: Union[torch.Tensor, dict],
    n_run: Optional[int] = None,
) -> Tuple[List[float], List[float], np.ndarray]:
    """Decode labels from embeddings with cross-validated ridge regression.

    Args:
        embedding_train (Union[torch.Tensor, dict]): Training embeddings;
            either a single tensor or a dict indexed first by session, then
            by run (multi-session case).
        embedding_valid (Union[torch.Tensor, dict]): Validation embeddings,
            same structure as ``embedding_train``.
        label_train (Union[torch.Tensor, dict]): Training labels; a tensor,
            or a dict of objects exposing a ``.continuous`` tensor.
        label_valid (Union[torch.Tensor, dict]): Validation labels, same
            structure as ``label_train``.
        n_run (Optional[int]): Run index used to select the embedding when
            dict inputs are provided.

    Returns:
        Per-dimension training R2 scores, validation R2 scores, and the
        validation predictions.
    """
    if isinstance(embedding_train, dict):  # multi-session inputs
        if n_run is None:
            raise ValueError(f"n_run must be specified, got {n_run}.")

        def _stack_embeddings(embeddings):
            # Concatenate the selected run of every session along samples.
            return np.concatenate(
                [
                    embeddings[i][n_run].cpu().numpy()
                    for i in range(len(embeddings))
                ],
                axis=0,
            )

        def _stack_labels(labels):
            # Concatenate the continuous labels of every session.
            return np.concatenate(
                [labels[i].continuous.cpu().numpy() for i in range(len(labels))],
                axis=0,
            )

        all_train_embeddings = _stack_embeddings(embedding_train)
        train = _stack_labels(label_train)
        all_val_embeddings = _stack_embeddings(embedding_valid)
        valid = _stack_labels(label_valid)
    else:
        all_train_embeddings = embedding_train.cpu().numpy()
        train = label_train.cpu().numpy()
        all_val_embeddings = embedding_valid.cpu().numpy()
        valid = label_valid.cpu().numpy()

    # Cross-validate the regularization strength over a log-spaced grid.
    decoder = GridSearchCV(Ridge(), {"alpha": np.logspace(-4, 0, 9)})
    decoder.fit(all_train_embeddings, train)

    train_scores = sklearn.metrics.r2_score(
        train,
        decoder.predict(all_train_embeddings),
        multioutput="raw_values",
    ).tolist()
    valid_prediction = decoder.predict(all_val_embeddings)
    valid_scores = sklearn.metrics.r2_score(
        valid,
        valid_prediction,
        multioutput="raw_values",
    ).tolist()

    return train_scores, valid_scores, valid_prediction
class SingleLayerDecoder(nn.Module):
    """Linear read-out head mapping embeddings to behavior targets.

    Note:
        The output dimension defaults to 2, so that the decoder predicts
        x/y velocity (Perich et al., 2018).
    """

    def __init__(self, input_dim, output_dim=2):
        super().__init__()
        # A single affine projection; no hidden layer, no nonlinearity.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        out = self.fc(x)
        return out
class TwoLayersDecoder(nn.Module):
    """Two-layer MLP read-out head mapping embeddings to behavior targets.

    Note:
        The output dimension defaults to 2, so that the decoder predicts
        x/y velocity (Perich et al., 2018).
    """

    def __init__(self, input_dim, output_dim=2):
        super().__init__()
        # Fixed hidden width of 32 with a GELU nonlinearity in between.
        hidden = [
            nn.Linear(input_dim, 32),
            nn.GELU(),
            nn.Linear(32, output_dim),
        ]
        self.fc = nn.Sequential(*hidden)

    def forward(self, x):
        return self.fc(x)
def mlp_decoding(
    embedding_train: Union[dict, torch.Tensor],
    embedding_valid: Union[dict, torch.Tensor],
    label_train: Union[dict, torch.Tensor],
    label_valid: Union[dict, torch.Tensor],
    num_epochs: int = 20,
    lr: float = 0.001,
    batch_size: int = 500,
    device: str = "cuda",
    model_type: str = "SingleLayerMLP",
    n_run: Optional[int] = None,
):
    """Perform MLP decoding on training and validation embeddings.

    Args:
        embedding_train (Union[dict, torch.Tensor]): Training embeddings;
            either a single tensor or a dict indexed first by session, then
            by run (multi-session case).
        embedding_valid (Union[dict, torch.Tensor]): Validation embeddings.
        label_train (Union[dict, torch.Tensor]): Training labels; a tensor,
            or a dict of objects exposing a ``.continuous`` tensor.
        label_valid (Union[dict, torch.Tensor]): Validation labels.
        num_epochs (int): Number of training epochs.
        lr (float): Learning rate for the optimizer.
        batch_size (int): Batch size for training.
        device (str): Device to run the model on ('cuda' or 'cpu').
        model_type (str): Type of MLP model to use ('SingleLayerMLP' or
            'TwoLayersMLP').
        n_run (Optional[int]): Optional run number for dataset definition.

    Returns:
        Training R2 scores, validation R2 scores, and validation predictions.

    Raises:
        ValueError: If dict inputs are provided without ``n_run``.
        NotImplementedError: If ``model_type`` is not recognized.
    """
    # Promote 1-D label tensors to 2-D column vectors. Guard with an
    # isinstance check: dict (multi-session) inputs have no ``.shape``, and
    # the unguarded check would raise AttributeError before ever reaching
    # the dict branch below.
    if isinstance(label_train, torch.Tensor) and len(label_train.shape) == 1:
        label_train = label_train[:, None]
        label_valid = label_valid[:, None]

    if isinstance(embedding_train, dict):  # multi-session inputs
        if n_run is None:
            raise ValueError(f"n_run must be specified, got {n_run}.")

        all_train_embeddings = torch.cat(
            [embedding_train[i][n_run] for i in range(len(embedding_train))],
            axis=0)
        train = torch.cat(
            [label_train[i].continuous for i in range(len(label_train))],
            axis=0)
        all_val_embeddings = torch.cat(
            [embedding_valid[i][n_run] for i in range(len(embedding_valid))],
            axis=0)
        valid = torch.cat(
            [label_valid[i].continuous for i in range(len(label_valid))],
            axis=0)
    else:
        all_train_embeddings = embedding_train
        train = label_train
        all_val_embeddings = embedding_valid
        valid = label_valid

    dataset = TensorDataset(all_train_embeddings.to(device), train.to(device))
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    input_dim = all_train_embeddings.shape[1]
    output_dim = train.shape[1]
    if model_type == "SingleLayerMLP":
        model = SingleLayerDecoder(input_dim=input_dim, output_dim=output_dim)
    elif model_type == "TwoLayersMLP":
        model = TwoLayersDecoder(input_dim=input_dim, output_dim=output_dim)
    else:
        raise NotImplementedError(f"Unknown model_type: {model_type}.")
    model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

    # Evaluate without building an autograd graph.
    model.eval()
    with torch.no_grad():
        train_pred = model(all_train_embeddings.to(device))
        valid_pred = model(all_val_embeddings.to(device))

    train_r2 = sklearn.metrics.r2_score(
        y_true=train.cpu().numpy(),
        y_pred=train_pred.cpu().numpy(),
        multioutput="raw_values",
    ).tolist()

    valid_r2 = sklearn.metrics.r2_score(
        y_true=valid.cpu().numpy(),
        y_pred=valid_pred.cpu().numpy(),
        multioutput="raw_values",
    ).tolist()

    return train_r2, valid_r2, valid_pred

cebra/models/decoders.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments (0)