-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils_sample.py
More file actions
66 lines (52 loc) · 2.16 KB
/
utils_sample.py
File metadata and controls
66 lines (52 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.distributions as D
def normalise(probs: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
    """
    Rescale a tensor of non-negative weights so that its entries sum to one.

    The sum is taken over every element of ``probs`` (no per-axis
    normalisation), and ``eps`` is added to that sum before dividing.
    :param torch.Tensor probs: tensor of probabilities (not necessarily normalized)
    :param float eps: small constant added to the divisor to guard against division by zero
    :return torch.Tensor: tensor of the same shape as ``probs``, divided by ``sum(probs) + eps``
    """
    denominator = probs.sum() + eps
    return torch.div(probs, denominator)
def bern_sample(prob: float | torch.Tensor) -> int:
"""
Sample from a Bernoulli distribution with given probability.
:param float | torch.Tensor prob: probability of success (between 0 and 1)
:return int: sampled value (0 or 1)
"""
return int(D.Bernoulli(probs=prob).sample().item())
def cat_sample(probs: torch.Tensor) -> int:
    """
    Draw a single category index from a categorical distribution.

    :param torch.Tensor probs: tensor of category weights (not necessarily normalized);
        assumed to be 1D, since the draw is converted with ``.item()``
    :return int: index of the sampled category
    """
    distribution = D.Categorical(probs=probs)
    index = distribution.sample()
    return int(index.item())
class Categorical2D:
    """
    Class for representing a 2D categorical distribution [e.g. a joint P(X,Y)].

    Attributes set by the constructor:
      dim1, dim2 -- number of rows and columns of the probability table
      probs2d    -- the table divided by its total, so it sums to 1
      probs      -- ``probs2d`` flattened to 1D (row-major), used for sampling
    """

    def __init__(self, probs: torch.Tensor):
        """
        :param torch.Tensor probs: 2D tensor of probabilities (not necessarily normalized)
        :raises AssertionError: if ``probs`` is not a 2D tensor
        """
        # Raised explicitly rather than via an ``assert`` statement so the
        # check still runs under ``python -O``; the exception type is unchanged.
        if probs.dim() != 2:
            raise AssertionError("probs must be a 2D tensor")
        self.dim1, self.dim2 = probs.shape
        self.probs2d = probs / torch.sum(probs)
        # Flatten the normalised table rather than the raw input: this keeps
        # ``probs`` consistent with ``probs2d`` and avoids holding a view onto
        # the caller's tensor, which could be mutated after construction.
        self.probs = self.probs2d.flatten()

    def sample(self) -> tuple[int, int]:
        """
        Sample from the 2D categorical distribution and return the corresponding indices (i, j).

        :return tuple[int, int]: row index ``i`` and column index ``j`` of the sampled cell
        """
        flat_index = D.Categorical(probs=self.probs).sample().item()
        # Undo the row-major flattening: row = flat // n_cols, col = flat % n_cols.
        i, j = divmod(flat_index, self.dim2)
        return int(i), int(j)
def cat2D_sample(probs: torch.Tensor) -> tuple[int, int]:
    """
    Draw one (i, j) index pair from the 2D categorical distribution given by ``probs``.

    A new ``Categorical2D`` is built from ``probs`` on every call and sampled once.
    :param torch.Tensor probs: 2D tensor of probabilities
    :return tuple[int, int]: sampled indices (i, j)
    """
    joint = Categorical2D(probs)
    row, col = joint.sample()
    return row, col