Rで回帰分析:Lasso回帰とRidge回帰

Rでデータサイエンス

Lasso回帰とRidge回帰

損失関数

Lasso回帰

\[\sum_{i=1}^{n}\left(y_i-\hat{y}_i\right)^2+\lambda\sum_{k=1}^p|\beta_k|\]

Ridge回帰

\[\sum_{i=1}^{n}\left(y_i-\hat{y}_i\right)^2+\dfrac{1}{2}\,\lambda\sum_{k=1}^p\beta_k^2\]

サンプル

  • 説明変数の数を\(p\)、サンプルサイズを\(n\)とするサンプルデータを作成。

\[y_i=\beta_0+\beta_1x_{1i}+\beta_2x_{2i}+\cdots+\beta_px_{pi}+\epsilon_i\quad(i=1,\cdots,n)\] \[ \begin{bmatrix}y_1\\\vdots\\y_n\end{bmatrix}= \begin{bmatrix} 1&x_{11}&\cdots&x_{p1}\\ \vdots&\vdots&&\vdots\\ 1&x_{1n}&\cdots&x_{pn} \end{bmatrix} \begin{bmatrix} \beta_0\\\vdots\\\beta_p \end{bmatrix}+ \begin{bmatrix} \epsilon_1\\\vdots\\\epsilon_n \end{bmatrix} \] \[ \textrm{y}=\textrm{X}\,\pmb{\beta}+\pmb{\epsilon} \]

library(dplyr)
# Simulate n observations from the linear model y = X %*% b + e.
# The order of the rnorm() calls is kept fixed so the seeded draws are
# unchanged.
seed <- 20230524
set.seed(seed)
n <- 50  # sample size
# Three mutually independent predictors.
x1 <- rnorm(n, mean = 1, sd = 2)
x2 <- rnorm(n, mean = 2, sd = 2)
x3 <- rnorm(n, mean = 3, sd = 2)
# Two predictors built as noisy linear combinations of x1 and x2, which
# introduces multicollinearity; x3 is left independent.
x4 <- 0.5 * x1 + 0.5 * x2 + rnorm(n, mean = 0, sd = 0.1)
x5 <- 0.3 * x1 + 0.3 * x2 + rnorm(n, mean = 0, sd = 0.1)
X0 <- cbind(x1, x2, x3, x4, x5)  # predictors only
X <- cbind(1, X0)                # design matrix with a leading intercept column
p <- ncol(X0)  # number of predictors, derived rather than hard-coded
b0 <- 2        # true intercept
# True coefficient vector (intercept first) as a (p + 1) x 1 column matrix.
b <- matrix(c(b0, 3, -2, 2, 1, 3), ncol = 1)
# Guard against the coefficient vector drifting out of sync with X.
stopifnot(nrow(b) == ncol(X))
e <- rnorm(n)  # standard-normal error term
y <- X %*% b + e
list(y = y, X = X, b = b)
$y
             [,1]
 [1,] 11.57213018
 [2,] 11.22960457
 [3,] -0.13896306
 [4,]  4.31449805
 [5,] 14.69031440
 [6,]  0.82149846
 [7,] 19.99600447
 [8,] -0.21207788
 [9,] 19.53285171
[10,] 39.12990558
[11,] -4.56517796
[12,] 15.94402528
[13,] 23.84211232
[14,] 12.58717709
[15,] 13.92581049
[16,]  1.26871526
[17,]  8.02534626
[18,]  4.00073647
[19,]  9.68669223
[20,] 13.17805031
[21,] 15.23190292
[22,] 36.81500091
[23,] 16.31280521
[24,]  7.11019833
[25,] 20.97273612
[26,] 14.05288165
[27,] 12.91697812
[28,]  8.64177676
[29,]  2.78866037
[30,]  2.17320743
[31,] -1.55189971
[32,] 23.01675674
[33,] 18.31498908
[34,] 15.69959574
[35,]  4.29024918
[36,]  4.75312634
[37,] 20.04329523
[38,]  8.07195185
[39,]  9.87220900
[40,] 12.19552233
[41,] 11.75092160
[42,] -2.39223928
[43,] 13.31357165
[44,]  0.04945589
[45,]  0.48869858
[46,] 13.76796211
[47,] 19.90664727
[48,] 18.84018766
[49,]  6.23102250
[50,] 11.82768733

