21  REF SVM

采用了REF(Recursive Feature Elimination)结合SVM(Support Vector Machine)的方法,对差异基因(参考 章节 11) 进行了特征筛选。通过这种方法,能够从大量的候选基因中识别出与特定生物过程或疾病状态最为相关的关键基因。筛选出的这些重要基因不仅具有统计学上的显著性,而且在实际应用中具有较高的预测能力。这些经过REF+SVM筛选的重要基因将被用于后续的生物信息学分析、疾病机制探究以及潜在治疗靶点的识别等研究中,以期进一步推动相关领域的科学进展。

21.1 加载R包

使用rm(list = ls())来清空环境中的所有变量。

library(tidyverse)
library(Biobase)
library(data.table)
library(caret)
library(e1071)
library(mlbench)
library(pROC)
library(Hmisc)
library(MLmetrics)

rm(list = ls())
options(stringsAsFactors = F)
options(future.globals.maxSize = 10000 * 1024^2)

grp_names <- c("Early Stage", "Late Stage")
grp_colors <- c("#8AC786", "#B897CA")
grp_shapes <- c(15, 16)

21.2 导入数据


da_res <- read.csv("./data/result/DA/HCC_Early_vs_Late_limma_select.csv")

ExprSet <- readRDS("./data/result/ExpSetObject/MergeExpSet_VoomSNM_VoomSNM_LIRI-JP_TCGA-LIHC.RDS")

21.3 准备数据


signif_DEG <- da_res %>%
  dplyr::filter(Enrichment != "Nonsignif") %>%
  dplyr::slice(-grep("\\.", FeatureID))

head(signif_DEG[, 1:6])
  • 基因表达谱:行名是样本,列名是基因ID

profile <- exprs(ExprSet) %>%
  as.data.frame()
rownames(profile) <- make.names(rownames(profile))

head(profile_DEG[, 1:6])
  • 去除共线性特征

在机器学习中,数据预处理中使用去除共线性特征的方法,其主要目的是为了提升模型的性能、稳定性和准确性。以下是具体的几个原因:

  1. 提高模型准确性:共线性特征指的是在数据集中存在高度相关性的特征。这些特征在建模时可能会提供冗余的信息,甚至可能导致模型过拟合,从而降低预测的准确性。通过去除共线性特征,可以减少这些冗余信息,使模型更加专注于关键信息,从而提高预测的准确性。

  2. 提高模型稳定性:当数据集中存在共线性特征时,模型的参数可能会变得不稳定,对训练数据的微小变化都可能会产生较大的影响。通过去除共线性特征,可以降低模型的复杂性,提高模型的稳定性,使模型在面对新的、未见过的数据时能够保持较好的性能。

  3. 加速模型训练:在训练模型时,如果数据集中存在大量的共线性特征,那么模型可能需要更多的时间和计算资源来找到最优的参数组合。通过去除共线性特征,可以减少模型的输入维度,降低模型的复杂性,从而加速模型的训练过程。

常用的去除共线性特征的方法包括:

  1. 方差膨胀因子(VIF):VIF用于量化特征之间的共线性程度。当VIF值较高时,表示该特征与其他特征之间存在较强的共线性关系。可以通过设定一个阈值,将VIF值超过该阈值的特征进行剔除或合并。

  2. 主成分分析(PCA):PCA是一种常用的降维方法,它通过找到数据中的主要变化方向(即主成分)来降低数据的维度。PCA可以有效地去除共线性特征,因为它会生成一组新的、不相关的特征(即主成分),这些特征能够捕获原始数据中的大部分信息。

  3. 相关系数分析:通过分析特征之间的相关系数,可以找出存在高度相关性的特征对。对于相关系数较高的特征对,可以选择其中一个特征进行保留,而将另一个特征进行剔除或合并。

这里通过Hmisc::rcorr(Harrell Jr 和 Harrell Jr 2019)采用了相关系数分析去除共线性。


