class: center, middle, inverse, title-slide

# PPOL670 | Introduction to Data Science for Public Policy
## Week 12
## Interpretable Machine Learning
### Prof. Eric Dunford ◆ Georgetown University ◆ McCourt School of Public Policy ◆ eric.dunford@georgetown.edu
---
layout: true

<div class="slide-footer"><span> PPOL670 | Introduction to Data Science for Public Policy &nbsp;&nbsp;&nbsp; Week 12 <!-- Week of the Footer Here --> &nbsp;&nbsp;&nbsp; Interpretable Machine Learning <!-- Title of the lecture here --> </span></div>

---
class: outline

# Outline for Today

- Approaches to **_Interpretable Machine Learning_**

  + _Variable Importance_

  + _Partial Dependencies_

  + _Individual Conditional Expectations_

  + _Global Surrogate Models_

---
class: newsection

# Interpretable Machine Learning

---

## Model Interpretation

- Knowing a model is predictive is _necessary_ but rarely _sufficient_ to draw **_substantive insights_**.

--

- In the social sciences, we are interested in understanding **_why_** certain features matter in an effort to detect potential **_interventions_**: if we change `\(X\)`, will we get a different outcome?

--

- Interpretability offers insights into the features the model **_relies on to make its predictions_**.

--

- In addition, interpretability is a useful debugging tool for **_detecting bias_** in machine learning models.

--

- The model needs to be a fairly **_good approximation of the data generating process_** (i.e. have high predictive accuracy) for its interpretation to matter.

---

## Variable Importance

![:space 2]

- Variable/feature importance is concerned with how much a given model **_relies on a set of variables/features to make accurate predictions_**.

- If those variables/features were removed, the model should **_perform worse_** (i.e. have diminished predictive capacity).

- Determining variable importance helps with **_variable selection_**:

  - What variables could we drop from the model (because they contribute little information)?

  - What variables should we make sure to always measure and use in the model?

---

## Variable Importance

Consider the output from a simple multivariate OLS regression: which variables seem to matter most?

![:space 3]

```
## # A tibble: 4 x 5
##   term        estimate std.error statistic p.value
##   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
## 1 (Intercept) -0.00776    0.0314    -0.247   0.805
## 2 x1           2.00       0.0318    62.9     0    
## 3 x2          -0.00812    0.0312    -0.260   0.795
## 4 x3           0.0515     0.0321     1.60    0.109
```

--

![:space 3]

- `x1` is clearly both substantively and statistically significant. We should keep it in the model.

---

## Variable Importance

![:space 2]

- Some models offer a natural way of determining importance:

  - _Regression_: coefficient and test statistic size

  - _Trees_: split importance

![:space 2]

- But other models are more complicated (e.g. support vector machines, KNN, neural networks). We call these **_black box_** models because it's difficult to "peer inside" the model to understand how it works.

![:space 2]

- We need ways of determining variable importance that are **_model agnostic_** (i.e. that don't depend on the type of model you use).

---

### Permutation Importance

- Permutation importance offers a model-agnostic way to determine variable importance.

- The idea: **_scramble the data_** one variable at a time and see if the predictive performance of the model _decreases_.

--

- How it works (a minimal hand-rolled sketch follows on the next slide):

  + **_Train_** a model.

  + **_Permute_** (i.e. scramble the order of) a single variable/feature in the training data.

  + Use the model to **_predict_** on the data with the permuted variable.

  + See if there is a **_drop in predictive performance_**.

  + **_Repeat_**.
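---

### Permutation Importance: a minimal sketch

Before turning to real data, here is a rough, hand-rolled illustration of the steps above. It is only a sketch of the logic (not the implementation any particular package uses); the simulated data, the `rmse()` helper, and the variable names are made up for this example.

```r
# Simulate data with two relevant features (x1, x2) and one irrelevant one (x3)
set.seed(42)
N <- 1000
sim_data <- tibble::tibble(
  x1 = rnorm(N),
  x2 = rnorm(N),
  x3 = rnorm(N),                  # unrelated to the outcome
  y  = 1 + 2*x1 - 2*x2 + rnorm(N)
)

# (1) Train a model
fit <- lm(y ~ x1 + x2 + x3, data = sim_data)

# Baseline performance of the trained model
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))
baseline <- rmse(sim_data$y, predict(fit, sim_data))

# (2)-(4) Permute one feature at a time, re-predict, record the drop in performance
importance <- sapply(c("x1", "x2", "x3"), function(var) {
  permuted <- sim_data
  permuted[[var]] <- sample(permuted[[var]])          # scramble one column
  rmse(sim_data$y, predict(fit, permuted)) - baseline # increase in error
})

sort(importance, decreasing = TRUE) # x1 and x2 matter; x3 barely moves
```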
---

### The logic of permuting

Let's walk through the logic of why this would work using simulated data.

```r
# Simulate some data
set.seed(123)
N = 1000               # Number of observations
x1 <- rnorm(N)         # independent variable
x2 <- rnorm(N)         # independent variable
error <- rnorm(N)      # error
y <- 1 + 2*x1 + -2*x2 + error # dependent variable
D <- tibble(y,x1,x2)   # Gather as data frame
head(D)
```

```
## # A tibble: 6 x 3
##       y      x1      x2
##   <dbl>   <dbl>   <dbl>
## 1  1.36 -0.560  -0.996 
## 2  2.86 -0.230  -1.04  
## 3  3.61  1.56   -0.0180
## 4  2.62  0.0705 -0.132 
## 5  6.53  0.129  -2.55  
## 6  1.73  1.72    1.04  
```

---

### The logic of permuting

```r
# Run Model
model <- lm(y ~ x1 + x2, data = D)

broom::tidy(model) %>% 
  mutate_if(is.numeric, function(x) round(x,2))
```

```
## # A tibble: 3 x 5
##   term        estimate std.error statistic p.value
##   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
## 1 (Intercept)     0.98      0.03      31.6       0
## 2 x1              1.98      0.03      63.1       0
## 3 x2             -1.97      0.03     -64.1       0
```

---

### The logic of permuting

```r
# Randomly scramble the order of X1 and estimate OLS model
model <- 
  D %>% 
* mutate(x1 = sample(x1)) %>% 
  lm(y ~ x1 + x2, data = .)

broom::tidy(model) %>% 
  mutate_if(is.numeric, function(x) round(x,2))
```

```
## # A tibble: 3 x 5
##   term        estimate std.error statistic p.value
##   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
## 1 (Intercept)     1         0.07     14.5     0   
## 2 x1              0.03      0.07      0.38    0.71
## 3 x2             -1.8       0.07    -26.3     0   
```

---

### The logic of permuting

![:space 5]

- Permuting a variable effectively **_breaks the statistical relationship_** between the outcome and the predictor.

- If a **_variable doesn't matter, then permuting it won't change the performance_** (as the model already doesn't rely on this variable).

- We must permute each variable **_multiple times_**, as permuting is a random process.

  + We want to ensure a specific importance ordering isn't the result of a single permutation.

---

### Example

What features in the data best predict whether someone will vote or not?

```r
vote_data
```

```
## # A tibble: 1,000 x 5
##    voted   age political_therm visited_europe eat_bread
##    <fct> <dbl>           <dbl>          <int>     <int>
##  1 Yes    0.24            0.84              0         0
##  2 No     0.53            0.49              0         1
##  3 No     0.53            0.11              0         1
##  4 No     0.54            0.35              0         1
##  5 Yes    0.67            0.76              0         1
##  6 No     0.54            0.39              0         1
##  7 Yes    0.19            0.46              1         1
##  8 No     0.31            0.11              0         0
##  9 Yes    0.56            0.96              1         1
## 10 Yes    0.47            0.99              0         0
## # … with 990 more rows
```

---

### Example

Run the machine learning model.

```r
# Cross fold validation
set.seed(1988)
folds <- createFolds(vote_data$voted, k = 5)
control_conditions <- 
  trainControl(method = 'cv',
               summaryFunction = twoClassSummary,
               classProbs = TRUE,
               index = folds)

# Random Forest Model
rf_model <-
  train(voted ~ .,
        data = vote_data, 
        method = "ranger",
        metric = "ROC",
        trControl = control_conditions)

# Predictive accuracy
pred <- predict(rf_model, vote_data)
Metrics::accuracy(vote_data$voted, pred)
```

```
## [1] 0.934
```
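---

### Example

The accuracy above is computed on the training data. Before interpreting the model, it is worth glancing at the cross-validated performance that `caret` stores on the fitted object — recall that interpretation only matters if the model is a reasonable approximation of the data generating process. A quick optional check, using the `rf_model` object trained on the previous slide:

```r
# Out-of-fold performance for each tuning combination caret tried
rf_model$results

# The tuning parameters caret ultimately selected
rf_model$bestTune
```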
---

### Example

Feature importance.

```r
require(vip)

vi_permute(rf_model,                # Machine learning model
           train = vote_data,       # Training data
           nsim = 10,               # Number of times to permute each variable
           target = "voted",        # Outcome
           reference_class = "Yes", # Which class are you predicting?
           metric = "accuracy",     # Performance metric
           pred_wrapper = predict)  # Prediction function
```

```
## # A tibble: 4 x 3
##   Variable        Importance   StDev
##   <chr>                <dbl>   <dbl>
## 1 age                 0.197  0.00587
## 2 political_therm     0.381  0.0136 
## 3 visited_europe      0.0812 0.00555
## 4 eat_bread           0.0759 0.00687
```

---

### Example

Use `vip` to auto-generate a feature importance plot.

```r
vip(rf_model, train = vote_data,
    nsim = 10, method = "permute",
    geom = "boxplot", target = "voted",
    reference_class = "Yes", metric = "accuracy",
    pred_wrapper = predict)
```

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-9-1.png" style="display: block; margin: auto;" />

---

### Example

We can use feature importance to **_select_** relevant variables, refining our models.

```r
# Random Forest Model
rf_model2 <-
  train(voted ~ age + political_therm,
        data = vote_data, 
        method = "ranger",
        metric = "ROC",
        trControl = control_conditions)
```

```
## note: only 1 unique complexity parameters in default grid. Truncating the grid to 1 .
```

```r
# Predictive accuracy
pred <- predict(rf_model2, vote_data)
Metrics::accuracy(vote_data$voted, pred)
```

```
## [1] 0.978
```

---

## Partial Dependence Plots (PDP)

![:space 5]

- Variable importance cannot tell us how variables **_relate_** to the outcome.

- Partial dependence plots show the **_marginal effect_** one or two features have on the predicted outcome of the model.

- A partial dependence plot can show whether the **_relationship_** between the target and a feature is linear, monotonic, or more complex.

- The partial dependence plot is a **_global method_**: it considers all instances and gives a statement about the global relationship of a feature with the predicted outcome.

---

## Partial Dependence Plots (PDP)

![:space 5]

- The steps (a hand-rolled sketch follows the next example):

  + Train a model.

  + Identify the features that matter most (feature importance).

  + Manipulate the values of those features (one at a time) and take the average prediction, holding all other features at their observed values.

  + Plot the values and interpret the curve.

---

## Partial Dependence Plots (PDP)

Recall our model from before. Let's explore the marginal effect of age on the likelihood of voting.

```r
require(pdp)

partial(rf_model,
        pred.var = "age",
        plot = TRUE, prob = T,
        plot.engine = "ggplot2")
```

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-11-1.png" style="display: block; margin: auto;" />
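---

## Partial Dependence Plots (PDP)

To make the averaging explicit, here is a rough hand-rolled version of the single-feature partial dependence plotted above. This is only a sketch of the logic (not what `pdp::partial()` literally does internally), and it assumes the `rf_model` and `vote_data` objects from the earlier slides; the grid of 20 age values is arbitrary.

```r
# For each candidate value of age: force every observation to that value,
# predict the probability of voting "Yes", and average across observations.
age_grid <- seq(min(vote_data$age), max(vote_data$age), length.out = 20)

pdp_by_hand <- 
  purrr::map_df(age_grid, function(a) {
    tmp <- vote_data
    tmp$age <- a    # all other features stay at their observed values
    tibble::tibble(age  = a,
                   yhat = mean(predict(rf_model, tmp, type = "prob")$Yes))
  })

# The averaged predictions trace out the partial dependence curve
ggplot(pdp_by_hand, aes(age, yhat)) + geom_line()
```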
---

## Partial Dependence Plots (PDP)

Let's explore the marginal effect of age and the political thermometer measure on the likelihood of voting. Useful for exploring **_interactions_** between features.

```r
# Provide a grid of values to speed up computation time.
# The more values, the longer the wait (like tuning).
*grid_values <- 
*  expand.grid(age = seq(.18, .75, by = .1),
*              political_therm = seq(0, 1, by = .05))

partial(rf_model,
        # Choose two predictor variables
*       pred.var = c("age", "political_therm"),
        plot = TRUE, prob = T,
        # Provide values to calculate predictions over
*       pred.grid = grid_values,
        # Plotting engine and color scheme
        plot.engine = "ggplot2",
        palette = "magma")
```

---

## Partial Dependence Plots (PDP)

Let's explore the marginal effect of age and the political thermometer measure on the likelihood of voting. Useful for exploring **_interactions_** between features.

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-13-1.png" style="display: block; margin: auto;" />

---

## Individual Conditional Expectation Plots (ICE)

- Partial dependence offers a plot of the **_average marginal effect_**; however, averaging can obscure heterogeneous relationships created by **_interactions_**.

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-14-1.png" style="display: block; margin: auto;" />

---

## Individual Conditional Expectation Plots (ICE)

- Partial dependence offers a plot of the **_average marginal effect_**; however, averaging can obscure heterogeneous relationships created by **_interactions_**.

![:space 3]

- ICE plots show the **_marginal effect for each observation_** in the data.

- We can observe whether there is **_divergence_** or **_convergence_** in the predicted effect across observations.

- The PDP is just the average taken across the different ICE curves.

---

## Individual Conditional Expectation Plots (ICE)

```r
partial(rf_model,
*       ice = T, alpha = .05,
        pred.var = "age",
        plot = TRUE, prob = T,
        plot.engine = "ggplot2")
```

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-15-1.png" style="display: block; margin: auto;" />

---

## Individual Conditional Expectation Plots (ICE)

```r
partial(rf_model,
*       ice = T, alpha = .05, center = T,
        pred.var = "age",
        plot = TRUE, prob = T,
        plot.engine = "ggplot2")
```

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-16-1.png" style="display: block; margin: auto;" />

---

## Global Surrogate Models

- Black box models (e.g. random forests) often do a better job of fitting the data, but at the expense of interpretability.

- A global surrogate model is an interpretable model (e.g. a decision tree or linear model) trained to **_approximate the predictions of a black box model_**.

- Steps:

  + Train a model.

  + Get the predictions from that model.

  + Train an interpretable model (e.g. CART, lm) on those predictions, using the original training data or a subset of that data.

  + Examine the fit. If it's good (low-ish error), interpret the output.

---

## Global Surrogate Models

```r
# Extract the predictions from the black box model
y_probs <- predict(rf_model, vote_data, type = "prob")$Yes

# Generate a new data frame
vote_data2 <- 
  vote_data %>% 
  mutate(y_probs = y_probs) %>% 
  select(-voted)

# Plot the predictions (just to see what these probabilities look like)
vote_data2 %>% 
  ggplot(aes(y_probs)) +
  geom_histogram()
```

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-17-1.png" style="display: block; margin: auto;" />

---

## Global Surrogate Models

```r
# Train a Decision Tree surrogate model
require(rpart) # Let's use the model package directly (rather than via caret)

cart_surrogate <-
  rpart(y_probs ~ .,
        data = vote_data2,
        control = rpart.control(maxdepth = 3))

rattle::fancyRpartPlot(cart_surrogate, sub = "", type = 1) # plot the tree
```

<img src="lecture-week-12-interpretable-ml-ppol670_files/figure-html/unnamed-chunk-18-1.png" style="display: block; margin: auto;" />
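---

## Global Surrogate Models

The surrogate does not have to be a tree: any interpretable learner can play this role (the steps above mention `lm` as an option). As a sketch, here is a linear-model surrogate fit to the same black box predictions, assuming the `vote_data2` data frame built on the earlier slide; the coefficients then summarize how the random forest uses each feature.

```r
# A linear model as an alternative surrogate for the black box predictions
lm_surrogate <- lm(y_probs ~ ., data = vote_data2)

# Interpretable summary of how the surrogate maps features to predictions
broom::tidy(lm_surrogate) %>% 
  mutate_if(is.numeric, function(x) round(x, 2))

# How closely does the linear surrogate track the black box predictions?
cor(vote_data2$y_probs, predict(lm_surrogate))^2
```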
---

## Global Surrogate Models

- If the surrogate doesn't do a good job of fitting the black box, then it's not useful for interpretation.

- Using R-squared, we can easily measure how well the surrogate model approximates the black box predictions.

```r
# Calculate the R-squared
y_hat <- predict(cart_surrogate)
y <- vote_data2$y_probs

TSS = sum((y - mean(y))^2)     # Total Sum of Squares
MSS = sum((y_hat - mean(y))^2) # Model Sum of Squares
R2 = MSS/TSS                   # R-Squared (variance explained)

round(R2, 2)
```

```
## [1] 0.75
```

- Ultimately, surrogate models are intuitive, model-agnostic ways of extracting substantive insights from a black box model.