$X
                 x1         x2          x3         x4         x5
 [1,] 1 -0.62700435  3.6382242  6.74937816  1.5746888  0.9496591
 [2,] 1  1.70394218 -0.9988507  1.35523872  0.3821955  0.3015187
 [3,] 1 -0.53187667  5.0274903  1.30043744  2.3202978  1.4245806
 [4,] 1 -0.64619318 -1.9789378  2.19923810 -1.3862515 -0.8212360
 [5,] 1  2.49694000  1.8051241  1.53564577  2.2075833  1.3361047
 [6,] 1 -0.68898149  2.9448171  1.73906547  1.1955192  0.4580169
 [7,] 1  0.97370932  2.2777880  6.65206468  1.5167554  0.9271465
 [8,] 1 -2.29752517  1.4128089  3.63263601 -0.4464841 -0.2508631
 [9,] 1  0.98090672  1.9892576  7.47783494  1.6780158  0.8813097
[10,] 1  6.40536040  0.5966047  5.32234121  3.5932482  2.1507636
[11,] 1 -1.77089107  4.3353585  0.80856582  1.5639769  0.8655342
[12,] 1  1.72277218  2.3068828  3.69465589  2.0463389  1.3166244
[13,] 1  3.02856261  1.0577749  4.26220052  2.0372742  1.2657041
[14,] 1  1.72293370  4.0010203  2.49165065  2.8811131  1.7191462
[15,] 1  4.14842114  2.7646134 -1.44792371  3.4269788  1.9587177
[16,] 1 -0.07434143  2.3624635  1.35627930  1.0202697  0.5776994
[17,] 1  0.17472004  1.0131491  3.92679863  0.5775880  0.2375525
[18,] 1  0.50432732  1.7834863  1.34476749  1.2206051  0.7400788
[19,] 1  0.45777539  4.8941133  3.90448642  2.6711288  1.8286488
[20,] 1  0.18101907 -0.5653752  4.30313069 -0.1615270 -0.1007700
[21,] 1  0.75583281  2.6251589  5.35497729  1.8407398  0.9456129
[22,] 1  4.25168234 -0.2045558  7.81113477  1.9546665  1.2254274
[23,] 1  1.78515712  1.1689084  4.02198211  1.3923890  0.7158893
[24,] 1  0.57281863  0.5815357  1.51677059  0.4510171  0.2226798
[25,] 1  1.10745424  2.7191921  6.64395560  2.0689818  1.2490652
[26,] 1  2.18099774  2.8153554  2.26270197  2.3972797  1.5193136
[27,] 1 -0.17271574  3.0125872  6.95886632  1.3488219  0.8487359
[28,] 1  0.52333234  0.7712954  2.27092760  0.5866684  0.4550115
[29,] 1 -1.94151865  1.4117855  4.14617824 -0.2918326 -0.2907826
[30,] 1 -1.04286100  0.3683858  3.01611889 -0.4970893 -0.3602515
[31,] 1  0.77202858  5.4852953 -1.85628372  3.2539863  1.7984990
[32,] 1  2.26186636  2.1824909  5.99272889  2.0939357  1.3013296
[33,] 1  3.34198645  1.5560493  0.09584852  2.5814933  1.5072305
[34,] 1  2.63665152  5.2446234  3.00870450  3.9516949  2.2891774
[35,] 1  0.43903452  0.3953852 -0.07678263  0.5780739  0.3208047
[36,] 1 -1.94188946  2.3148825  6.88531831  0.1022032  0.1628152
[37,] 1  2.78963259  4.7522747  4.14847718  3.7905167  2.4306930
[38,] 1 -0.23485672  1.1980746  3.87252848  0.5776187  0.2859197
[39,] 1  0.47152183  4.0076978  3.97607268  2.1483553  1.2742200
[40,] 1  0.46362372  3.6975343  4.46451456  2.0008658  1.2957557
[41,] 1  0.78958542  1.6882553  2.81652926  1.2451234  0.9338407
[42,] 1 -2.00357792  1.4159499  2.65499885 -0.3083493 -0.3850617
[43,] 1  1.41936599  1.1407479  2.36015950  1.2872721  0.8413491
[44,] 1 -0.89240949  5.0604682  2.44499626  2.1369174  1.1927702
[45,] 1 -2.21736886  5.6379918  6.65510388  1.6360455  0.8481050
[46,] 1  0.14826974  0.6515563  6.48110381  0.4270066  0.3063347
[47,] 1  2.08292259  1.7703542  5.45336972  2.0253355  0.9951496
[48,] 1  2.83840199  4.8917578  2.33700365  3.9245972  2.4423175
[49,] 1 -0.39497737  1.4598331  3.19092015  0.5748499  0.4897625
[50,] 1  1.71672547  3.5860753  2.09065542  2.7901775  1.5491791

