9 Logistic Regression

We proceed in stages. First we use the sparklyr machine learning library to run a logit and collect its summary statistics and coefficients into a local table; then we use a logit to make predictions.

9.1 Logit: extracting the coefficients

The equivalent of glm() in sparklyr is ml_generalized_linear_regression(), whose syntax is very similar to that of glm():

# Model formula and binomial GLM with logit link, fitted on the Spark table dat
f <- heart ~ agecat + gender + state + bmi + smoking + income + education + marital_status
logit <- ml_generalized_linear_regression(dat, formula = f, family = "binomial", link = "logit")

# Collect the coefficient table locally with tidy() and print it with kable()
logit %>% tidy() %>%
  kable(digits = c(0, 2, 2, 2, 3), format.args = list(big.mark = ","),
        caption = "**Logit coefficients**.")

Table 3: Logit coefficients.

term                       estimate  std.error  statistic  p.value
(Intercept)                   -0.73       0.02     -30.23    0.000
agecat_18.0, 35.0             -5.19       0.02    -334.71    0.000
agecat_35.0, 45.0             -3.46       0.01    -343.03    0.000
agecat_45.0, 55.0             -2.31       0.01    -300.26    0.000
agecat_55.0, 65.0             -1.39       0.01    -208.65    0.000
agecat_65.0, 75.0             -0.71       0.01    -117.69    0.000
agecat_75.0, 85.0             -0.25       0.01     -45.21    0.000
gender_F                      -0.66       0.00    -211.33    0.000
state_Aboda                    0.01       0.02       0.65    0.514
state_Sintbu                   0.01       0.02       0.79    0.430
state_Isnor                    0.02       0.02       0.89    0.376
state_Itsware                  0.01       0.02       0.41    0.680
state_Haivismal                0.01       0.02       0.36    0.721
state_Blitzbar                 0.00       0.02      -0.06    0.950
state_Morgenor                 0.01       0.02       0.37    0.714
bmi_Overweight                -0.01       0.01      -0.66    0.510
bmi_Normal                    -0.15       0.01     -10.11    0.000
bmi_Obese                      0.20       0.01      13.17    0.000
smoking_Never smoked          -0.03       0.00      -6.46    0.000
smoking_Ex-smoker              0.23       0.00      48.13    0.000
income_70K+                   -0.21       0.01     -33.62    0.000
income_<20K                    0.25       0.01      43.95    0.000
income_50K-70K                -0.13       0.01     -20.44    0.000
income_20K-30K                 0.14       0.01      23.46    0.000
income_40K-50K                -0.02       0.01      -3.62    0.000
education_University           0.06       0.01      10.21    0.000
education_Diploma              0.05       0.01       7.31    0.000
education_Certif               0.00       0.01       0.08    0.940
education_NoCertif             0.07       0.01      11.65    0.000
marital_status_partnered      -0.07       0.00     -15.38    0.000
marital_status_separated      -0.11       0.01     -16.36    0.000
marital_status_single         -0.18       0.01     -22.24    0.000
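
The estimates above are on the log-odds scale. If we prefer to read them as odds ratios we can exponentiate them once tidy() has brought them into a local data frame (as it does in the kable() call above). A minimal sketch, for illustration only:

logit %>%
  tidy() %>%
  dplyr::mutate(odds_ratio = exp(estimate)) %>%   # exp(log-odds) = odds ratio
  dplyr::select(term, odds_ratio)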

There is another function, specific to the logit, ml_logistic_regression(), which produces a different output; we will use it in the next example.

9.2 Making predictions with logit

Here we use the same data as in the previous example to build a simple predictive model and test its accuracy using cross-validation.

It is common to test the accuracy of a model on several cross-validation samples, which we can create with the sparklyr function sdf_random_split(): it takes a Spark table and splits it randomly into chunks of given sizes. We use it as follows:

# Ten folds, each holding (roughly) 10% of the rows
weights <- purrr::set_names(rep(0.1, 10), paste0("fold", 1:10))
weights
##  fold1  fold2  fold3  fold4  fold5  fold6  fold7  fold8  fold9 fold10 
##    0.1    0.1    0.1    0.1    0.1    0.1    0.1    0.1    0.1    0.1
vfolds <- sdf_random_split(dat, weights = weights, seed = 1)
sapply(vfolds, sdf_nrow)  # number of rows in each fold
##   fold1   fold2   fold3   fold4   fold5   fold6   fold7   fold8   fold9  fold10 
## 1001569 1003031 1005253 1002322 1001871 1000550 1002636 1002102 1001933 1002095

This function creates a list of 10 Spark tables (“folds”), each containing (approximately) 10% of the original table, as specified by the named vector weights. We can now select the first fold as the test set and the remaining nine for training (in practice we would do this ten times, selecting a different fold for testing each time, but we skip that here for simplicity; see the sketch after the ROC plot below). To bind folds 2 to 10 we can use the convenience function sdf_bind_rows(), although do.call(rbind, …) would work as well:

training <- sdf_bind_rows(vfolds[2:10])
test <- vfolds[[1]]

Now we run the logit on the training set using the function ml_logistic_regression() and evaluate its performance on the test set using ml_evaluate():

logit <- ml_logistic_regression(training, formula = f)
validation <- ml_evaluate(logit, test)
validation
## BinaryLogisticRegressionSummaryImpl 
##  Access the following via `$` or `ml_summary()`. 
##  - features_col() 
##  - label_col() 
##  - predictions() 
##  - probability_col() 
##  - area_under_roc() 
##  - f_measure_by_threshold() 
##  - pr() 
##  - precision_by_threshold() 
##  - recall_by_threshold() 
##  - roc() 
##  - prediction_col() 
##  - accuracy() 
##  - f_measure_by_label() 
##  - false_positive_rate_by_label() 
##  - labels() 
##  - precision_by_label() 
##  - recall_by_label() 
##  - true_positive_rate_by_label() 
##  - weighted_f_measure() 
##  - weighted_false_positive_rate() 
##  - weighted_precision() 
##  - weighted_recall() 
##  - weighted_true_positive_rate()

As shown above, the object validation contains a number of accuracy measures. For example, we can extract the ROC curve and the area under the ROC curve (AUC) and plot them:

roc <- validation$roc() %>% collect()  # bring the ROC curve into local memory
ggplot(roc, aes(x = FPR, y = TPR)) +
  geom_line() +
  geom_abline(lty = "dashed") +
  ggtitle(paste("AUC:", round(validation$area_under_roc(), 2)))
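
In practice we would repeat this train/evaluate cycle over all ten folds and average the results. A sketch of that loop, reusing the vfolds list and the formula f from above (illustrative, not run here):

# For each fold: train on the other nine, evaluate on the held-out fold,
# and keep the area under the ROC curve.
aucs <- sapply(seq_along(vfolds), function(i) {
  fit <- ml_logistic_regression(sdf_bind_rows(vfolds[-i]), formula = f)
  ml_evaluate(fit, vfolds[[i]])$area_under_roc()
})
mean(aucs)  # cross-validated AUC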

The predictions on the test set are stored in validation under the field predictions(), which has the following structure:

validation$predictions()
## # Source: spark<?> [?? x 23]
##    geo      gender   age       id bmi        smoking        income  education  marital_status heart diabetes hypertension stroke cancer gvt_cost   cost state   agecat     features   label rawPrediction probability prediction
##    <chr>    <chr>  <int>    <int> <chr>      <chr>          <chr>   <chr>      <chr>          <int>    <int>        <int>  <int>  <int>    <dbl>  <dbl> <chr>   <chr>      <list>     <dbl> <list>        <list>           <dbl>
##  1 Aabdilli F         20 10616144 Normal     Current smoker 70K+    University separated          0        0            0      0      0    315.   351.  Itsware 18.0, 35.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  2 Aabdilli F         31 12484925 Normal     Never smoked   70K+    Diploma    partnered          0        0            0      0      0     36.3   36.3 Itsware 18.0, 35.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  3 Aabdilli F         37 13354911 Normal     Never smoked   70K+    University partnered          0        0            0      0      1    319.   319.  Itsware 35.0, 45.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  4 Aabdilli F         44 14160212 Overweight Ex-smoker      50K-70K University partnered          0        0            0      0      0    193.   231.  Itsware 35.0, 45.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  5 Aabdilli F         47 15009235 Obese      Never smoked   40K-50K University separated          0        0            0      0      1   1183.  1183.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  6 Aabdilli F         51 15792336 Overweight Never smoked   50K-70K University partnered          0        0            0      0      1    964.  1607.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  7 Aabdilli F         51 15792365 Normal     Ex-smoker      70K+    University partnered          0        0            0      0      1    579.   579.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  8 Aabdilli F         53 15792317 Obese      Never smoked   70K+    University partnered          0        0            1      0      1   2657.  3293.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
##  9 Aabdilli F         57 16571288 Obese      Ex-smoker      <20K    Certif     partnered          0        1            1      0      0   1905.  2270.  Itsware 55.0, 65.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
## 10 Aabdilli F         59 16571220 Obese      Never smoked   30K-40K NoCertif   partnered          0        0            1      0      0    730.   730.  Itsware 55.0, 65.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0
## # ... with more rows

Notice that some of the columns, like “probability”, are actually lists. In the case of probability, the list contains the probabilities of labels 0 and 1 respectively. To extract them into proper columns, sparklyr provides the convenience function sdf_separate_column(), which takes the name of the column we are interested in and creates new columns with the given names:

sdf_separate_column(validation$predictions(), "probability", into=c("prob0","prob1"))
## # Source: spark<?> [?? x 25]
##    geo      gender   age       id bmi        smoking        income  education  marital_status heart diabetes hypertension stroke cancer gvt_cost   cost state   agecat     features   label rawPrediction probability prediction prob0    prob1
##    <chr>    <chr>  <int>    <int> <chr>      <chr>          <chr>   <chr>      <chr>          <int>    <int>        <int>  <int>  <int>    <dbl>  <dbl> <chr>   <chr>      <list>     <dbl> <list>        <list>           <dbl> <dbl>    <dbl>
##  1 Aabdilli F         20 10616144 Normal     Current smoker 70K+    University separated          0        0            0      0      0    315.   351.  Itsware 18.0, 35.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.999 0.000932
##  2 Aabdilli F         31 12484925 Normal     Never smoked   70K+    Diploma    partnered          0        0            0      0      0     36.3   36.3 Itsware 18.0, 35.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.999 0.000921
##  3 Aabdilli F         37 13354911 Normal     Never smoked   70K+    University partnered          0        0            0      0      1    319.   319.  Itsware 35.0, 45.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.995 0.00523 
##  4 Aabdilli F         44 14160212 Overweight Ex-smoker      50K-70K University partnered          0        0            0      0      0    193.   231.  Itsware 35.0, 45.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.992 0.00846 
##  5 Aabdilli F         47 15009235 Obese      Never smoked   40K-50K University separated          0        0            0      0      1   1183.  1183.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.973 0.0266  
##  6 Aabdilli F         51 15792336 Overweight Never smoked   50K-70K University partnered          0        0            0      0      1    964.  1607.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.980 0.0202  
##  7 Aabdilli F         51 15792365 Normal     Ex-smoker      70K+    University partnered          0        0            0      0      1    579.   579.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.979 0.0211  
##  8 Aabdilli F         53 15792317 Obese      Never smoked   70K+    University partnered          0        0            1      0      1   2657.  3293.  Itsware 45.0, 55.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.977 0.0229  
##  9 Aabdilli F         57 16571288 Obese      Ex-smoker      <20K    Certif     partnered          0        1            1      0      0   1905.  2270.  Itsware 55.0, 65.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.898 0.102   
## 10 Aabdilli F         59 16571220 Obese      Never smoked   30K-40K NoCertif   partnered          0        0            1      0      0    730.   730.  Itsware 55.0, 65.0 <dbl [31]>     0 <dbl [2]>     <dbl [2]>            0 0.932 0.0683  
## # ... with more rows
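
With prob0 and prob1 available as ordinary columns we can, for instance, tabulate predicted against observed classes directly on the Spark table with dplyr. An illustrative sketch, assuming a 0.5 classification threshold:

validation$predictions() %>%
  sdf_separate_column("probability", into = c("prob0", "prob1")) %>%
  mutate(pred_class = ifelse(prob1 > 0.5, 1, 0)) %>%  # classify at threshold 0.5
  count(label, pred_class) %>%                        # confusion matrix counts
  collect()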

Issue: running things on my laptop, it seems that the function ml_logistic_regression() is much slower than ml_generalized_linear_regression(). The advantage of the former seems to be that its evaluation object already contains the ROC curve, although a major disadvantage is that it does not return standard errors and p-values. Also, the prediction output of the latter contains a clearly labelled column “prediction” holding the probability of class 1, so it does not require a call to sdf_separate_column(). Overall, given how easy it is to compute an ROC curve, ml_generalized_linear_regression() seems the more interesting option.
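
For reference, a sketch of how the AUC could be obtained from the GLM interface with ml_predict() and ml_binary_classification_evaluator(). This is an assumption-laden illustration: it assumes the transformed test set carries a "label" column (as in the output above) and that the GLM's "prediction" column holds the probability of class 1.

glm_fit  <- ml_generalized_linear_regression(training, formula = f,
                                             family = "binomial", link = "logit")
glm_pred <- ml_predict(glm_fit, test)  # adds a "prediction" column (fitted probability)

# Spark's binary classification evaluator also accepts a double-valued score column
ml_binary_classification_evaluator(glm_pred,
                                    label_col = "label",
                                    raw_prediction_col = "prediction",
                                    metric_name = "areaUnderROC")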
