Rでロジスティクス回帰分析

顔妻です。

今回はRを使ったロジスティクス回帰分析のやり方をご紹介します。

利用データ

titanicのデータを使います。trainが学習データ、testがテストデータ、submissionが最終的な成果判定用です。

前処理

基礎集計は割愛します。NAの穴埋めやSurvivedとクロス集計や比較分析を行ったときに関係が大きいものはトリミングを行っています。

library(tidyverse)

theme_set(theme_bw())

d <- read.csv('train.csv')
d_test <- read.csv('test.csv')
d_submission <- read.csv('submission.csv')

d %>% head()
d %>% summary()


#### 前処理 ####
d$Survived <- as.factor(d$Survived)
d_submission$Survived <- as.factor(d_submission$Survived)

#### データ成型 ####
d$Cabin_2 <- 
  ifelse(grepl(pattern = 'A',x = d$Cabin),'A',
  ifelse(grepl(pattern = 'B',x = d$Cabin),'B',
  ifelse(grepl(pattern = 'C',x = d$Cabin),'C',
  ifelse(grepl(pattern = 'D',x = d$Cabin),'D',
  ifelse(grepl(pattern = 'E',x = d$Cabin),'E',
  ifelse(grepl(pattern = 'F',x = d$Cabin),'F',
  ifelse(grepl(pattern = 'G',x = d$Cabin),'G','T'
  )))))))

d %>% group_by(Survived,Cabin_2) %>% summarise(cnt = n()) %>% spread(Cabin_2,cnt) %>% write_clip()
d$Age_2 <- ifelse(is.na(d$Age),median(d$Age,na.rm = T),d$Age)
d$Fare_2 <- ifelse(is.na(d$Fare),median(d$Fare,na.rm = T),d$Fare)  

d_test$Cabin_2 <- 
  ifelse(grepl(pattern = 'A',x = d_test$Cabin),'A',
  ifelse(grepl(pattern = 'B',x = d_test$Cabin),'B',
  ifelse(grepl(pattern = 'C',x = d_test$Cabin),'C',
  ifelse(grepl(pattern = 'D',x = d_test$Cabin),'D',
  ifelse(grepl(pattern = 'E',x = d_test$Cabin),'E',
  ifelse(grepl(pattern = 'F',x = d_test$Cabin),'F',
  ifelse(grepl(pattern = 'G',x = d_test$Cabin),'G','T'
  )))))))

d_test$Age_2 <- ifelse(is.na(d_test$Age),median(d_test$Age,na.rm = T),d_test$Age)  
d_test$Fare_2 <- ifelse(is.na(d_test$Fare),median(d_test$Fare,na.rm = T),d_test$Fare)  

ロジスティクス回帰分析の実施

Survivedに関係の強そうなものはひとまずすべて説明変数として設定しました。

結果としてPclass、Sexmale、Age_2、SibSpの要素の影響が強そうです。

d_logi_01 <- 
  glm(formula = Survived  ~  Pclass + Sex + Age_2 + SibSp + Parch + Fare_2 + Cabin_2,family = binomial,data = d)

summary(d_logi_01)

## Call:
## glm(formula = Survived ~ Pclass + Sex + Age_2 + SibSp + Parch + 
##     Fare_2 + Cabin_2, family = binomial, data = d)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6224  -0.5667  -0.4168   0.6260   2.4266  
## 
## Coefficients:
##              Estimate Std. Error z value Pr(>|z|)    
## (Intercept)  5.005103   0.751177   6.663 2.68e-11 ***
## Pclass      -0.852748   0.173552  -4.914 8.95e-07 ***
## Sexmale     -2.798597   0.202995 -13.787  < 2e-16 ***
## Age_2       -0.042309   0.008096  -5.226 1.73e-07 ***
## SibSp       -0.358452   0.110372  -3.248  0.00116 ** 
## Parch       -0.110644   0.118124  -0.937  0.34892    
## Fare_2       0.003089   0.002587   1.194  0.23238    
## Cabin_2B     0.023954   0.722394   0.033  0.97355    
## Cabin_2C    -0.550281   0.668248  -0.823  0.41024    
## Cabin_2D     0.515511   0.740642   0.696  0.48641    
## Cabin_2E     0.912436   0.733874   1.243  0.21375    
## Cabin_2F     0.520992   0.945340   0.551  0.58155    
## Cabin_2G    -1.548005   1.241461  -1.247  0.21243    
## Cabin_2T    -0.659208   0.629322  -1.047  0.29487    
## ---
## Signif. codes:  
## 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 1186.66  on 890  degrees of freedom
## Residual deviance:  772.38  on 877  degrees of freedom
## AIC: 800.38
## 
## Number of Fisher Scoring iterations: 5

予測精度の検証

先ほどの予測モデルにテストデータを適用し、混合行列を使って精度評価を行います。

# テストデータを使ったモデルの適用
d_logi_01.pred <- predict(object = d_logi_01,newdata = d_test)

# 予測結果の出力
d_submission$pred_logi_01 <- predict(object = d_logi_01,newdata = d_test)
d_submission$pred_logi_01_flag <- ifelse(d_submission$pred_logi_01 >= 0.5,1,0)

# 混合行列の作成
tbl.logit_cm <-
  d_submission %>%
  ungroup() %>% 
  group_by(Survived,pred_logi_01_flag) %>% 
  summarise(cnt = n()) %>% 
  spread(Survived,cnt)
tbl.logit_cm

# 正答率、検出率の作成
(tbl.logit_cm[1,2] + tbl.logit_cm[2,3]) / length(d_submission$PassengerId) # 正答率
(tbl.logit_cm[2,3]) / (tbl.logit_cm[2,2] + tbl.logit_cm[2,3]) # 検出率

## > # 正答率、検出率の作成
## > (tbl.logit_cm[1,2] + tbl.logit_cm[2,3]) / length(d_submission$PassengerId) # 正答率
## 0
## 1 0.9210526
## > (tbl.logit_cm[2,3]) / (tbl.logit_cm[2,2] + tbl.logit_cm[2,3]) # 検出率
## 1
## 1 0.9541985