文章目录
- 前言
- 环境准备
- 分步操作:绘制混淆矩阵
- 第一步:计算混淆矩阵
- 第二步:可视化混淆矩阵
- 第三步:从混淆矩阵衍生关键指标
- 分步操作:绘制ROC曲线与计算AUC
- 第一步:计算ROC曲线所需数据
- 第二步:绘制ROC曲线
- 第三步:解读ROC曲线上的点
- 完整代码示例
- 踩坑提示
- 总结
前言
在模型训练完成后,很多新手朋友(包括当年的我)拿到一个“准确率95%”的结果就兴高采烈,觉得大功告成。直到有一次,我把一个预测用户是否点击广告的模型部署上线,线上效果却一塌糊涂。复盘时才发现,那个“95%准确率”的模型,是把所有样本都预测为“不点击”得来的——因为数据中不点击的用户本就占95%。这个惨痛教训让我明白,评估模型绝不能只看一个数字。混淆矩阵和ROC曲线就是帮助我们全面、可视化地诊断模型性能的“X光机”和“雷达图”。今天,我就带大家一步步用代码实现它们,把模型性能看得清清楚楚。
环境准备
我们使用最经典的scikit-learn和可视化库matplotlib、seaborn。确保你的环境里已经安装好它们。
pipinstallscikit-learn matplotlib seaborn接下来,我们导入必要的库,并创建一个简单的二分类数据集用于演示。这里我选择乳腺癌数据集,因为它特征清晰,且类别相对均衡。
# 导入必要库importnumpyasnpimportmatplotlib.pyplotaspltimportseabornassnsfromsklearn.datasetsimportload_breast_cancerfromsklearn.model_selectionimporttrain_test_splitfromsklearn.linear_modelimportLogisticRegressionfromsklearn.metricsimportconfusion_matrix,roc_curve,auc,RocCurveDisplay# 设置中文字体和图表样式(可选,使图表更美观)plt.rcParams['font.sans-serif']=['SimHei']# 用来正常显示中文标签plt.rcParams['axes.unicode_minus']=False# 用来正常显示负号sns.set_style("whitegrid")# 加载数据并划分训练集、测试集data=load_breast_cancer()X=data.data y=data.target# 标签,0代表恶性(Malignant),1代表良性(Benign)X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=42)# 训练一个简单的逻辑回归模型作为我们的评估对象model=LogisticRegression(max_iter=10000)model.fit(X_train,y_train)y_pred=model.predict(X_test)# 模型预测的类别y_score=model.predict_proba(X_test)[:,1]# 模型预测为正类(良性)的概率,用于ROC曲线分步操作:绘制混淆矩阵
混淆矩阵是理解模型分类错误类型的基石。它是一个NxN的矩阵(N为类别数),对于二分类,就是2x2矩阵。
第一步:计算混淆矩阵
使用sklearn.metrics.confusion_matrix函数,传入真实标签和预测标签即可。
# 计算混淆矩阵cm=confusion_matrix(y_test,y_pred)print("混淆矩阵(原始数值):")print(cm)# 输出示例:# [[ 59 4]# [ 2 106]]# 解读:59个真阴性(TN),4个假阳性(FP),2个假阴性(FN),106个真阳性(TP)第二步:可视化混淆矩阵
直接看数字不够直观,我们用热力图来绘制它。
defplot_confusion_matrix(cm,classes):""" 绘制美观的混淆矩阵热力图 Args: cm: 混淆矩阵数组 classes: 类别名称列表,如 ['恶性', '良性'] """plt.figure(figsize=(6,5))# 使用seaborn的热力图函数,annot=True显示数字,fmt='d'表示整数格式sns.heatmap(cm,annot=True,fmt='d',cmap='Blues',xticklabels=classes,yticklabels=classes)plt.title('混淆矩阵')plt.ylabel('真实标签')plt.xlabel('预测标签')plt.tight_layout()plt.show()# 调用函数,传入我们计算好的cm和类别名称class_names=['恶性 (0)','良性 (1)']plot_confusion_matrix(cm,class_names)这段代码会生成一个带数字的蓝色系热力图。从图中,你可以一眼看出模型在哪个类别上犯了更多错误。比如,如果“假阴性”(FN,实际是恶性但预测为良性)很多,那么这个模型在医疗场景下是极其危险的。
第三步:从混淆矩阵衍生关键指标
混淆矩阵本身包含了计算几乎所有常用指标的信息:
TN,FP,FN,TP=cm.ravel()# 将2x2矩阵展平为四个值accuracy=(TP+TN)/(TP+TN+FP+FN)# 准确率precision=TP/(TP+FP)# 精确率/查准率recall=TP/(TP+FN)# 召回率/查全率f1_score=2*precision*recall/(precision+recall)# F1分数print(f"准确率 (Accuracy):{accuracy:.3f}")print(f"精确率 (Precision):{precision:.3f}")print(f"召回率 (Recall):{recall:.3f}")print(f"F1分数 (F1-Score):{f1_score:.3f}")分步操作:绘制ROC曲线与计算AUC
ROC曲线描绘的是模型在不同判定阈值下,“真正例率(TPR)”和“假正例率(FPR)”的权衡关系。AUC是曲线下的面积,越接近1模型性能越好。
第一步:计算ROC曲线所需数据
我们需要模型预测的“概率值”,而不是最终的类别标签。之前我们已经用model.predict_proba获取了y_score。
# 计算FPR, TPR和对应的阈值fpr,tpr,thresholds=roc_curve(y_test,y_score)# 计算AUC值roc_auc=auc(fpr,tpr)print(f"AUC值为:{roc_auc:.3f}")第二步:绘制ROC曲线
我们可以用sklearn新版本提供的简便方法,也可以手动绘制以进行更多定制。
方法一:使用RocCurveDisplay(推荐,简洁)
RocCurveDisplay.from_estimator(model,X_test,y_test)plt.plot([0,1],[0,1],'k--',label='随机猜测 (AUC=0.5)')# 添加对角线作为参考plt.legend(loc="lower right")plt.title('ROC曲线 (使用RocCurveDisplay)')plt.show()方法二:手动绘制(更灵活,可对比多个模型)
plt.figure(figsize=(8,6))plt.plot(fpr,tpr,color='darkorange',lw=2,label=f'逻辑回归 (AUC ={roc_auc:.3f})')plt.plot([0,1],[0,1],color='navy',lw=2,linestyle='--',label='随机猜测')plt.xlim([0.0,1.0])plt.ylim([0.0,1.05])plt.xlabel('假正例率 (FPR)')plt.ylabel('真正例率 (TPR)')plt.title('乳腺癌分类的ROC曲线')plt.legend(loc="lower right")plt.grid(True,alpha=0.3)plt.show()ROC曲线离左上角越近,说明模型性能越好。AUC值是一个综合评判标准,通常认为:
- AUC = 0.5: 模型没有区分能力,等同于随机猜测。
- 0.5 < AUC < 0.7: 模型有较低区分能力。
- 0.7 ≤ AUC < 0.9: 模型有较好区分能力。
- AUC ≥ 0.9: 模型有非常高区分能力。
第三步:解读ROC曲线上的点
曲线上每一个点对应一个分类阈值。我们可以找出最靠近左上角(即(0,1)点)的阈值,它可能是理论上最优的阈值(但需结合业务)。
# 计算每个阈值点到左上角(0,1)的欧氏距离distances=np.sqrt((0-fpr)**2+(1-tpr)**2)optimal_idx=np.argmin(distances)# 找到最小距离的索引optimal_threshold=thresholds[optimal_idx]optimal_point=(fpr[optimal_idx],tpr[optimal_idx])print(f"理论最优阈值:{optimal_threshold:.3f}")print(f"该阈值对应的 (FPR, TPR):{optimal_point}")# 在实际项目中,你可能需要根据业务成本(如假阴性的代价很高)来调整阈值,而不是单纯依赖这个几何最优点。完整代码示例
将以上步骤整合成一个完整的、可执行的脚本。
# -*- coding: utf-8 -*-""" 混淆矩阵与ROC曲线可视化完整示例 """importnumpyasnpimportmatplotlib.pyplotaspltimportseabornassnsfromsklearn.datasetsimportload_breast_cancerfromsklearn.model_selectionimporttrain_test_splitfromsklearn.linear_modelimportLogisticRegressionfromsklearn.metricsimportconfusion_matrix,roc_curve,auc,RocCurveDisplay# 1. 准备数据与模型data=load_breast_cancer()X=data.data y=data.target X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=42)model=LogisticRegression(max_iter=10000)model.fit(X_train,y_train)y_pred=model.predict(X_test)y_score=model.predict_proba(X_test)[:,1]# 2. 混淆矩阵cm=confusion_matrix(y_test,y_pred)class_names=['恶性 (0)','良性 (1)']plt.figure(figsize=(12,5))plt.subplot(1,2,1)sns.heatmap(cm,annot=True,fmt='d',cmap='Blues',xticklabels=class_names,yticklabels=class_names)plt.title('混淆矩阵')plt.ylabel('真实标签')plt.xlabel('预测标签')# 3. ROC曲线plt.subplot(1,2,2)RocCurveDisplay.from_estimator(model,X_test,y_test,ax=plt.gca())plt.plot([0,1],[0,1],'k--',label='随机猜测 (AUC=0.5)')plt.title('ROC曲线')plt.legend(loc="lower right")plt.tight_layout()plt.show()# 4. 打印关键指标TN,FP,FN,TP=cm.ravel()print("="*50)print("模型性能关键指标:")print(f" 准确率 (Accuracy):{(TP+TN)/(TP+TN+FP+FN):.3f}")print(f" 精确率 (Precision):{TP/(TP+FP):.3f}")print(f" 召回率 (Recall):{TP/(TP+FN):.3f}")print(f" F1分数 (F1-Score):{2*TP/(2*TP+FP+FN):.3f}")fpr,tpr,_=roc_curve(y_test,y_score)print(f" AUC值:{auc(fpr,tpr):.3f}")print("="*50)踩坑提示
- ROC曲线适用于二分类:对于多分类问题,需要为每个类别绘制一条“一对多”的ROC曲线,或者计算宏平均/微平均ROC。
- 类别不平衡时,准确率是陷阱:正如前言中的例子,一定要结合混淆矩阵、精确率、召回率或ROC曲线来评估。在不平衡数据上,AUC通常比准确率更可靠。
predictvspredict_proba:画ROC曲线必须使用预测概率(predict_proba),而不是预测的类别标签(predict)。predict本质上是使用默认阈值(通常为0.5)将概率转化为了类别。- 阈值的选择是业务决策:ROC曲线展示了所有可能的阈值。选择哪个阈值,取决于你的业务场景是更容忍“误报”(FP)还是“漏报”(FN)。例如,在垃圾邮件检测中,我们宁愿漏掉一些垃圾邮件(FN),也绝不能把正常邮件判为垃圾(FP高)。
- 测试集要干净:评估一定要在独立的测试集上进行,不要在训练集上画ROC和算指标,否则会得到过于乐观的结果,即过拟合。
总结
混淆矩阵和ROC曲线是我们模型评估工具箱里的“黄金搭档”。混淆矩阵像一份详细的“诊断报告”,清晰地列出每一种对错情况;而ROC曲线则像一份“能力雷达图”,综合展示了模型在不同严苛程度下的性能表现。养成在模型评估时同时生成和分析它们的习惯,能让你对模型的认知从“好像不错”提升到“好在哪,差在哪”,从而做出更准确的优化和上线决策。
记住,一个好的AI工程师,不仅要会“炼丹”(调参训练),更要会“验丹”(全面评估)。
如有问题欢迎评论区交流,持续更新中…