如何解释神经网络为给定输入产生输出的方式?

人工智能 神经网络 分类 可解释的ai
2021-10-29 21:31:34

我正在使用一个小型神经网络(2 个隐藏层,每个 60 个神经元)来解决一个相当复杂的二元分类问题。

网络运行良好,但我想知道它是如何使用输入来执行分类的。最终,我想解释经过训练的网络,以便更多地了解负责生成数据的过程。

理想情况下,我最终会得到一个方程,该方程允许我在没有网络的情况下执行分类,并且具有我可以在使用网络的系统的上下文中解释的参数。

我的第一个想法是在程序上屏蔽约 4000 个参数的不断增长的子集,直到在性能和简单性之间进行适当的权衡,然后可能使用符号逻辑库来尝试进一步简化。

我认为这不是最好的计划,所以我想知道是否存在解释神经网络的现有工作流程。

1个回答

理想情况下,我最终会得到一个等式,它允许我在没有网络的情况下执行分类”。

如果你可以在没有机器学习的情况下找到这样的解析方程,那么为什么首先要训练一个多层感知器呢?或者换一种说法,你训练的 mlp 就是那个方程。而且我并不是想讽刺,如果您需要分析解释,那么不要使用多层感知器,而是转向决策树算法,例如,您可以从字面上绘制模型本身(仍然难以解释取决于您正在使用的功能数量)。

相反,如果您仍然想坚持使用 mlp,那么您可以做一些事情来更好地理解您的模型,即绘制由它学习的决策边界。Sklearn 有一个关于如何做的很好的教程,我复制了它并用 mlp 更改了 svm 只是为了表明无论模型如何,该方法都有效

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.neural_network import MLPClassifier

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features. We could
# avoid this ugly slicing by using a two-dim dataset
y = iris.target

h = 0.02  # step size in the mesh

# we create an instance of SVM and fit out data. We do not scale our
# data since we want to plot the support vectors
C = 1.0  # SVM regularization parameter
mlp = MLPClassifier().fit(X, y)

# create a mesh to plot in
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# title for the plots
title = "MLP boundries"


# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
# plt.subplot(1, 1)
# plt.subplots_adjust(wspace=0.4, hspace=0.4)

Z = mlp.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)

# Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm)
plt.xlabel("Sepal length")
plt.ylabel("Sepal width")
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.title(title)

plt.show()

代码输出此图: 在此处输入图像描述

您可以看到边界是 mlp 行为的粗略近似,它们只是根据应用于由输入数据的 2 个特征生成的 2D 图的所有点的蛮力预测来估计的。因此,边界也会根据您决定绘制的特征而改变。但它让您了解 mlp 学习到的关系。

如果你想要更多,我再次强调你应该训练一个不同的模型,比如决策树、随机森林或 XGboost,使用这些模型,你可以计算特征重要性的分数,并从字面上绘制模型学习的决策阈值。