if (file.exists("./data/result/ML/SVM/HCC_SVM_RM_profile.tsv")) {
  profile_remain <- fread("./data/result/ML/SVM/HCC_SVM_RM_profile.tsv", sep = "\t") %>%
    tibble::column_to_rownames("V1")
} else {
  
  write.table(profile_remain, "./data/result/ML/SVM/HCC_SVM_RM_profile.tsv",
              row.names = T, sep = "\t", quote = F)
}

print(dim(profile_remain))
  • 临床表型表:包含分组等信息

metadata <- pData(ExprSet)
head(metadata[, 1:6])
  • 合并数据

MergeData <- metadata %>%
  dplyr::select(SampleID, Group) %>%
  dplyr::inner_join(profile_remain %>%
                      tibble::rownames_to_column("SampleID"),
                    by = "SampleID") %>%
  tibble::column_to_rownames("SampleID") %>%
  dplyr::mutate(Group = recode(Group, "Early Stage" = "Early",
                               "Late Stage" = "Late")) %>%
  dplyr::mutate(Group = factor(Group))

head(MergeData[, 1:6])

21.4 机器学习特征筛选

特征筛选步骤:

  1. 数据分割:

将原始数据集划分为训练集、验证集和测试集。通常,训练集用于模型训练,验证集用于调整超参数和选择最佳模型,测试集用于评估最终模型的性能。 划分比例可以根据数据集的大小和特性进行调整,但一般常见的比例是训练集占60%-80%,验证集和测试集各占10%-20%。

  1. 数据转换:

数据清洗:处理缺失值、异常值、重复值等。对于缺失值,可以使用均值、中位数、众数等填充,或采用插值法、机器学习预测等方法进行填充。

特征编码:对于分类变量,需要进行编码以便模型能够处理。常见的编码方式包括独热编码(One-Hot Encoding)、标签编码(Label Encoding)等。

特征离散化:将连续特征转换为离散类别,例如通过分箱操作。这有助于处理一些具有非线性关系的特征。

特征标准化或归一化:对于不同量纲或不同分布的特征,需要进行标准化或归一化,以便模型能够更好地处理它们。

  1. 特征筛选:

相关性分析:计算每个特征与目标变量之间的相关性,如皮尔逊相关系数或斯皮尔曼等级相关系数。通过相关性分析,可以初步筛选出与目标变量强相关的特征。

树模型重要性:利用决策树、随机森林等树模型评估特征的重要性。这些模型可以根据特征在树中分裂的贡献度来评估特征的重要性。

启发式搜索策略:如前向序列选择方法,从空的候选集合出发,逐步添加与目标变量相关性最强的特征。

全局最优搜索策略:如穷举法或分支界定法,从所有可能的特征组合中挑选出表现最优的特征子集。

  1. 训练集构建模型:使用经过特征筛选的训练集数据构建机器学习模型。选择合适的机器学习算法,如逻辑回归、支持向量机、决策树、随机森林、神经网络等。

  2. 调参:使用验证集对模型进行调参,优化模型的超参数。超参数是机器学习算法中需要人为设定的参数,如学习率、迭代次数、正则化参数等。 可以采用网格搜索(Grid Search)、随机搜索(Random Search)等方法进行调参。

  3. 测试集评估模型:使用测试集对经过调参的模型进行评估,计算模型的性能指标,如准确率、精确率、召回率、F1分数、AUC-ROC等。根据评估结果选择性能最佳的模型作为最终模型。

21.4.1 数据分割

在机器学习的实践中,数据分割是一个至关重要的步骤,用于将原始数据集划分为训练集和测试集(使用了caret::createDataPartition (Kuhn 2008))。


set.seed(123)

trainIndex <- caret::createDataPartition(
          MergeData$Group, 
          p = 0.8, 
          list = FALSE, 
          times = 1)

trainData <- MergeData[trainIndex, ]
X_train <- trainData[, -1]
y_train <- trainData[, 1]

testData <- MergeData[-trainIndex, ]
X_test <- testData[, -1]
y_test <- testData[, 1]

21.4.2 基础模型


set.seed(123)

base.fit <- e1071::svm(
  Group ~ .,
  data = trainData,
  kernel = "radial")