$b
     [,1]
[1,]    2
[2,]    3
[3,]   -2
[4,]    2
[5,]    1
[6,]    3
# Check for multicollinearity: variance inflation factors of the ordinary
# least-squares fit on all five predictors.
ols_for_vif <- lm(y ~ x1 + x2 + x3 + x4 + x5)
car::vif(mod = ols_for_vif)
        x1         x2         x3         x4         x5 
106.989034 105.204805   1.049484 197.635668  58.505162 
# Register a parallel backend (used by the cv.glmnet(parallel = TRUE) calls
# below) and load glmnet.
# parallel::detectCores() is documented to possibly return NA; fall back to a
# single worker in that case instead of passing NA on to registerDoParallel().
n_cores <- parallel::detectCores()
if (is.na(n_cores)) {
  n_cores <- 1L
}
doParallel::registerDoParallel(n_cores)
library(glmnet)
packageVersion("glmnet")
[1] '4.1.8'

Lasso回帰

# Lasso regression: pick lambda by 10-fold cross-validation.
# Re-seed right before the call so the cross-validation run is reproducible.
set.seed(seed = seed)
result_lasso <- cv.glmnet(
  x = X0,
  y = y,
  family = "gaussian",
  alpha = 1,  # 1 = lasso, 0 = ridge
  standardize = TRUE,
  intercept = TRUE,
  grouped = FALSE,
  parallel = TRUE,
  nfolds = 10
)
# Lambda with the minimum cross-validated error; printed and kept in `s`.
(s <- result_lasso$lambda.min)
[1] 0.01648718
# Lasso coefficients (intercept first) evaluated at the selected lambda `s`.
coef_lasso <- coef(object = result_lasso, s = s)
coef_lasso
6 x 1 sparse Matrix of class "dgCMatrix"
                       s1
(Intercept)  2.0293548660
x1           3.2524585611
x2          -1.5520915998
x3           1.9914108842
x4           0.0004384731
x5           3.5268052193

Ridge回帰

# Ridge regression: same cross-validation setup as the lasso fit, but with
# alpha = 0.
set.seed(seed = seed)
result_ridge <- cv.glmnet(
  x = X0,
  y = y,
  family = "gaussian",
  alpha = 0,  # 1 = lasso, 0 = ridge
  standardize = TRUE,
  intercept = TRUE,
  grouped = FALSE,
  parallel = TRUE,
  nfolds = 10
)
# Lambda with the minimum cross-validated error; printed and kept in `s`
# (this overwrites the lasso value stored under the same name).
(s <- result_ridge$lambda.min)
[1] 0.7652671
# Ridge coefficients (intercept first) evaluated at the selected lambda `s`.
coef_ridge <- coef(object = result_ridge, s = s)
coef_ridge
6 x 1 sparse Matrix of class "dgCMatrix"
                   s1
