1 线性判别分析

Fisher 线性判别分析(Linear Discriminant Analysis, LDA):用于分类任务的降维技术。

Fisher判别法试图最大化类间差异(不同类别的数据点彼此远离)并最小化类内差异(同一类别的数据点尽可能聚集。

它侧重于最大化类间差异(between-class variance)与类内差异(within-class variance)的比率

1.0.1 MASS

Code
# 加载MASS包,它包含了lda函数
library(MASS)

# 加载内置的鸢尾花数据集
data(iris)

# 查看数据集结构
str(iris)
#> 'data.frame':    150 obs. of  5 variables:
#>  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
#>  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
#>  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#>  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#>  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

# 应用Fisher线性判别分析
# 使用鸢尾花数据集的前四列作为特征,Species作为类别
lda_model <- lda(Species ~ ., data=iris)

# 查看判别模型的摘要
summary(lda_model)
#>         Length Class  Mode     
#> prior    3     -none- numeric  
#> counts   3     -none- numeric  
#> means   12     -none- numeric  
#> scaling  8     -none- numeric  
#> lev      3     -none- character
#> svd      2     -none- numeric  
#> N        1     -none- numeric  
#> call     3     -none- call     
#> terms    3     terms  call     
#> xlevels  0     -none- list
lda_model
#> Call:
#> lda(Species ~ ., data = iris)
#> 
#> Prior probabilities of groups:
#>     setosa versicolor  virginica 
#>  0.3333333  0.3333333  0.3333333 
#> 
#> Group means:
#>            Sepal.Length Sepal.Width Petal.Length Petal.Width
#> setosa            5.006       3.428        1.462       0.246
#> versicolor        5.936       2.770        4.260       1.326
#> virginica         6.588       2.974        5.552       2.026
#> 
#> Coefficients of linear discriminants:
#>                     LD1         LD2
#> Sepal.Length  0.8293776 -0.02410215
#> Sepal.Width   1.5344731 -2.16452123
#> Petal.Length -2.2012117  0.93192121
#> Petal.Width  -2.8104603 -2.83918785
#> 
#> Proportion of trace:
#>    LD1    LD2 
#> 0.9912 0.0088
# 打印判别函数的系数
print(lda_model$coefficients)
#> NULL

# 使用判别模型对数据进行分类
predicted_species <- predict(lda_model, iris)

# 计算准确率
accuracy <- sum(predicted_species$class == iris$Species) / nrow(iris)
print(paste("分类准确率:", accuracy))
#> [1] "分类准确率: 0.98"

# 可视化判别结果
plot(lda_model)

  1. 模型调用(Call):

    • 显示了创建LDA模型时使用的函数调用。在这个例子中,模型使用鸢尾花数据集的所有特征(Sepal.Length, Sepal.Width, Petal.Length, Petal.Width)来预测物种(Species)。
  2. 组的先验概率(Prior probabilities of groups):

    • 显示了每个物种(setosa, versicolor, virginica)的先验概率。这里每个物种的先验概率都是0.3333,意味着在没有任何额外信息的情况下,每个物种出现的概率是相同的。
  3. 组内均值(Group means):

    • 显示了每个物种在各个特征上的均值。例如,setosa物种的花萼长度(Sepal.Length)均值是5.006,花萼宽度(Sepal.Width)均值是3.428,花瓣长度(Petal.Length)均值是1.462,花瓣宽度(Petal.Width)均值是0.246。
  4. 线性判别系数(Coefficients of linear discriminants):

    • 显示了两个线性判别函数(LD1和LD2)的系数。这些系数用于计算判别分数,以区分不同的物种。例如,LD1判别函数中,花萼长度(Sepal.Length)的系数是0.8293776,花萼宽度(Sepal.Width)的系数是1.5344731,以此类推。
  5. 特征值的比例(Proportion of trace):

    • 显示了每个线性判别函数对总方差的解释比例。在这个例子中,LD1解释了99.12%的方差,而LD2仅解释了0.88%的方差。这表明LD1是主要的判别方向,而LD2的贡献相对较小。

如何使用这些信息:

  • 可以使用这些系数来计算每个观测值在LD1和LD2上的判别分数。判别分数的计算公式为: LD1=0.8293776×Sepal.Length+1.5344731×Sepal.Width−2.2012117×Petal.Length−2.8104603×Petal.WidthLD1=0.8293776×Sepal.Length+1.5344731×Sepal.Width−2.2012117×Petal.Length−2.8104603×Petal.Width LD2=−0.02410215×Sepal.Length−2.16452123×Sepal.Width+0.93192121×Petal.Length−2.83918785×Petal.WidthLD2=−0.02410215×Sepal.Length−2.16452123×Sepal.Width+0.93192121×Petal.Length−2.83918785×Petal.Width

  • 通常,主要的判别函数(在这个例子中是LD1)足以进行有效的分类。如果需要,也可以使用LD2作为辅助。

  • 根据判别分数,可以确定每个观测值最有可能属于的物种类别。

1.0.2 tidymodels

Code
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
#> ✔ broom        1.0.6     ✔ rsample      1.2.1
#> ✔ dials        1.2.1     ✔ tune         1.2.1
#> ✔ infer        1.0.7     ✔ workflows    1.1.4
#> ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
#> ✔ parsnip      1.2.1     ✔ yardstick    1.3.1
#> ✔ recipes      1.1.0
library(discrim)
Code
Smarket <- read_csv("data/Smarket.csv")
#> Rows: 1250 Columns: 9
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> chr (1): Direction
#> dbl (8): Year, Lag1, Lag2, Lag3, Lag4, Lag5, Volume, Today
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
Smarket$Direction <- factor(Smarket$Direction)
head(Smarket)
#> # A tibble: 6 × 9
#>    Year   Lag1   Lag2   Lag3   Lag4   Lag5 Volume  Today Direction
#>   <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <fct>    
#> 1  2001  0.381 -0.192 -2.62  -1.06   5.01    1.19  0.959 Up       
#> 2  2001  0.959  0.381 -0.192 -2.62  -1.06    1.30  1.03  Up       
#> 3  2001  1.03   0.959  0.381 -0.192 -2.62    1.41 -0.623 Down     
#> 4  2001 -0.623  1.03   0.959  0.381 -0.192   1.28  0.614 Up       
#> 5  2001  0.614 -0.623  1.03   0.959  0.381   1.21  0.213 Up       
#> 6  2001  0.213  0.614 -0.623  1.03   0.959   1.35  1.39  Up
Code
lda_spec <- discrim_linear() %>%
  set_mode("classification") %>%
  set_engine("MASS")
lda_fit <- lda_spec %>%
  fit(Direction ~ Lag1 + Lag2, data = Smarket)

lda_fit
#> parsnip model object
#> 
#> Call:
#> lda(Direction ~ Lag1 + Lag2, data = data)
#> 
#> Prior probabilities of groups:
#>   Down     Up 
#> 0.4816 0.5184 
#> 
#> Group means:
#>             Lag1        Lag2
#> Down  0.05068605  0.03229734
#> Up   -0.03969136 -0.02244444
#> 
#> Coefficients of linear discriminants:
#>             LD1
#> Lag1 -0.7567605
#> Lag2 -0.4707872
Code
predict(lda_fit, new_data = Smarket)
#> # A tibble: 1,250 × 1
#>    .pred_class
#>    <fct>      
#>  1 Up         
#>  2 Down       
#>  3 Down       
#>  4 Up         
#>  5 Up         
#>  6 Up         
#>  7 Down       
#>  8 Up         
#>  9 Up         
#> 10 Down       
#> # ℹ 1,240 more rows
predict(lda_fit, new_data = Smarket, type = "prob")
#> # A tibble: 1,250 × 2
#>    .pred_Down .pred_Up
#>         <dbl>    <dbl>
#>  1      0.486    0.514
#>  2      0.503    0.497
#>  3      0.510    0.490
#>  4      0.482    0.518
#>  5      0.485    0.515
#>  6      0.492    0.508
#>  7      0.509    0.491
#>  8      0.490    0.510
#>  9      0.477    0.523
#> 10      0.505    0.495
#> # ℹ 1,240 more rows
augment(lda_fit, new_data = Smarket) %>%
  conf_mat(truth = Direction, estimate = .pred_class) 
#>           Truth
#> Prediction Down  Up
#>       Down  114 102
#>       Up    488 546

augment(lda_fit, new_data = Smarket) %>%
  accuracy(truth = Direction, estimate = .pred_class) 
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.528
Back to top