Tuning text models

model tuning
text analysis
logistic regression
Bayesian optimization
extracting results

Prepare text data for predictive modeling and tune with both grid and iterative search.

Introduction

To use code in this article, you will need to install the following packages: stopwords, textfeatures, textrecipes, and tidymodels.

This article demonstrates an advanced example of training and tuning models for text data. Before modeling, text must be processed and transformed into a numeric representation; in tidymodels, we use a recipe for this preprocessing. This article also shows how to extract information from each model fit during tuning for later use.

Text as data

The text data we’ll use in this article are from Amazon:

This dataset consists of reviews of fine foods from Amazon. The data span a period of more than 10 years, including all ~500,000 reviews up to October 2012. Reviews include product and user information, ratings, and a plaintext review.

This article uses a small subset of the total reviews available at the original source. We sampled a single review from each of 5,000 random products and allocated 80% of these data (4,000 reviews) to the training set, with the remaining 1,000 reviews held out for the test set.

There is a column for the product, a column for the text of the review, and a factor column for the outcome variable. The outcome is whether the reviewer gave the product a five-star rating or not.

library(tidymodels)

data("small_fine_foods")
training_data
#> # A tibble: 4,000 × 3
#>    product    review                                                       score
#>    <chr>      <chr>                                                        <fct>
#>  1 B000J0LSBG "this stuff is  not stuffing  its  not good at all  save yo… other
#>  2 B000EYLDYE "I absolutely LOVE this dried fruit.  LOVE IT.  Whenever I … great
#>  3 B0026LIO9A "GREAT DEAL, CONVENIENT TOO.  Much cheaper than WalMart and… great
#>  4 B00473P8SK "Great flavor, we go through a ton of this sauce! I discove… great
#>  5 B001SAWTNM "This is excellent salsa/hot sauce, but you can get it for … great
#>  6 B000FAG90U "Again, this is the best dogfood out there.  One suggestion… great
#>  7 B006BXTCEK "The box I received was filled with teas, hot chocolates, a… other
#>  8 B002GWH5OY "This is delicious coffee which compares favorably with muc… great
#>  9 B003R0MFYY "Don't let these little tiny cans fool you.  They pack a lo… great
#> 10 B001EO5ZXI "One of the nicest, smoothest cup of chai I've made. Nice m… great
#> # ℹ 3,990 more rows

Our goal is to create features from the text of the reviews and use them to predict whether a review was five-star or not.
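To make the idea of turning text into numeric features concrete, here is a minimal sketch using textrecipes on a made-up data frame. The data (`toy_reviews`) and the choice of simple term counts are assumptions for illustration only; this is not the recipe that is tuned in this article.

```r
library(recipes)
library(textrecipes)

# Hypothetical toy data with the same column names as the article's data
toy_reviews <- data.frame(
  review = c("great coffee, great price", "not good at all", "the best tea"),
  score = factor(c("great", "other", "great"))
)

toy_rec <- recipe(score ~ review, data = toy_reviews) %>%
  # Split each review into lower-case word tokens
  step_tokenize(review) %>%
  # Count how often each token occurs in each review
  step_tf(review)

# Estimate the recipe and return the processed training data
toy_features <- toy_rec %>% prep() %>% bake(new_data = NULL)
toy_features
```

Each token becomes its own numeric column, so the number of features grows with the vocabulary. That growth is why text preprocessing usually includes a step that limits or compresses the number of terms.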

Resampling

With 4,000 training reviews, 10-fold cross-validation holds out 400 reviews at a time to estimate performance. Estimates based on this many observations have low enough noise to measure and tune models.

set.seed(8935)
folds <- vfold_cv(training_data)
folds
#> #  10-fold cross-validation 
#> # A tibble: 10 × 2
#>    splits             id    
#>    <list>             <chr> 
#>  1 <split [3600/400]> Fold01
#>  2 <split [3600/400]> Fold02
#>  3 <split [3600/400]> Fold03
#>  4 <split [3600/400]> Fold04
#>  5 <split [3600/400]> Fold05
#>  6 <split [3600/400]> Fold06
#>  7 <split [3600/400]> Fold07
#>  8 <split [3600/400]> Fold08
#>  9 <split [3600/400]> Fold09
#> 10 <split [3600/400]> Fold10
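The split sizes shown above can be checked directly. The following sketch uses a made-up stand-in data frame (`toy`) with the same number of rows as the training set, so it runs without the review data; `analysis()` and `assessment()` return the rows used for fitting and for performance estimation, respectively.

```r
library(rsample)

# Hypothetical stand-in with the same number of rows as training_data
toy <- data.frame(row = 1:4000)

set.seed(8935)
toy_folds <- vfold_cv(toy, v = 10)

# Rows used for fitting and for assessment in each fold
fold_sizes <- data.frame(
  id = toy_folds$id,
  n_analysis = vapply(toy_folds$splits, function(s) nrow(analysis(s)), integer(1)),
  n_assessment = vapply(toy_folds$splits, function(s) nrow(assessment(s)), integer(1))
)
fold_sizes
```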

Extracted results

Let’s return to the grid search results and examine the results of our extract function. For each fitted model, a tibble was saved that contains the relationship between the number of predictors and the penalty value. Let’s look at these results for the best model:

params <- select_best(five_star_glmnet, metric = "roc_auc")
params
#> # A tibble: 1 × 4
#>   penalty mixture num_terms .config               
#>     <dbl>   <dbl>     <dbl> <chr>                 
#> 1   0.695    0.01      4096 Preprocessor3_Model019

Recall that we saved the glmnet results in a tibble. The column five_star_glmnet$.extracts is a list of tibbles. As an example, the first element of the list is:

five_star_glmnet$.extracts[[1]]
#> # A tibble: 300 × 5
#>    num_terms penalty mixture .extracts          .config               
#>        <dbl>   <dbl>   <dbl> <list>             <chr>                 
#>  1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model001
#>  2       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model002
#>  3       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model003
#>  4       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model004
#>  5       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model005
#>  6       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model006
#>  7       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model007
#>  8       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model008
#>  9       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model009
#> 10       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model010
#> # ℹ 290 more rows

More nested tibbles! Let’s unnest() the five_star_glmnet$.extracts column:

library(tidyr)
extracted <- 
  five_star_glmnet %>% 
  dplyr::select(id, .extracts) %>% 
  unnest(cols = .extracts)
extracted
#> # A tibble: 3,000 × 6
#>    id     num_terms penalty mixture .extracts          .config               
#>    <chr>      <dbl>   <dbl>   <dbl> <list>             <chr>                 
#>  1 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model001
#>  2 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model002
#>  3 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model003
#>  4 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model004
#>  5 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model005
#>  6 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model006
#>  7 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model007
#>  8 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model008
#>  9 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model009
#> 10 Fold01       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model010
#> # ℹ 2,990 more rows
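If list-columns are unfamiliar, this small sketch shows what unnest() does. The table `nested` and its values are made up; only the column names follow the results above.

```r
library(tidyr)
library(tibble)

# Hypothetical list-column: one small tibble per resample
nested <- tibble(
  id = c("Fold01", "Fold02"),
  .extracts = list(
    tibble(penalty = c(1, 0.1), num_vars = c(0L, 5L)),
    tibble(penalty = c(1, 0.1), num_vars = c(0L, 7L))
  )
)

# Each inner tibble is expanded into rows; `id` is repeated as needed
flat <- unnest(nested, cols = .extracts)
flat
```

The outer columns are recycled for every row of the inner tibble, which is why the row count grows with each level of unnesting.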

One thing to realize here is that tune_grid() does not necessarily fit a separate model for every parameter combination that is evaluated. In this case, for each value of mixture and num_terms, a single model fit covers all penalty values (this is a feature of this particular model and is not generally true for other engines). To select the rows for the best parameter set, we can drop the penalty column in extracted and match on num_terms and mixture only:

extracted <- 
  extracted %>% 
  dplyr::select(-penalty) %>% 
  inner_join(params, by = c("num_terms", "mixture")) %>% 
  # The join adds `penalty` back from `params`, so remove it again
  dplyr::select(-penalty)
extracted
#> # A tibble: 200 × 6
#>    id     num_terms mixture .extracts          .config.x              .config.y 
#>    <chr>      <dbl>   <dbl> <list>             <chr>                  <chr>     
#>  1 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model001 Preproces…
#>  2 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model002 Preproces…
#>  3 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model003 Preproces…
#>  4 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model004 Preproces…
#>  5 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model005 Preproces…
#>  6 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model006 Preproces…
#>  7 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model007 Preproces…
#>  8 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model008 Preproces…
#>  9 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model009 Preproces…
#> 10 Fold01      4096    0.01 <tibble [100 × 2]> Preprocessor3_Model010 Preproces…
#> # ℹ 190 more rows
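The same select-and-join pattern can be tried on a small made-up table. Here `toy_extracted` and `toy_best` are hypothetical stand-ins for `extracted` and `params`, with far fewer rows:

```r
library(dplyr)

# Hypothetical stand-in for `extracted`: 2 folds x 2 num_terms x 2 penalties
toy_extracted <- tibble(
  id = rep(c("Fold01", "Fold02"), each = 4),
  num_terms = rep(c(256, 256, 4096, 4096), times = 2),
  mixture = 0.01,
  penalty = rep(c(0.1, 1), times = 4)
)

# Hypothetical stand-in for `params`, the best parameter set
toy_best <- tibble(penalty = 0.695, mixture = 0.01, num_terms = 4096)

matched <- toy_extracted %>%
  select(-penalty) %>%
  # Keep only rows that share num_terms and mixture with the best set
  inner_join(toy_best, by = c("num_terms", "mixture")) %>%
  # Drop the penalty value that the join brought in from toy_best
  select(-penalty)
matched
```

The best penalty value itself (0.695) never has to appear in the grid for the match to succeed, because the join keys are only num_terms and mixture.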

Now we can get at the results that we want using another unnest():

extracted <- 
  extracted %>% 
  unnest(cols = .extracts) # <- these contain a `penalty` column
extracted
#> # A tibble: 20,000 × 7
#>    id     num_terms mixture penalty num_vars .config.x              .config.y   
#>    <chr>      <dbl>   <dbl>   <dbl>    <int> <chr>                  <chr>       
#>  1 Fold01      4096    0.01    8.60        0 Preprocessor3_Model001 Preprocesso…
#>  2 Fold01      4096    0.01    8.21        2 Preprocessor3_Model001 Preprocesso…
#>  3 Fold01      4096    0.01    7.84        2 Preprocessor3_Model001 Preprocesso…
#>  4 Fold01      4096    0.01    7.48        3 Preprocessor3_Model001 Preprocesso…
#>  5 Fold01      4096    0.01    7.14        3 Preprocessor3_Model001 Preprocesso…
#>  6 Fold01      4096    0.01    6.82        3 Preprocessor3_Model001 Preprocesso…
#>  7 Fold01      4096    0.01    6.51        4 Preprocessor3_Model001 Preprocesso…
#>  8 Fold01      4096    0.01    6.21        6 Preprocessor3_Model001 Preprocesso…
#>  9 Fold01      4096    0.01    5.93        7 Preprocessor3_Model001 Preprocesso…
#> 10 Fold01      4096    0.01    5.66        7 Preprocessor3_Model001 Preprocesso…
#> # ℹ 19,990 more rows
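The penalty and num_vars columns resemble what a glmnet fit stores for its regularization path. As a sketch on simulated data, and assuming the extract function recorded the model's `lambda` and `df` components (the data and object names here are made up), a single fit yields a table of the same form:

```r
library(glmnet)

# Simulated data: 200 rows, 20 numeric predictors, two-class outcome
set.seed(1)
x <- matrix(rnorm(200 * 20), nrow = 200)
y <- factor(ifelse(x[, 1] - x[, 2] + rnorm(200) > 0, "great", "other"))

# One call fits the whole path of penalty values; alpha is the mixture
fit <- glmnet(x, y, family = "binomial", alpha = 0.01)

# One row per penalty value: the penalty and the number of nonzero coefficients
path <- data.frame(penalty = fit$lambda, num_vars = fit$df)
head(path)
```

Because the whole path comes from one fit, the same table is attached to every penalty value in the grid for a given resample, mixture, and num_terms.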

Let’s look at a plot of these results (per resample):

ggplot(extracted, aes(x = penalty, y = num_vars)) + 
  geom_line(aes(group = id, col = id), alpha = .5) + 
  ylab("Number of retained predictors") + 
  scale_x_log10()  + 
  ggtitle(paste("mixture =", params$mixture, "and", params$num_terms, "features")) + 
  theme(legend.position = "none")

These results might help guide the choice of the penalty range if more optimization were conducted.
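As a sketch of how a revised range could be specified, the dials package lets us set the range of penalty() on the log10 scale. The range used here is purely hypothetical and is not a recommendation derived from the plot:

```r
library(dials)

# Hypothetical range in log10 units: from 10^-2 to 10^1
pen_param <- penalty(range = c(-2, 1))

# A regular grid of 4 values, returned on the original (not log) scale
pen_grid <- grid_regular(pen_param, levels = 4)
pen_grid
```

A parameter object like `pen_param` could then be used when building the grid for a further round of tuning.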

Session information

#> ─ Session info ─────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.1 (2023-06-16)
#>  os       macOS Ventura 13.5.2
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Los_Angeles
#>  date     2023-09-26
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ─────────────────────────────────────────────────────────
#>  package      * version date (UTC) lib source
#>  broom        * 1.0.5   2023-06-09 [1] CRAN (R 4.3.0)
#>  dials        * 1.2.0   2023-04-03 [1] CRAN (R 4.3.0)
#>  dplyr        * 1.1.3   2023-09-03 [1] CRAN (R 4.3.0)
#>  ggplot2      * 3.4.3   2023-08-14 [1] CRAN (R 4.3.0)
#>  infer        * 1.0.5   2023-09-06 [1] CRAN (R 4.3.0)
#>  parsnip      * 1.1.1   2023-08-17 [1] CRAN (R 4.3.0)
#>  purrr        * 1.0.2   2023-08-10 [1] CRAN (R 4.3.0)
#>  recipes      * 1.0.8   2023-08-25 [1] CRAN (R 4.3.0)
#>  rlang          1.1.1   2023-04-28 [1] CRAN (R 4.3.0)
#>  rsample      * 1.2.0   2023-08-23 [1] CRAN (R 4.3.0)
#>  stopwords    * 2.3     2021-10-28 [1] CRAN (R 4.3.0)
#>  textfeatures * 0.3.3   2019-09-03 [1] CRAN (R 4.3.0)
#>  textrecipes  * 1.0.4   2023-08-17 [1] CRAN (R 4.3.0)
#>  tibble       * 3.2.1   2023-03-20 [1] CRAN (R 4.3.0)
#>  tidymodels   * 1.1.1   2023-08-24 [1] CRAN (R 4.3.0)
#>  tune         * 1.1.2   2023-08-23 [1] CRAN (R 4.3.0)
#>  workflows    * 1.1.3   2023-02-22 [1] CRAN (R 4.3.0)
#>  yardstick    * 1.2.0   2023-04-21 [1] CRAN (R 4.3.0)
#> 
#>  [1] /Users/emilhvitfeldt/Library/R/arm64/4.3/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#> 
#> ────────────────────────────────────────────────────────────────────