Jackson Curtis

XGBoost and Model Assumptions

Updated: Jun 28, 2021

My post this week is simply to highlight the work of Guillaume Saupin, found here. In his post, Saupin demonstrates that a major limitation of XGBoost is that it will never extrapolate beyond the extremes of the training data, even when such extrapolation is desirable. This makes it a poor choice, on its own, for a wide variety of modeling problems with trends, like time series and growth curves. This ties into a larger point about data modeling in general: there’s no such thing as an assumptionless model. Every model you can fit comes with some kind of limitation, and there is a danger in thinking you can use a model without understanding it, because the conclusions you draw may be invalidated by assumptions you don’t understand.


XGBoost and Extrapolation


What do I mean when I say that XGBoost does not extrapolate, and why is it a problem? Consider an extremely simple time series:

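library(tidyverse)
set.seed(12)
# x counts days back from today, so y rises by about 1 per day as the date advances
data = tibble(x = 1:300, date = Sys.Date() - x, y = 350 - x + rnorm(300, 0, 10))
ggplot(data, aes(x = date, y = y)) + geom_point() + theme_minimal()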



Now, let’s fit an XGBoost model to the data and look at the model fit:


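library(tidymodels)

xgb_spec <- boost_tree(
  trees = 1000,            ## number of boosting rounds
  tree_depth = 2,          ## keep each tree shallow
  min_n = 5,               ## minimum observations in a node to split further
  loss_reduction = .01,    ## minimum loss reduction required to split
  mtry = 1,                ## predictors sampled per split (only one exists here)
  learn_rate = .01         ## step size
) %>% 
  set_engine("xgboost") %>% 
  set_mode("regression")

xgb_wf <- workflow() %>%
  add_formula(y ~ x) %>%
  add_model(xgb_spec) %>%
  fit(data = data)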


# Predict over the same range of x that the model was trained on
preds = tibble(x = 1:300) %>% 
  mutate(pred = predict(xgb_wf, .) %>% pull(),
         date = Sys.Date() - x)

ggplot(data, aes(x = date, y = y)) +
  geom_point() +
  geom_line(data = preds, size = 2, aes(y = pred, color = "XGBoost Prediction")) +
  theme_minimal() +
  theme(legend.title = element_blank()) +
  ylim(c(0, 400))


XGBoost fits the observed data extremely well. While it’s not a perfectly smooth line, we don’t expect it to be. Tree models are never going to be perfectly smooth (that’s why we use them), but we could probably tweak some hyperparameters to make it smoother. Now what happens when we project this time series out a few months?

# x from -150 to 0 covers today and the next 150 days, beyond the training data
preds = tibble(x = -150:300) %>% 
  mutate(pred = predict(xgb_wf, .) %>% pull(),
         date = Sys.Date() - x)

ggplot(data, aes(x = date, y = y)) +
  geom_point() +
  geom_line(data = preds, size = 2, aes(y = pred, color = "XGBoost Prediction")) +
  theme_minimal() +
  theme(legend.title = element_blank(), legend.position = "bottom") +
  ylim(c(0, 480)) +
  ggtitle("XGBoost Time Series Projection")

What happened to our smoothly increasing trend? After months and months of a linearly increasing trend, what does XGBoost predict for the coming months? No linear trend at all. A complete flatline from June onward.


So what went wrong? Saupin’s post linked above goes into much more mathematical detail about XGBoost, but in short, tree-based models don’t extrapolate. Outside the range of the training data their predictions stay constant, so they will never continue a trend past the observed data. This can be seen quite intuitively by thinking about a decision tree. The end of every decision tree is a rule of the form “if x>100 predict y1, else predict y2,” where y1 and y2 are fixed numbers. None of the standard tree models like Random Forest or XGBoost (although I’m sure there are interesting extensions) have rules like “if x>100 predict 350-x.”**
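To make the decision-tree intuition concrete, here is a minimal base R sketch of boosting with one-split trees ("stumps"). It is not XGBoost: there is no regularization and no second-order gradient information, and the helper names (fit_stump, predict_stump, boost, predict_boost) are made up for this illustration. It only assumes the same simulated data as above. The point is that a sum of "if x <= s predict a, else predict b" rules is still piecewise constant, so every x beyond the edge of the training data gets the same prediction.

```r
# Illustrative sketch only: hand-rolled boosted stumps, NOT the XGBoost algorithm.
set.seed(12)
x <- 1:300
y <- 350 - x + rnorm(300, 0, 10)

# Fit one stump: pick the split that minimises the total squared error.
fit_stump <- function(x, y) {
  candidates <- sort(unique(x))
  candidates <- candidates[-length(candidates)]  # keep both sides non-empty
  sse <- sapply(candidates, function(s) {
    left <- y[x <= s]
    right <- y[x > s]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  split <- candidates[which.min(sse)]
  list(split = split,
       left_value = mean(y[x <= split]),
       right_value = mean(y[x > split]))
}

# A stump can only ever return one of two fixed numbers.
predict_stump <- function(stump, new_x) {
  ifelse(new_x <= stump$split, stump$left_value, stump$right_value)
}

# Boosting: repeatedly fit a stump to the current residuals.
boost <- function(x, y, rounds = 100, learn_rate = 0.1) {
  base <- mean(y)
  pred <- rep(base, length(y))
  stumps <- vector("list", rounds)
  for (i in seq_len(rounds)) {
    stumps[[i]] <- fit_stump(x, y - pred)
    pred <- pred + learn_rate * predict_stump(stumps[[i]], x)
  }
  list(base = base, stumps = stumps, learn_rate = learn_rate)
}

predict_boost <- function(model, new_x) {
  pred <- rep(model$base, length(new_x))
  for (s in model$stumps) {
    pred <- pred + model$learn_rate * predict_stump(s, new_x)
  }
  pred
}

model <- boost(x, y)

# x = 1 is the edge of the training data; smaller x lies in the "future".
new_x <- c(-150, -50, 0, 1)
boost_pred <- predict_boost(model, new_x)
lm_pred <- predict(lm(y ~ x), newdata = data.frame(x = new_x))
```

Printing boost_pred should show one value repeated for every x at or below 1, because each of those points falls on the same side of every split. The linear model in lm_pred, by contrast, keeps climbing as x moves further into the future, since it assumes the trend continues.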


All models have assumptions


Typically people’s standard reason for turning away from linear regression comes down to a desire to get away from its restrictive assumptions. And they are restrictive, and if those assumptions aren’t met in the problem at hand you definitely should find a modeling technique that doesn’t rely on them. Where I think people frequently go wrong is in assuming that machine learning offers a way to get predictions without assumptions. No one talks about the assumptions of XGBoost, but the assumption that “no trends continue outside the range of the observed data” is a HUGE assumption and is wildly inappropriate for many types of prediction problems (like time series). Similarly, one of the reasons I feel that progress in deep learning is so slow is that it sells itself as the true assumptionless model, yet every single network topology has its own assumptions that are appropriate for different problems. CNNs, RNNs, LSTMs, Transformers, etc. have all been so successful because they’ve managed to align the assumptions their model makes with the data they’re trying to model. Deep learning models are extremely flexible, but you pay for that flexibility with opaqueness in understanding the limitations of different architectures.



**A side note for time series enthusiasts

The above example demonstrates one important difference between normal machine learning and time series modeling. In normal machine learning, you typically do k-fold cross validation by randomly selecting a group of points to hold out and fitting the model to the rest. In time series, you should always do cross validation by holding out a future set and training on past data: for example, holding out January 2021 and training on 2020, then holding out February and training on everything up through January. I used to think this was important only with models that specifically modeled autocorrelation, but this is an excellent example of a case where normal k-fold validation would report incredible accuracy while your real-world predictions would be atrocious. If you properly followed the typical time series validation scheme, however, you would get a more accurate estimate of the real-world error and (hopefully) be able to build a better model.
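The two validation schemes can be compared in a few lines of base R. This is a rough sketch under stated assumptions: it reuses the simulated series from the post, and it swaps in a simple nearest-neighbour average (the made-up helper knn_predict) as a stand-in for a tree model, since both share the key property of never predicting outside the values they were trained on. The cutoffs and the 30-day horizon are arbitrary choices for illustration.

```r
# Illustrative sketch: random k-fold vs. hold-out-the-future validation.
set.seed(12)
x <- 1:300                      # days back from today (larger x = older)
y <- 350 - x + rnorm(300, 0, 10)

rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

# Stand-in for a non-extrapolating model: average the k nearest training points.
knn_predict <- function(train_x, train_y, new_x, k = 5) {
  sapply(new_x, function(x0) {
    nearest <- order(abs(train_x - x0))[1:k]
    mean(train_y[nearest])
  })
}

# Scheme 1: ordinary 5-fold cross validation with randomly assigned folds.
folds <- sample(rep(1:5, length.out = length(x)))
kfold_rmse <- mean(sapply(1:5, function(f) {
  test <- folds == f
  rmse(y[test], knn_predict(x[!test], y[!test], x[test]))
}))

# Scheme 2: train only on data older than a cutoff,
# then test on the 30 days immediately after it.
cutoffs <- c(240, 180, 120, 60)
rolling_rmse <- mean(sapply(cutoffs, function(cutoff) {
  train <- x > cutoff                       # the past
  test <- x <= cutoff & x > cutoff - 30     # the next 30 days
  rmse(y[test], knn_predict(x[train], y[train], x[test]))
}))

c(kfold = kfold_rmse, rolling = rolling_rmse)
```

The random folds let the model interpolate between neighbouring days, so kfold_rmse stays close to the noise level, while rolling_rmse comes out noticeably larger because every test point sits past the edge of its training data. In a tidymodels workflow, the rsample package provides helpers such as rolling_origin() for building this kind of split.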
