Skip to content

Instantly share code, notes, and snippets.

@bradleyboehmke
Created December 22, 2018 15:58
Show Gist options
  • Select an option

  • Save bradleyboehmke/a40087e2eb8ed75c11eaf2ddadc585a4 to your computer and use it in GitHub Desktop.

Select an option

Save bradleyboehmke/a40087e2eb8ed75c11eaf2ddadc585a4 to your computer and use it in GitHub Desktop.
Illustrates how smaller learning rates require more trees to converge to optimal solutions
library(dplyr)
library(tidyr)
library(ggplot2)
library(gganimate)
# Simulate 500 evenly spaced points on [0, 2*pi]. `y` is a noisy sine wave
# (Gaussian noise, sd = 0.5) and `truth` is the noise-free sine curve.
set.seed(1112)
df <- tibble::tibble(
  # `length.out` is spelled out in full instead of relying on partial
  # argument matching of `length`.
  x = seq(from = 0, to = 2 * pi, length.out = 500),
  y = sin(x) + rnorm(length(x), sd = 0.5),
  truth = sin(x)
)
# Build `results`: one copy of `df` per tree count, with a prediction column
# for each learning rate and a `trees` column recording the tree count.

# Tree counts at which predictions are snapshotted: 1, 25, 50, ..., 3000.
tree_counts <- c(1, seq(25, 3000, by = 25))
max_trees <- max(tree_counts)

# Fit a GBM of `truth ~ x` with the given learning rate. The seed is reset
# before every fit so all three models see the same random stream.
fit_gbm <- function(shrinkage) {
  set.seed(8451)
  gbm::gbm(
    truth ~ x,
    data = df,
    # Stated explicitly; gbm would otherwise guess the distribution from the
    # numeric response.
    distribution = "gaussian",
    n.trees = max_trees,
    shrinkage = shrinkage
  )
}

# Each model is fit once with the maximum number of trees instead of being
# refit from scratch for every tree count. Boosting adds trees sequentially
# and the seed is fixed, so predicting from the first `i` trees of the full
# model is expected to match a model fit with `n.trees = i`
# (NOTE(review): confirm against the installed gbm version if exact
# reproducibility of earlier output matters).
gbm_model_hi <- fit_gbm(.1)
gbm_model_med <- fit_gbm(.01)
gbm_model_lo <- fit_gbm(.001)

# Collect the per-tree-count frames in a list and bind once, rather than
# growing a data frame with rbind() on every iteration.
results <- bind_rows(lapply(tree_counts, function(i) {
  add_results <- df
  add_results$`shrinkage = .1` <- predict(gbm_model_hi, df, n.trees = i)
  add_results$`shrinkage = .01` <- predict(gbm_model_med, df, n.trees = i)
  add_results$`shrinkage = .001` <- predict(gbm_model_lo, df, n.trees = i)
  add_results$trees <- i
  add_results
}))
# Animated comparison: the fitted curve (red) against the true sine curve
# (blue), one panel per learning rate, with frames advancing over the number
# of trees.
p <- results %>%
  # Long format: one row per (x, trees, rate) with its prediction.
  pivot_longer(
    cols = `shrinkage = .1`:`shrinkage = .001`,
    names_to = "rate",
    values_to = "prediction"
  ) %>%
  mutate(
    trees = as.integer(trees),
    # Fix the panel order from highest to lowest learning rate. Base factor()
    # is used because forcats (home of fct_relevel()) is never attached by
    # this script, so the original call failed with "could not find function".
    rate = factor(
      rate,
      levels = c("shrinkage = .1", "shrinkage = .01", "shrinkage = .001")
    )
  ) %>%
  ggplot(aes(x, prediction)) +
  ylab("y") +
  geom_line(data = df, aes(x, truth), colour = "blue", size = 1) +
  geom_line(colour = "red", size = 1) +
  facet_wrap(~ rate, nrow = 1) +
  # {frame_time} is filled in by gganimate with the current `trees` value.
  labs(title = "Number of trees: {frame_time}") +
  transition_time(trees)

# Render the animation to a GIF using PNG frames.
animate(p, renderer = gifski_renderer(), device = "png")
# Static comparison of the highest and lowest learning rates: the fitted curve
# (red) against the true sine curve (blue), with tree counts as facet rows and
# learning rates as facet columns.
# NOTE(review): `results` holds every snapshotted tree count, so this grid gets
# one facet row per count; consider filtering `trees` to a handful of values
# before plotting if a readable figure is the goal.
results %>%
  # Only the two extreme rates are reshaped; the `shrinkage = .01` column is
  # carried along unused.
  pivot_longer(
    cols = c(`shrinkage = .1`, `shrinkage = .001`),
    names_to = "rate",
    values_to = "prediction"
  ) %>%
  ggplot(aes(x, prediction)) +
  ylab("y") +
  geom_line(data = df, aes(x, truth), colour = "blue", size = 1) +
  geom_line(colour = "red", size = 1) +
  facet_grid(trees ~ rate)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment