# Load the tidyverse meta-package (ggplot2, dplyr, tidyr, readr, purrr,
# tibble, stringr, forcats) used for data wrangling and plotting below.
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.2 ✓ purrr 0.3.4
## ✓ tibble 3.0.3 ✓ dplyr 1.0.2
## ✓ tidyr 1.1.1 ✓ stringr 1.4.0
## ✓ readr 1.3.1 ✓ forcats 0.5.0
## ── Conflicts ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
# Load pre-saved objects from earlier parts of the project:
#   df           - the full dataset (presumably the raw data; confirm upstream)
#   step_2_a_df  - the modeling data for this step (outcome_2 plus inputs)
#   iii_models   - models fit in part iii, kept for later comparison
df <- readr::read_rds("df.rds")
step_2_a_df <- readr::read_rds("step_2_a_df.rds")
iii_models <- readr::read_rds("iii_models.rds")
This R Markdown file tackles part iv, specifically training, evaluating, tuning, and comparing models for the binary classifier outcome_2 as a function of xA, xB, and x01:x11.
We will use the caret
package to handle training, testing, and evaluation.
# caret provides a uniform interface for training, resampling, tuning,
# and evaluating all of the models fit below.
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
##
## lift
Throughout all methods, we will use 5-fold cross validation, as our resampling method, by specifying "repeatedcv"
as the method
argument to a caret::trainControl()
. For this classification problem, we will use the Area under the ROC
curve as our primary performance metric. We must specify the summaryFunction
argument to be twoClassSummary
within the trainControl()
function in order to maximize the area under the ROC curve. We will also instruct caret
to return the class predicted probabilities.
# Shared resampling specification: 5-fold cross-validation repeated 5
# times. twoClassSummary computes ROC / Sens / Spec, which requires the
# predicted class probabilities (classProbs = TRUE); hold-out
# predictions are retained for later inspection.
my_ctrl <- caret::trainControl(
  method = "repeatedcv",
  number = 5,
  repeats = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  savePredictions = TRUE
)

# Primary performance metric: area under the ROC curve.
roc_metric <- "ROC"
First we will train a logistic regression model with additive terms, using method = "glm" in caret::train. We will train the model for outcome_2 as a function of xA, xB, and x01:x11.
The main purpose of this logistic regression model is to provide a baseline comparison to the other complex models we will train.
# Baseline additive logistic regression for outcome_2 on all inputs.
# Inputs are centered and scaled; resampling follows my_ctrl (repeated
# 5-fold CV), and the model is selected on ROC AUC. Seeding makes the
# CV fold assignments reproducible across models.
set.seed(12345)
mod_glm <- caret::train(outcome_2 ~ .,
method = "glm",
metric = roc_metric,
trControl = my_ctrl,
preProcess = c("center", "scale"),
data = step_2_a_df)
# Print the cross-validated resampling summary.
mod_glm
## Generalized Linear Model
##
## 2013 samples
## 13 predictor
## 2 classes: 'Fail', 'Pass'
##
## Pre-processing: centered (15), scaled (15)
## Resampling: Cross-Validated (5 fold, repeated 5 times)
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ...
## Resampling results:
##
## ROC Sens Spec
## 0.6179373 0.5267005 0.6527407
Look at confusion matrix associated with the mod_glm
model.
# Cross-validated confusion matrix (average cell percentages across
# resamples). Call the confusionMatrix() generic rather than the S3
# method confusionMatrix.train() directly; dispatch selects the same
# method, and the generic is the supported public interface.
confusionMatrix(mod_glm)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 25.8 17.7
## Pass 23.2 33.3
##
## Accuracy (average) : 0.5911
We now try a regularization approach. Elastic net is a mixture between Lasso and Ridge penalties. We will train two different models with interactions, specifically one with all pair interactions between all step_2_a_df
input variables, and one with all triplet interactions.
Let’s first fit a regularized regression model with elastic net, on all pairwise interactions between all step_2_a_df
inputs, using caret::train
with method="glmnet"
. We specify centering and scaling as preprocessing steps.
# Elastic net (glmnet) on all pairwise interactions of the inputs,
# using caret's default tuning grid for alpha and lambda. Centering and
# scaling are applied to the expanded design matrix; the same seed
# reproduces the CV folds used by the other models.
set.seed(12345)
mod_glmnet_2 <- caret::train(outcome_2 ~ (.)^2,
method = "glmnet",
preProcess = c("center", "scale"),
metric = roc_metric,
trControl = my_ctrl,
data = step_2_a_df)
# Print the resampling results across the default tuning grid.
mod_glmnet_2
## glmnet
##
## 2013 samples
## 13 predictor
## 2 classes: 'Fail', 'Pass'
##
## Pre-processing: centered (117), scaled (117)
## Resampling: Cross-Validated (5 fold, repeated 5 times)
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ...
## Resampling results across tuning parameters:
##
## alpha lambda ROC Sens Spec
## 0.10 0.0001455549 0.6283084 0.5447716 0.6435927
## 0.10 0.0014555489 0.6170795 0.5269036 0.6496235
## 0.10 0.0145554894 0.6172807 0.5248731 0.6507895
## 0.55 0.0001455549 0.6238588 0.5435533 0.6443713
## 0.55 0.0014555489 0.6164439 0.5264975 0.6498196
## 0.55 0.0145554894 0.6185974 0.5220305 0.6577921
## 1.00 0.0001455549 0.6232111 0.5419289 0.6455401
## 1.00 0.0014555489 0.6165564 0.5269036 0.6494303
## 1.00 0.0145554894 0.6183943 0.5175635 0.6624523
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0.1 and lambda = 0.0001455549.
Create a custom tuning grid enet_grid
to try out many possible values of the penalty factor (lambda
) and the mixing fraction (alpha
).
# Custom elastic-net tuning grid: 9 mixing fractions (alpha) crossed
# with 25 penalty values (lambda) spaced evenly on the log scale
# between exp(-6) and exp(0.5) — 225 candidate settings in total.
enet_grid <- expand.grid(
  alpha = seq(from = 0.1, to = 0.9, by = 0.1),
  lambda = exp(seq(from = -6, to = 0.5, length.out = 25))
)
Now retrain the pairwise interactions model using tuneGrid = enet_grid
.
# Retrain the pairwise-interaction elastic net over the custom
# enet_grid (finer alpha/lambda coverage than caret's default grid).
set.seed(12345)
mod_glmnet_2_b <- caret::train(outcome_2 ~ (.)^2,
method = "glmnet",
preProcess = c("center", "scale"),
tuneGrid = enet_grid,
metric = roc_metric,
trControl = my_ctrl,
data = step_2_a_df)
# Print the resampling results across the custom grid.
mod_glmnet_2_b
## glmnet
##
## 2013 samples
## 13 predictor
## 2 classes: 'Fail', 'Pass'
##
## Pre-processing: centered (117), scaled (117)
## Resampling: Cross-Validated (5 fold, repeated 5 times)
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ...
## Resampling results across tuning parameters:
##
## alpha lambda ROC Sens Spec
## 0.1 0.002478752 0.6165188 0.524873096 0.6484537
## 0.1 0.003249784 0.6165159 0.525482234 0.6482624
## 0.1 0.004260651 0.6165151 0.525482234 0.6494322
## 0.1 0.005585954 0.6166770 0.525279188 0.6494303
## 0.1 0.007323502 0.6168269 0.525888325 0.6488477
## 0.1 0.009601525 0.6170501 0.526091371 0.6505972
## 0.1 0.012588142 0.6172558 0.526091371 0.6500128
## 0.1 0.016503766 0.6173706 0.525279188 0.6515700
## 0.1 0.021637371 0.6175801 0.522436548 0.6529349
## 0.1 0.028367816 0.6178745 0.522842640 0.6535202
## 0.1 0.037191811 0.6180846 0.522842640 0.6535174
## 0.1 0.048760568 0.6182851 0.521015228 0.6558522
## 0.1 0.063927861 0.6188195 0.521015228 0.6579882
## 0.1 0.083813041 0.6191233 0.520000000 0.6601269
## 0.1 0.109883635 0.6190885 0.516751269 0.6640095
## 0.1 0.144063659 0.6186933 0.512487310 0.6671182
## 0.1 0.188875603 0.6180372 0.499289340 0.6752868
## 0.1 0.247626595 0.6171168 0.480812183 0.6920189
## 0.1 0.324652467 0.6146657 0.438578680 0.7229401
## 0.1 0.425637741 0.6119760 0.360406091 0.7851963
## 0.1 0.558035146 0.6064190 0.272690355 0.8455127
## 0.1 0.731615629 0.5215393 0.018883249 0.9862136
## 0.1 0.959189457 0.5000000 0.000000000 1.0000000
## 0.1 1.257551613 0.5000000 0.000000000 1.0000000
## 0.1 1.648721271 0.5000000 0.000000000 1.0000000
## 0.2 0.002478752 0.6164164 0.526497462 0.6482624
## 0.2 0.003249784 0.6165596 0.525888325 0.6494303
## 0.2 0.004260651 0.6167460 0.526091371 0.6496263
## 0.2 0.005585954 0.6169060 0.525888325 0.6498205
## 0.2 0.007323502 0.6170325 0.526091371 0.6498167
## 0.2 0.009601525 0.6173449 0.524873096 0.6513730
## 0.2 0.012588142 0.6173960 0.522842640 0.6527407
## 0.2 0.016503766 0.6177171 0.523045685 0.6541037
## 0.2 0.021637371 0.6179976 0.522639594 0.6544902
## 0.2 0.028367816 0.6183068 0.520203046 0.6554592
## 0.2 0.037191811 0.6188033 0.521624365 0.6583775
## 0.2 0.048760568 0.6189697 0.519593909 0.6597386
## 0.2 0.063927861 0.6188790 0.515329949 0.6638115
## 0.2 0.083813041 0.6183066 0.510050761 0.6678987
## 0.2 0.109883635 0.6172998 0.494619289 0.6820999
## 0.2 0.144063659 0.6154955 0.470253807 0.6984362
## 0.2 0.188875603 0.6126165 0.419695431 0.7361695
## 0.2 0.247626595 0.6081873 0.325482234 0.8098963
## 0.2 0.324652467 0.5777423 0.150050761 0.9078011
## 0.2 0.425637741 0.5000000 0.000000000 1.0000000
## 0.2 0.558035146 0.5000000 0.000000000 1.0000000
## 0.2 0.731615629 0.5000000 0.000000000 1.0000000
## 0.2 0.959189457 0.5000000 0.000000000 1.0000000
## 0.2 1.257551613 0.5000000 0.000000000 1.0000000
## 0.2 1.648721271 0.5000000 0.000000000 1.0000000
## 0.3 0.002478752 0.6165397 0.526700508 0.6494293
## 0.3 0.003249784 0.6166947 0.526091371 0.6494322
## 0.3 0.004260651 0.6168449 0.526091371 0.6492370
## 0.3 0.005585954 0.6170464 0.526091371 0.6509856
## 0.3 0.007323502 0.6173576 0.524670051 0.6511788
## 0.3 0.009601525 0.6174652 0.523654822 0.6525465
## 0.3 0.012588142 0.6176945 0.523248731 0.6533270
## 0.3 0.016503766 0.6180401 0.522233503 0.6533213
## 0.3 0.021637371 0.6184254 0.520406091 0.6577959
## 0.3 0.028367816 0.6186978 0.520609137 0.6579882
## 0.3 0.037191811 0.6187187 0.517563452 0.6610950
## 0.3 0.048760568 0.6186448 0.515939086 0.6640057
## 0.3 0.063927861 0.6175732 0.505583756 0.6719839
## 0.3 0.083813041 0.6162671 0.488527919 0.6821037
## 0.3 0.109883635 0.6138465 0.456446701 0.7108852
## 0.3 0.144063659 0.6102588 0.393502538 0.7614663
## 0.3 0.188875603 0.6049319 0.282233503 0.8357878
## 0.3 0.247626595 0.5204231 0.018477157 0.9864078
## 0.3 0.324652467 0.5000000 0.000000000 1.0000000
## 0.3 0.425637741 0.5000000 0.000000000 1.0000000
## 0.3 0.558035146 0.5000000 0.000000000 1.0000000
## 0.3 0.731615629 0.5000000 0.000000000 1.0000000
## 0.3 0.959189457 0.5000000 0.000000000 1.0000000
## 0.3 1.257551613 0.5000000 0.000000000 1.0000000
## 0.3 1.648721271 0.5000000 0.000000000 1.0000000
## 0.4 0.002478752 0.6166453 0.526700508 0.6496263
## 0.4 0.003249784 0.6167599 0.525888325 0.6492351
## 0.4 0.004260651 0.6169703 0.525888325 0.6504011
## 0.4 0.005585954 0.6172490 0.524467005 0.6511778
## 0.4 0.007323502 0.6173923 0.523654822 0.6525465
## 0.4 0.009601525 0.6176281 0.523451777 0.6529358
## 0.4 0.012588142 0.6179363 0.522842640 0.6552669
## 0.4 0.016503766 0.6183829 0.521218274 0.6572115
## 0.4 0.021637371 0.6186955 0.520812183 0.6589590
## 0.4 0.028367816 0.6187167 0.517563452 0.6614824
## 0.4 0.037191811 0.6185758 0.514923858 0.6634232
## 0.4 0.048760568 0.6173234 0.505380711 0.6731499
## 0.4 0.063927861 0.6160457 0.490152284 0.6826862
## 0.4 0.083813041 0.6136052 0.455634518 0.7064182
## 0.4 0.109883635 0.6095516 0.395329949 0.7585489
## 0.4 0.144063659 0.6042792 0.278781726 0.8359839
## 0.4 0.188875603 0.5110218 0.000000000 1.0000000
## 0.4 0.247626595 0.5000000 0.000000000 1.0000000
## 0.4 0.324652467 0.5000000 0.000000000 1.0000000
## 0.4 0.425637741 0.5000000 0.000000000 1.0000000
## 0.4 0.558035146 0.5000000 0.000000000 1.0000000
## 0.4 0.731615629 0.5000000 0.000000000 1.0000000
## 0.4 0.959189457 0.5000000 0.000000000 1.0000000
## 0.4 1.257551613 0.5000000 0.000000000 1.0000000
## 0.4 1.648721271 0.5000000 0.000000000 1.0000000
## 0.5 0.002478752 0.6166966 0.526294416 0.6494303
## 0.5 0.003249784 0.6168429 0.525888325 0.6500109
## 0.5 0.004260651 0.6171462 0.525076142 0.6519555
## 0.5 0.005585954 0.6173083 0.523857868 0.6523514
## 0.5 0.007323502 0.6174829 0.524060914 0.6521582
## 0.5 0.009601525 0.6178138 0.523451777 0.6542979
## 0.5 0.012588142 0.6182613 0.521015228 0.6564319
## 0.5 0.016503766 0.6186490 0.521827411 0.6575970
## 0.5 0.021637371 0.6186878 0.518375635 0.6599252
## 0.5 0.028367816 0.6186380 0.516548223 0.6622562
## 0.5 0.037191811 0.6176482 0.509847716 0.6698413
## 0.5 0.048760568 0.6162440 0.495837563 0.6805465
## 0.5 0.063927861 0.6138608 0.467614213 0.6964973
## 0.5 0.083813041 0.6103134 0.415228426 0.7400682
## 0.5 0.109883635 0.6051515 0.303959391 0.8229439
## 0.5 0.144063659 0.5309444 0.030253807 0.9784087
## 0.5 0.188875603 0.5000000 0.000000000 1.0000000
## 0.5 0.247626595 0.5000000 0.000000000 1.0000000
## 0.5 0.324652467 0.5000000 0.000000000 1.0000000
## 0.5 0.425637741 0.5000000 0.000000000 1.0000000
## 0.5 0.558035146 0.5000000 0.000000000 1.0000000
## 0.5 0.731615629 0.5000000 0.000000000 1.0000000
## 0.5 0.959189457 0.5000000 0.000000000 1.0000000
## 0.5 1.257551613 0.5000000 0.000000000 1.0000000
## 0.5 1.648721271 0.5000000 0.000000000 1.0000000
## 0.6 0.002478752 0.6167351 0.525685279 0.6492342
## 0.6 0.003249784 0.6169407 0.526091371 0.6513739
## 0.6 0.004260651 0.6171543 0.524263959 0.6521544
## 0.6 0.005585954 0.6173811 0.524670051 0.6519650
## 0.6 0.007323502 0.6176091 0.524060914 0.6531319
## 0.6 0.009601525 0.6180259 0.521624365 0.6552631
## 0.6 0.012588142 0.6184687 0.522436548 0.6577940
## 0.6 0.016503766 0.6187594 0.520609137 0.6577874
## 0.6 0.021637371 0.6188025 0.517969543 0.6607009
## 0.6 0.028367816 0.6179336 0.512893401 0.6661454
## 0.6 0.037191811 0.6169148 0.502741117 0.6758778
## 0.6 0.048760568 0.6147347 0.478984772 0.6856055
## 0.6 0.063927861 0.6117418 0.436548223 0.7200322
## 0.6 0.083813041 0.6064129 0.350659898 0.7871456
## 0.6 0.109883635 0.5691158 0.137055838 0.9130466
## 0.6 0.144063659 0.5000000 0.000000000 1.0000000
## 0.6 0.188875603 0.5000000 0.000000000 1.0000000
## 0.6 0.247626595 0.5000000 0.000000000 1.0000000
## 0.6 0.324652467 0.5000000 0.000000000 1.0000000
## 0.6 0.425637741 0.5000000 0.000000000 1.0000000
## 0.6 0.558035146 0.5000000 0.000000000 1.0000000
## 0.6 0.731615629 0.5000000 0.000000000 1.0000000
## 0.6 0.959189457 0.5000000 0.000000000 1.0000000
## 0.6 1.257551613 0.5000000 0.000000000 1.0000000
## 0.6 1.648721271 0.5000000 0.000000000 1.0000000
## 0.7 0.002478752 0.6167618 0.525076142 0.6504021
## 0.7 0.003249784 0.6170553 0.524467005 0.6517604
## 0.7 0.004260651 0.6171669 0.524467005 0.6525475
## 0.7 0.005585954 0.6174059 0.524060914 0.6515728
## 0.7 0.007323502 0.6177751 0.524467005 0.6540990
## 0.7 0.009601525 0.6182228 0.523248731 0.6570163
## 0.7 0.012588142 0.6186528 0.521827411 0.6575970
## 0.7 0.016503766 0.6187421 0.520812183 0.6597319
## 0.7 0.021637371 0.6184595 0.517157360 0.6638115
## 0.7 0.028367816 0.6175508 0.507411168 0.6712081
## 0.7 0.037191811 0.6158492 0.492791878 0.6797755
## 0.7 0.048760568 0.6130957 0.459086294 0.7017533
## 0.7 0.063927861 0.6088593 0.401218274 0.7499919
## 0.7 0.083813041 0.6032275 0.280406091 0.8350073
## 0.7 0.109883635 0.5068880 0.000000000 1.0000000
## 0.7 0.144063659 0.5000000 0.000000000 1.0000000
## 0.7 0.188875603 0.5000000 0.000000000 1.0000000
## 0.7 0.247626595 0.5000000 0.000000000 1.0000000
## 0.7 0.324652467 0.5000000 0.000000000 1.0000000
## 0.7 0.425637741 0.5000000 0.000000000 1.0000000
## 0.7 0.558035146 0.5000000 0.000000000 1.0000000
## 0.7 0.731615629 0.5000000 0.000000000 1.0000000
## 0.7 0.959189457 0.5000000 0.000000000 1.0000000
## 0.7 1.257551613 0.5000000 0.000000000 1.0000000
## 0.7 1.648721271 0.5000000 0.000000000 1.0000000
## 0.8 0.002478752 0.6168812 0.525888325 0.6507895
## 0.8 0.003249784 0.6170996 0.524263959 0.6519593
## 0.8 0.004260651 0.6172724 0.524670051 0.6515757
## 0.8 0.005585954 0.6176063 0.524060914 0.6525475
## 0.8 0.007323502 0.6178828 0.522436548 0.6540961
## 0.8 0.009601525 0.6183473 0.522639594 0.6581823
## 0.8 0.012588142 0.6186736 0.520609137 0.6595330
## 0.8 0.016503766 0.6186813 0.518375635 0.6605049
## 0.8 0.021637371 0.6179455 0.512081218 0.6655610
## 0.8 0.028367816 0.6167882 0.503147208 0.6749088
## 0.8 0.037191811 0.6143869 0.479593909 0.6856064
## 0.8 0.048760568 0.6113780 0.435939086 0.7200360
## 0.8 0.063927861 0.6058259 0.349847716 0.7888960
## 0.8 0.083813041 0.5677464 0.121421320 0.9208458
## 0.8 0.109883635 0.5000000 0.000000000 1.0000000
## 0.8 0.144063659 0.5000000 0.000000000 1.0000000
## 0.8 0.188875603 0.5000000 0.000000000 1.0000000
## 0.8 0.247626595 0.5000000 0.000000000 1.0000000
## 0.8 0.324652467 0.5000000 0.000000000 1.0000000
## 0.8 0.425637741 0.5000000 0.000000000 1.0000000
## 0.8 0.558035146 0.5000000 0.000000000 1.0000000
## 0.8 0.731615629 0.5000000 0.000000000 1.0000000
## 0.8 0.959189457 0.5000000 0.000000000 1.0000000
## 0.8 1.257551613 0.5000000 0.000000000 1.0000000
## 0.8 1.648721271 0.5000000 0.000000000 1.0000000
## 0.9 0.002478752 0.6170305 0.524873096 0.6515662
## 0.9 0.003249784 0.6170523 0.524873096 0.6525456
## 0.9 0.004260651 0.6173041 0.523654822 0.6515766
## 0.9 0.005585954 0.6176021 0.523451777 0.6537106
## 0.9 0.007323502 0.6181586 0.523248731 0.6572096
## 0.9 0.009601525 0.6184531 0.521827411 0.6581814
## 0.9 0.012588142 0.6186364 0.521827411 0.6593408
## 0.9 0.016503766 0.6183262 0.516954315 0.6620611
## 0.9 0.021637371 0.6176239 0.509035533 0.6688733
## 0.9 0.028367816 0.6159141 0.496040609 0.6795813
## 0.9 0.037191811 0.6132204 0.466802030 0.6994213
## 0.9 0.048760568 0.6090681 0.411776650 0.7383187
## 0.9 0.063927861 0.6032487 0.292182741 0.8270339
## 0.9 0.083813041 0.5108156 0.007106599 0.9945631
## 0.9 0.109883635 0.5000000 0.000000000 1.0000000
## 0.9 0.144063659 0.5000000 0.000000000 1.0000000
## 0.9 0.188875603 0.5000000 0.000000000 1.0000000
## 0.9 0.247626595 0.5000000 0.000000000 1.0000000
## 0.9 0.324652467 0.5000000 0.000000000 1.0000000
## 0.9 0.425637741 0.5000000 0.000000000 1.0000000
## 0.9 0.558035146 0.5000000 0.000000000 1.0000000
## 0.9 0.731615629 0.5000000 0.000000000 1.0000000
## 0.9 0.959189457 0.5000000 0.000000000 1.0000000
## 0.9 1.257551613 0.5000000 0.000000000 1.0000000
## 0.9 1.648721271 0.5000000 0.000000000 1.0000000
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0.1 and lambda = 0.08381304.
Print out the non-zero coefficients, specifying the optimal value of lambda identified by resampling.
# Extract the coefficients of the final glmnet fit at the resampling-
# selected lambda, tidy them into a two-column tibble (name, value),
# and keep only the terms the elastic net retained (non-zero).
coef(mod_glmnet_2_b$finalModel, s = mod_glmnet_2_b$bestTune$lambda) %>%
as.matrix() %>%
as.data.frame() %>%
tibble::rownames_to_column("coef_name") %>%
tibble::as_tibble() %>%
purrr::set_names(c("coef_name", "coef_value")) %>%
filter(coef_value != 0)
## # A tibble: 57 x 2
## coef_name coef_value
## <chr> <dbl>
## 1 (Intercept) 0.0426
## 2 xBB2 0.00933
## 3 xBB3 0.00195
## 4 xBB4 -0.0217
## 5 x02 0.0449
## 6 x09 0.0101
## 7 x10 -0.00917
## 8 xAA2:xBB3 0.0310
## 9 xBB2:x01 0.00648
## 10 xBB3:x01 0.00783
## # … with 47 more rows
Visualize trends of metric AUC with respect to mixing percentage alpha
and regularization parameter lambda
, for model trained with our defined enet_grid
.
# ROC vs lambda, one curve per alpha; xTrans = log puts lambda on a log axis.
plot(mod_glmnet_2_b, xTrans = log)
Now fit a regularized regression model with elastic net, on all triplet interactions between all step_2_a_df
inputs, using tuneGrid = enet_grid
, then displaying the optimal tuning parameters.
Warning: The code chunk below takes more than a few minutes to run to completion.
# Elastic net on all triplet (3-way) interactions, tuned over the same
# custom enet_grid. The expanded design matrix is large (535 columns),
# so this chunk is slow to run.
set.seed(12345)
mod_glmnet_3_b <- caret::train(outcome_2 ~ (.)^3,
method = "glmnet",
preProcess = c("center", "scale"),
tuneGrid = enet_grid,
metric = roc_metric,
trControl = my_ctrl,
data = step_2_a_df)
# Show only the optimal (alpha, lambda) pair selected by resampling.
mod_glmnet_3_b$bestTune
## alpha lambda
## 1 0.1 0.002478752
Check number of coefficients:
# Number of model coefficients (excluding the intercept) in the 3-way model.
length(mod_glmnet_3_b$coefnames)
## [1] 535
# Sanity check: design-matrix columns (minus the intercept) should
# match the model's coefficient count exactly, i.e. print 0.
ncol(model.matrix(outcome_2 ~ (.)^3, data = step_2_a_df)) - 1 - length(mod_glmnet_3_b$coefnames)
## [1] 0
Visualize trends of metric AUC with respect to mixing percentage alpha
and regularization parameter lambda
, for model trained with our defined enet_grid
.
# ROC vs lambda, one curve per alpha, for the 3-way interaction model.
plot(mod_glmnet_3_b, xTrans = log)
Compare resampling results across the two different models.
# Pair the resampling results of the 2-way and 3-way elastic nets
# (same seed, hence same CV folds) and compare metric distributions.
glmnet_results <- resamples(list(glmnet_2way = mod_glmnet_2_b,
glmnet_3way = mod_glmnet_3_b))
dotplot(glmnet_results)
Check confusionMatrix for both models. glmnet_2_b
:
# Cross-validated confusion matrix for the 2-way elastic net. Use the
# confusionMatrix() generic instead of calling the S3 method
# confusionMatrix.train() directly; dispatch reaches the same method.
confusionMatrix(mod_glmnet_2_b)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 25.4 17.4
## Pass 23.5 33.7
##
## Accuracy (average) : 0.5916
glmnet_3_b
:
# Cross-validated confusion matrix for the 3-way elastic net, via the
# generic (dispatches to the train method).
confusionMatrix(mod_glmnet_3_b)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 26.1 18.2
## Pass 22.9 32.8
##
## Accuracy (average) : 0.5892
As described in Dr Yurko’s Ionosphere Caret Demo, “partial least squares (PLS) models are particularly well suited when the inputs are highly correlated to each other”. Although our EDA did not reveal any particularly interesting correlations between inputs like there are in the Ionosphere dataset, we can still try PLS to see how well the model performs for the step_2_a_df
inputs.
# Candidate numbers of PLS components to evaluate (1 through 5).
pls_grid <- expand.grid(ncomp = seq(from = 1, to = 5, by = 1))
# Partial least squares classifier, tuning the number of latent
# components over pls_grid. Centering/scaling applied as before.
set.seed(12345)
mod_pls <- caret::train(outcome_2 ~ .,
method = "pls",
preProcess = c("center", "scale"),
tuneGrid = pls_grid,
metric = roc_metric,
trControl = my_ctrl,
data = step_2_a_df)
# Plot cross-validated ROC against the number of components.
plot(mod_pls)
Check confusion matrix.
# Cross-validated confusion matrix for the PLS model, via the generic
# (dispatches to the train method).
confusionMatrix(mod_pls)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 26.9 19.5
## Pass 22.1 31.6
##
## Accuracy (average) : 0.5847
Now we will try several more complex, non-linear methods (which can capture non-linear relationships between inputs).
Random forests have become a handy and convenient learning algorithm that has good predictive performance with “relatively little hyperparameter tuning”. We will use method = "rf"
that allows us to use caret::train
as we have for all other models. By default, the random forest model creates 500 bagged tree models. The random forest model randomly selects, at each split, mtry
features to consider for the splitting process.
We use a custom grid for different mtry
values. Because we have 13 predictors, we will try mtry = c(2, 4, 5, 7, 9, 11, 13)
. The code chunk below might take a few minutes to run to completion.
# Candidate values for mtry (number of inputs considered per split),
# spanning from near sqrt(p) up to all 13 predictors.
rf_grid <- expand.grid(mtry = c(2, 4, 5, 7, 9, 11, 13))
# Random forest (500 trees by default). importance = TRUE stores
# variable importance for later inspection; no centering/scaling is
# needed for tree-based models.
set.seed(12345)
mod_rf <- caret::train(outcome_2 ~ .,
method = "rf",
importance = TRUE,
tuneGrid = rf_grid,
trControl = my_ctrl,
metric = roc_metric,
data = step_2_a_df)
# Print the resampling results across mtry values.
mod_rf
## Random Forest
##
## 2013 samples
## 13 predictor
## 2 classes: 'Fail', 'Pass'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 5 times)
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ...
## Resampling results across tuning parameters:
##
## mtry ROC Sens Spec
## 2 0.7886988 0.7122843 0.7182714
## 4 0.8002955 0.7608122 0.7073663
## 5 0.8000063 0.7634518 0.7013393
## 7 0.7984453 0.7652792 0.6974293
## 9 0.7949668 0.7612183 0.7001563
## 11 0.7938937 0.7549239 0.6995728
## 13 0.7916721 0.7506599 0.6982108
##
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 4.
Check confusion matrix based on cross-validation results.
# Cross-validated confusion matrix for the random forest, via the
# generic (dispatches to the train method).
confusionMatrix(mod_rf)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 37.2 14.9
## Pass 11.7 36.1
##
## Accuracy (average) : 0.7335
Gradient boosting machines (GBM) build shallow trees in sequence, with each tree “learning and improving on the previous one”; as opposed to random forests which build deep independent trees. When gradient boosted and tuned, these shallow trees collectively form one of the best predictive models.
Set method = "xgbTree"
in caret::train
.
# Gradient boosted trees via xgboost, tuned over caret's default grid
# (eta, max_depth, colsample_bytree, subsample, nrounds).
# verbose = FALSE suppresses xgboost's per-iteration logging.
set.seed(12345)
mod_xgb <- caret::train(outcome_2 ~ .,
method = "xgbTree",
verbose = FALSE,
metric = roc_metric,
trControl = my_ctrl,
data = step_2_a_df)
# Print the resampling results across the default tuning grid.
mod_xgb
## eXtreme Gradient Boosting
##
## 2013 samples
## 13 predictor
## 2 classes: 'Fail', 'Pass'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 5 times)
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ...
## Resampling results across tuning parameters:
##
## eta max_depth colsample_bytree subsample nrounds ROC Sens
## 0.3 1 0.6 0.50 50 0.7698187 0.6584772
## 0.3 1 0.6 0.50 100 0.7727299 0.6749239
## 0.3 1 0.6 0.50 150 0.7719606 0.6800000
## 0.3 1 0.6 0.75 50 0.7727008 0.6576650
## 0.3 1 0.6 0.75 100 0.7754783 0.6741117
## 0.3 1 0.6 0.75 150 0.7747493 0.6706599
## 0.3 1 0.6 1.00 50 0.7729116 0.6408122
## 0.3 1 0.6 1.00 100 0.7778342 0.6659898
## 0.3 1 0.6 1.00 150 0.7773261 0.6720812
## 0.3 1 0.8 0.50 50 0.7711913 0.6637563
## 0.3 1 0.8 0.50 100 0.7708642 0.6686294
## 0.3 1 0.8 0.50 150 0.7694209 0.6755330
## 0.3 1 0.8 0.75 50 0.7727520 0.6546193
## 0.3 1 0.8 0.75 100 0.7756598 0.6777665
## 0.3 1 0.8 0.75 150 0.7732621 0.6730964
## 0.3 1 0.8 1.00 50 0.7733490 0.6422335
## 0.3 1 0.8 1.00 100 0.7780521 0.6655838
## 0.3 1 0.8 1.00 150 0.7779404 0.6694416
## 0.3 2 0.6 0.50 50 0.7763377 0.6795939
## 0.3 2 0.6 0.50 100 0.7694828 0.6791878
## 0.3 2 0.6 0.50 150 0.7658762 0.6791878
## 0.3 2 0.6 0.75 50 0.7845011 0.6826396
## 0.3 2 0.6 0.75 100 0.7825355 0.6905584
## 0.3 2 0.6 0.75 150 0.7769851 0.6887310
## 0.3 2 0.6 1.00 50 0.7900593 0.6848731
## 0.3 2 0.6 1.00 100 0.7873255 0.6885279
## 0.3 2 0.6 1.00 150 0.7831357 0.6972589
## 0.3 2 0.8 0.50 50 0.7769232 0.6862944
## 0.3 2 0.8 0.50 100 0.7715836 0.6802030
## 0.3 2 0.8 0.50 150 0.7702877 0.6864975
## 0.3 2 0.8 0.75 50 0.7880421 0.6864975
## 0.3 2 0.8 0.75 100 0.7857674 0.6897462
## 0.3 2 0.8 0.75 150 0.7812681 0.6842640
## 0.3 2 0.8 1.00 50 0.7934665 0.6921827
## 0.3 2 0.8 1.00 100 0.7911592 0.6917766
## 0.3 2 0.8 1.00 150 0.7877259 0.6948223
## 0.3 3 0.6 0.50 50 0.7731563 0.6791878
## 0.3 3 0.6 0.50 100 0.7640676 0.6842640
## 0.3 3 0.6 0.50 150 0.7571524 0.6806091
## 0.3 3 0.6 0.75 50 0.7844676 0.6948223
## 0.3 3 0.6 0.75 100 0.7741658 0.6854822
## 0.3 3 0.6 0.75 150 0.7698236 0.6909645
## 0.3 3 0.6 1.00 50 0.7936389 0.6944162
## 0.3 3 0.6 1.00 100 0.7864082 0.6992893
## 0.3 3 0.6 1.00 150 0.7824682 0.6929949
## 0.3 3 0.8 0.50 50 0.7772389 0.6976650
## 0.3 3 0.8 0.50 100 0.7666117 0.6852792
## 0.3 3 0.8 0.50 150 0.7613667 0.6812183
## 0.3 3 0.8 0.75 50 0.7895111 0.6964467
## 0.3 3 0.8 0.75 100 0.7808798 0.6923858
## 0.3 3 0.8 0.75 150 0.7751093 0.6856853
## 0.3 3 0.8 1.00 50 0.7931749 0.7021320
## 0.3 3 0.8 1.00 100 0.7897491 0.7035533
## 0.3 3 0.8 1.00 150 0.7836854 0.6964467
## 0.4 1 0.6 0.50 50 0.7663168 0.6617259
## 0.4 1 0.6 0.50 100 0.7657465 0.6737056
## 0.4 1 0.6 0.50 150 0.7606027 0.6702538
## 0.4 1 0.6 0.75 50 0.7728382 0.6672081
## 0.4 1 0.6 0.75 100 0.7746753 0.6775635
## 0.4 1 0.6 0.75 150 0.7710139 0.6802030
## 0.4 1 0.6 1.00 50 0.7742321 0.6517766
## 0.4 1 0.6 1.00 100 0.7778580 0.6732995
## 0.4 1 0.6 1.00 150 0.7760507 0.6745178
## 0.4 1 0.8 0.50 50 0.7660028 0.6607107
## 0.4 1 0.8 0.50 100 0.7651894 0.6692386
## 0.4 1 0.8 0.50 150 0.7633827 0.6661929
## 0.4 1 0.8 0.75 50 0.7730557 0.6590863
## 0.4 1 0.8 0.75 100 0.7729226 0.6751269
## 0.4 1 0.8 0.75 150 0.7693549 0.6755330
## 0.4 1 0.8 1.00 50 0.7741773 0.6495431
## 0.4 1 0.8 1.00 100 0.7777564 0.6722843
## 0.4 1 0.8 1.00 150 0.7761287 0.6755330
## 0.4 2 0.6 0.50 50 0.7703060 0.6797970
## 0.4 2 0.6 0.50 100 0.7612938 0.6775635
## 0.4 2 0.6 0.50 150 0.7557010 0.6749239
## 0.4 2 0.6 0.75 50 0.7805801 0.6814213
## 0.4 2 0.6 0.75 100 0.7738490 0.6832487
## 0.4 2 0.6 0.75 150 0.7698275 0.6891371
## 0.4 2 0.6 1.00 50 0.7870354 0.6816244
## 0.4 2 0.6 1.00 100 0.7824073 0.6901523
## 0.4 2 0.6 1.00 150 0.7762232 0.6871066
## 0.4 2 0.8 0.50 50 0.7731089 0.6767513
## 0.4 2 0.8 0.50 100 0.7646752 0.6755330
## 0.4 2 0.8 0.50 150 0.7607448 0.6761421
## 0.4 2 0.8 0.75 50 0.7841538 0.6909645
## 0.4 2 0.8 0.75 100 0.7769770 0.6856853
## 0.4 2 0.8 0.75 150 0.7685628 0.6735025
## 0.4 2 0.8 1.00 50 0.7907755 0.6871066
## 0.4 2 0.8 1.00 100 0.7854047 0.6901523
## 0.4 2 0.8 1.00 150 0.7809258 0.6879188
## 0.4 3 0.6 0.50 50 0.7570689 0.6735025
## 0.4 3 0.6 0.50 100 0.7485898 0.6806091
## 0.4 3 0.6 0.50 150 0.7462318 0.6822335
## 0.4 3 0.6 0.75 50 0.7788115 0.6944162
## 0.4 3 0.6 0.75 100 0.7670294 0.6844670
## 0.4 3 0.6 0.75 150 0.7621416 0.6787817
## 0.4 3 0.6 1.00 50 0.7839113 0.6948223
## 0.4 3 0.6 1.00 100 0.7752788 0.6844670
## 0.4 3 0.6 1.00 150 0.7708931 0.6885279
## 0.4 3 0.8 0.50 50 0.7621217 0.6797970
## 0.4 3 0.8 0.50 100 0.7510226 0.6759391
## 0.4 3 0.8 0.50 150 0.7412198 0.6682234
## 0.4 3 0.8 0.75 50 0.7779812 0.6964467
## 0.4 3 0.8 0.75 100 0.7688880 0.6814213
## 0.4 3 0.8 0.75 150 0.7638909 0.6820305
## 0.4 3 0.8 1.00 50 0.7883959 0.6962437
## 0.4 3 0.8 1.00 100 0.7826881 0.6962437
## 0.4 3 0.8 1.00 150 0.7779497 0.6984772
## Spec
## 0.7581785
## 0.7501994
## 0.7404622
## 0.7585593
## 0.7496159
## 0.7494047
## 0.7599223
## 0.7601222
## 0.7535051
## 0.7535041
## 0.7375572
## 0.7348226
## 0.7583661
## 0.7538906
## 0.7461075
## 0.7614861
## 0.7577902
## 0.7562302
## 0.7400957
## 0.7338489
## 0.7274298
## 0.7482510
## 0.7400758
## 0.7326706
## 0.7523220
## 0.7546512
## 0.7488174
## 0.7295780
## 0.7303367
## 0.7283874
## 0.7525323
## 0.7484452
## 0.7425953
## 0.7564206
## 0.7492134
## 0.7505764
## 0.7320890
## 0.7173128
## 0.7105025
## 0.7406555
## 0.7320966
## 0.7258669
## 0.7490069
## 0.7429837
## 0.7320966
## 0.7275908
## 0.7190443
## 0.7250931
## 0.7416225
## 0.7383074
## 0.7299541
## 0.7550386
## 0.7493990
## 0.7384968
## 0.7515671
## 0.7383348
## 0.7336481
## 0.7606943
## 0.7496045
## 0.7445579
## 0.7593502
## 0.7525371
## 0.7505849
## 0.7361781
## 0.7330637
## 0.7326725
## 0.7585697
## 0.7470812
## 0.7420260
## 0.7614880
## 0.7525352
## 0.7521392
## 0.7295600
## 0.7250912
## 0.7155482
## 0.7447521
## 0.7332692
## 0.7291622
## 0.7558248
## 0.7484461
## 0.7416368
## 0.7363779
## 0.7254805
## 0.7260668
## 0.7422221
## 0.7357982
## 0.7309477
## 0.7581501
## 0.7453175
## 0.7437604
## 0.7132209
## 0.7038740
## 0.6953322
## 0.7317054
## 0.7196287
## 0.7145726
## 0.7480284
## 0.7295468
## 0.7223547
## 0.7231513
## 0.7145915
## 0.7035018
## 0.7316969
## 0.7192555
## 0.7095278
## 0.7456983
## 0.7379219
## 0.7309174
##
## Tuning parameter 'gamma' was held constant at a value of 0
## Tuning
## parameter 'min_child_weight' was held constant at a value of 1
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 50, max_depth = 3, eta
## = 0.3, gamma = 0, colsample_bytree = 0.6, min_child_weight = 1 and subsample
## = 1.
The best model identified has 50 iterations (nrounds
), complexity (max_depth
) of 3, learning rate (eta
) of 0.3, and minimum number of training set samples in a node to commence sampling (subsample
) of 1.
Check confusion matrix.
# Cross-validated confusion matrix for the boosted tree model, via the
# generic (dispatches to the train method).
confusionMatrix(mod_xgb)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 34.0 12.8
## Pass 15.0 38.3
##
## Accuracy (average) : 0.7223
The motivation for fitting a Support Vector Machine (SVM) is that SVMs have several advantages compared to other methods, as mentioned in the Hands-On Machine Learning with R book:
The basic idea of SVMs is dividing classes through hyperplanes; using a “kernel trick”, as Dr Yurko puts it, to transform from the original space to a new feature space, on which it then tries to create linear separating boundaries between the classes.
First, load kernlab
library.
# kernlab provides the radial-basis SVM backend used by method = "svmRadial".
library(kernlab)
##
## Attaching package: 'kernlab'
## The following object is masked from 'package:purrr':
##
## cross
## The following object is masked from 'package:ggplot2':
##
## alpha
We will stick to the general rule of thumb to use a radial basis kernel in our caret::train
call, using method="svmRadial"
.
First see what are the parameters to be learned:
# Inspect the tunable parameters for "svmRadial": sigma (kernel width)
# and C (cost).
caret::getModelInfo("svmRadial")$svmRadial$parameters
## parameter class label
## 1 sigma numeric Sigma
## 2 C numeric Cost
Now fit the model.
# Radial-basis SVM with caret's default tuning (sigma estimated
# analytically and held constant; C varied). Centering/scaling is
# important for SVMs since the kernel is distance-based.
set.seed(12345)
mod_svm <- caret::train(outcome_2 ~ .,
method = "svmRadial",
preProcess = c("center", "scale"),
metric = roc_metric,
trControl = my_ctrl,
data = step_2_a_df)
# Print the resampling results across cost values.
mod_svm
## Support Vector Machines with Radial Basis Function Kernel
##
## 2013 samples
## 13 predictor
## 2 classes: 'Fail', 'Pass'
##
## Pre-processing: centered (15), scaled (15)
## Resampling: Cross-Validated (5 fold, repeated 5 times)
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ...
## Resampling results across tuning parameters:
##
## C ROC Sens Spec
## 0.25 0.7187581 0.6355330 0.6867866
## 0.50 0.7374841 0.6410152 0.7070234
## 1.00 0.7458523 0.6473096 0.7212181
##
## Tuning parameter 'sigma' was held constant at a value of 0.04054147
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were sigma = 0.04054147 and C = 1.
Plot the results to see cross-validated ROC
scores against different cost values.
# Cross-validated ROC versus the cost tuning parameter.
plot(mod_svm)
Use a refined custom grid search, based on the identified best sigma
.
# Refined tuning grid: scale the best sigma found above by 0.25x-2x and
# sweep a much wider range of cost values.
sigma_candidates <- mod_svm$bestTune$sigma * c(0.25, 0.5, 1, 2)
cost_candidates <- c(0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0)
svm_grid <- expand.grid(sigma = sigma_candidates,
                        C = cost_candidates)
# Re-train the radial SVM over the refined (sigma, C) grid.
set.seed(12345)
mod_svm_b <- caret::train(outcome_2 ~ .,
                          data = step_2_a_df,
                          method = "svmRadial",
                          tuneGrid = svm_grid,
                          metric = roc_metric,
                          trControl = my_ctrl,
                          preProcess = c("center", "scale"))
# Best (sigma, C) pair selected by cross-validated ROC.
mod_svm_b$bestTune
## sigma C
## 24 0.08108294 2
Plot results.
# Cross-validated ROC against cost, one line per candidate sigma.
ggplot(mod_svm_b) + theme_bw()
Consistent with the bestTune result above, the model with sigma = 0.08108294
is the best model, achieving the highest AUC at Cost = 2.0
.
Check confusion matrix.
# Average cross-validated confusion matrix (cells are percentages).
confusionMatrix.train(mod_svm_b)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 32.3 14.5
## Pass 16.6 36.5
##
## Accuracy (average) : 0.6881
The motivation for fitting a Multivariate Adaptive Regression Splines (MARS) model is to explore more nonlinear relationships between the inputs. MARS is capable of extending linear models to capture multiple nonlinear relationships by searching for and discovering nonlinearities and interactions in the data that will help maximize predictive accuracy.
First, load earth
library for MARS modeling.
library(earth)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
## Loading required package: TeachingDemos
Hands-On Machine Learning with R describes the inner workings of MARS. Instead of explicitly defining polynomial functions or natural spline functions ourselves, MARS provides a convenient approach to capture the nonlinear relationships in the data by assessing cutpoints, like step functions. The procedure assesses each data point for each input as a knot and creates a linear regression model with the candidate feature(s).
To help in the tuning of this procedure, we can specify tuning parameters such as the maximum degree of interactions, degree
, and the number of terms retained in the final model, nprune
, in a tuning grid to be passed into the caret::train
call. Since there is rarely any benefit in assessing greater than triplet interactions, we choose degree = 1:3
. We also start out with 10 evenly spaced values and intend to zoom in when we later find an approximate optimal solution and there is cause to.
# Tuning grid for MARS: interaction degrees 1-3 crossed with ten evenly
# spaced candidate counts of retained terms (floored to integers).
nprune_candidates <- floor(seq(2, 100, length.out = 10))
mars_grid <- expand.grid(degree = 1:3,
                         nprune = nprune_candidates)
