1 交叉验证

Code
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
#> ✔ broom        1.0.6     ✔ rsample      1.2.1
#> ✔ dials        1.2.1     ✔ tune         1.2.1
#> ✔ infer        1.0.7     ✔ workflows    1.1.4
#> ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
#> ✔ parsnip      1.2.1     ✔ yardstick    1.3.1
#> ✔ recipes      1.1.0

# Load the credit-default data set from disk. `read_csv()` is namespace-
# qualified because `library(tidymodels)` does not attach readr.
data <- readr::read_csv("data/Default.csv")
#> Rows: 10000 Columns: 4
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> chr (2): default, student
#> dbl (2): balance, income
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

# Recode both character columns as 0/1 factors: "No" becomes 0 and every
# other value becomes 1.
data <- mutate(
  data,
  across(
    c(default, student),
    function(col) factor(if_else(col == "No", 0, 1))
  )
)

# Show the recoded data as an interactive table.
DT::datatable(data)
Code

# Fix the random seed so the split below is reproducible.
set.seed(10)

# Split the data into a training set (80%) and a test set (20%).
# Stratifying on the outcome keeps the share of each `default` class similar
# in both sets, which guards against an unlucky split that leaves one set
# with very few events.
data_split <- initial_split(data, prop = 0.8, strata = default)
train_data <- training(data_split)
test_data <- testing(data_split)

1.1 tidymodels

Code
# Model specification: logistic regression fitted with the "glm" engine.
log_reg <- set_engine(logistic_reg(), "glm")

# Workflow that pairs the model with a formula using every remaining column
# as a predictor of `default`.
full_workflow <- add_formula(
  add_model(workflow(), log_reg),
  default ~ .
)

# Fit the full model on the training set.
full_fit <- fit(full_workflow, data = train_data)

# Predict class probabilities for the test set and attach the true outcome
# column so the predictions can be evaluated.
full_predictions <- bind_cols(
  predict(full_fit, test_data, type = "prob"),
  select(test_data, default)
)


# ROC curve of the full model.
# event_level = "second": the outcome is coded 0/1, so the second factor
# level ("1") is treated as the event.
full_roc <- full_predictions %>%
  roc_curve(truth = default, .pred_1, event_level = "second") %>%
  autoplot() +
  ggtitle("ROC Curve - Full Model")

print(full_roc)

Code

# Calibration curve for the full model: bin the predicted probabilities into
# ten equal-width bins, then compare the mean prediction in each bin with the
# observed event rate.
# `include.lowest = TRUE` closes the first bin at 0, so a predicted
# probability of exactly 0 is binned instead of becoming NA.
full_predictions <- full_predictions %>%
  mutate(
    pred_bin = cut(
      .pred_1,
      breaks = seq(0, 1, by = 0.1),
      include.lowest = TRUE
    )
  )

# Per-bin mean predicted probability and observed share of `default == "1"`.
calibration_data <- full_predictions %>%
  group_by(pred_bin) %>%
  dplyr::summarize(
    mean_pred = mean(.pred_1),
    mean_actual = mean(default == "1")
  )

# The dashed diagonal marks perfect calibration (predicted == observed).
ggplot(calibration_data, aes(x = mean_pred, y = mean_actual)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  xlim(0, 1) + ylim(0, 1) +
  ggtitle("Calibration Curve - Full Model") +
  xlab("Predicted Probability") +
  ylab("Observed Probability")

Code
# Define the cross-validation schemes on the training set. Folds are
# stratified on the outcome so each fold keeps a similar share of each
# `default` class.
cv_5 <- vfold_cv(train_data, v = 5, strata = default)
cv_10 <- vfold_cv(train_data, v = 10, strata = default)

# 5-fold cross-validation. Predictions are saved so ROC and calibration
# curves can be drawn afterwards. `event_level = "second"` keeps the event
# definition consistent with the ROC curve calls elsewhere in this script.
cv_5_results <- fit_resamples(
  full_workflow,
  resamples = cv_5,
  metrics = metric_set(roc_auc),
  control = control_resamples(save_pred = TRUE, event_level = "second")
)

# 10-fold cross-validation, configured the same way.
cv_10_results <- fit_resamples(
  full_workflow,
  resamples = cv_10,
  metrics = metric_set(roc_auc),
  control = control_resamples(save_pred = TRUE, event_level = "second")
)

# Gather the predictions saved during 5-fold cross-validation.
cv_5_predictions <- cv_5_results %>% collect_predictions()

# ROC curve of the 5-fold cross-validated model; the second factor level
# ("1") is treated as the event.
cv_5_roc <- cv_5_predictions %>%
  roc_curve(truth = default, .pred_1, event_level = "second") %>%
  autoplot() +
  ggtitle("ROC Curve - 5-fold Cross-Validation Model")

print(cv_5_roc)

Code

# Calibration curve for the 5-fold cross-validation predictions: bin the
# predicted probabilities into ten equal-width bins, then compare the mean
# prediction in each bin with the observed event rate.
# `include.lowest = TRUE` closes the first bin at 0, so a predicted
# probability of exactly 0 is binned instead of becoming NA.
cv_5_predictions <- cv_5_predictions %>%
  mutate(
    pred_bin = cut(
      .pred_1,
      breaks = seq(0, 1, by = 0.1),
      include.lowest = TRUE
    )
  )

# Per-bin mean predicted probability and observed share of `default == "1"`.
calibration_data_5 <- cv_5_predictions %>%
  group_by(pred_bin) %>%
  dplyr::summarize(
    mean_pred = mean(.pred_1),
    mean_actual = mean(default == "1")
  )

# The dashed diagonal marks perfect calibration (predicted == observed).
ggplot(calibration_data_5, aes(x = mean_pred, y = mean_actual)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  xlim(0, 1) + ylim(0, 1) +
  ggtitle("Calibration Curve - 5-fold Cross-Validation Model") +
  xlab("Predicted Probability") +
  ylab("Observed Probability")

Code

# Gather the predictions saved during 10-fold cross-validation.
cv_10_predictions <- cv_10_results %>% collect_predictions()

# ROC curve of the 10-fold cross-validated model; the second factor level
# ("1") is treated as the event.
cv_10_roc <- cv_10_predictions %>%
  roc_curve(truth = default, .pred_1, event_level = "second") %>%
  autoplot() +
  ggtitle("ROC Curve - 10-fold Cross-Validation Model")

print(cv_10_roc)

Code

# Calibration curve for the 10-fold cross-validation predictions: bin the
# predicted probabilities into ten equal-width bins, then compare the mean
# prediction in each bin with the observed event rate.
# `include.lowest = TRUE` closes the first bin at 0, so a predicted
# probability of exactly 0 is binned instead of becoming NA.
cv_10_predictions <- cv_10_predictions %>%
  mutate(
    pred_bin = cut(
      .pred_1,
      breaks = seq(0, 1, by = 0.1),
      include.lowest = TRUE
    )
  )

# Per-bin mean predicted probability and observed share of `default == "1"`.
calibration_data_10 <- cv_10_predictions %>%
  group_by(pred_bin) %>%
  dplyr::summarize(
    mean_pred = mean(.pred_1),
    mean_actual = mean(default == "1")
  )

# The dashed diagonal marks perfect calibration (predicted == observed).
ggplot(calibration_data_10, aes(x = mean_pred, y = mean_actual)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  xlim(0, 1) + ylim(0, 1) +
  ggtitle("Calibration Curve - 10-fold Cross-Validation Model") +
  xlab("Predicted Probability") +
  ylab("Observed Probability")

Back to top