-
Notifications
You must be signed in to change notification settings - Fork 66
Open
Description
# compare adaptive HMC to HMC on some knotty problems
library(greta)
# number of chains and warmup iterations used by each sampler run
n_chains <- 10
warmup <- 2000

# target settings: the "very hard" configuration is the active one
correlation <- 0.99
rel_sd_range <- 1e5

# the "very easy" alternative, kept for reference
# correlation <- 0
# rel_sd_range <- 1

# the target is a correlated multivariate Gaussian whose marginal variances
# differ between dimensions
dim <- 4

# correlation matrix: `correlation` off the diagonal, 1 on the diagonal
C <- matrix(correlation, nrow = dim, ncol = dim)
diag(C) <- 1

# marginal standard deviations, evenly spaced from 1 up to rel_sd_range
sd <- seq(from = 1, to = rel_sd_range, length.out = dim)

# covariance matrix: scale the correlation matrix on both sides by the
# standard deviations
Sigma <- diag(sd) %*% C %*% diag(sd)

x <- multivariate_normal(mean = zeros(1, dim), Sigma = Sigma)
m <- model(x)

# a single set of initial values: a 1 x dim row of zeros, matching the mean
init <- initials(x = matrix(0, nrow = 1, ncol = dim))
# Each run below is wrapped in system.time() so that the draws and the time
# taken to produce them are both kept.

# 1. HMC with its default settings
hmc_time <- system.time({
  draws_hmc <- mcmc(
    m,
    sampler = hmc(),
    chains = n_chains,
    warmup = warmup
  )
})

# 2. HMC with too few leapfrog steps (stands in for more complex posteriors
#    sampled with the default step sizes)
hmc_l2_time <- system.time({
  draws_hmc_l2 <- mcmc(
    m,
    sampler = hmc(Lmin = 1, Lmax = 3),
    chains = n_chains,
    warmup = warmup
  )
})

# 3. adaptive HMC, leaving the initial values at their defaults
ahmc_time <- system.time({
  draws_ahmc <- mcmc(
    m,
    sampler = adaptive_hmc(),
    chains = n_chains,
    warmup = warmup
  )
})

# 4. adaptive HMC started from a good initial value, the same single point
#    (`init`) being supplied for the run
ahmc_init_time <- system.time({
  draws_ahmc_init <- mcmc(
    m,
    sampler = adaptive_hmc(),
    initial_values = init,
    chains = n_chains,
    warmup = warmup
  )
})
# convergence/sampling statistics
# run times are similar (hmc_l2 expected to be faster)
# each value is the "elapsed" entry of the system.time() result recorded
# around the corresponding mcmc() call
hmc_time["elapsed"]
hmc_l2_time["elapsed"]
ahmc_time["elapsed"]
ahmc_init_time["elapsed"]
# what's the sampler efficiency (effective number of samples per second of
# overall sampling) for the worst-sampled variable
#
# draws: MCMC draws in a form accepted by coda::effectiveSize()
# time:  a timing with an "elapsed" element, e.g. the result of system.time()
#
# returns a single unnamed number: the smallest effective sample size across
# variables divided by the elapsed time
worst_efficiency <- function(draws, time) {
  neffs <- coda::effectiveSize(draws)
  # `[[` extracts the bare value; `[` would keep the "elapsed" name and the
  # efficiency would then be printed under that misleading label
  min(neffs) / time[["elapsed"]]
}
# worst-variable sampling efficiency for each of the four runs, pairing each
# set of draws with the timing recorded for the same run
# higher is better
worst_efficiency(draws_hmc, hmc_time)
worst_efficiency(draws_hmc_l2, hmc_l2_time)
worst_efficiency(draws_ahmc, ahmc_time)
worst_efficiency(draws_ahmc_init, ahmc_init_time)
# Upper bound of the univariate potential scale reduction factor for the
# worst-sampled variable.
#
# draws: MCMC draws in a form accepted by coda::gelman.diag()
#
# returns a single number: the largest upper bound across variables
worst_rhat <- function(draws) {
  diagnostics <- coda::gelman.diag(
    draws,
    multivariate = FALSE,
    autoburnin = FALSE
  )
  # the second column of the psrf matrix holds the upper bounds
  max(diagnostics[["psrf"]][, 2])
}
# worst-variable potential scale reduction factor for each of the four runs
# lower is better, <1.1 is ideal
worst_rhat(draws_hmc)
worst_rhat(draws_hmc_l2)
worst_rhat(draws_ahmc)
worst_rhat(draws_ahmc_init)
# trace plots of the default HMC and adaptive HMC draws
bayesplot::mcmc_trace(draws_hmc)
bayesplot::mcmc_trace(draws_ahmc)

# default plot method for the same two sets of draws
plot(draws_hmc)
plot(draws_ahmc)

# return to a single plotting panel for the scatter plot below
par(mfrow = c(1, 1))

# first two columns of the adaptive HMC draws in light grey, with the default
# HMC draws overlaid in red
plot(as.matrix(draws_ahmc)[, 1:2],
     cex = 0.5,
     pch = 16,
     col = grey(0.8))
points(as.matrix(draws_hmc)[, 1:2],
       cex = 0.5,
       pch = 16,
       col = "red")
Metadata
Metadata
Assignees
Labels
No labels