(Intercept)  2.451201
x1           2.756858
x2          -1.733561
x3           1.850065
x4           1.199551
x5           2.457834

線形回帰

# Unpenalized least-squares fit on the same predictor matrix, for reference.
result_lm <- lm(y ~ X0)
# Equivalent alternative via glm():
# result_lm <- glm(y ~ X0,family = gaussian(link = 'identity'))
summary(result_lm)

Call:
lm(formula = y ~ X0)

Residuals:
     Min       1Q   Median       3Q      Max 
-2.34790 -0.58413 -0.04564  0.69683  2.15141 

Coefficients:
            Estimate Std. Error t value             Pr(>|t|)    
(Intercept)  2.00582    0.38469   5.214           0.00000474 ***
X0x1         1.85050    0.91797   2.016              0.04995 *  
X0x2        -2.97643    0.92251  -3.226              0.00237 ** 
X0x3         2.01399    0.07186  28.028 < 0.0000000000000002 ***
X0x4         2.32846    1.82219   1.278              0.20801    
X0x5         4.27996    1.62262   2.638              0.01149 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 1.121 on 44 degrees of freedom
Multiple R-squared:  0.9865,    Adjusted R-squared:  0.9849 
F-statistic: 641.7 on 5 and 44 DF,  p-value: < 0.00000000000000022

偏回帰係数の比較

# Side-by-side table of the true coefficients `b` and the OLS, lasso and
# ridge estimates. Columns are aligned by position (intercept first).
data.frame(
  b = b,
  lm = coef(result_lm),
  lasso = coef_lasso[, 1],
  ridge = coef_ridge[, 1]
)
             b        lm         lasso     ridge
(Intercept)  2  2.005819  2.0293548660  2.451201
X0x1         3  1.850497  3.2524585611  2.756858
X0x2        -2 -2.976431 -1.5520915998 -1.733561
X0x3         2  2.013986  1.9914108842  1.850065
X0x4         1  2.328460  0.0004384731  1.199551
X0x5         3  4.279964  3.5268052193  2.457834

引数 standardize について

scale {base}による標準化

\[\dfrac{x_i-\bar{x}}{\displaystyle\sqrt{\dfrac{1}{n-1}\sum\left(x_i-\bar{x}\right)^2}}\]

# Standardization in the style of base::scale(): centre every column, then
# divide by the sample standard deviation (denominator n - 1).
column_means <- apply(X0, 2, mean)
X0_centered <- sweep(X0, MARGIN = 2, STATS = column_means, FUN = "-")
sample_sd_of_centered <- function(v) sqrt(sum(v^2) / (length(v) - 1))
column_sds <- apply(X0_centered, 2, sample_sd_of_centered)
X0.scaled.by.scale <- sweep(X0_centered, MARGIN = 2, STATS = column_sds, FUN = "/")
# Number of elements that differ from scale(X0).
sum(X0.scaled.by.scale != scale(X0))
[1] 0

glmnet {glmnet}による標準化

\[\dfrac{x_i-\bar{x}}{\displaystyle\sqrt{\dfrac{1}{n}\sum\left(x_i-\bar{x}\right)^2}}\]

# Standardization in the style of glmnet: the standard deviation uses n in
# the denominator rather than n - 1.
X0.scaled.by.glmnet <- apply(
  X = X0_centered,
  MARGIN = 2,
  FUN = function(v) v / sqrt(sum(v^2) / length(v))
)

両者の比較

