我编写了一个用于 10 折交叉验证的函数,我想将其用于不同的模型,例如 PPR、MARS。但是,运行它时出现错误,我无法弄清楚为什么它不起作用?我的简历功能:
cv10 <- function(reg.fn, formula, dataset, ...)
{
set.seed(201)
### Number of observations
nrow <- nrow(dataset)
### Create a permutation of the observations indices
Ind <- sample.int(nrow,nrow, replace = FALSE)
### Compute the size of each of the 10 folds
M <- nrow / 10 # 'fold size'
### Initialize the score
score <- 0
### The first fold will then contain the observations which correspond to..
### ..the indices of the first M elements of Ind.
for(i in 1:10){
beg <- i*M
end <- (i+1)*M
### Data to train the model
data.train <- dataset[Ind[-beg:-end],]
### Data to test the model
data.fold <- dataset[Ind[beg:end],]
### Fit the model
model <- reg.fn(formula,data=data.train,...)
predicted.y <- predict(model,data.fold)
### Update the CV-score
score <- sum((predicted.y - data.fold[,1])^2) / M
}
return(score/10)
}
使用 ppr 进行测试:
cv.scores <- numeric(10)
### Some code
for(i in 1:10){
score <- cv10(reg.fn = ppr, formula = y~.,
dataset = data, nterms=i)
cv.scores[i] <- scores
}
cv.scores
追溯:
> Error in matrix(NA, length(keep), object$q, dimnames = list(rn,
> object$ynames)) :
> length of 'dimnames' [1] not equal to array extent
> 4.
> matrix(NA, length(keep), object$q, dimnames = list(rn, object$ynames))
> 3.
> predict.ppr(model, data.fold)
> 2.
> predict(model, data.fold)
> 1.
> cv10(reg.fn = ppr, formula = y ~ ., dataset = data, nterms = i)
我正在使用的数据:
structure(list(y = c(23.0551546516262, 27.8893494373006, 3.32468370559938,
-13.5852336127512, -5.14668013186906, -0.489523212484223, -14.328750654513,
-4.26428395686341, -2.75486620989581, 17.3107345018601, 25.6193450849393,
0.605103858286016, -1.30909806542865, 2.03575942172917, -19.1193524499977,
-1.46508279385589, 2.65778970954973, 14.8513018374104, -2.87449028138997,
1.37368992108124, -1.43518738939116, 0.0199676357940499, -1.549025998582,
-4.06263285631006, -9.15130335901099, -2.62794216480131, -1.68473200963303,
3.15144283445608, 7.78027589015824, 9.09732626327383), x1 = c(0.286060694657523,
-0.344546030966432, 0.325763726232689, -1.69658096808073, -1.2854825202758,
-0.0750318862014798, 0.266937353823139, 0.0559340444850217, -2.30403430891787,
0.189004139305415, 0.693296170158882, 0.223809355083932, 0.398456942903131,
1.01347438447768, -0.64785307166209, 0.648452713333917, 0.207342703528518,
0.0643901392726141, 0.669380920067964, -0.374254446133507, -0.244000842201787,
-0.988253138922366, 1.24206047974719, -1.68266602919039, 1.44289062580162,
-0.465439746975312, 0.693661499094998, -0.0877255722586039, -0.955080382553146,
0.170100884691593), x2 = c(-0.343601401483176, -0.924078839603673,
0.973710320640175, 0.0267187344544633, -1.36283892301834, 0.105184057636645,
-0.644019900369909, 0.960031901250783, 0.147336523178527, 0.339467057535232,
-0.192287076626924, 0.0722969316029643, 0.389789911800799, -0.328247051156339,
-0.090450711707476, 0.716681577815978, 0.0626860575507786, -0.69236622624416,
0.584444051353438, -0.0911664147267412, -0.315213328094698, -0.0806856079787168,
0.484583750517842, -0.120406402869962, 0.596077475841207, -0.36353784662963,
-0.780093462571257, 0.324679908484668, 0.508548510215705, -0.193595813912055
), x3 = c(0.982327855388361, 0.624091435911063, 0.621531522270016,
-0.902870741076395, 0.931325903563023, -1.05264178470207, 0.307132555544596,
0.275469955530981, 2.78596687577565, -0.590390951909848, -0.0257046477898407,
-0.122008374353289, 0.455026913225061, -0.607514744574133, 0.595817459312108,
1.48223488775224, 0.636854208609479, 0.201054337281812, -0.716437866742046,
-2.30960460962945, -1.11690418809942, 0.296611889529358, 0.992033628272787,
-0.769290105905667, -1.4112664763812, 0.972758797977034, 0.680563892580633,
0.0312007101558726, 2.40109797772769, 0.27149586035907), x4 = c(2.87744884192944,
2.97037391737103, 2.04590974515304, -2.09065303439274, -0.886272139381617,
0.258417838253081, -2.48789734393358, -1.14431498106569, 1.52785618370399,
2.43856811150908, 2.88160788919777, 0.143826744519174, -1.32458955561742,
0.850324050989002, -2.63397432630882, -0.270683331415057, 1.85416122945026,
2.19268380571157, -1.33175755385309, 1.08762756781653, 0.7014160878025,
0.907778979744762, -1.3183526317589, 0.718872689176351, -2.21834870846942,
-0.750489700119942, -0.889076801016927, 1.39292777515948, 2.34955989941955,
2.1975970286876), x5 = c(1.48984236368162, 0.869139640762428,
0.748845036625717, 0.351786000608901, -1.47779050566991, -2.3154451409239,
2.20221698212952, 0.414262887380592, 0.244955910040375, 0.429121363729595,
-0.317306195296495, -1.38016320237183, 0.694020488858179, 0.305431051706151,
-0.398558943204744, -1.00163421976715, 1.29024064725421, -0.770948417017754,
0.741664981312622, 0.169399870781162, -1.35676745536567, 0.471865193264912,
0.960859048309877, 1.46760491067668, 1.4378809852526, 0.0349201858899876,
-1.42177690061078, -1.43127605517511, -0.101638629745238, 1.49972397311187
)), .Names = c("y", "x1", "x2", "x3", "x4", "x5"), row.names = c(NA,
30L), class = "data.frame")