Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## v0.30 - development version

* Uses prettier `tqdm` output that is now aware of Jupyter notebooks.
* `bayes.update` now supports `lognorm` and `gamma` distributions in addition to `norm` and `beta`.

## v0.29 - latest release

Expand Down
6 changes: 5 additions & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,8 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
- Type hints: Use throughout codebase
- Error handling: Validate inputs, use ValueError with descriptive messages
- Use operator overloading (`__add__`, `__mul__`, etc.) and custom operators (`@` for sampling)
- Tests: Descriptive names, unit tests match module structure, use hypothesis for property testing
- Tests: Descriptive names, unit tests match module structure, use hypothesis for property testing

## Workflow
- Always run `make format` before committing to ensure code passes formatting checks
- Update CHANGES.md when adding new features or fixing bugs
69 changes: 62 additions & 7 deletions squigglepy/bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@

from datetime import datetime

from .distributions import BetaDistribution, NormalDistribution, norm, beta, mixture
from .distributions import (
BetaDistribution,
NormalDistribution,
LognormalDistribution,
GammaDistribution,
norm,
beta,
lognorm,
gamma,
mixture,
)
from .utils import _core_cuts, _init_tqdm, _tick_tqdm, _flush_tqdm


Expand Down Expand Up @@ -310,14 +320,15 @@ def update(prior, evidence, evidence_weight=1):
Parameters
----------
prior : Distribution
The prior distribution. Currently must either be normal or beta type. Other
types are not yet supported.
The prior distribution. Supported types: normal, lognormal, beta, gamma.
evidence : Distribution
The distribution used to update the prior. Currently must either be normal
or beta type. Other types are not yet supported.
The distribution used to update the prior. Must be the same type as prior.
Supported types: normal, lognormal, beta, gamma.
evidence_weight : float
How much weight to put on the evidence distribution? Currently this only matters
for normal distributions, where this should be equivalent to the sample weight.
How much weight to put on the evidence distribution? For normal and lognormal
distributions, this is equivalent to the sample weight. For gamma distributions,
this scales the evidence shape parameter. For beta distributions, this parameter
is currently ignored.