# Compare the first column under the two scalings (first 10 rows).
head(
  cbind(scale = X0.scaled.by.scale[, 1], glmnet = X0.scaled.by.glmnet[, 1]),
  n = 10
)
            scale      glmnet
 [1,] -0.79486137 -0.80293123
 [2,]  0.49678693  0.50183058
 [3,] -0.74214823 -0.74968292
 [4,] -0.80549448 -0.81367230
 [5,]  0.93621112  0.94571604
 [6,] -0.82920478 -0.83762332
 [7,]  0.09214271  0.09307820
 [8,] -1.72054772 -1.73801565
 [9,]  0.09613101  0.09710698
[10,]  3.10198562  3.13347867

手作業による標準化

# Compute the coefficients by hand: pass already-standardized data with
# standardize = FALSE, then convert the resulting standardized coefficients
# back to coefficients on the original scale.

## Population standard deviation (denominator n), shared by y and X0.
pop_sd <- function(v) sqrt(sum((v - mean(v))^2) / length(v))
sd_y <- pop_sd(y)
sd_x <- apply(X = X0, MARGIN = 2, FUN = pop_sd)

## Standardize the data.
y_std <- (y - mean(y)) / sd_y
X0_std <- apply(X = X0, MARGIN = 2, FUN = function(x) (x - mean(x)) / pop_sd(x))

## Fit on the standardized data with standardize = FALSE.
set.seed(seed = seed)
std_FALSE <- cv.glmnet(
  x = X0_std,
  y = y_std,
  intercept = TRUE,
  standardize = FALSE,
  standardize.response = FALSE
)

## Coefficients on the original scale.
# standardized coefficient = coefficient * sd(predictor) / sd(response)
# coefficient = standardized coefficient * sd(response) / sd(predictor)
std_coef0 <- coef(object = std_FALSE, s = std_FALSE$lambda.min)[-1]  # drop intercept
std_coef <- std_coef0 * sd_y / sd_x

## Intercept on the original scale.
std_intercept <- mean(y) - sum(std_coef * colMeans(X0))

## Intercept and coefficients obtained by hand.
(standardize_by_manual <- c(std_intercept, std_coef))
                         x1            x2            x3            x4 
 2.0293548660  3.2524585611 -1.5520915998  1.9914108842  0.0004384731 
           x5 
 3.5268052193 
# Compute the coefficients with glmnet's own standardization
# (standardize = TRUE); the result is reported on the original scale.
# The fit object is named std_TRUE to pair with std_FALSE (it was misspelled
# as sfd_TRUE).
set.seed(seed = seed)
std_TRUE <- cv.glmnet(x = X0, y = y, intercept = TRUE, standardize = TRUE)
(standardize_by_glmnet <- coef(object = std_TRUE, s = std_TRUE$lambda.min))
6 x 1 sparse Matrix of class "dgCMatrix"
                       s1
(Intercept)  2.0293548660
x1           3.2524585611
x2          -1.5520915998
x3           1.9914108842
x4           0.0004384731
x5           3.5268052193
# The hand-computed and glmnet-computed coefficients agree.
cbind(
  by_manual = standardize_by_manual,
  by_glmnet = standardize_by_glmnet[, 1]
)
       by_manual     by_glmnet
    2.0293548660  2.0293548660
x1  3.2524585611  3.2524585611
x2 -1.5520915998 -1.5520915998
x3  1.9914108842  1.9914108842
x4  0.0004384731  0.0004384731
x5  3.5268052193  3.5268052193

標準化偏回帰係数の算出

# Baseline fit on the unscaled predictors (glmnet standardizes internally);
# its coefficients are the input to the Agresti-method conversions below.
set.seed(seed = seed)
result <- cv.glmnet(x = X0, y = y, intercept = TRUE, standardize = TRUE)
coef_by_non_scale <- coef(object = result, s = result$lambda.min)

Agresti method

