如何使用 k 折交叉验证从 R 中的 nnet 获得泛化性能?

机器算法验证 r 机器学习 交叉验证 神经网络
2022-03-28 15:09:52

我正在使用该nnet包在 R 中进行一些机器学习。我想通过使用 k 折交叉验证来估计我的分类器的泛化性能。

我该怎么做呢?是否有一些内置函数可以为我执行此操作?tune.nnet()e1071包装中看到过,但我不确定它是否完全符合我的要求。

基本上我想做交叉验证(分成 10 组,训练 9 组,测试另一组 1 组,重复),然后从中获得某种衡量我的分类器泛化程度的方法——但我不确定是什么应该是这样的措施。我想我想看看不同交叉验证示例的平均准确度,但我不确定如何使用tune.nnet()上面的函数来做到这一点。

有任何想法吗?

2个回答

如果您计划在训练数据上调整网络(例如,为学习率选择一个值)并确定同一数据集上的错误泛化,您需要使用嵌套交叉验证,在每个折叠中,您都在调整数据集的 9/10 上的模型(使用 10 倍 cv)。看到这篇文章,我问了一个类似的问题并得到了很好的答案(帖子)。

我不确定是否有一个 R 函数来完成此任务 - 也许可以通过在调用中传递 tune 函数来使用ipred包 - 不确定。不过,简单地编写一个循环来完成整个过程是相当微不足道的。

至于要检查的指标,这取决于您的问题(分类与回归)以及您是否对准确性(分类率或 kappa)或模型排名(例如提升)或回归的 MAE 或 RMSE 感兴趣。

在 R 中实现 k-fold CV(有或没有嵌套)相对简单;分层抽样(关于班级成员或受试者的特征,例如年龄或性别)并不难。

关于评估一个分类器性能的方式,你可以直接看tune()函数的R代码。(只需tune在 R 提示符下键入。)对于分类问题,这是计算的类一致性(预测和观察到的类成员之间)。

但是,如果您正在寻找一个完整的 R 框架,其中在几个命令中提供了数据预处理(特征消除、缩放等)、训练/测试重采样和分类器准确性的比较测量,我绝对建议您查看caret包,其中还包含许多有用的小插曲(另请参阅JSS 论文)。

值得注意的是,尽管 NN 是可从内部调用的方法的一部分,但caret您可能需要查看其他性能和大多数情况下比 NN 更好的方法(例如,随机森林、SVM 等)