diff --git a/CHANGES.md b/CHANGES.md index 5c60514..2c87ca5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,7 @@ ## v0.30 - development version * Uses prettier `tqdm` output that is now aware of Jupyter notebooks. +* `bayes.update` now supports `lognorm` and `gamma` distributions in addition to `norm` and `beta`. ## v0.29 - latest release diff --git a/CLAUDE.md b/CLAUDE.md index 83c77f4..fe9fe1b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,4 +17,8 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co - Type hints: Use throughout codebase - Error handling: Validate inputs, use ValueError with descriptive messages - Use operator overloading (`__add__`, `__mul__`, etc.) and custom operators (`@` for sampling) -- Tests: Descriptive names, unit tests match module structure, use hypothesis for property testing \ No newline at end of file +- Tests: Descriptive names, unit tests match module structure, use hypothesis for property testing + +## Workflow +- Always run `make format` before committing to ensure code passes formatting checks +- Update CHANGES.md when adding new features or fixing bugs \ No newline at end of file diff --git a/squigglepy/bayes.py b/squigglepy/bayes.py index 4487823..bce4728 100644 --- a/squigglepy/bayes.py +++ b/squigglepy/bayes.py @@ -8,7 +8,17 @@ from datetime import datetime -from .distributions import BetaDistribution, NormalDistribution, norm, beta, mixture +from .distributions import ( + BetaDistribution, + NormalDistribution, + LognormalDistribution, + GammaDistribution, + norm, + beta, + lognorm, + gamma, + mixture, +) from .utils import _core_cuts, _init_tqdm, _tick_tqdm, _flush_tqdm @@ -310,14 +320,15 @@ def update(prior, evidence, evidence_weight=1): Parameters ---------- prior : Distribution - The prior distribution. Currently must either be normal or beta type. Other - types are not yet supported. + The prior distribution. Supported types: normal, lognormal, beta, gamma. 
evidence : Distribution - The distribution used to update the prior. Currently must either be normal - or beta type. Other types are not yet supported. + The distribution used to update the prior. Must be the same type as prior. + Supported types: normal, lognormal, beta, gamma. evidence_weight : float - How much weight to put on the evidence distribution? Currently this only matters - for normal distributions, where this should be equivalent to the sample weight. + How much weight to put on the evidence distribution? For normal and lognormal + distributions, this is equivalent to the sample weight. For gamma distributions, + this scales the evidence shape parameter. For beta distributions, this parameter + is currently ignored. Returns ------- @@ -330,6 +341,14 @@ def update(prior, evidence, evidence_weight=1): >> evidence = sq.norm(2,3) >> bayes.update(prior, evidence) norm(mean=2.53, sd=0.29) + >> prior = sq.lognorm(1, 10) + >> evidence = sq.lognorm(2, 8) + >> bayes.update(prior, evidence) + lognorm(...) + >> prior = sq.gamma(shape=2, scale=1) + >> evidence = sq.gamma(shape=3, scale=1.5) + >> bayes.update(prior, evidence) + gamma(shape=5, scale=...) 
""" if isinstance(prior, NormalDistribution) and isinstance(evidence, NormalDistribution): prior_mean = prior.mean @@ -345,12 +364,48 @@ def update(prior, evidence, evidence_weight=1): (evidence_var * prior_var) / (evidence_weight * prior_var + evidence_var) ), ) + elif isinstance(prior, LognormalDistribution) and isinstance(evidence, LognormalDistribution): + # Lognormal update is performed in log-space where it behaves like a normal + prior_norm_mean = prior.norm_mean + prior_norm_var = prior.norm_sd**2 + evidence_norm_mean = evidence.norm_mean + evidence_norm_var = evidence.norm_sd**2 + # Apply normal update formula in log-space + posterior_norm_mean = ( + evidence_norm_var * prior_norm_mean + + evidence_weight * (prior_norm_var * evidence_norm_mean) + ) / (evidence_weight * prior_norm_var + evidence_norm_var) + posterior_norm_sd = math.sqrt( + (evidence_norm_var * prior_norm_var) + / (evidence_weight * prior_norm_var + evidence_norm_var) + ) + return lognorm(norm_mean=posterior_norm_mean, norm_sd=posterior_norm_sd) elif isinstance(prior, BetaDistribution) and isinstance(evidence, BetaDistribution): prior_a = prior.a prior_b = prior.b evidence_a = evidence.a evidence_b = evidence.b return beta(prior_a + evidence_a, prior_b + evidence_b) + elif isinstance(prior, GammaDistribution) and isinstance(evidence, GammaDistribution): + # Gamma update: add the shape parameters and combine the scales via their rates + # For gamma distributions with shape α and scale θ, the mean is αθ + # We add the shapes and compute a shape-weighted harmonic mean of scales + prior_shape = prior.shape + prior_scale = prior.scale + evidence_shape = evidence.shape * evidence_weight + evidence_scale = evidence.scale + # Posterior shape is sum of shapes + posterior_shape = prior_shape + evidence_shape + # Posterior scale comes from a shape-weighted average of the rates + # Using rate (1/scale) for combination, then converting back + prior_rate = 1 / prior_scale + evidence_rate = 1 / 
evidence_scale + # Weight rates by their respective shape parameters (effective sample sizes) + posterior_rate = (prior_shape * prior_rate + evidence_shape * evidence_rate) / ( + prior_shape + evidence_shape + ) + posterior_scale = 1 / posterior_rate + return gamma(shape=posterior_shape, scale=posterior_scale) elif not isinstance(prior, type(evidence)): print(type(prior), type(evidence)) raise ValueError("can only update distributions of the same type.") diff --git a/tests/test_bayes.py b/tests/test_bayes.py index 09c6323..e2789bd 100644 --- a/tests/test_bayes.py +++ b/tests/test_bayes.py @@ -3,9 +3,15 @@ from ..squigglepy.bayes import simple_bayes, bayesnet, update, average from ..squigglepy.samplers import sample -from ..squigglepy.distributions import discrete, norm, beta, gamma +from ..squigglepy.distributions import discrete, norm, beta, gamma, lognorm, poisson from ..squigglepy.rng import set_seed -from ..squigglepy.distributions import BetaDistribution, MixtureDistribution, NormalDistribution +from ..squigglepy.distributions import ( + BetaDistribution, + GammaDistribution, + LognormalDistribution, + MixtureDistribution, + NormalDistribution, +) def test_simple_bayes(): @@ -407,9 +413,54 @@ def test_update_beta(): assert out.b == 3 +def test_update_lognormal(): + # Lognormal update is performed in log-space + out = update(lognorm(1, 10), lognorm(2, 8)) + assert isinstance(out, LognormalDistribution) + # The posterior should be between the prior and evidence in log-space + assert out.norm_mean is not None + assert out.norm_sd is not None + # Posterior uncertainty should be less than both prior and evidence + prior_norm_sd = lognorm(1, 10).norm_sd + evidence_norm_sd = lognorm(2, 8).norm_sd + assert out.norm_sd < prior_norm_sd + assert out.norm_sd < evidence_norm_sd + + +def test_update_lognormal_evidence_weight(): + # With higher evidence weight, posterior should be closer to evidence + prior = lognorm(1, 10) + evidence = lognorm(2, 8) + out_weight1 = 
update(prior, evidence, evidence_weight=1) + out_weight3 = update(prior, evidence, evidence_weight=3) + assert isinstance(out_weight3, LognormalDistribution) + # With higher weight on evidence, posterior norm_mean should be closer to evidence + assert abs(out_weight3.norm_mean - evidence.norm_mean) < abs( + out_weight1.norm_mean - evidence.norm_mean + ) + + +def test_update_gamma(): + out = update(gamma(shape=2, scale=1), gamma(shape=3, scale=1.5)) + assert isinstance(out, GammaDistribution) + # Shape should be the sum of shapes + assert out.shape == 5 + # Scale should be between the two scales (weighted harmonic mean) + assert 1 < out.scale < 1.5 + + +def test_update_gamma_evidence_weight(): + prior = gamma(shape=2, scale=1) + evidence = gamma(shape=3, scale=2) + out = update(prior, evidence, evidence_weight=2) + assert isinstance(out, GammaDistribution) + # Shape should be prior.shape + evidence_weight * evidence.shape + assert out.shape == 2 + 2 * 3 # = 8 + + def test_update_not_implemented(): with pytest.raises(ValueError) as excinfo: - update(gamma(1), gamma(2)) + update(poisson(1), poisson(2)) assert "not supported" in str(excinfo.value)