\[ \begin{eqnarray} b_x^{*} & = & b_x \cdot \sigma _x\\ b_0^{*} & = & b_0 + \displaystyle\sum b_x \cdot \mu_x \end{eqnarray} \]

scale {base}による標準化

標準化したデータによる
# Fit on predictors pre-standardized with base::scale(), with glmnet's own
# standardization switched off.
set.seed(seed = seed)
result <- cv.glmnet(x = scale(X0), y = y, intercept = TRUE, standardize = FALSE)
(coef_by_scale <- coef(object = result, s = result$lambda.min))
6 x 1 sparse Matrix of class "dgCMatrix"
                       s1
(Intercept) 11.2867022628
x1           5.8694824331
x2          -2.7638296978
x3           4.5470050804
x4           0.0005417872
x5           2.6626186854
Agresti methodによる
# Agresti-method conversion using the sample standard deviation (sd()),
# matching the scaling applied by base::scale().
# vapply() is used instead of sapply() so the result type is fixed.
# sigma_x: per-predictor standard deviation
sigma <- vapply(as.data.frame(X0), sd, numeric(1))
# mu_x: per-predictor mean
mu <- vapply(as.data.frame(X0), mean, numeric(1))
# Standardized partial regression coefficients: b_x * sigma_x
b_std <- coef_by_non_scale[-1, 1] * sigma
# Standardized intercept: b_0 + sum(b_x * mu_x)
b0_std <- coef_by_non_scale[1, 1] + sum(coef_by_non_scale[-1, 1] * mu)
# The two results agree.
cbind(by_scale = coef_by_scale[, 1], by_agresti = c(b0_std, b_std))
                 by_scale    by_agresti
(Intercept) 11.2867022628 11.2867022628
x1           5.8694824331  5.8694824331
x2          -2.7638296978 -2.7638296978
x3           4.5470050804  4.5470050804
x4           0.0005417872  0.0005417872
x5           2.6626186854  2.6626186854

手作業による標準化

標準化したデータによる
# Fit on the hand-standardized predictors X0_std, with glmnet's own
# standardization switched off.
set.seed(seed = seed)
result <- cv.glmnet(x = X0_std, y = y, intercept = TRUE, standardize = FALSE)
coef_by_manual <- coef(object = result, s = result$lambda.min)
Agresti methodによる
# Agresti-method conversion using the population standard deviation
# (denominator n), matching the scaling used to build X0_std.
# sigma_x: per-predictor population standard deviation
sigma <- apply(X0, 2, function(x) sqrt(sum((x - mean(x))^2) / n))
# mu_x: per-predictor mean (vapply() instead of sapply() fixes the type)
mu <- vapply(as.data.frame(X0), mean, numeric(1))
# Standardized partial regression coefficients: b_x * sigma_x
b_std <- coef_by_non_scale[-1, 1] * sigma
# Standardized intercept: b_0 + sum(b_x * mu_x)
b0_std <- coef_by_non_scale[1, 1] + sum(coef_by_non_scale[-1, 1] * mu)
# The two results agree.
# NOTE(review): the first column holds coef_by_manual (the fit on X0_std);
# the label `by_scale` is kept as-is so the printed table header is unchanged.
cbind(by_scale = coef_by_manual[, 1], by_agresti = c(b0_std, b_std))
                 by_scale    by_agresti
(Intercept) 11.2867022628 11.2867022628
x1           5.8104911627  5.8104911627
x2          -2.7360518099 -2.7360518099
x3           4.5013053770  4.5013053770
x4           0.0005363419  0.0005363419
x5           2.6358580194  2.6358580194

参考引用資料

最終更新

Sys.time()
[1] "2024-03-27 12:35:14 JST"

R、Quarto、Package

R.Version()$version.string
[1] "R version 4.3.3 (2024-02-29 ucrt)"
quarto::quarto_version()
[1] '1.4.542'
packageVersion(pkg = "tidyverse")
[1] '2.0.0'

著者