Created
December 22, 2018 15:58
-
-
Save bradleyboehmke/a40087e2eb8ed75c11eaf2ddadc585a4 to your computer and use it in GitHub Desktop.
Illustrates how smaller learning rates require more trees to converge to optimal solutions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| library(dplyr) | |
| library(tidyr) | |
| library(ggplot2) | |
| library(gganimate) | |
# Simulate 500 evenly spaced points of a sine wave on [0, 2 * pi].
#   x     - the predictor
#   y     - the signal plus Gaussian noise (sd = 0.5)
#   truth - the noise-free signal, sin(x)
set.seed(1112) # fixed seed so the simulated noise is reproducible
df <- tibble::tibble(
  # `length.out` is spelled out in full rather than relying on partial
  # argument matching of `length`.
  x = seq(from = 0, to = 2 * pi, length.out = 500),
  y = sin(x) + rnorm(length(x), sd = 0.5),
  truth = sin(x)
)
# Build `results`: the rows of `df` repeated once per ensemble size, with one
# prediction column per learning rate and a `trees` column giving the number
# of trees used for those predictions.
#
# Each learning rate is fitted ONCE with the maximum number of trees, and
# predictions are then read off at every tree count of interest. This replaces
# refitting the models from scratch for every tree count and growing `results`
# with rbind() inside a loop.
#
# NOTE(review): the models are fitted to the noise-free `truth`, not the noisy
# `y` -- presumably deliberate so the curves converge to the plotted truth
# line; confirm.
tree_counts <- c(1, seq(25, 3000, by = 25))

# Fit a GBM with the given shrinkage and return its predictions on `df` for
# every value in `tree_counts`, stacked into a single vector (all rows for the
# first tree count, then all rows for the second, and so on).
predict_by_tree_count <- function(shrinkage) {
  set.seed(8451) # same seed for every learning rate
  model <- gbm::gbm(
    truth ~ x,
    data = df,
    # Stated explicitly instead of letting gbm guess it from the response.
    distribution = "gaussian",
    n.trees = max(tree_counts),
    shrinkage = shrinkage
  )
  # With a vector of tree counts, predict() returns a matrix with one row per
  # observation and one column per tree count.
  # NOTE(review): assumes gbm consumes the seeded RNG tree by tree, so the
  # first i trees of this fit match a separate fit stopped at i trees --
  # confirm on the installed gbm version.
  preds <- predict(model, df, n.trees = tree_counts)
  as.vector(preds) # column-major: matches the row order of `results` below
}

results <- df[rep(seq_len(nrow(df)), times = length(tree_counts)), ]
results$`shrinkage = .1` <- predict_by_tree_count(.1)
results$`shrinkage = .01` <- predict_by_tree_count(.01)
results$`shrinkage = .001` <- predict_by_tree_count(.001)
results$trees <- rep(tree_counts, each = nrow(df))
# Animate the fitted curve (red) against the true signal (blue), one panel per
# learning rate, with one frame per ensemble size.

# Panel order, from the largest learning rate to the smallest.
rate_levels <- c("shrinkage = .1", "shrinkage = .01", "shrinkage = .001")

p <- results %>%
  # Long format: one row per (observation, tree count, learning rate).
  pivot_longer(
    cols = all_of(rate_levels),
    names_to = "rate",
    values_to = "prediction"
  ) %>%
  mutate(
    # Integer tree counts so the frame title shows whole numbers.
    trees = as.integer(trees),
    # Base factor() fixes the panel order; the original called fct_relevel()
    # without forcats being attached, which fails with "could not find
    # function".
    rate = factor(rate, levels = rate_levels)
  ) %>%
  ggplot(aes(x, prediction)) +
  ylab("y") +
  # `df` has no `rate` column, so the truth curve is drawn in every panel.
  geom_line(data = df, aes(x, truth), colour = "blue", size = 1) +
  geom_line(colour = "red", size = 1) +
  facet_wrap(~ rate, nrow = 1) +
  labs(title = "Number of trees: {frame_time}") +
  transition_time(trees)

# Render the animation as a GIF built from PNG frames.
animate(p, renderer = gifski_renderer(), device = "png")
# Static version of the comparison: fitted curve (red) against the true signal
# (blue) for the largest and smallest learning rates only, with one facet row
# per tree count and one facet column per learning rate.
#
# NOTE(review): `results` holds 121 distinct tree counts, so this grid has 121
# rows of panels -- presumably meant to be run on a filtered subset of tree
# counts; confirm.
results %>%
  # pivot_longer() replaces the superseded gather(); only the two extreme
  # learning rates are reshaped, and `shrinkage = .01` stays an unused column.
  pivot_longer(
    cols = c(`shrinkage = .1`, `shrinkage = .001`),
    names_to = "rate",
    values_to = "prediction"
  ) %>%
  ggplot(aes(x, prediction)) +
  ylab("y") +
  geom_line(data = df, aes(x, truth), colour = "blue", size = 1) +
  geom_line(colour = "red", size = 1) +
  facet_grid(trees ~ rate)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment