# 交叉验证
```{r}
library(tidymodels)
# Load the data. read_csv() is namespace-qualified because
# library(tidymodels) does not attach readr, so the bare name is only
# available if some other code has already loaded readr/tidyverse.
data <- readr::read_csv("data/Default.csv")
# Recode both text columns as 0/1 factors: "No" becomes 0, anything else 1.
# factor() sorts the levels, so "1" is the second level of each factor.
data <- data %>%
  mutate(
    default = factor(if_else(default == "No", 0, 1)),
    student = factor(if_else(student == "No", 0, 1))
  )
# Show the recoded data as an interactive table
data %>% DT::datatable()
# Set the seed so the random split is reproducible
set.seed(10)
# Split the data into a training set (80%) and a test set (20%)
# NOTE(review): the split is not stratified on `default`; if the outcome is
# imbalanced, consider `strata = default` (this would change the split).
data_split <- initial_split(data, prop = 0.8)
train_data <- training(data_split)
test_data <- testing(data_split)
```
## tidymodels
```{r}
# Model specification: logistic regression using the "glm" engine
log_reg <- logistic_reg() %>%
  set_engine("glm")
# Workflow: the model plus a formula that uses every other column of the
# data as a predictor of `default`
full_workflow <- workflow() %>%
  add_model(log_reg) %>%
  add_formula(default ~ .)
# Fit the full model on the training set
full_fit <- full_workflow %>%
  fit(data = train_data)
# Predicted class probabilities on the test set, next to the observed outcome
full_predictions <- full_fit %>%
  predict(test_data, type = "prob") %>%
  bind_cols(test_data %>% select(default))
# ROC curve of the full model. The outcome is coded 0/1, so the event ("1")
# is the second factor level, hence event_level = "second".
full_roc <- roc_curve(
  full_predictions,
  truth = default,
  .pred_1,
  event_level = "second"
) %>%
  autoplot() +
  ggtitle("ROC Curve - Full Model")
print(full_roc)
# Calibration curve: bin the predicted probabilities into ten equal-width
# intervals. include.lowest = TRUE closes the first interval at 0, so a
# prediction of exactly 0 is binned instead of silently becoming NA.
full_predictions <- full_predictions %>%
  mutate(
    pred_bin = cut(
      .pred_1,
      breaks = seq(0, 1, by = 0.1),
      include.lowest = TRUE
    )
  )
# Per bin: mean predicted probability vs. observed share of events
calibration_data <- full_predictions %>%
  group_by(pred_bin) %>%
  dplyr::summarize(
    mean_pred = mean(.pred_1),
    mean_actual = mean(default == "1")
  )
# The dashed diagonal marks perfect calibration (predicted == observed)
ggplot(calibration_data, aes(x = mean_pred, y = mean_actual)) +
  geom_point() +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  xlim(0, 1) +
  ylim(0, 1) +
  ggtitle("Calibration Curve - Full Model") +
  xlab("Predicted Probability") +
  ylab("Observed Probability")
```
```{r}
# Helper: add a `pred_bin` column that bins `.pred_1` into ten equal-width
# intervals. include.lowest = TRUE closes the first interval at 0, so a
# prediction of exactly 0 is binned instead of silently becoming NA.
add_pred_bin <- function(predictions) {
  predictions %>%
    mutate(
      pred_bin = cut(
        .pred_1,
        breaks = seq(0, 1, by = 0.1),
        include.lowest = TRUE
      )
    )
}

# Helper: per bin, the mean predicted probability and the observed share
# of events (`default == "1"`).
summarise_calibration <- function(binned_predictions) {
  binned_predictions %>%
    group_by(pred_bin) %>%
    dplyr::summarize(
      mean_pred = mean(.pred_1),
      mean_actual = mean(default == "1")
    )
}

# Helper: calibration plot with a dashed diagonal marking perfect
# calibration. Returns the ggplot object.
plot_calibration <- function(calibration, title) {
  ggplot(calibration, aes(x = mean_pred, y = mean_actual)) +
    geom_point() +
    geom_line() +
    geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
    xlim(0, 1) +
    ylim(0, 1) +
    ggtitle(title) +
    xlab("Predicted Probability") +
    ylab("Observed Probability")
}

# Define the cross-validation folds on the training set
cv_5 <- vfold_cv(train_data, v = 5)
cv_10 <- vfold_cv(train_data, v = 10)
# 5-fold cross-validation; save_pred = TRUE keeps the out-of-fold
# predictions so they can be collected afterwards
cv_5_results <- fit_resamples(
  full_workflow,
  resamples = cv_5,
  metrics = metric_set(roc_auc),
  control = control_resamples(save_pred = TRUE)
)
# 10-fold cross-validation
cv_10_results <- fit_resamples(
  full_workflow,
  resamples = cv_10,
  metrics = metric_set(roc_auc),
  control = control_resamples(save_pred = TRUE)
)

# 5-fold: collect the saved predictions from all folds
cv_5_predictions <- collect_predictions(cv_5_results)
# 5-fold: ROC curve over the collected predictions ("1" is the second level)
cv_5_roc <- roc_curve(
  cv_5_predictions,
  truth = default,
  .pred_1,
  event_level = "second"
) %>%
  autoplot() +
  ggtitle("ROC Curve - 5-fold Cross-Validation Model")
print(cv_5_roc)
# 5-fold: calibration curve
cv_5_predictions <- add_pred_bin(cv_5_predictions)
calibration_data_5 <- summarise_calibration(cv_5_predictions)
plot_calibration(
  calibration_data_5,
  "Calibration Curve - 5-fold Cross-Validation Model"
)

# 10-fold: collect the saved predictions from all folds
cv_10_predictions <- collect_predictions(cv_10_results)
# 10-fold: ROC curve over the collected predictions
cv_10_roc <- roc_curve(
  cv_10_predictions,
  truth = default,
  .pred_1,
  event_level = "second"
) %>%
  autoplot() +
  ggtitle("ROC Curve - 10-fold Cross-Validation Model")
print(cv_10_roc)
# 10-fold: calibration curve
cv_10_predictions <- add_pred_bin(cv_10_predictions)
calibration_data_10 <- summarise_calibration(cv_10_predictions)
plot_calibration(
  calibration_data_10,
  "Calibration Curve - 10-fold Cross-Validation Model"
)
```