Created
October 8, 2025 10:04
-
-
Save carlislerainey/37521a4fed2ca7d3c74b6408c35ccf24 to your computer and use it in GitHub Desktop.
Illustrating the Metropolis algorithm with a beta(11, 165) posterior
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # ---- setup ---- | |
| # nice printing | |
| options(digits = 3) | |
| # load packages | |
| library(tidyverse) | |
| library(posterior) | |
| library(foreach) | |
| library(doParallel) | |
| library(scales) | |
| library(prettyunits) | |
| # ---- metropolis ---- | |
| # create function | |
#' Random-walk Metropolis sampler
#'
#' Generates `S` dependent samples from the distribution whose log-density
#' (up to an additive constant) is `logf`. Proposals are the current value
#' plus independent Uniform(-tau, tau) noise in each coordinate, which is
#' symmetric, so the acceptance rule only needs the log-density difference.
#'
#' @param logf Function returning the log of the (possibly unnormalized)
#'   target density for a numeric vector of length `length(theta_start)`.
#' @param theta_start Numeric vector of starting values; stored as sample 1.
#' @param S Number of samples to return, including the starting value.
#' @param tau Half-width of the uniform random-walk proposal.
#' @param progress_bar Show a text progress bar while sampling?
#' @param ... Further arguments forwarded to `logf`.
#' @return A list with `prop_accepted` (share of the `S - 1` proposals that
#'   were accepted; `NA` when `S` is 1) and `samples` (a matrix with `S` rows
#'   and one column per element of `theta_start`).
metrop <- function(logf, theta_start, S = 10000, tau = 0.1,
                   progress_bar = FALSE, ...) {
  stopifnot(
    is.function(logf),
    is.numeric(theta_start), length(theta_start) >= 1,
    is.numeric(S), length(S) == 1, !is.na(S), S >= 1,
    is.numeric(tau), length(tau) == 1, !is.na(tau), tau >= 0
  )

  # record start time
  start_time <- Sys.time()

  # initialize matrix of samples with starting values
  k <- length(theta_start)
  samples <- matrix(NA_real_, nrow = S, ncol = k)
  samples[1, ] <- theta_start
  n_accepted <- 0

  # the log-density at the current location is carried along and only
  # refreshed on acceptance, so logf is evaluated once per iteration
  current <- samples[1, ]
  lp_current <- logf(current, ...)
  if (anyNA(lp_current)) {
    stop("`logf(theta_start)` is NA/NaN; choose a different starting value.",
         call. = FALSE)
  }

  # optionally initialize progress bar; on.exit closes it even if logf errors
  if (progress_bar) {
    pb <- txtProgressBar(min = 0, max = S, style = 3)
    on.exit(close(pb), add = TRUE)
    # whole-number stride: avoids a floating-point modulo when S / 100 is
    # not an integer
    update_every <- max(1, S %/% 100)
  }

  # proceed with algorithm; seq_len() yields an empty loop when S is 1,
  # whereas 2:S would count down through c(2, 1)
  for (s in seq_len(S - 1) + 1L) {
    # generate symmetric random-walk proposal
    proposal <- current + runif(k, -tau, tau)
    lp_proposal <- logf(proposal, ...)

    # acceptance step
    delta <- lp_proposal - lp_current
    if (is.na(delta)) {
      # e.g. -Inf - (-Inf) when both points lie outside the support:
      # reject rather than abort
      accept <- FALSE
    } else if (delta > 0) {
      accept <- TRUE
    } else {
      accept <- (log(runif(1)) <= delta)
    }

    # update samples (a rejected proposal repeats the current value)
    if (accept) {
      current <- proposal
      lp_current <- lp_proposal
      n_accepted <- n_accepted + 1
    }
    samples[s, ] <- current

    # update progress bar occasionally
    if (progress_bar && (s %% update_every == 0 || s == S)) {
      setTxtProgressBar(pb, s)
    }
  }

  # no proposals are made when S is 1, so the rate is undefined
  prop_accepted <- if (S > 1) n_accepted / (S - 1) else NA_real_

  # print summary report
  message(
    "💪 Successfully generated ", scales::comma(S), " dependent samples! 🎉\n\n",
    "✅ Accepted moves: ", scales::comma(n_accepted), "\n",
    "﹪ Acceptance rate: ", scales::percent(prop_accepted, accuracy = 1),
    " (tune tau so that acceptance rate is about 20%-50%).\n",
    "⏰ Total time: ", prettyunits::pretty_dt(Sys.time() - start_time), "\n",
    "🧠 Reminder: if a proposal is not accepted, the previous sample is ",
    "reused. This makes the samples depending on starting values and each ",
    "other. Discard initial samples, use R-hat to assess convergence, and ",
    "use ESS to assess effective sample size."
  )

  list(
    prop_accepted = prop_accepted,
    samples = samples
  )
}
| # ---- beta(11, 165) posterior ----- | |
| # log-posterior distribution | |
# Log-density of the beta(11, 165) posterior at p; -Inf wherever p lies at or
# beyond the boundaries 0 and 1. Works elementwise on vectors.
logf <- function(p) {
  outside_support <- p <= 0 | p >= 1
  log_density <- dbeta(p, 11, 165, log = TRUE)
  ifelse(outside_support, -Inf, log_density)
}
# draw 10,000 samples with the chain started at 0.5
m <- metrop(logf, theta_start = 0.5, S = 10000, tau = 0.1, progress_bar = TRUE)

