Multinomial logistic regression

PSTAT197A/CMPSC190DD Fall 2022

Trevor Ruiz

UCSB

Dimension reduction

From last time

Last time we ended having constructed a TF-IDF document term matrix for the claims data.

  • \(n = 552\) observations

  • \(p = 15,868\) variables (word tokens)

  • binary (rather than multiclass) labels

# A tibble: 552 × 15,870
   .id    bclass    adams afternoon  agent android    app arkansas arrest arrive
   <chr>  <fct>     <dbl>     <dbl>  <dbl>   <dbl>  <dbl>    <dbl>  <dbl>  <dbl>
 1 url1   relevant 0.0692    0.0300 0.0365  0.0450 0.0330   0.0390 0.0140 0.0305
 2 url10  irrelev… 0         0      0       0      0        0      0      0     
 3 url100 irrelev… 0         0      0       0      0        0      0      0     
 4 url101 relevant 0         0      0       0      0        0      0      0     
 5 url102 relevant 0         0      0       0      0        0      0      0     
 6 url105 relevant 0         0      0       0      0        0      0      0     
 7 url106 irrelev… 0         0      0       0      0        0      0      0     
 8 url107 irrelev… 0         0      0       0      0        0      0      0     
 9 url108 relevant 0         0      0       0      0        0      0      0     
10 url109 irrelev… 0         0      0       0      0        0      0      0     
# … with 542 more rows, and 15,860 more variables: assist <dbl>, attic <dbl>,
#   barricade <dbl>, block <dbl>, blytheville <dbl>, burn <dbl>, captain <dbl>,
#   catch <dbl>, check <dbl>, chemical <dbl>, copyright <dbl>, county <dbl>,
#   custody <dbl>, dehydration <dbl>, demand <dbl>, department <dbl>,
#   desktop <dbl>, device <dbl>, dispute <dbl>, division <dbl>, drug <dbl>,
#   enter <dbl>, exit <dbl>, family <dbl>, federal <dbl>, fire <dbl>,
#   force <dbl>, gas <dbl>, gray <dbl>, hartzell <dbl>, hold <dbl>, …

High dimensionality, again

Similar to the ASD data, we again have \(p > n\): more predictors than observations.

But this time, model interpretation is not important.

  • The goal is prediction, not explanation.

  • Individual tokens aren’t likely to be strongly associated with the labels, anyway.

So we have more options for tackling the dimensionality problem.

Sparsity

Another way of saying we have 15,868 predictors is that each observation's predictor vector lives in 15,868-dimensional space.

However, the document term matrix is extremely sparse:

# coerce DTM to sparse matrix
library(tidyverse) # for select() and %>%
library(Matrix)    # for the sparseMatrix class and nnzero()

claims_dtm <- claims %>% 
  select(-.id, -bclass) %>%
  as.matrix() %>%
  as('sparseMatrix') 

# proportion of zero entries ('sparsity')
1 - nnzero(claims_dtm)/length(claims_dtm)
[1] 0.99278

Projection

Since >99% of the data values are zero, there is almost certainly a lower-dimensional representation that approximates the full ~16K-dimensional predictor well.

So here’s a strategy:

  • project the predictor onto a subspace

  • fit a logistic regression model using the projected data

Principal components

The principal components of a data matrix \(X\) are an orthogonal basis (i.e., coordinate system) for the predictor space, chosen so that the variance of the projected data is successively maximized along each direction.

  • subcollections of PC’s span subspaces

  • used to find a projection that preserves variance

    • choose the first \(k\) PC’s along which the projected data retain a specified proportion of the total variance

Illustration

Image from Wikipedia.

Computation

The principal components can be computed by singular value decomposition (SVD):

\[ X = UDV' \]

  • columns of \(V\) give the projection directions (the PC loadings)

  • the diagonal entries of \(D\) (the singular values) are proportional to the standard deviations of the projections along each direction (when \(X\) is column-centered)
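
These relationships can be checked on a small toy matrix; a minimal sketch (the explicit column-centering mirrors what prcomp() does by default):

set.seed(1)
X_toy <- matrix(rnorm(20 * 4), nrow = 20)

# center the columns, then take the SVD
X_ctr <- scale(X_toy, center = TRUE, scale = FALSE)
toy_svd <- svd(X_ctr)
toy_pca <- prcomp(X_toy)

# singular values scaled by sqrt(n - 1) match the PC standard deviations
toy_svd$d / sqrt(nrow(X_toy) - 1)
toy_pca$sdev

# scores: projections of the centered data onto the columns of V
# (agree with toy_pca$x, up to possible sign flips of the columns)
head(X_ctr %*% toy_svd$v)
head(toy_pca$x)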

Selecting components

  1. Find the smallest number of components such that the proportion of variance retained exceeds a specified value:

    \[ n_{pc} = \min \left\{i: \frac{\sum_{j = 1}^i d_{jj}^2}{\sum_{j} d_{jj}^2} > q\right\} \]

  2. Select the corresponding projections and project the data:

    \[ \tilde{X} = XV_{1:n_{pc}} \quad\text{where}\quad V_{1:n_{pc}} = \left( v_1 \;\cdots\; v_{n_{pc}}\right) \]

Projected data are referred to as ‘scores’.
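
In code, the selection rule might look like the following sketch, reusing the toy SVD from the previous slide (q is the target proportion of variance):

# proportion of total variance along each principal direction
var_props <- toy_svd$d^2 / sum(toy_svd$d^2)

# smallest number of components retaining more than q of the total variance
q <- 0.8
n_pc <- min(which(cumsum(var_props) > q))

# project the (centered) data onto the first n_pc directions to obtain the scores
scores <- X_ctr %*% toy_svd$v[, 1:n_pc, drop = FALSE]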

Implementation

Usually prcomp() does the trick, and has a broom::tidy method available, but it’s slow for large matrices.

It’s better to use an SVD implementation that exploits sparsity, such as sparsesvd::sparsesvd().

library(sparsesvd) # truncated SVD for sparse matrices

# time a sparse SVD of the document term matrix
start <- Sys.time()
svd_out <- sparsesvd(claims_dtm)
end <- Sys.time()
time_ssvd <- end - start

# time prcomp() on the same matrix for comparison
start <- Sys.time()
prcomp_out <- prcomp(claims_dtm, center = TRUE)
end <- Sys.time()
time_prcomp <- end - start

# how much longer prcomp() takes
time_prcomp - time_ssvd
Time difference of 38.9305 secs

Obtaining projections

For today, we’ll use a function I’ve written to obtain principal components. It’s basically a wrapper around sparsesvd().

The following will return the data projected onto a subspace in which it retains at least the proportion .prop of the total variance.

proj_out <- projection_fn(claims_dtm, .prop = 0.7)

proj_out$data
# A tibble: 552 × 63
        pc1      pc2     pc3      pc4     pc5      pc6      pc7      pc8     pc9
      <dbl>    <dbl>   <dbl>    <dbl>   <dbl>    <dbl>    <dbl>    <dbl>   <dbl>
 1  5.24e-6 -7.68e-5 0.0135  -2.26e-4 0.00137 -0.00586  0.00255 -4.79e-3  0.0292
 2  1.03e-5 -3.51e-4 0.00205 -2.46e-2 0.0627  -0.00335  0.00468 -1.91e-3  0.0661
 3  1.00e-5 -6.56e-5 0.00281 -2.10e-4 0.00532 -0.00970  0.00158 -2.10e-3  0.0392
 4  9.92e-6 -1.74e-3 0.00527 -1.32e-3 0.00296 -0.0821   0.297    1.48e-2  0.0614
 5  3.90e-6 -7.43e-5 0.00575 -4.68e-4 0.00245 -0.00441  0.00891 -1.97e-3  0.0209
 6  6.65e-6 -1.61e-3 0.0546  -8.38e-4 0.00198 -0.0342   0.119   -1.18e-2  0.0699
 7  4.95e-6 -1.74e-4 0.00415 -6.60e-4 0.00170 -0.00907  0.0245  -2.75e-3  0.0600
 8  7.46e-6 -1.41e-4 0.00101 -2.58e-4 0.00130 -0.00273  0.00336 -8.41e-4  0.0267
 9  2.01e-5 -4.44e-4 0.00172 -5.78e-4 0.00157 -0.0133   0.0397   3.82e-4  0.0609
10  6.76e-7 -6.03e-5 0.00288 -9.14e-4 0.00397 -2.05    -0.537    2.95e-2 -0.0998
# … with 542 more rows, and 54 more variables: pc10 <dbl>, pc11 <dbl>,
#   pc12 <dbl>, pc13 <dbl>, pc14 <dbl>, pc15 <dbl>, pc16 <dbl>, pc17 <dbl>,
#   pc18 <dbl>, pc19 <dbl>, pc20 <dbl>, pc21 <dbl>, pc22 <dbl>, pc23 <dbl>,
#   pc24 <dbl>, pc25 <dbl>, pc26 <dbl>, pc27 <dbl>, pc28 <dbl>, pc29 <dbl>,
#   pc30 <dbl>, pc31 <dbl>, pc32 <dbl>, pc33 <dbl>, pc34 <dbl>, pc35 <dbl>,
#   pc36 <dbl>, pc37 <dbl>, pc38 <dbl>, pc39 <dbl>, pc40 <dbl>, pc41 <dbl>,
#   pc42 <dbl>, pc43 <dbl>, pc44 <dbl>, pc45 <dbl>, pc46 <dbl>, pc47 <dbl>, …

Activity 1 (10 min)

  1. Partition the claims data into training and test sets.
  2. Using the training data, find principal components that preserve at least 80% of the total variance and project the data onto those PCs.
  3. Fit a logistic regression model to the binary class labels using the projected training data (a sketch of all three steps follows).
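
A minimal sketch of these three steps, assuming projection_fn() behaves as shown earlier and using rsample for the partition (names and seed are illustrative, not the official activity solution):

library(tidyverse)
library(Matrix)
library(rsample) # initial_split(), training(), testing()

# 1. partition the claims data
set.seed(197)
partitions <- initial_split(claims, prop = 0.8)
train <- training(partitions)

# 2. project the training DTM onto PCs retaining at least 80% of total variance
train_dtm <- train %>%
  select(-.id, -bclass) %>%
  as.matrix() %>%
  as('sparseMatrix')
proj_out <- projection_fn(train_dtm, .prop = 0.8)
train_projected <- bind_cols(select(train, bclass), proj_out$data)

# 3. logistic regression of the binary labels on the PC scores
fit <- glm(bclass ~ ., data = train_projected, family = binomial)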

Overfitting

You should have observed a warning that numerically 0 or 1 fitted probabilities occurred.

  • that means the model fit some data points exactly

Overfitting occurs when a model is fit too closely to the training data.

  • measures of fit suggest high quality

  • but the model predicts poorly out of sample

The curious can verify this using the model they just fit; a rough check is sketched below.
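
The check below assumes fit and train_projected come from the Activity 1 sketch, and that test_projected is a hypothetical tibble holding the test-set scores on the same components along with bclass:

# in-sample accuracy: an overfit model looks excellent here
train_prob <- predict(fit, type = 'response')
train_pred <- ifelse(train_prob > 0.5, 'relevant', 'irrelevant') # assumes 'relevant' is the modeled event level
mean(train_pred == as.character(train_projected$bclass))

# out-of-sample accuracy: typically much worse for an overfit model
test_prob <- predict(fit, newdata = test_projected, type = 'response')
test_pred <- ifelse(test_prob > 0.5, 'relevant', 'irrelevant')
mean(test_pred == as.character(test_projected$bclass))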

Another use of regularization

Last week we spoke about using LASSO regularization for variable selection.

Regularization can also be used to reduce overfitting.

  • LASSO penalty \(\|\beta\|_1 < t\) works

  • ‘ridge’ penalty \(\|\beta\|_2 < t\) also works (but won’t shrink parameters to zero)

  • or the ‘elastic net’ penalty \(\|\beta\|_1 < t\) AND \(\|\beta\|_2 < s\)
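
A minimal sketch of fitting an elastic net logistic regression with glmnet, reusing names from the Activity 1 sketch (the mixing parameter alpha interpolates between ridge, alpha = 0, and lasso, alpha = 1; x_test is a hypothetical matrix of test-set scores):

library(glmnet)

x_train <- as.matrix(proj_out$data) # PC scores
y_train <- train_projected$bclass   # binary labels

# elastic net logistic regression; penalty strength lambda chosen by cross-validation
cv_fit <- cv.glmnet(x_train, y_train, family = 'binomial', alpha = 0.5)

# predicted probabilities on the test scores at the selected penalty
test_prob <- predict(cv_fit, newx = x_test, s = 'lambda.min', type = 'response')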

Activity 2 (10 min)

  1. Follow activity instructions to fit a logistic regression model with an elastic net penalty to the training data.
  2. Quantify classification accuracy on the test data using sensitivity, specificity, and AUROC.
# A tibble: 4 × 3
  .metric     .estimator .estimate
  <chr>       <chr>          <dbl>
1 sensitivity binary         0.621
2 specificity binary         0.830
3 accuracy    binary         0.721
4 roc_auc     binary         0.796
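
Metrics like those above can be computed with yardstick; a sketch, assuming a tibble pred_df with the true labels (bclass), predicted classes (pred_class), and predicted probability of the event class (pred_prob):

library(yardstick)
library(dplyr)

# class metrics use the predicted labels; AUROC uses the predicted probabilities
# (set event_level = 'second' if the event class is the second factor level)
class_metrics <- metric_set(sensitivity, specificity, accuracy)
bind_rows(
  class_metrics(pred_df, truth = bclass, estimate = pred_class),
  roc_auc(pred_df, truth = bclass, pred_prob)
)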

Multinomial regression

Quick refresher

The logistic regression model is

\[ \log\left(\frac{P(Y_i = 1)}{P(Y_i = 0)}\right) = \beta_0 + \beta_1 x_{i1} + \cdots + \beta_p x_{ip} \]

This is for a binary outcome \(Y_i \in \{0, 1\}\).

Multinomial response

If the response is instead \(Y \in \{1, 2, \dots, K\}\), its probability distribution can be described by the multinomial distribution (with 1 trial):

\[ P(Y = k) = p_k \quad\text{for}\quad k = 1, \dots, K \quad\text{with}\quad \sum_k p_k = 1 \]

Multinomial regression

Multinomial regression fits the following model:

\[ \begin{aligned} \log\left(\frac{p_1}{p_K}\right) &= \beta_0^{(1)} + x_i^T \beta^{(1)} \\ \log\left(\frac{p_2}{p_K}\right) &= \beta_0^{(2)} + x_i^T \beta^{(2)} \\ &\vdots \\ \log\left(\frac{p_{K - 1}}{p_K}\right) &= \beta_0^{(K - 1)} + x_i^T \beta^{(K - 1)} \\ \end{aligned} \]

So the number of parameters is \((p + 1)\times (K - 1)\). For example, with, say, the 63 PC scores obtained earlier as predictors and \(K = 5\) classes, that’s \(64 \times 4 = 256\) parameters.

Prediction

With some manipulation, one can obtain expressions for each \(p_k\), and thus estimates of the probabilities \(\hat{p}_k\) for each class \(k\).

A natural prediction to use is whichever class is most probable:

\[ \hat{Y}_i = \arg\max_k \hat{p}_k \]
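
The ‘manipulation’ is a softmax-style inversion of the \(K - 1\) log odds. A small numeric sketch with hypothetical values, where eta holds \(\beta_0^{(k)} + x^T\beta^{(k)}\) for \(k = 1, \dots, K - 1\) and class \(K\) is the reference:

# hypothetical linear predictors for one observation, K = 5
eta <- c(0.8, -0.4, 1.2, 0.1)

# recover class probabilities: p_k = exp(eta_k) / (1 + sum(exp(eta))), p_K = 1 / (1 + sum(exp(eta)))
p <- c(exp(eta), 1) / (1 + sum(exp(eta)))
sum(p) # probabilities sum to 1

# predicted class: the most probable one
which.max(p)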

Activity 3 (10 min)

  1. Follow instructions to fit a multinomial model to the claims data.

  2. Compute predictions and evaluate accuracy.
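
As a rough sketch (the activity materials may prescribe a different interface), glmnet also supports a multinomial family, so the elastic net workflow carries over to the multiclass labels; y_multi, x_train, and x_test are hypothetical names:

library(glmnet)

# elastic net multinomial regression on the multiclass labels
# (y_multi: hypothetical factor of multiclass labels for the training observations)
cv_fit_multi <- cv.glmnet(x_train, y_multi, family = 'multinomial', alpha = 0.5)

# predicted class probabilities and most-probable-class predictions on the test scores
probs <- predict(cv_fit_multi, newx = x_test, s = 'lambda.min', type = 'response')
preds <- predict(cv_fit_multi, newx = x_test, s = 'lambda.min', type = 'class')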

Results

# A tibble: 111 × 5
   irrelevant physical fatality unlawful    other
        <dbl>    <dbl>    <dbl>    <dbl>    <dbl>
 1    0.758    0.0509   0.0639  0.0822   0.0454  
 2    0.206    0.0254   0.728   0.0257   0.0149  
 3    0.564    0.0680   0.244   0.0715   0.0526  
 4    0.206    0.0254   0.728   0.0257   0.0149  
 5    0.524    0.260    0.0303  0.137    0.0490  
 6    0.755    0.0505   0.0652  0.0836   0.0460  
 7    0.251    0.0304   0.669   0.0235   0.0257  
 8    0.121    0.839    0.0184  0.00427  0.0177  
 9    0.00946  0.00355  0.00326 0.978    0.00579 
10    0.0107   0.00248  0.985   0.000806 0.000859
# … with 101 more rows
# confusion matrix (rows: true class, columns: predicted class)
           fatality irrelevant other physical unlawful
fatality         17         10     0        0        0
irrelevant        2         46     1        1        3
other             0          3     0        0        0
physical          0          9     0        6        0
unlawful          0          6     0        0        7
# overall accuracy
sum(diag(pred_tbl))/sum(pred_tbl)
[1] 0.6846847
# classwise accuracy (proportion correct within each true class)
diag(pred_tbl)/rowSums(pred_tbl)
  fatality irrelevant      other   physical   unlawful 
 0.6296296  0.8679245  0.0000000  0.4000000  0.5384615 
# predictionwise accuracy (proportion correct among each predicted class)
diag(pred_tbl)/colSums(pred_tbl)
  fatality irrelevant      other   physical   unlawful 
 0.8947368  0.6216216  0.0000000  0.8571429  0.7000000