Skip to content

Code for comparing HMC to adaptive HMC #790

@njtierney

Description

@njtierney
# Compare adaptive HMC to HMC on some knotty problems.

library(greta)

# settings shared by every sampler run below
n_chains <- 10
warmup <- 2000


# very hard model: strongly correlated, with marginal standard deviations
# running from 1 up to rel_sd_range
correlation <- 0.99
rel_sd_range <- 1e5

# # very easy model
# correlation <- 0
# rel_sd_range <- 1

# Sample from a correlated multivariate Gaussian with different marginal
# variances. The names below avoid masking base::dim(), stats::sd() and
# stats::C().
n_dim <- 4
corr_mat <- matrix(correlation, n_dim, n_dim)
diag(corr_mat) <- 1
sds <- seq(1, rel_sd_range, length.out = n_dim)
# Sigma[i, j] = sds[i] * corr_mat[i, j] * sds[j]. Passing `nrow` stops diag()
# from treating a length-1 `sds` as the size of an identity matrix.
sd_mat <- diag(sds, nrow = n_dim)
Sigma <- sd_mat %*% corr_mat %*% sd_mat
x <- multivariate_normal(mean = zeros(1, n_dim),
                         Sigma = Sigma)
m <- model(x)
# a single initial value at the target's mean (all zeros), used by the
# well-initialised adaptive HMC run
init <- initials(x = t(rep(0, n_dim)))



# Each sampler configuration is timed end to end with system.time(); the
# draws are kept because the assignment happens inside the timed expression.

# 1. greta's default HMC sampler
hmc_time <- system.time({
  draws_hmc <- mcmc(
    m,
    sampler = hmc(),
    warmup = warmup,
    chains = n_chains
  )
})

# 2. HMC restricted to 1-3 leapfrog steps per iteration: too few, which is
# representative of more complex posteriors when using the default step sizes
hmc_l2_time <- system.time({
  draws_hmc_l2 <- mcmc(
    m,
    sampler = hmc(Lmin = 1, Lmax = 3),
    warmup = warmup,
    chains = n_chains
  )
})

# 3. adaptive HMC with greta's default initialisation
ahmc_time <- system.time({
  draws_ahmc <- mcmc(
    m,
    sampler = adaptive_hmc(),
    warmup = warmup,
    chains = n_chains
  )
})

# 4. adaptive HMC started from a single, well-chosen point (`init`)
ahmc_init_time <- system.time({
  draws_ahmc_init <- mcmc(
    m,
    sampler = adaptive_hmc(),
    initial_values = init,
    warmup = warmup,
    chains = n_chains
  )
})

# convergence/sampling statistics

# Run times are similar (hmc_l2 expected to be faster).
# print() makes the values appear even when the script is run via source(),
# where bare top-level expressions are not auto-printed.
elapsed_times <- c(
  hmc = hmc_time[["elapsed"]],
  hmc_l2 = hmc_l2_time[["elapsed"]],
  ahmc = ahmc_time[["elapsed"]],
  ahmc_init = ahmc_init_time[["elapsed"]]
)
print(elapsed_times)

# What's the sampler efficiency (effective number of samples per second of
# overall run time) for the worst-sampled variable?
#
# draws: an object accepted by coda::effectiveSize(), e.g. an mcmc.list
# time:  timings from system.time(); only the "elapsed" entry is used
# Returns a single unnamed number; higher is better.
worst_efficiency <- function(draws, time) {
  neffs <- coda::effectiveSize(draws)
  # `[[` drops the "elapsed" name, which `[` would carry into the result
  min(neffs) / time[["elapsed"]]
}

# higher is better
efficiency <- c(
  worst_efficiency(draws_hmc, hmc_time),
  worst_efficiency(draws_hmc_l2, hmc_l2_time),
  worst_efficiency(draws_ahmc, ahmc_time),
  worst_efficiency(draws_ahmc_init, ahmc_init_time)
)
names(efficiency) <- c("hmc", "hmc_l2", "ahmc", "ahmc_init")
# print() so the values are shown under source() as well as interactively
print(efficiency)


# Upper confidence bound of the univariate potential scale reduction factor
# (Gelman-Rubin R-hat) for the worst-sampled variable.
#
# draws: an object accepted by coda::gelman.diag(), e.g. an mcmc.list
# Returns the largest upper bound across variables.
worst_rhat <- function(draws) {
  diagnostics <- coda::gelman.diag(
    draws,
    autoburnin = FALSE,
    multivariate = FALSE
  )
  # psrf has one row per variable: "Point est." and "Upper C.I." columns
  max(diagnostics$psrf[, "Upper C.I."])
}

# Lower is better; < 1.1 is ideal.
# NOTE(review): the ahmc_init chains all start from `init`, while R-hat is
# designed for over-dispersed starting points, so its value may be optimistic.
rhat <- c(
  worst_rhat(draws_hmc),
  worst_rhat(draws_hmc_l2),
  worst_rhat(draws_ahmc),
  worst_rhat(draws_ahmc_init)
)
names(rhat) <- c("hmc", "hmc_l2", "ahmc", "ahmc_init")
# print() so the values are shown under source() as well as interactively
print(rhat)

# mcmc_trace() returns a ggplot object, which is only drawn when printed.
# Auto-printing happens only at the interactive top level, not under source().
print(bayesplot::mcmc_trace(draws_hmc))
print(bayesplot::mcmc_trace(draws_ahmc))

# coda's trace and density plots for each sampler
plot(draws_hmc)
plot(draws_ahmc)

# Overlay the first two dimensions of the draws: adaptive HMC in grey,
# default HMC in red. The axis limits cover both sets of draws, because
# points() cannot widen the limits set by plot(); otherwise HMC draws outside
# the adaptive HMC range would be silently clipped.
ahmc_xy <- as.matrix(draws_ahmc)[, 1:2]
hmc_xy <- as.matrix(draws_hmc)[, 1:2]
par(mfrow = c(1, 1))
plot(ahmc_xy,
     xlim = range(ahmc_xy[, 1], hmc_xy[, 1]),
     ylim = range(ahmc_xy[, 2], hmc_xy[, 2]),
     cex = 0.5,
     pch = 16,
     col = grey(0.8))
points(hmc_xy,
       cex = 0.5,
       pch = 16,
       col = "red")

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions