可解释增强机器#
API 参考链接:ExplainableBoostingClassifier, ExplainableBoostingRegressor
摘要
可解释增强机器 (EBM) 是一种基于树的循环梯度增强广义加性模型,具有自动交互检测功能。EBM 通常与最先进的黑盒模型一样准确,同时保持完全可解释性。
工作原理
作为框架的一部分,InterpretML 还包含一种新的可解释性算法——可解释增强机器 (EBM)。EBM 是一种白盒模型,旨在实现与随机森林和增强树等最先进的机器学习方法相当的准确性,同时具有高度的可理解性和可解释性。EBM 是一种广义加性模型 (GAM),其形式为
其中 \(g\) 是一个链接函数,用于将 GAM 应用于不同的设置,例如回归或分类。
相较于传统 GAM [2],EBM 有一些重大改进。首先,EBM 使用现代机器学习技术(如 bagging 和梯度增强)来学习每个特征函数 \(f_j\)。增强过程被仔细限制为以循环方式每次仅在一个特征上进行训练,使用非常低的学习率,这样特征顺序就不重要了。它通过循环特征来减轻共线性的影响,并学习每个特征的最佳特征函数 \(f_j\),以显示每个特征如何对模型在该问题上的预测做出贡献。其次,EBM 可以自动检测并包含以下形式的成对交互项
这在提高准确性的同时保持了可理解性。EBM 是 GA2M 算法 [1] 的快速实现,用 C++ 和 Python 编写。该实现是可并行的,并利用 joblib 提供多核和多机并行化。训练过程、成对交互项的选择以及案例研究的算法细节可以在 [1, 3, 4] 中找到。
EBM 具有高度的可理解性,因为通过绘制 \(f_j\),每个特征对最终预测的贡献可以被可视化和理解。由于 EBM 是一个加性模型,每个特征以模块化的方式对预测做出贡献,这使得很容易推断每个特征对预测的贡献。
为了进行个体预测,每个函数 \(f_j\) 充当每个特征的查找表,并返回一个项贡献。这些项贡献被简单地累加起来,并通过链接函数 \(g\) 计算最终预测。由于模块化(加性),可以对项贡献进行排序和可视化,以显示哪些特征对任何个体预测影响最大。
为了保持单个项的加性,EBM 支付了额外的训练成本,这使得它比类似方法稍慢。然而,由于进行预测只需要在特征函数 \(f_j\) 内部进行简单的加法和查找,EBM 是预测时执行速度最快的模型之一。EBM 的轻量内存使用和快速预测时间使其在生产环境中部署模型时特别有吸引力。
如果您觉得视频是学习该算法更好的媒介,可以在下面找到该算法的概念概述:
代码示例
以下代码将为成人收入数据集训练一个 EBM 分类器。提供的可视化效果将用于全局和局部解释。
from interpret import set_visualize_provider
from interpret.provider import InlineProvider
set_visualize_provider(InlineProvider())
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from interpret.glassbox import ExplainableBoostingClassifier
from interpret import show
df = pd.read_csv(
"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
header=None)
df.columns = [
"Age", "WorkClass", "fnlwgt", "Education", "EducationNum",
"MaritalStatus", "Occupation", "Relationship", "Race", "Gender",
"CapitalGain", "CapitalLoss", "HoursPerWeek", "NativeCountry", "Income"
]
X = df.iloc[:, :-1]
y = df.iloc[:, -1]
seed = 42
np.random.seed(seed)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)
ebm = ExplainableBoostingClassifier()
ebm.fit(X_train, y_train)
auc = roc_auc_score(y_test, ebm.predict_proba(X_test)[:, 1])
print("AUC: {:.3f}".format(auc))
AUC: 0.930
show(ebm.explain_global())
show(ebm.explain_local(X_test[:5], y_test[:5]), 0)
更多资源
参考文献
[1] Yin Lou, Rich Caruana, Johannes Gehrke, and Giles Hooker. Accurate intelligible models with pairwise interactions. In Proceedings of the 19th ACM SIGKDD international conference on Knowledge discovery and data mining, 623–631. 2013. 论文链接
[2] Trevor Hastie and Robert Tibshirani. Generalized additive models: some applications. Journal of the American Statistical Association, 82(398):371–386, 1987.
[3] Yin Lou, Rich Caruana, and Johannes Gehrke. Intelligible models for classification and regression. In Proceedings of the 18th ACM SIGKDD international conference on Knowledge discovery and data mining, 150–158. 2012. 论文链接
[4] Rich Caruana, Yin Lou, Johannes Gehrke, Paul Koch, Marc Sturm, and Noemie Elhadad. Intelligible models for healthcare: predicting pneumonia risk and hospital 30-day readmission. In Proceedings of the 21th ACM SIGKDD international conference on knowledge discovery and data mining, 1721–1730. 2015. 论文链接