head(mars_grid)
## degree nprune
## 1 1 2
## 2 2 2
## 3 3 2
## 4 1 12
## 5 2 12
## 6 3 12
We will use caret::train
, as in the previous sections. The grid search might take a few minutes.
# Fit MARS over the custom (degree, nprune) grid; the search can take a
# few minutes.
set.seed(12345)
mod_mars <- caret::train(outcome_2 ~ .,
                         data = step_2_a_df,
                         method = "earth",
                         tuneGrid = mars_grid,
                         metric = roc_metric,
                         trControl = my_ctrl)
# Optimal number of retained terms and interaction degree.
mod_mars$bestTune
## nprune degree
## 13 23 2
Plot the model.
# Cross-validated ROC against nprune, one line per interaction degree.
ggplot(mod_mars)
Here, because the optimal ROC
values stay constant beyond roughly nprune = 23
terms, there is no need to adjust to a more specific tuning grid.
# Distribution of hold-out ROC/Sens/Spec across the 25 fold-repeats.
mod_mars$resample %>% summary()
## ROC Sens Spec Resample
## Min. :0.7267 Min. :0.5990 Min. :0.6699 Length:25
## 1st Qu.:0.7712 1st Qu.:0.6802 1st Qu.:0.7136 Class :character
## Median :0.7848 Median :0.7005 Median :0.7317 Mode :character
## Mean :0.7856 Mean :0.7003 Mean :0.7424
## 3rd Qu.:0.8034 3rd Qu.:0.7208 3rd Qu.:0.7756
## Max. :0.8241 Max. :0.7563 Max. :0.8398
Check confusion matrix.
# Average cross-validated confusion matrix (cells are percentages).
confusionMatrix.train(mod_mars)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Fail Pass
## Fail 34.3 13.2
## Pass 14.7 37.9
##
## Accuracy (average) : 0.7218
Now that we have fit all of the models, we can compare the cross-validation hold-out set performance metrics. We first compile all of the “resample” results together with the resamples() function.
# Pool the resample performance (ROC, Sens, Spec) of every tuned model
# so their CV hold-out metric distributions can be compared directly.
# Note the tuned variants (_b suffix) are used where available.
iv_results <- resamples(list(glm = mod_glm,
glmnet_2way = mod_glmnet_2_b,
glmnet_3way = mod_glmnet_3_b,
nnet = mod_nnet_b,
rf = mod_rf,
xgb = mod_xgb,
svm = mod_svm_b,
mars = mod_mars,
pls = mod_pls))
Then we visually compare the performance metrics.
# Dot plots of the resampled metrics: all three together, then each
# metric on its own for an easier ranking.
dotplot(iv_results)
dotplot(iv_results, metric = "ROC")
dotplot(iv_results, metric = "Sens")
dotplot(iv_results, metric = "Spec")
Based on AUC, rf
is the best model; although rf
, nnet
, xgb
, mars
seem to be close to each other in terms of performance. While rf
does the best in terms of AUC and Sens
, it does not fare as well in Spec
.
Assemble the ROC curves for comparison. First, identify the best tuned model and combine the cross-validation hold-out set predictions.
# Helper: extract a model's cross-validation hold-out predictions for
# its best tuning-parameter combination, tagged with a model label.
# semi_join() against bestTune keeps exactly the rows whose tuning
# columns match the best tune, replacing the per-model copy-pasted
# filter(param == bestTune$param, ...) chains of the original.
# tbl_df() is deprecated since dplyr 1.0.0, so as_tibble() is used.
get_best_cv_preds <- function(model, model_name) {
  model$pred %>%
    tibble::as_tibble() %>%
    semi_join(model$bestTune, by = names(model$bestTune)) %>%
    select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
    mutate(model_name = model_name)
}

# Combine the best-tune hold-out predictions of every model into one
# long tibble (one block of rows per model, labeled by model_name).
cv_pred_results <- list(glm = mod_glm,
                        glmnet_2_b = mod_glmnet_2_b,
                        glmnet_3_b = mod_glmnet_3_b,
                        pls = mod_pls,
                        nnet = mod_nnet_b,
                        rf = mod_rf,
                        xgb = mod_xgb,
                        svm = mod_svm_b,
                        mars = mod_mars) %>%
  purrr::imap(get_best_cv_preds) %>%
  bind_rows()
## Warning: `tbl_df()` is deprecated as of dplyr 1.0.0.
## Please use `tibble::as_tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
### nrounds = 50, max_depth = 2, eta = 0.3, gamma = 0,
### colsample_bytree = 0.8, min_child_weight = 1 and subsample = 1.
Load plotROC
to plot ROC curves.
library(plotROC)
Visualize the ROC curves for each fold-resample broken up by the methods.
# ROC curves per fold-resample, faceted by model. The event of
# interest is the Fail class, so d = 1 when obs == "Fail" and the
# marker m is the predicted Fail probability.
cv_pred_results %>%
ggplot(mapping = aes(m = Fail,
d = ifelse(obs == "Fail",
1,
0))) +
geom_roc(cutoffs.at = 0.5,
mapping = aes(color = Resample)) +
geom_roc(cutoffs.at = 0.5) +
coord_equal() +
facet_wrap(~model_name) +
style_roc()
The black line is the ROC curve averaged over all folds and repeats. Examine rf
and mars
models more closely since they seem to be the best performing.
# Zoom in on the two strongest models: per-resample ROC curves for
# rf and mars only, plus the fold-averaged curve in black.
cv_pred_results %>%
filter(model_name %in% c("rf", "mars")) %>%
ggplot(mapping = aes(m = Fail,
d = ifelse(obs == "Fail",
1,
0))) +
geom_roc(cutoffs.at = 0.5,
mapping = aes(color = Resample)) +
geom_roc(cutoffs.at = 0.5) +
coord_equal() +
facet_wrap(~model_name) +
style_roc()
Compare cross-validation averaged ROC curves.
# Overlay the cross-validation averaged ROC curve of every model on a
# single set of axes, colored by model.
cv_pred_results %>%
ggplot(mapping = aes(m = Fail,
d = ifelse(obs == "Fail",
1,
0),
color = model_name)) +
geom_roc(cutoffs.at = 0.5) +
coord_equal() +
style_roc() +
ggthemes::scale_color_calc()
As we expected, rf
, nnet
, xgb
and mars
all perform comparatively well.
Consider the calibration curves associated with the cross-validation hold-out sets for the above four models, and a linear model glmnet
.
# Helper: a model's cross-validation hold-out rows at its best tune,
# keeping only the columns needed for calibration analysis.
# semi_join() against bestTune replaces the per-model copy-pasted
# filter() chains; tbl_df() is deprecated since dplyr 1.0.0, so
# as_tibble() is used instead.
get_best_holdout_fail <- function(model) {
  model$pred %>%
    tibble::as_tibble() %>%
    semi_join(model$bestTune, by = names(model$bestTune)) %>%
    select(obs, Fail, rowIndex, Resample)
}

rf_test_pred_good <- get_best_holdout_fail(mod_rf)
nnet_test_pred_good <- get_best_holdout_fail(mod_nnet_b)
xgb_test_pred_good <- get_best_holdout_fail(mod_xgb)
mars_test_pred_good <- get_best_holdout_fail(mod_mars)
glmnet_3_b_test_pred_good <- get_best_holdout_fail(mod_glmnet_3_b)

# Join each model's predicted Fail probability on the shared hold-out
# row identifiers, giving one wide tibble: observed outcome plus one
# probability column per model.
cal_holdout_preds <- rf_test_pred_good %>% rename(rf = Fail) %>%
  left_join(nnet_test_pred_good %>% rename(nnet = Fail),
            by = c("obs", "rowIndex", "Resample")) %>%
  left_join(xgb_test_pred_good %>% rename(xgb = Fail),
            by = c("obs", "rowIndex", "Resample")) %>%
  left_join(mars_test_pred_good %>% rename(mars = Fail),
            by = c("obs", "rowIndex", "Resample")) %>%
  left_join(glmnet_3_b_test_pred_good %>% rename(glmnet = Fail),
            by = c("obs", "rowIndex", "Resample")) %>%
  select(outcome_2 = obs, rf, nnet, xgb, mars, glmnet)
Generate calibration curves.
# Calibration curves with 10 probability bins: observed event rate
# versus predicted Fail probability for each model.
cal_object <- calibration(outcome_2 ~ rf + nnet + xgb + mars + glmnet,
data = cal_holdout_preds,
cuts = 10)
ggplot(cal_object) + theme_bw() + theme(legend.position = "top")
glmnet
seems to not be so well calibrated.
# Coarser calibration check: same models with only 5 probability bins.
cal_object <- calibration(outcome_2 ~ rf + nnet + xgb + mars + glmnet,
data = cal_holdout_preds,
cuts = 5)
ggplot(cal_object) + theme_bw() + theme(legend.position = "top")
Based on Accuracy
, the ranking of the models appears slightly different, although the previously identified four best models remain the same.
# Cross-validated accuracy of a caret model, computed from its averaged
# confusion matrix.
#
# confusionMatrix.train() reports percentual average cell counts, so the
# diagonal (correct classifications) is summed and divided by 100 to
# yield a proportion. sum(diag()) generalizes the original hard-coded
# two-class version (cf$table[1,1] + cf$table[2,2]) to any number of
# classes; explicit return() is dropped per R convention.
calc_accuracy <- function(model) {
  cf <- confusionMatrix.train(model)
  sum(diag(cf$table)) / 100
}
# Accuracy of every model, sorted best-first.
# NOTE(review): the tuned SVM (mod_svm_b) is used here so the comparison
# is consistent with the resamples() comparison above, which also uses
# mod_svm_b; the original passed the untuned mod_svm by mistake.
models <- list(glm = mod_glm,
               glmnet_2way = mod_glmnet_2_b,
               glmnet_3way = mod_glmnet_3_b,
               nnet = mod_nnet_b,
               rf = mod_rf,
               xgb = mod_xgb,
               pls = mod_pls,
               svm = mod_svm_b,
               mars = mod_mars)
accuracy_results <- purrr::map_dbl(models, calc_accuracy)
accuracy_results %>% sort(decreasing = TRUE)
## rf xgb mars nnet svm glmnet_2way
## 0.7335320 0.7223050 0.7218082 0.6949826 0.6850472 0.5915549
## glm glmnet_3way pls
## 0.5910581 0.5891704 0.5846995
rf
is the best performing model in terms of Accuracy
.
Complex non-linear models can be difficult to interpret. We can consider ranking the relative importance of the input variables in the step_2_a_df
dataset. Plot variable importance based on rf
model.
# Relative input importance according to the random forest model.
plot(varImp(mod_rf))
Plot variable importance based on xgb
model.
# Relative input importance according to the xgboost model.
plot(varImp(mod_xgb))
x07
and x08
seem to be the two most important inputs.
# Prefix every model's list name with "iv_" to mark these as the part
# iv classifiers, then persist the whole list for later documents.
names(models) <- paste0("iv_", names(models))
readr::write_rds(models, "iv_models.rds")