base.fit

21.4.3 Recursive Feature Elimination特征筛选

递归特征消除(Recursive Feature Elimination,RFE)是一种基于模型的特征选择方法,其原理是通过反复训练模型和剔除最不重要特征的方式来选择最优的特征子集。以下是RFE特征筛选的具体步骤:


if (!file.exists("./data/result/ML/SVM/SVM_preData.RData")) {
  set.seed(123)
  
  if (!dir.exists("./data/result/ML/SVM")) {
    dir.create("./data/result/ML/SVM", recursive = TRUE)
  }

  save(X_train, y_train, X_test, y_test, fs_rfe,
       file = "./data/result/ML/SVM/SVM_preData.RData")
} else {
  load("./data/result/ML/SVM/SVM_preData.RData")
}

print(fs_rfe)
#list the chosen features
predictors(fs_rfe)[1:41]
plot(fs_rfe, type = c("g", "o"))

trainData_select <- trainData %>%
  dplyr::select(all_of(c("Group", feature_rfe)))
X_train_select <- trainData_select[, -1]
y_train_select <- trainData_select[, 1]

testData_select <- testData %>%
  dplyr::select(all_of(c("Group", feature_rfe)))
X_test_select <- testData_select[, -1]
y_test_select <- testData_select[, 1]

21.4.4 调参


set.seed(123)

if (file.exists("./data/result/ML/SVM/HCC_SVM_tuneFit.RData")) {
  load("./data/result/ML/SVM/HCC_SVM_tuneFit.RData")
} else {

  set.seed(123)
  tune_fit <- train(
    Group ~.,
    data = trainData_select,
    method = "svmLinear", # svmRadial svmLinear svmRadialCost
    trControl = myControl,
    tuneGrid = tuneGrid)
  
  save(tune_fit, file = "./data/result/ML/SVM/HCC_SVM_tuneFit.RData")
}

## Plot model accuracy vs different values of Cost
print(plot(tune_fit))

## Print the best tuning parameter that maximizes model accuracy
optimalVar <- data.frame(tune_fit$results[which.max(tune_fit$results[, 3]), ])
print(optimalVar)

21.4.5 最终分类模型

在已经通过适当的参数调优确定了Cost关键参数的值之后,可以基于这些参数构建最终的SVM(支持向量机)模型。这个模型将使用选定的Cost值来控制对误差的容忍度,从而确保模型在训练数据上具有良好的拟合能力,并在未知数据上保持优秀的泛化性能。


optimal <- length(fs_rfe$optVariables[1:selected_num])
selected_columns <- c("Group", fs_rfe$optVariables[1:selected_num])

trainData_optimal <- trainData_select %>%
  dplyr::select(all_of(selected_columns))

testData_optimal <- testData_select %>%
  dplyr::select(all_of(selected_columns))

set.seed(123)

svm_fit_optimal

21.4.6 测试集验证

首先,将训练好的模型应用于测试集数据,以预测每个样本的分类标签。预测结果将与测试集的真实标签进行对比,以计算模型在各个类别上的分类准确性。

为了更详细地了解模型的性能,构建混淆矩阵。混淆矩阵是一个N x N的表格(其中N为分类类别数),用于显示每个类别下的真实标签与预测标签之间的对比情况。通过混淆矩阵,可以计算出精确度(Precision)、召回率(Recall)、F1分数(F1-Score)等评估指标,这些指标能够全面反映模型在各类别上的分类效果。

此外,还可以绘制ROC曲线来评估模型的性能。ROC曲线是通过设置不同的分类阈值,计算真正例率(True Positive Rate,TPR)和假正例率(False Positive Rate,FPR)得到的。ROC曲线越靠近左上角,说明模型在保持较低假正例率的同时,能够获得较高的真正例率,即模型的性能越好。

通过混淆矩阵和ROC曲线,可以全面、客观地评估机器学习模型在测试集上的性能。

  • 混淆矩阵
print(caret::confusionMatrix(pred_raw, testData_optimal$Group))

AUROC <- function(
    DataTest, 
    PredProb = pred_prob, 
    nfeature) {
  
  # plot
  pl <- ggplot(data = roc, aes(x = fpr, y = tpr)) +
    geom_path(color = "red", size = 1) +
    geom_abline(intercept = 0, slope = 1, 
                color = "grey", linewidth = 1, linetype = 2) +
    labs(x = "False Positive Rate (1 - Specificity)",
         y = "True Positive Rate",
         title = paste0("AUROC (", nfeature, " Features)")) +
    annotate("text", 
             x = 1 - rocbj_df$specificities[max_value_row] + 0.15, 
             y = rocbj_df$sensitivities[max_value_row] - 0.05, 
             label = paste0(threshold, " (", 
                            rocbj_df$specificities[max_value_row], ",",
                            rocbj_df$sensitivities[max_value_row], ")"),
             size = 5, family = "serif") +
    annotate("point", 
             x = 1 - rocbj_df$specificities[max_value_row], 
             y = rocbj_df$sensitivities[max_value_row], 
             color = "black", size = 2) +    
    annotate("text", 
             x = .75, y = .25, 
             label = roc_CI_lab,
             size = 5, family = "serif") +
    coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
    theme_bw() +
    theme(panel.background = element_rect(fill = "transparent"),
          plot.title = element_text(size = 12, color = "black", face = "bold"),
          axis.title = element_text(size = 11, color = "black", face = "bold"), 
          axis.text = element_text(size= 10, color = "black"),
          axis.ticks.length = unit(0.4, "lines"),
          axis.ticks = element_line(color = "black"),
          axis.line = element_line(size = .5, color = "black"),
          text = element_text(size = 8, color = "black", family = "serif"))
  
  res <- list(rocobj = rocobj,
              roc_CI = roc_CI_lab,
              roc_pl = pl)
  
  return(res)
}

AUROC_res <- AUROC(
    DataTest = testData_optimal, 
    PredProb = pred_prob, 
    nfeature = optimal)

AUROC_res$roc_pl

AUPRC <- function(DataTest, PredProb, nfeature) {
  
  # plot
  pl <- ggplot(data = prc, aes(x = recall, y = precision)) +
    geom_path(color = "red", size = 1) +
    labs(x = "Recall",
         y = "Precision",
         title = paste0("AUPRC (", nfeature, " Features)")) +
    coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
    theme_bw() +
    theme(panel.background = element_rect(fill = "transparent"),
          plot.title = element_text(color = "black", size = 14, face = "bold"),
          axis.ticks.length = unit(0.4, "lines"),
          axis.ticks = element_line(color = "black"),
          axis.line = element_line(size = .5, color = "black"),
          axis.title = element_text(color = "black", size = 12, face = "bold"),
          axis.text = element_text(color = "black", size = 10),
          text = element_text(size = 8, color = "black", family = "serif"))
  
  res <- list(dat_PR = dat_PR,
              PC_pl = pl)
  
  return(res)
} 

AUPRC_res <- AUPRC(
    DataTest = testData_optimal, 
    PredProb = pred_prob, 
    nfeature = optimal)

AUPRC_res$PC_pl

Evaluate_index <- function(DataTest, PredProb, label, PredRaw) {
  
  threshold <- rocbj_df$threshold[max_value_row]
  sen <- round(TP / (TP + FN), 3) # caret::sensitivity(con_matrix)
  spe <- round(TN / (TN + FP), 3) # caret::specificity(con_matrix)
  acc <- round((TP + TN) / (TP + TN + FP + FN), 3) # Accuracy
  pre <- round(TP / (TP + FP), 3) # precision
  rec <- round(TP / (TP + FN), 3) # recall
  #F1S <- round(2 * TP / (TP + TN + FP + FN + TP - TN), 3)# F1-Score
  F1S <- round(2 * TP / (2 * TP + FP + FN), 3)# F1-Score
  youden <- sen + spe - 1 # youden index
  
  
  # AUROC
  AUROC <- round(as.numeric(auc(DataTest$Group, PredProb[, 1])), 3)
  
  # AUPRC
  AUPRC <- round(MLmetrics::PRAUC(y_pred = PredProb[, 1], 
                                  y_true = DataTest$Group), 3)  
  
  index_df <- data.frame(Index = c("Threshold", "Sensitivity",
                                   "Specificity", "Accuracy",
                                   "Precision", "Recall",
                                   "F1 Score", "Youden index",
                                   "AUROC", "AUPRC"),
                         Value = c(threshold, sen, spe,
                                   acc, pre, rec, F1S, 
                                   youden, AUROC, AUPRC)) %>%
    stats::setNames(c("Index", label))
  
  return(index_df)
}

