将 2 类模型扩展到多类问题

机器算法验证 r 机器学习 分类 统计学习 多级
2022-03-24 04:45:37

这篇关于 Adaboost的论文给出了一些建议和代码(第 17 页),用于将 2 类模型扩展到 K 类问题。我想概括这段代码,以便我可以轻松插入不同的 2 类模型并比较结果。因为大多数分类模型都有公式接口和predict方法,所以其中一些应该相对容易。不幸的是,我还没有找到从 2 类模型中提取类概率的标准方法,因此每个模型都需要一些自定义代码。

这是我编写的一个函数,用于将 K 类问题分解为 2 类问题,并返回 K 个模型:

oneVsAll <- function(X,Y,FUN,...) {
    models <- lapply(unique(Y), function(x) {
        name <- as.character(x)
        .Target <- factor(ifelse(Y==name,name,'other'), levels=c(name, 'other'))
        dat <- data.frame(.Target, X)
        model <- FUN(.Target~., data=dat, ...)
        return(model)
    })
    names(models) <- unique(Y)
    info <- list(X=X, Y=Y, classes=unique(Y))
    out <- list(models=models, info=info)
    class(out) <- 'oneVsAll'
    return(out)
}

这是我编写的一种预测方法,用于迭代每个模型并进行预测:

predict.oneVsAll <- function(object, newX=object$info$X, ...) {
    stopifnot(class(object)=='oneVsAll')
    lapply(object$models, function(x) {
        predict(x, newX, ...)
    })
}

data.frame最后,这是一个对预测概率进行归一化并对案例进行分类的函数。请注意,您可以data.frame从每个模型中构建概率的 K 列,因为没有统一的方法从 2 类模型中提取类概率:

classify <- function(dat) {
    out <- dat/rowSums(dat)
    out$Class <- apply(dat, 1, function(x) names(dat)[which.max(x)])
    out
}

这是一个使用示例adaboost

library(ada)
library(caret) 
X <- iris[,-5]
Y <- iris[,5]
myModels <- oneVsAll(X, Y, ada)
preds <- predict(myModels, X, type='probs')
preds <- data.frame(lapply(preds, function(x) x[,2])) #Make a data.frame of probs
preds <- classify(preds)
>confusionMatrix(preds$Class, Y)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         47         2
  virginica       0          3        48

这是一个使用示例lda(我知道 lda 可以处理多个类,但这只是一个示例):

library(MASS)
myModels <- oneVsAll(X, Y, lda)
preds <- predict(myModels, X)
preds <- data.frame(lapply(preds, function(x) x[[2]][,1])) #Make a data.frame of probs
preds <- classify(preds)
>confusionMatrix(preds$Class, Y)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         39         5
  virginica       0         11        45

这些函数应该适用于任何具有公式接口和predict方法的 2 类模型。注意,你必须手动拆分X和Y分量,这有点难看,但现在写一个公式界面已经超出了我的范围。

这种方法对每个人都有意义吗?有什么办法可以改进它,或者是否有现有的软件包可以解决这个问题?

1个回答

改进的一种方法是使用“加权所有对”方法,该方法据说比“一对多”更好,同时仍可扩展。

对于现有包,glmnet支持(正则化)多项式 logit,可用作多类分类器。