它并不算严重的过拟合(取决于如何定义)。测试集的目标(标签)信息没有被用于训练,因此不存在标签泄漏。半监督方法允许生成额外的合成数据集(即带伪标签的测试样本)来训练模型。在所描述的方法中,原始训练数据以 4:3 的比例与未加权的合成数据混合(4 折时每次用其余 3/4 的测试样本作为伪标签数据,而示例中训练集与测试集大小相同)。因此,如果合成数据的质量很差,这种方法将是灾难性的。我猜对于任何预测本身就不确定的问题,合成数据集的准确率都会很差。如果底层结构非常复杂且系统噪声很低,我猜生成合成数据可能会有所帮助。我认为半监督学习在深度学习领域相当流行(这不是我的专长),在那里还需要同时学习特征表示。
我尝试在几个数据集上分别用 rf(randomForest)和 xgboost 进行半监督训练,希望重现更高的准确率,但没有得到任何积极的结果。[欢迎随意修改我的代码。]我注意到,在 Kaggle 的报告中,半监督训练带来的实际准确率提升相当有限,也许只是随机波动?
# Fix the RNG seed so the simulated data, the fold assignment and the printed
# error rates are reproducible between runs. (Restart R for a clean session
# rather than clearing the workspace with rm(list = ls()).)
set.seed(1)
# Simulated data: fy2() draws class labels 0..(nclass - 1) and fX2() turns
# them into 3-D features lying on interleaved helices.

# Draw `nobs` class labels uniformly at random (with replacement) from
# 0..(nclass - 1).
#
# nobs:   number of labels to draw.
# nclass: number of classes; must be >= 1.
# Returns a double vector of length `nobs`.
fy2 = function(nobs = 2000, nclass = 9) {
  stopifnot(length(nclass) == 1, nclass >= 1)
  # sample.int() consumes the RNG exactly like sample(0:(nclass - 1), ...),
  # so seeded runs give the same labels; subtracting a double (not 1L) keeps
  # the original double return type.
  sample.int(nclass, nobs, replace = TRUE) - 1
}
# Build 3-D features for class labels `y`: each class occupies its own arm of
# a set of interleaved helices wound around the x1 axis.
#
# y:         class labels 0..(nclass - 1), e.g. from fy2().
# noise:     standard deviation of the Gaussian noise added to x2 and x3.
# twist:     x1 is drawn as runif() * twist; it is both the position along
#            the axis and the extra rotation applied to every arm.
# min.width: added to |x1| to give the helix radius.
# nclass:    number of helix arms. NULL (default) means max(y) + 1, so every
#            label maps to a valid arm even when some classes are absent
#            from `y`. (Counting length(unique(y)) instead produced NA
#            features for such labels and could give train and test sets
#            different geometries.)
# Returns a length(y) x 3 matrix with columns x1, x2, x3.
fX2 = function(y, noise = .05, twist = 8, min.width = .7,
               nclass = NULL) {
  if (is.null(nclass)) {
    nclass = if (length(y) > 0) max(y) + 1 else 1
  }
  stopifnot(all(y >= 0), all(y < nclass))
  n = length(y)
  x1 = runif(n) * twist
  # One evenly spaced starting angle per arm; [-1] drops the leading 0.
  helixStart = seq(0, 2 * pi, length.out = nclass + 1)[-1]
  angle = helixStart[y + 1] + x1
  radius = abs(x1) + min.width
  # RNG draw order (runif, then rnorm for x2, then rnorm for x3) matches the
  # original, so seeded runs reproduce the same data.
  x2 = sin(angle) * radius + rnorm(n) * noise
  x3 = cos(angle) * radius + rnorm(n) * noise
  cbind(x1, x2, x3)
}
# Pseudo-label ("semi-supervised") wrapper. Fit `model` on the training set,
# then split the TEST set into `nfold` folds. For each fold, refit on the
# training data plus the other test folds (labelled with the first model's
# predictions), and predict the held-out fold with that refitted model.
#
# model:  fitting function called as model(X, y, ...); its result must
#         support predict(fit, newX).
# trainX: training feature matrix.
# trainy: training labels; a factor for classifiers such as randomForest,
#         otherwise a vector (e.g. numeric class ids for xgboost).
# testX:  test feature matrix to predict.
# nfold:  number of test-set folds.
# ...:    extra arguments forwarded to every call of `model`.
# Returns one prediction per row of testX, in the original row order.
smartTrainPred = function(model, trainX, trainy, testX, nfold = 4, ...) {
  obj = model(trainX, trainy, ...)
  nTest = nrow(testX)
  # Folds partition the rows of testX (not trainX). rep_len() avoids split()'s
  # warning when nTest is not a multiple of nfold.
  folds = split(sample(seq_len(nTest)), rep_len(seq_len(nfold), nTest))
  predDF = do.call(rbind, lapply(folds, function(fold) {
    # Positive indices: testX[-integer(0), ] would select no rows at all.
    others = setdiff(seq_len(nTest), fold)
    if (length(others) == 0) {
      # Nothing left to pseudo-label, so use the base model directly.
      return(data.frame(sampleID = fold,
                        pred = predict(obj, testX[fold, , drop = FALSE])))
    }
    otherX = testX[others, , drop = FALSE]
    pseudoy = predict(obj, otherX)
    bigX = rbind(trainX, otherX)
    if (is.factor(trainy)) {
      # Combine via character labels: c() on two factors gives integer codes
      # before R 4.1 but a factor afterwards, so arithmetic on it is unsafe.
      bigy = factor(c(as.character(trainy), as.character(pseudoy)),
                    levels = levels(trainy))
    } else {
      bigy = c(trainy, pseudoy)
    }
    bigModel = model(bigX, bigy, ...)
    data.frame(sampleID = fold,
               pred = predict(bigModel, testX[fold, , drop = FALSE]))
  }))
  # Restore the original test-row order.
  predDF[["pred"]][order(predDF$sampleID)]
}
library(xgboost)
library(randomForest)
# Demo: complex but perfectly separable data ----
# One shared class count for the label generator and xgboost's num_class, so
# the two cannot silently drift apart.
nClass = 9
trainy = fy2(nclass = nClass); trainX = fX2(trainy)
testy = fy2(nclass = nClass); testX = fX2(testy)
pairs(trainX, col = trainy + 1)

# randomForest: supervised vs. pseudo-label training ----
rf = randomForest(trainX, factor(trainy))
normPred = predict(rf, testX)
cat("\n supervised rf", mean(testy != normPred))
smartPred = smartTrainPred(randomForest, trainX, factor(trainy), testX,
                           nfold = 4)
cat("\n semi-supervised rf", mean(testy != smartPred))

# xgboost: supervised vs. pseudo-label training ----
xgb = xgboost(trainX, trainy,
  nrounds = 35, verbose = FALSE, objective = "multi:softmax",
  num_class = nClass)
normPred = predict(xgb, testX)
cat("\n supervised xgboost", mean(testy != normPred))
smartPred = smartTrainPred(xgboost, trainX, trainy, testX, nfold = 4,
  nrounds = 35, verbose = FALSE, objective = "multi:softmax",
  num_class = nClass)
cat("\n semi-supervised xgboost", mean(testy != smartPred), "\n")
Printed prediction errors (test-set misclassification rate):
supervised rf 0.007
semi-supervised rf 0.0085
supervised xgboost 0.046
semi-supervised xgboost 0.049