# posterior mean: closed form vs. the average of every simulated sample
# (starting value and early iterations included)
beta_mean_cf <- 11 / (11 + 165)
beta_mean_sim_all <- mean(m$samples)
# ---- plotting the samples ----

# 2-by-2 grid: a histogram of the draws plus trace plots at three zoom levels
par(mfrow = c(2, 2))
hist(
  m$samples,
  main = "All 10,000 samples",
  xlab = "p"
)
plot(
  m$samples,
  type = "l",
  main = "All 10,000 samples",
  xlab = "Iteration",
  ylab = "p"
)
plot(
  m$samples[1:100],
  type = "l",
  main = "First 100 samples",
  xlab = "Iteration",
  ylab = "p"
)
plot(
  m$samples[1001:1100],
  type = "l",
  main = "Samples 1,001–1,100",
  xlab = "Iteration",
  ylab = "p"
)
# ---- burn-in ----

# mean of the last 5,000 of the 10,000 draws, i.e. with the early samples
# that depend on the starting value dropped
mean(tail(m$samples[, 1], 5000))
# ---- rhat; short run ----

# four chains of 25 iterations each, started at spread-out values
m1 <- metrop(logf, theta_start = 0.01, S = 25)
m2 <- metrop(logf, theta_start = 0.25, S = 25)
m3 <- metrop(logf, theta_start = 0.75, S = 25)
m4 <- metrop(logf, theta_start = 0.99, S = 25)

# one column per chain, one row per iteration
matrix_of_chains <- cbind(
  m1$samples[, 1], m2$samples[, 1], m3$samples[, 1], m4$samples[, 1]
)

# R-hat across the four chains
posterior::rhat(matrix_of_chains)

# long format for plotting: one row per (iteration, chain)
colnames(matrix_of_chains) <- paste("Chain", seq_len(ncol(matrix_of_chains)))
gg_df <- as_tibble(matrix_of_chains) |>
  mutate(Iteration = row_number()) |>
  pivot_longer(
    cols = starts_with("Chain"),
    names_to = "Chain",
    values_to = "Sample"
  ) |>
  separate(Chain, into = c("tmp", "Chain Number"), remove = FALSE) |>
  # NOTE(review): `ordered = TRUE` does not appear to be a reorder() argument
  # (the documented one is `order`), so it is forwarded to the summary
  # function and Chain stays an unordered factor with levels sorted by chain
  # number - confirm that is the intent.
  mutate(
    `Chain Number` = as.numeric(`Chain Number`),
    Chain = reorder(Chain, `Chain Number`, ordered = TRUE)
  )

# trace plot, one line per chain
ggplot(gg_df, aes(Iteration, Sample, group = Chain, colour = Chain)) +
  geom_line() +
  labs(
    title = "Metropolis samples (25 iters per chain)",
    subtitle = "Short runs from overdispersed starts show strong start-value dependence"
  ) +
  theme_minimal()
# ---- rhat; long run ----

# four chains of 25k iterations each, started at spread-out values
m1 <- metrop(logf, theta_start = 0.01, S = 25000)
m2 <- metrop(logf, theta_start = 0.25, S = 25000)
m3 <- metrop(logf, theta_start = 0.75, S = 25000)
m4 <- metrop(logf, theta_start = 0.99, S = 25000)

# keep iterations 10,001-25,000 of each chain; the first 10,000 are discarded
# as burn-in
kept_iterations <- 10001:25000

# combine samples: one column per chain, one row per kept iteration
matrix_of_chains <- cbind(
  m1$samples[kept_iterations, 1],
  m2$samples[kept_iterations, 1],
  m3$samples[kept_iterations, 1],
  m4$samples[kept_iterations, 1]
)

# compute R-hat
posterior::rhat(matrix_of_chains)