Evaluate_index(
  DataTest = testData_optimal,
  PredProb = pred_prob,
  label = group_names[1],
  PredRaw = pred_raw)

结果:从表中我们可以看到以下指标及其对应的数值:

  • 阈值(Threshold):0.598。阈值用于将模型的预测结果转换为具体的类别(例如,在二分类问题中,通常将预测概率大于阈值的样本视为正类,小于阈值的视为负类)。

  • 灵敏度(Sensitivity):0.911。也称为真正率(True Positive Rate, TPR)或召回率(Recall),它表示所有实际为正样本的样本中,被模型正确预测为正样本的比例。

  • 特异度(Specificity):0.343。也称为真负率(True Negative Rate, TNR),它表示所有实际为负样本的样本中,被模型正确预测为负样本的比例。

  • 准确率(Accuracy):0.737。它表示模型在所有样本上的正确分类的比例。

  • 精度(Precision):0.758。它表示所有被模型预测为正样本的样本中,实际为正样本的比例。

  • 召回率(Recall):0.911(注意这个与Sensitivity的值是一样的,因为在二分类问题中,Sensitivity和Recall是等价的)。

  • F1得分(F1Score):0.828。它是精度和召回率的调和平均值,用于综合衡量两者的表现。

  • Youden指数:0.254。它是一个用于评估诊断测试性能的指标,等于灵敏度与特异度之和减去1。

  • AUROC(Area Under the Receiver Operating Characteristic Curve):0.663。AUROC是一种衡量分类模型性能的指标,它表示ROC曲线下的面积。

  • AUPRC(Area Under the Precision-Recall Curve):0.231。AUPRC是另一种衡量分类模型性能的指标,特别适用于正样本数量远少于负样本的情况,它表示Precision-Recall曲线下的面积。

21.5 标记基因

通过REF特征筛选,模型能够自动筛选出那些对预测结果具有显著影响的特征,因此被称为标记基因或显著特征。


optimal_feature <- fs_rfe$variables %>%
  dplyr::rename(FeatureID = var) %>%
  dplyr::filter(FeatureID %in% fs_rfe$optVariables[1:selected_num]) %>%
  dplyr::inner_join(fs_rfe$results, by = "Variables") %>%
  dplyr::select(FeatureID, Accuracy, Kappa, AccuracySD, KappaSD) %>%
  dplyr::arrange(Accuracy) %>%
  dplyr::mutate(FeatureID = forcats::fct_inorder(FeatureID))

optimal_feature <- optimal_feature[pmatch(unique(optimal_feature$FeatureID),
                                          optimal_feature$FeatureID), ,]

head(optimal_feature)
  • 提取特征的表达谱

profile_SVM <- profile[pmatch(optimal_feature$FeatureID,
                              rownames(profile)), ]

head(profile_SVM[, 1:6])

21.6 输出结果



if (!dir.exists("./data/result/ML/SVM")) {
  dir.create("./data/result/ML/SVM", recursive = TRUE)
}

write.csv(optimal_feature, "./data/result/ML/SVM/HCC_SVM_feature.csv", row.names = F)
write.table(profile_SVM, "./data/result/ML/SVM/HCC_SVM_profile.tsv", row.names = T, sep = "\t", quote = F)
save(AUROC_res, file = "./data/result/ML/SVM/HCC_SVM_AUROC.RData")

if (!dir.exists("./data/result/Figure/")) {
  dir.create("./data/result/Figure/", recursive = TRUE)
}

