This shows how cross-validation is used to assess the proper complexity parameter for a regression tree. This is the kind of graph shown by plotcp(), which is based on the cross-validation computations in rpart(); both functions are in the rpart package.
Required packages.
require(rpart); require(rpart.plot); require(sp)
## Loading required package: rpart
## Loading required package: rpart.plot
## Loading required package: sp
The sample dataset is Meuse (heavy metals in soil samples), which comes with the sp package:
data(meuse)
names(meuse)
## [1] "x" "y" "cadmium" "copper" "lead" "zinc" "elev"
## [8] "dist" "om" "ffreq" "soil" "lime" "landuse" "dist.m"
# meuse$logZn <- log10(meuse$zinc)
We select Zn (zinc) as the target and try to predict it from flooding frequency (ffreq), distance to the river in meters (dist.m), and elevation above the local river bed in meters (elev).
Build a complex tree, allowing almost unlimited splitting (minsplit = 2) and a very small complexity parameter (cp = 0.001):
m.lzn.rp <- rpart(zinc ~ ffreq + dist.m + elev,
                  data = meuse, minsplit = 2, cp = 0.001)
rpart.plot(m.lzn.rp)
This tree is obviously too site-specific, i.e., overfit to just 155 observations.
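To put a number on the overfitting, a quick check (not part of the original analysis) counts the terminal nodes of the tree and computes its resubstitution RMSE, i.e., the error on the same 155 records used to build it. That error is optimistic by construction; the names n.leaves and rmse.fit are only illustrative.

```r
library(rpart)
library(sp)
data(meuse)

m.lzn.rp <- rpart(zinc ~ ffreq + dist.m + elev,
                  data = meuse, minsplit = 2, cp = 0.001)
# Terminal nodes are marked "<leaf>" in the frame of an rpart object
n.leaves <- sum(m.lzn.rp$frame$var == "<leaf>")
# Resubstitution RMSE: predictions at the very records used for fitting
rmse.fit <- sqrt(mean((meuse$zinc - predict(m.lzn.rp, meuse))^2))
c(observations = nrow(meuse), leaves = n.leaves, rmse.fit = rmse.fit)
```

Many leaves for so few records, together with a resubstitution error far below the spread of zinc values, is the usual symptom of a tree that has memorized the sample.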
The cross-validation procedure is:
1. Randomly split the dataset into ten parts. Do this by a random permutation of the record numbers, followed by a split.
2. For each of the ten parts:
   2.1 Hold the subset out for validation.
   2.2 Build a tree with the others.
   2.3 Predict at the held-out points.
   2.4 Compute validation statistics.
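The steps above can be sketched as a small function. kfold_tree_rmse() is a hypothetical helper, not part of rpart; it assumes a numeric response, passes any extra arguments (such as maxdepth) through rpart() to rpart.control(), and labels folds with rep_len() so that every record is held out exactly once. It is demonstrated on the built-in mtcars data only so that it runs without sp.

```r
library(rpart)

# Hypothetical helper (not part of rpart): k-fold cross-validation RMSE
# of a regression tree, following steps 1 and 2.1-2.4 above.
kfold_tree_rmse <- function(formula, data, k = 10, ...) {
  response <- all.vars(formula)[1]          # name of the target variable
  perm <- sample(nrow(data))                # step 1: random permutation...
  fold <- rep_len(seq_len(k), nrow(data))   # ...then a split into k parts
  rmse <- numeric(k)
  for (j in seq_len(k)) {
    ix <- perm[fold == j]                                  # 2.1 hold out part j
    fit <- rpart(formula, data = data[-ix, ], ...)         # 2.2 build with the others
    pred <- predict(fit, newdata = data[ix, ])             # 2.3 predict held-out points
    rmse[j] <- sqrt(mean((data[ix, response] - pred)^2))   # 2.4 RMSE for this part
  }
  rmse
}

# Usage on built-in data; maxdepth is passed through to rpart.control()
set.seed(1)
r <- kfold_tree_rmse(mpg ~ wt + hp, data = mtcars, k = 5, maxdepth = 2)
mean(r)                   # cross-validation RMSE
sd(r) / sqrt(length(r))   # its standard error
```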
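First, decide on the number of folds k, and from that determine the size of each hold-out evaluation sample. Because floor() rounds down, any remainder records are never held out: here 155 - 10 × 15 = 5 records are only ever used for training.
# number of cross-validation folds, our choice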
k <- 10; (size <- floor(dim(meuse)[1]/k))
## [1] 15
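A quick check of that arithmetic, with n = 155 hard-coded as the number of meuse records. One alternative, not used in the rest of this analysis, is to give every position of the permutation a fold label with rep_len(), so that fold sizes differ by at most one and no record is left out.

```r
n <- 155        # number of records in meuse, i.e., nrow(meuse)
k <- 10
size <- floor(n / k)
k * size        # 150 records covered by the fixed-size slices
n - k * size    # 5 records that are never held out
# Alternative: label each position of the permutation with a fold 1..k
fold <- rep_len(seq_len(k), n)
table(fold)     # folds 1-5 get 16 records, folds 6-10 get 15
```

With this labeling, the hold-out set inside the loop would be ix <- records.permute[fold == i + 1].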
Set up a data frame to keep the results of each cross-validation run.
max.split <- 10 # the largest maximum tree depth (maxdepth) we will test
split.vs.xval <- as.data.frame(matrix(0, nrow=max.split, ncol=3))
names(split.vs.xval) <- c("nsplit", "xerr", "xerr.sd")
To avoid any sequential effects in the dataset (we don’t know why the records are presented in this order), make a random permutation of the record numbers, from which we take the folds.
records.permute <- sample(1:dim(meuse)[1])
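Because sample() draws a fresh permutation on every run, the cross-validation results change each time the code is run; the original code does not set a seed. A minimal sketch of making them reproducible by fixing the random seed with set.seed() before the permutation is drawn; the seed 314 is arbitrary, and n = 155 stands in for nrow(meuse).

```r
n <- 155          # number of records in meuse
set.seed(314)     # arbitrary seed, for illustration only
p1 <- sample(1:n)
set.seed(314)     # same seed again...
p2 <- sample(1:n)
identical(p1, p2) # ...gives the same permutation: TRUE
```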
Run the cross-validation for each maximum tree depth (maxdepth, from 1 to max.split). Note that depth is not the same as the number of splits: a tree of depth d can have up to 2^d - 1 splits, so the nsplit column below actually records the depth.
for (i.split in 1:max.split) {
  rmse <- rep(0, k)
  for (i in 0:(k-1)) {
    ix <- records.permute[(i*size+1):(i*size+size)]
    # training set to build the tree, test set to evaluate
    meuse.cal <- meuse[-ix, ]; meuse.val <- meuse[ix, ]
    # build the tree; with cp = 0, growth is limited only by maxdepth
    # and the default minsplit and minbucket
    m.cal <- rpart(zinc ~ ffreq + dist.m + elev,
                   data = meuse.cal, maxdepth = i.split, cp = 0)
    val <- (meuse.val$zinc - predict(m.cal, meuse.val)) # vector of residuals
    rmse[i+1] <- sqrt(sum(val^2)/length(val)) ## RMSE
  } # for k-fold
  # Now we have `k` RMSE values, one per fold:
  rmse.m <- mean(rmse) # this is the value we match with the depth
  rmse.sd <- sd(rmse)/sqrt(k) # standard error of the mean over the k folds
  # save it
  split.vs.xval[i.split, ] <- c(i.split, rmse.m, rmse.sd)
} # for maxdepth
From this we can find an “optimum” as the depth with the minimum cross-validation RMSE:
split.vs.xval
## nsplit xerr xerr.sd
## 1 1 253.1795 4.849847
## 2 2 212.5010 3.657651
## 3 3 202.1870 3.290276
## 4 4 209.4729 3.776380
## 5 5 211.3142 3.749992
## 6 6 212.1163 3.710246
## 7 7 211.6462 3.744795
## 8 8 211.6462 3.744795
## 9 9 211.6462 3.744795
## 10 10 211.6462 3.744795
(ix.min <- which.min(split.vs.xval$xerr))
## [1] 3
Final result: the cross-validation RMSE vs. maximum tree depth, with a dashed line at the minimum RMSE plus one standard error.
plot(xerr ~ nsplit, data = split.vs.xval, type = "b",
     ylab = "cross-validation RMSE", xlab = "Maximum tree depth (maxdepth)")
grid()
abline(h = split.vs.xval[ix.min, "xerr"] + split.vs.xval[ix.min, "xerr.sd"], lty = 2)
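The dashed line is the basis of the one-standard-error rule, which is also how the plotcp() help page suggests choosing cp: take the simplest model whose error lies below the line. A sketch of that selection as a hypothetical helper, one.se.choice(), that expects the columns of split.vs.xval; the input table here is made up purely to show the mechanics and is not a result from meuse.

```r
# Hypothetical helper: one-standard-error rule on a results table shaped
# like split.vs.xval (columns nsplit, xerr, xerr.sd).
one.se.choice <- function(res) {
  ix.min <- which.min(res$xerr)
  threshold <- res$xerr[ix.min] + res$xerr.sd[ix.min]  # height of the dashed line
  # simplest model (smallest nsplit value) whose error is below the line
  min(res$nsplit[res$xerr <= threshold])
}

# Illustrative input only (made-up numbers)
res <- data.frame(nsplit = 1:4,
                  xerr = c(250, 205, 200, 210),
                  xerr.sd = c(5, 5, 6, 5))
one.se.choice(res)  # 2: depth 2 is within one SE of the minimum at depth 3
```

On the real results the call would be one.se.choice(split.vs.xval).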
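rpart::plotcp
What does the built-in procedure say about this model? Each run gives a different cross-validation curve, because rpart() assigns the records to its internal cross-validation folds at random (by default xval = 10 folds, see rpart.control()), although the fitted tree itself is the same.
par(mfrow = c(2, 2)) # show the four runs side by side
for (i in 1:4) {
  m.lzn.rp.005 <- rpart(zinc ~ ffreq + dist.m + elev,
                        data = meuse, minsplit = 2, cp = 0.005)
  plotcp(m.lzn.rp.005)
}
par(mfrow = c(1, 1))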
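The same rule can be applied to rpart's own cross-validation results, which printcp() prints and which are stored in the cptable component of the fitted tree (columns CP, nsplit, rel error, xerror, xstd). A sketch, assuming the one-standard-error rule is the selection criterion you want; note that xerror is relative to the root-node error rather than an RMSE, and the seed is arbitrary.

```r
library(rpart)
library(sp)
data(meuse)

set.seed(314)  # arbitrary; rpart assigns its internal cross-validation folds at random
m <- rpart(zinc ~ ffreq + dist.m + elev,
           data = meuse, minsplit = 2, cp = 0.005)
printcp(m)     # CP, nsplit, rel error, xerror, xstd for each candidate subtree

cpt <- m$cptable
ix.min <- which.min(cpt[, "xerror"])
threshold <- cpt[ix.min, "xerror"] + cpt[ix.min, "xstd"]
# first (simplest) subtree whose cross-validated error is within one SE of the minimum
ix.1se <- which(cpt[, "xerror"] <= threshold)[1]
m.pruned <- prune(m, cp = cpt[ix.1se, "CP"])
cpt[ix.1se, "nsplit"]  # number of splits in the selected subtree
```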