# plot chains
colnames(matrix_of_chains) <- paste("Chain", seq_len(ncol(matrix_of_chains)))
gg_df <- matrix_of_chains |>
  as_tibble() |>
  # label rows with the original iteration number so the x axis reads
  # 10,001-25,000 rather than restarting at 1 after the burn-in is dropped
  mutate(Iteration = kept_iterations) |>
  pivot_longer(cols = `Chain 1`:`Chain 4`, names_to = "Chain", values_to = "Sample") |>
  separate(Chain, into = c("tmp", "Chain Number"), remove = FALSE) |>
  # NOTE(review): `ordered = TRUE` does not appear to be a reorder() argument
  # (the documented one is `order`), so Chain stays an unordered factor with
  # levels sorted by chain number - confirm that is the intent.
  mutate(`Chain Number` = as.numeric(`Chain Number`),
         Chain = reorder(Chain, `Chain Number`, ordered = TRUE))

# the subtitle describes this long, post-burn-in run; it previously repeated
# the short-run caption
ggplot(gg_df, aes(x = Iteration, y = Sample, group = Chain, color = Chain)) +
  geom_line() +
  labs(title = "Metropolis samples (25000 iters per chain)",
       subtitle = "Long runs from overdispersed starts, after discarding the first 10,000 iterations as burn-in") +
  theme_minimal()
# ---- parallel computation ----

# number of physical cores on this machine
parallel::detectCores(logical = FALSE)

# start a cluster of 10 workers and register it as the foreach backend
# NOTE(review): the cluster size is fixed at 10 regardless of the core count
# reported above - confirm this is intended on smaller machines.
cl <- makeCluster(10)
registerDoParallel(cl)

# run 10 chains in parallel, one per starting value; each worker returns its
# 250-row sample matrix and the results are column-bound
# NOTE(review): no RNG seed is set on the workers, so runs are presumably not
# reproducible - verify if reproducibility matters here.
starting_values <- seq(0.05, 0.95, length.out = 10)
matrix_of_chains <- foreach(
  start_value = starting_values,
  .combine = cbind
) %dopar% {
  metrop(logf, theta_start = start_value, S = 250)$samples
}
stopCluster(cl)

# long format for plotting: one row per (iteration, chain)
colnames(matrix_of_chains) <- paste("Chain", seq_len(ncol(matrix_of_chains)))
gg_df <- as_tibble(matrix_of_chains) |>
  mutate(Iteration = row_number()) |>
  pivot_longer(
    cols = starts_with("Chain"),
    names_to = "Chain",
    values_to = "Sample"
  ) |>
  separate(Chain, into = c("tmp", "Chain Number"), remove = FALSE) |>
  mutate(
    `Chain Number` = as.numeric(`Chain Number`),
    Chain = reorder(Chain, `Chain Number`, ordered = TRUE)
  )

# trace plot, one line per chain
ggplot(gg_df, aes(Iteration, Sample, group = Chain, colour = Chain)) +
  geom_line() +
  labs(
    title = "Metropolis samples (250 iters per chain)",
    subtitle = "Short runs from overdispersed starts show strong start-value dependence"
  ) +
  theme_minimal()
# ---- ess ----

# rows = iterations, cols = chains
c(nrow(matrix_of_chains), ncol(matrix_of_chains))

# bulk and tail effective sample sizes computed across all chains
posterior::ess_bulk(matrix_of_chains)
posterior::ess_tail(matrix_of_chains)
# ---- developing our tuning intuition ----

# same start (0.99) and length (1,000); only the proposal half-width tau varies
m1 <- metrop(logf, theta_start = 0.99, S = 1000, tau = 0.1)
m2 <- metrop(logf, theta_start = 0.99, S = 1000, tau = 0.01)
m3 <- metrop(logf, theta_start = 0.99, S = 1000, tau = 0.001)
m4 <- metrop(logf, theta_start = 0.99, S = 1000, tau = 3)

# trace plots on a common 0-1 vertical scale, one panel per tau
par(mfrow = c(2, 2))
plot(
  m1$samples,
  type = "l", ylim = c(0, 1),
  main = "tau = 0.1 (nice)",
  xlab = "iter", ylab = "p"
)
plot(
  m2$samples,
  type = "l", ylim = c(0, 1),
  main = "tau = 0.01 (accept too often)",
  xlab = "iter", ylab = "p"
)
plot(
  m3$samples,
  type = "l", ylim = c(0, 1),
  main = "tau = 0.001 (way too small)",
  xlab = "iter", ylab = "p"
)
plot(
  m4$samples,
  type = "l", ylim = c(0, 1),
  main = "tau = 3 (accept too rarely)",
  xlab = "iter", ylab = "p"
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment