Activity: multinomial logistic regression
While we are getting started, on your table’s workstation open RStudio and execute the following command:
Then open a new script, copy-paste the code chunk below, and execute once.
# packages
library(tidyverse)
library(tidymodels)
library(glmnet)
# path to activity files on repo
url <- ''
# load a few functions for the activity
source(paste(url, 'projection-functions.R', sep = ''))
# read in data
claims <- paste(url, 'claims-multi-tfidf.csv', sep = '') %>%
  read_csv()
# preview
claims
# A tibble: 552 × 15,871
mclass .id bclass adams afternoon agent android app arkansas arrest
<chr> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 unlawful url1 relev… 0.0692 0.0300 0.0365 0.0450 0.0330 0.0390 0.0140
2 irreleva… url3 irrel… 0 0 0 0 0 0 0
3 other url4 relev… 0 0 0 0 0 0 0
4 fatality url5 relev… 0 0 0 0 0 0 0
5 irreleva… url7 irrel… 0 0 0 0 0 0 0.0107
6 fatality url8 relev… 0 0 0 0 0 0 0
7 irreleva… url9 irrel… 0 0 0 0 0 0 0
8 irreleva… url10 irrel… 0 0 0 0 0 0 0
9 unlawful url11 relev… 0 0 0 0 0 0 0.0716
10 unlawful url12 relev… 0 0 0 0 0 0 0.0161
# … with 542 more rows, and 15,861 more variables: arrive <dbl>, assist <dbl>,
# attic <dbl>, barricade <dbl>, block <dbl>, blytheville <dbl>, burn <dbl>,
# captain <dbl>, catch <dbl>, check <dbl>, chemical <dbl>, copyright <dbl>,
# county <dbl>, custody <dbl>, dehydration <dbl>, demand <dbl>,
# department <dbl>, desktop <dbl>, device <dbl>, dispute <dbl>,
# division <dbl>, drug <dbl>, enter <dbl>, exit <dbl>, family <dbl>,
# federal <dbl>, fire <dbl>, force <dbl>, gas <dbl>, gray <dbl>, …
Activity 1 (10 min)
You’ll be given about ten minutes to do the following on your group’s workstation.
- Partition the data into training and test sets.
- Using the training data, find principal components that preserve at least 70% of the total variance and project the data onto those PCs.
- Fit a logistic regression model to the training data.
Step 1: partitioning
This should be familiar from last week’s lab. Use the code chunk below to partition the data. Do not change the RNG seed or split proportion!
# partition data
partitions <- claims %>% initial_split(prop = 0.8)
# separate DTM from labels
test_dtm <- testing(partitions) %>%
  select(-.id, -bclass, -mclass)
test_labels <- testing(partitions) %>%
  select(.id, bclass, mclass)
# same, training set
train_dtm <- training(partitions) %>%
  select(-.id, -bclass, -mclass)
train_labels <- training(partitions) %>%
  select(.id, bclass, mclass)
Note that we have separated the document term matrix (DTM) from the labels for both partitions. When we project the data onto a subspace, we only want to project the DTM and not the labels.
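For intuition about what the split does, here is a base-R sketch of an 80/20 random partition followed by separating term columns from identifiers and labels. It is not rsample's implementation of initial_split(), and the toy data frame (with columns t1 to t3 standing in for terms) is made up for illustration.

```r
# Hypothetical stand-in for a labelled DTM: 10 documents, 3 term columns
set.seed(1)
toy <- data.frame(.id = paste0('url', 1:10),
                  bclass = rep(c('relevant', 'irrelevant'), 5),
                  t1 = runif(10), t2 = runif(10), t3 = runif(10))

# draw 80% of the row indices for training; the remaining rows form the test set
n_train <- floor(0.8 * nrow(toy))
train_idx <- sample(nrow(toy), n_train)
toy_train <- toy[train_idx, ]
toy_test  <- toy[-train_idx, ]

# keep only term columns in the DTM; identifiers and labels go in a separate frame
label_cols <- c('.id', 'bclass')
toy_train_dtm    <- toy_train[, setdiff(names(toy_train), label_cols)]
toy_train_labels <- toy_train[, label_cols]
```

Every document lands in exactly one partition, and only the numeric term columns would be passed on to the projection step.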
Step 2: projection
Now find the number of principal components that capture at least 70% of the variation and project the DTM onto those components, using the custom function projection_fn(.dtm, .prop).
# find projections based on training data
proj_out <- projection_fn(.dtm = train_dtm, .prop = 0.7)
train_dtm_projected <- proj_out$data
# how many components were used?
proj_out$n_pc
Note: projections were found using the training data only. The test data will ultimately be projected onto the same components, as if it were new information we were feeding into a predictive model developed entirely using the training data.
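The internals of projection_fn() and reproject_fn() (from projection-functions.R) aren't shown here, so the following is only a base-R sketch of the same idea, not their implementation. It picks the smallest number of PCs whose cumulative share of variance reaches prop, using training data only, and then projects both training and test rows by centering with the training means and multiplying by the training loadings. The helper pc_project() and the simulated matrix X are hypothetical; a DTM with thousands of columns would likely call for a sparse or truncated SVD instead of prcomp().

```r
# Hypothetical helper: choose PCs on `train`, apply the same projection to `test`
pc_project <- function(train, test, prop = 0.7) {
  pca <- prcomp(train, center = TRUE, scale. = FALSE)
  # cumulative proportion of total variance explained by the first j components
  cum_var <- cumsum(pca$sdev^2) / sum(pca$sdev^2)
  n_pc <- which(cum_var >= prop)[1]
  rot <- pca$rotation[, 1:n_pc, drop = FALSE]
  # both sets are centred with the TRAINING means and use the TRAINING loadings
  list(n_pc  = n_pc,
       train = scale(train, center = pca$center, scale = FALSE) %*% rot,
       test  = scale(test,  center = pca$center, scale = FALSE) %*% rot)
}

# simulated data with correlated columns
set.seed(42)
X <- matrix(rnorm(200 * 6), ncol = 6) %*% matrix(rnorm(36), nrow = 6)
out <- pc_project(X[1:160, ], X[161:200, ], prop = 0.7)
out$n_pc
```

Because the test rows never influence the centering or the loadings, they are treated exactly like new documents arriving after the model was built.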
Step 3: regression
Bind the binary labels to the projected document term matrix and fit a logistic regression model.
The code chunk below gives you the input data frame you need to use glm(). It's up to you to specify the other arguments needed to fit the model.
train <- train_labels %>%
  transmute(bclass = factor(bclass)) %>%
  bind_cols(train_dtm_projected)
fit <- glm(..., data = train, ...)
You will most likely get a warning of some kind – that’s expected. Take note of what the warning says and stop here.
Briefly discuss with your table: what do you think the warning means?
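If you want something concrete to compare against, here is a small made-up example of one common source of glm() warnings in logistic regression: separation, where a predictor (or combination of predictors) splits the two classes perfectly. This is only an illustration on toy data; whether it explains the warning you saw is for your table to decide.

```r
# Toy data (made up): x perfectly separates the two classes
x <- c(-3, -2, -1, 1, 2, 3)
y <- factor(rep(c('irrelevant', 'relevant'), each = 3))

# Fit the model, collecting any warnings instead of printing them
warns <- character(0)
fit_sep <- withCallingHandlers(
  glm(y ~ x, family = binomial),
  warning = function(w) {
    warns <<- c(warns, conditionMessage(w))
    invokeRestart('muffleWarning')
  }
)
warns          # warning message(s) raised by glm.fit
coef(fit_sep)  # very large slope: the likelihood keeps improving as the slope grows
```

With perfect separation the maximum likelihood estimate does not exist, so the fitting algorithm pushes the coefficients outward until the fitted probabilities are numerically 0 or 1.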
Activity 2 (10 min)
This part will guide you through the following steps.
- Fit a logistic regression model with an elastic net penalty to the training data.
- Quantify classification accuracy on the test data using sensitivity, specificity, and AUROC.
Step 1: fit a regularized logistic regression
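The glmnet() function fits logistic regression with an elastic net penalty, whose mixing parameter is set through the alpha argument (the default, alpha = 1, is the LASSO). In the function call, a predictor matrix and response vector are used to specify the model instead of a formula.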
Use the code chunk below to fit the model for a path of regularization strengths, select a strength, and extract the fitted model corresponding to that strength. Do not adjust the RNG seed.
# store predictors and response as matrix and vector
x_train <- train %>% select(-bclass) %>% as.matrix()
y_train <- train_labels %>% pull(bclass)
# fit enet model
alpha_enet <- 0.3
fit_reg <- glmnet(x = x_train,
                  y = y_train,
                  family = 'binomial',
                  alpha = alpha_enet)
# choose a strength by cross-validation
cvout <- cv.glmnet(x = x_train,
                   y = y_train,
                   family = 'binomial',
                   alpha = alpha_enet)
# store optimal strength
lambda_opt <- cvout$lambda.min
# view results
Comment. The elastic net parameter alpha controls the balance between ridge and LASSO penalties: alpha = 0 corresponds to ridge regression, alpha = 1 corresponds to LASSO, and all other values specify a mixture. When the parameter is closer to 1, the LASSO penalty is stronger relative to ridge, and vice versa when it's closer to 0.
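To make the mixture concrete, the sketch below evaluates the elastic net penalty in the form glmnet documents, lambda * [ (1 - alpha)/2 * ||beta||_2^2 + alpha * ||beta||_1 ], with the intercept left unpenalized. The function enet_penalty() and the coefficient vector are illustrative only; glmnet computes this internally and does not export such a function.

```r
# Hypothetical helper: elastic net penalty for a coefficient vector (no intercept)
enet_penalty <- function(beta, lambda, alpha) {
  ridge <- sum(beta^2) / 2   # half the squared L2 norm
  lasso <- sum(abs(beta))    # L1 norm
  lambda * ((1 - alpha) * ridge + alpha * lasso)
}

beta <- c(2, -1, 0.5)
enet_penalty(beta, lambda = 1, alpha = 0)    # pure ridge: 5.25 / 2 = 2.625
enet_penalty(beta, lambda = 1, alpha = 1)    # pure LASSO: 2 + 1 + 0.5 = 3.5
enet_penalty(beta, lambda = 1, alpha = 0.3)  # 0.7 * 2.625 + 0.3 * 3.5 = 2.8875
```

Intermediate values of alpha interpolate linearly between the two penalties, which is why alpha = 0.3 leans toward ridge.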
Step 2: prediction
To compute predictions, we'll need to project the test data onto the same directions used to transform the training data.
Once that's done, we can simply feed the projected test data, the fitted model fit_reg, and the optimal strength lambda_opt to predict():
# project test data onto PCs
test_dtm_projected <- reproject_fn(.dtm = test_dtm, proj_out)
# coerce to matrix
x_test <- as.matrix(test_dtm_projected)
# compute predicted probabilities
preds <- predict(fit_reg,
                 s = lambda_opt,
                 newx = x_test,
                 type = 'response')
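The argument type = 'response' asks for predicted probabilities rather than the linear predictor (log-odds, type = 'link'). For a logistic model the two are related by the inverse logit, p = 1 / (1 + exp(-eta)). The sketch below checks this with base R's glm() on simulated data as a stand-in for the glmnet fit; the data and variable names are made up.

```r
# Simulated binary data (illustrative only)
set.seed(7)
n  <- 100
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- rbinom(n, 1, plogis(0.5 + x1 - x2))
fit <- glm(y ~ x1 + x2, family = binomial)

new <- data.frame(x1 = c(-1, 0, 1), x2 = c(0, 0.5, 1))
eta      <- predict(fit, newdata = new, type = 'link')      # log-odds
p_resp   <- predict(fit, newdata = new, type = 'response')  # probabilities
p_manual <- 1 / (1 + exp(-eta))                             # inverse logit by hand
all.equal(p_resp, p_manual)
```

A consequence used in the next step: thresholding the probability at 0.5 is the same as thresholding the log-odds at 0.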
Next bind the test labels to the predictions:
# store predictions in a data frame with true labels
pred_df <- test_labels %>%
  transmute(bclass = factor(bclass)) %>%
  bind_cols(pred = as.numeric(preds)) %>%
  mutate(bclass.pred = factor(pred > 0.5,
                              labels = levels(bclass)))
# define classification metric panel
panel <- metric_set(sensitivity,
                    specificity,
                    roc_auc)
# compute test set accuracy
pred_df %>% panel(truth = bclass,
                  estimate = bclass.pred,
                  pred,
                  event_level = 'second')
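As a check on what the panel reports, here is a base-R sketch computing the same three quantities by hand on made-up labels and probabilities, treating the second factor level as the event (as event_level = 'second' does). It is an illustration, not yardstick's code; AUROC is computed as the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as one half.

```r
# Made-up test labels and predicted probabilities of the positive ('relevant') class
truth <- factor(c('irrelevant', 'irrelevant', 'relevant', 'relevant', 'relevant', 'irrelevant'),
                levels = c('irrelevant', 'relevant'))
prob  <- c(0.2, 0.6, 0.7, 0.4, 0.9, 0.1)
pred  <- factor(ifelse(prob > 0.5, 'relevant', 'irrelevant'), levels = levels(truth))

pos  <- levels(truth)[2]                                     # second level is the event
sens <- sum(pred == pos & truth == pos) / sum(truth == pos)  # true positive rate
spec <- sum(pred != pos & truth != pos) / sum(truth != pos)  # true negative rate

# AUROC via pairwise comparisons of positive vs negative scores
auroc <- function(prob, is_pos) {
  p <- prob[is_pos]
  n <- prob[!is_pos]
  mean(outer(p, n, '>') + 0.5 * outer(p, n, '=='))
}
auc <- auroc(prob, truth == pos)
c(sensitivity = sens, specificity = spec, roc_auc = auc)
```

Sensitivity and specificity depend on the 0.5 threshold, whereas AUROC uses only the ranking of the probabilities.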
Briefly discuss with your table:
- How satisfied are you with the predictive performance?
- Does the classifier do a better job picking out relevant pages or irrelevant pages?
Activity 3 (10 min)
Now we’ll fit a multinomial logistic regression model using the multiclass labels rather than the binary ones, still using regularization to prevent overfitting.
Step 1: multinomial regression
Use the code chunk below to do the fitting. Notice that it's as simple as supplying the multiclass labels and changing family = 'binomial' to family = 'multinomial', but the number of non-intercept parameters is now
\[ \text{number of predictors} \times (\text{number of classes} - 1) \]
So in our case, the logistic regression model had \(p = 55\) non-intercept parameters, one per principal component, but when we fit a multinomial model to the data using labels with \(k = 5\) classes, we have \(p(k - 1) = 55 \times 4 = 220\) parameters!
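The count p(k - 1) comes from the baseline-category form of the multinomial model. One class serves as the reference with linear predictor fixed at 0, and each of the other k - 1 classes gets its own intercept and p slopes, which a softmax turns into class probabilities. The sketch below uses random coefficients and a hypothetical helper class_probs(). Note that glmnet itself stores one coefficient vector per class (a symmetric parametrization) and lets the penalty resolve the redundancy, so its coefficient output won't line up one-to-one with this count.

```r
# Baseline-category multinomial logit with p = 55 predictors and k = 5 classes
p <- 55
k <- 5
set.seed(3)
B  <- matrix(rnorm(p * (k - 1), sd = 0.1), nrow = p)  # p x (k - 1) slopes (random, illustrative)
b0 <- rnorm(k - 1)                                    # k - 1 intercepts

# Hypothetical helper: class probabilities for one observation x
class_probs <- function(x, b0, B) {
  eta <- c(0, b0 + drop(x %*% B))  # reference class has linear predictor 0
  exp(eta) / sum(exp(eta))         # softmax
}

x  <- rnorm(p)
pr <- class_probs(x, b0, B)
length(B)  # non-intercept parameters: 55 * 4 = 220
sum(pr)    # probabilities sum to 1
```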
# get multiclass labels
y_train_multi <- train_labels %>% pull(mclass)
# fit enet model
alpha_enet <- 0.2
fit_reg_multi <- glmnet(x = x_train,
                        y = y_train_multi,
                        family = 'multinomial',
                        alpha = alpha_enet)
# choose a strength by cross-validation
cvout_multi <- cv.glmnet(x = x_train,
                         y = y_train_multi,
                         family = 'multinomial',
                         alpha = alpha_enet)
# view results
Step 2: predictions
The predictions from this model are a set of probabilities, one per class:
preds_multi <- predict(fit_reg_multi,
                       s = cvout_multi$lambda.min,
                       newx = x_test,
                       type = 'response')
as_tibble(preds_multi[, , 1])
If we choose the most probable class as the prediction and cross-tabulate with the actual label, we end up with the following table:
pred_class <- as_tibble(preds_multi[, , 1]) %>%
  mutate(row = row_number()) %>%
  pivot_longer(-row,
               names_to = 'label',
               values_to = 'probability') %>%
  group_by(row) %>%
  slice_max(probability, n = 1) %>%
  pull(label)
table(pull(test_labels, mclass), pred_class)
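The same "most probable class" rule can be written in base R with max.col(), which returns the column index of each row's maximum. The sketch below applies it to a small made-up probability matrix over three of the class labels and then reads overall accuracy off the diagonal of the cross-tabulation. It's an illustration of the rule, not a replacement for the tidyverse pipeline.

```r
# Made-up class probabilities: rows = documents, columns = classes
classes <- c('fatality', 'irrelevant', 'other')
probs <- matrix(c(0.7, 0.2, 0.1,
                  0.1, 0.8, 0.1,
                  0.3, 0.3, 0.4,
                  0.2, 0.5, 0.3),
                ncol = 3, byrow = TRUE, dimnames = list(NULL, classes))
truth <- c('fatality', 'irrelevant', 'fatality', 'other')

# most probable class per row; ties.method = 'first' makes the choice deterministic
pred <- colnames(probs)[max.col(probs, ties.method = 'first')]

# fixing the factor levels keeps the table square even if a class is never predicted
conf_tbl <- table(truth = factor(truth, levels = classes),
                  pred  = factor(pred,  levels = classes))
conf_tbl
sum(diag(conf_tbl)) / sum(conf_tbl)  # overall accuracy: 2 of 4 correct
```

Reading across a row of the table shows how documents of one true class are distributed over the predicted classes, which is useful for the discussion questions that follow.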
If time, take a moment to discuss:
- What do you think of the overall accuracy?
- Which classes are well-predicted and which are not?
- Do you prefer the logistic or multinomial regression and why?