pdf("./data/result/Figure/SFig3-G.pdf", width = 5, height = 4)
plot(fs_rfe, type = c("g", "o"))
dev.off()

ggsave("./data/result/Figure/SFig3-H.pdf", AUROC_res$roc_pl, width = 5, height = 4, dpi = 600)

21.7 总结

经过对差异基因进行REF特征筛选,成功识别出一组与特定生物学现象或条件紧密相关的基因子集。接着,利用支持向量机构建预测模型。通过这一详尽的模型训练和评估过程,最终确定了41个特征,这些特征在预测任务中展现出了相对不错的性能(需要注意,它在Late分期预测不好)。这41个特征将被用于后续的分析,以进一步揭示这些差异基因在生物学过程中的作用机制和潜在价值。

警告

相比其他两种机器学习方法(见 章节 19章节 20)得到的预测模型,REF+SVM的预测模型在Late分期样本预测效果相对不佳。

系统信息
sessionInfo()
R version 4.3.3 (2024-02-29)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Sonoma 14.2

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: Asia/Shanghai
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices datasets  utils     methods   base     

other attached packages:
 [1] MLmetrics_1.1.3     Hmisc_5.1-2         pROC_1.18.5        
 [4] mlbench_2.1-5       e1071_1.7-14        caret_6.0-94       
 [7] lattice_0.22-6      data.table_1.15.4   Biobase_2.62.0     
[10] BiocGenerics_0.48.1 lubridate_1.9.3     forcats_1.0.0      
[13] stringr_1.5.1       dplyr_1.1.4         purrr_1.0.2        
[16] readr_2.1.5         tidyr_1.3.1         tibble_3.2.1       
[19] ggplot2_3.5.1       tidyverse_2.0.0    

loaded via a namespace (and not attached):
 [1] tidyselect_1.2.1     timeDate_4032.109    fastmap_1.1.1       
 [4] digest_0.6.35        rpart_4.1.23         timechange_0.3.0    
 [7] lifecycle_1.0.4      cluster_2.1.6        survival_3.7-0      
[10] magrittr_2.0.3       compiler_4.3.3       rlang_1.1.3         
[13] tools_4.3.3          utf8_1.2.4           yaml_2.3.8          
[16] knitr_1.46           htmlwidgets_1.6.4    plyr_1.8.9          
[19] foreign_0.8-86       withr_3.0.0          nnet_7.3-19         
[22] grid_4.3.3           stats4_4.3.3         fansi_1.0.6         
[25] colorspace_2.1-0     future_1.33.2        globals_0.16.3      
[28] scales_1.3.0         iterators_1.0.14     MASS_7.3-60.0.1     
[31] cli_3.6.2            rmarkdown_2.26       generics_0.1.3      
[34] rstudioapi_0.16.0    future.apply_1.11.2  reshape2_1.4.4      
[37] tzdb_0.4.0           proxy_0.4-27         splines_4.3.3       
[40] parallel_4.3.3       BiocManager_1.30.23  base64enc_0.1-3     
[43] vctrs_0.6.5          hardhat_1.3.1        Matrix_1.6-5        
[46] jsonlite_1.8.8       hms_1.1.3            htmlTable_2.4.2     
[49] Formula_1.2-5        listenv_0.9.1        foreach_1.5.2       
[52] gower_1.0.1          recipes_1.0.10       glue_1.7.0          
[55] parallelly_1.37.1    codetools_0.2-19     stringi_1.8.4       
[58] gtable_0.3.5         munsell_0.5.1        pillar_1.9.0        
[61] htmltools_0.5.8.1    ipred_0.9-14         lava_1.8.0          
[64] R6_2.5.1             evaluate_0.23        backports_1.4.1     
[67] renv_1.0.0           class_7.3-22         Rcpp_1.0.12         
[70] checkmate_2.3.1      gridExtra_2.3        nlme_3.1-164        
[73] prodlim_2023.08.28   xfun_0.43            pkgconfig_2.0.3     
[76] ModelMetrics_1.2.2.2