Returns
-------
Expand All @@ -330,6 +341,14 @@ def update(prior, evidence, evidence_weight=1):
>> evidence = sq.norm(2,3)
>> bayes.update(prior, evidence)
<Distribution> norm(mean=2.53, sd=0.29)
>> prior = sq.lognorm(1, 10)
>> evidence = sq.lognorm(2, 8)
>> bayes.update(prior, evidence)
<Distribution> lognorm(...)
>> prior = sq.gamma(shape=2, scale=1)
>> evidence = sq.gamma(shape=3, scale=1.5)
>> bayes.update(prior, evidence)
<Distribution> gamma(shape=5, scale=...)
"""
if isinstance(prior, NormalDistribution) and isinstance(evidence, NormalDistribution):
prior_mean = prior.mean
Expand All @@ -345,12 +364,48 @@ def update(prior, evidence, evidence_weight=1):
(evidence_var * prior_var) / (evidence_weight * prior_var + evidence_var)
),
)
elif isinstance(prior, LognormalDistribution) and isinstance(evidence, LognormalDistribution):
# Lognormal update is performed in log-space where it behaves like a normal
prior_norm_mean = prior.norm_mean
prior_norm_var = prior.norm_sd**2
evidence_norm_mean = evidence.norm_mean
evidence_norm_var = evidence.norm_sd**2
# Apply normal update formula in log-space
posterior_norm_mean = (
evidence_norm_var * prior_norm_mean
+ evidence_weight * (prior_norm_var * evidence_norm_mean)
) / (evidence_weight * prior_norm_var + evidence_norm_var)
posterior_norm_sd = math.sqrt(
(evidence_norm_var * prior_norm_var)
/ (evidence_weight * prior_norm_var + evidence_norm_var)
)
return lognorm(norm_mean=posterior_norm_mean, norm_sd=posterior_norm_sd)
elif isinstance(prior, BetaDistribution) and isinstance(evidence, BetaDistribution):
prior_a = prior.a
prior_b = prior.b
evidence_a = evidence.a
evidence_b = evidence.b
return beta(prior_a + evidence_a, prior_b + evidence_b)
elif isinstance(prior, GammaDistribution) and isinstance(evidence, GammaDistribution):
# Gamma conjugate update: combine shape parameters and compute weighted scale
# For gamma distributions with shape α and scale θ, the mean is αθ
# We add the shapes and compute a weighted harmonic mean of scales
prior_shape = prior.shape
prior_scale = prior.scale
evidence_shape = evidence.shape * evidence_weight
evidence_scale = evidence.scale
# Posterior shape is sum of shapes
posterior_shape = prior_shape + evidence_shape
# Posterior scale uses precision-weighted combination (like normal variance)
# Using rate (1/scale) for combination, then converting back
prior_rate = 1 / prior_scale
evidence_rate = 1 / evidence_scale
# Weight rates by their respective shape parameters (effective sample sizes)
posterior_rate = (prior_shape * prior_rate + evidence_shape * evidence_rate) / (
prior_shape + evidence_shape
)
posterior_scale = 1 / posterior_rate
return gamma(shape=posterior_shape, scale=posterior_scale)
elif not isinstance(prior, type(evidence)):
print(type(prior), type(evidence))
raise ValueError("can only update distributions of the same type.")
Expand Down
57 changes: 54 additions & 3 deletions tests/test_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@

from ..squigglepy.bayes import simple_bayes, bayesnet, update, average
from ..squigglepy.samplers import sample
from ..squigglepy.distributions import discrete, norm, beta, gamma
from ..squigglepy.distributions import discrete, norm, beta, gamma, lognorm, poisson
from ..squigglepy.rng import set_seed
from ..squigglepy.distributions import BetaDistribution, MixtureDistribution, NormalDistribution
from ..squigglepy.distributions import (
BetaDistribution,
GammaDistribution,
LognormalDistribution,
MixtureDistribution,
NormalDistribution,
)


def test_simple_bayes():
Expand Down Expand Up @@ -407,9 +413,54 @@ def test_update_beta():
assert out.b == 3


def test_update_lognormal():
# Lognormal update is performed in log-space
out = update(lognorm(1, 10), lognorm(2, 8))
assert isinstance(out, LognormalDistribution)
# The posterior should be between the prior and evidence in log-space
assert out.norm_mean is not None
assert out.norm_sd is not None
# Posterior uncertainty should be less than both prior and evidence
prior_norm_sd = lognorm(1, 10).norm_sd
evidence_norm_sd = lognorm(2, 8).norm_sd
assert out.norm_sd < prior_norm_sd
assert out.norm_sd < evidence_norm_sd


def test_update_lognormal_evidence_weight():
# With higher evidence weight, posterior should be closer to evidence
prior = lognorm(1, 10)
evidence = lognorm(2, 8)
out_weight1 = update(prior, evidence, evidence_weight=1)
out_weight3 = update(prior, evidence, evidence_weight=3)
assert isinstance(out_weight3, LognormalDistribution)
# With higher weight on evidence, posterior norm_mean should be closer to evidence
assert abs(out_weight3.norm_mean - evidence.norm_mean) < abs(
out_weight1.norm_mean - evidence.norm_mean
)


def test_update_gamma():
out = update(gamma(shape=2, scale=1), gamma(shape=3, scale=1.5))
assert isinstance(out, GammaDistribution)
# Shape should be the sum of shapes
assert out.shape == 5
# Scale should be between the two scales (weighted harmonic mean)
assert 1 < out.scale < 1.5


def test_update_gamma_evidence_weight():
prior = gamma(shape=2, scale=1)
evidence = gamma(shape=3, scale=2)
out = update(prior, evidence, evidence_weight=2)
assert isinstance(out, GammaDistribution)
# Shape should be prior.shape + evidence_weight * evidence.shape
assert out.shape == 2 + 2 * 3 # = 8


def test_update_not_implemented():
with pytest.raises(ValueError) as excinfo:
update(gamma(1), gamma(2))
update(poisson(1), poisson(2))
assert "not supported" in str(excinfo.value)


Expand Down