predict.rsf {randomSurvivalForest} | R Documentation |
Prediction on test data using Random Survival Forests.
predict.rsf(object = NULL, test = NULL, importance = c("randomsplit", "permute", "none")[1], na.action = c("na.omit", "na.impute")[1], outcome = c("train", "test")[1], proximity = FALSE, split.depth = FALSE, seed = NULL, do.trace = FALSE, ...)
object |
An object of class (rsf, grow) or (rsf, forest). Requires forest = TRUE in the original rsf call. |
test |
Data frame containing test data. Missing values allowed. |
importance |
Method used to compute variable importance (VIMP). Only applies when test data contains survival outcomes. |
na.action |
Action to be taken if the data contain NAs. Possible values are "na.omit", which removes the entire record if even one of its entries is NA, and "na.impute", which imputes the test data. See details below. |
outcome |
Which data frame is used in calculating the ensemble, either "train" or "test". By default the training data are used; see details below. |
proximity |
Logical. Should the proximity measure between test observations be calculated? The result can be large. Default is FALSE. |
split.depth |
Should the minimal depth of each variable be returned for each individual in the test set? Default is FALSE. See splitDepth below. |
seed |
Seed for the random number generator. Must be a negative integer (the R wrapper handles incorrectly set seed values). |
do.trace |
Logical. Should trace output be enabled? Default is FALSE. Integer values can also be passed; a positive value causes output to be printed every do.trace iterations. |
... |
Further arguments passed to or from other methods. |
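The "na.omit" rule in the na.action entry above (drop a record if even one of its entries is NA) can be checked in base R before calling predict.rsf, for example to anticipate the value of n in the returned object. The sketch below uses a hypothetical test data frame and base R's complete.cases(); it only mimics the documented rule and is not the package's own code.

```r
# Hypothetical test set with missing entries (illustrative values only)
test <- data.frame(time   = c(100, 250, 400, 520),
                   status = c(1, 0, 1, 1),
                   age    = c(62, NA, 55, 70),
                   bili   = c(1.1, 0.9, NA, 2.3))

# na.omit-style rule: keep only records with no NA in any column
keep      <- complete.cases(test)
test.omit <- test[keep, ]

nrow(test.omit)  # sample size one would expect under na.action = "na.omit"
which(!keep)     # records containing at least one NA
```

Under na.action = "na.impute" the flagged records would presumably be kept and imputed instead; compare them with imputedIndv in the returned object.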
predict.rsf takes a test data set and drops it down the grow-forest (i.e., the forest grown from the training data) to compute an ensemble cumulative hazard function (CHF). CHFs are calculated for each individual in the test data at every unique death time point from the original grow data. If survival outcome information is present in the test data, the overall error rate and the VIMP of each variable are also returned.
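As described in Ishwaran et al. (2008), each tree's CHF for an individual is a Nelson-Aalen estimate computed from the grow data in the terminal node that individual falls into, and the ensemble CHF is the average of these tree CHFs. The base R sketch below illustrates that idea on a toy grid of time points; nelson_aalen() and the three toy nodes are hypothetical and this is not the package's internal implementation.

```r
# Hypothetical helper: Nelson-Aalen cumulative hazard for the cases in one
# terminal node, evaluated as a step function on a grid (cf. timeInterest).
nelson_aalen <- function(time, status, grid) {
  event_times <- sort(unique(time[status == 1]))
  # hazard increment d_j / Y_j: deaths at t_j over number still at risk
  increments <- vapply(event_times, function(t) {
    sum(time == t & status == 1) / sum(time >= t)
  }, numeric(1))
  chf <- cumsum(increments)
  # index of the last event time <= each grid point (0 if none yet)
  idx <- findInterval(grid, event_times)
  out <- numeric(length(grid))
  out[idx > 0] <- chf[idx[idx > 0]]
  out
}

# Toy example: the terminal nodes one test case lands in, in three trees
grid  <- c(1, 2, 3, 4)
nodes <- list(
  list(time = c(1, 2, 3), status = c(1, 1, 0)),
  list(time = c(2, 4),    status = c(1, 1)),
  list(time = c(3, 4, 4), status = c(0, 1, 0))
)

# one column per tree, one row per grid point
tree.chf <- sapply(nodes, function(nd) nelson_aalen(nd$time, nd$status, grid))
ensemble.chf <- rowMeans(tree.chf)  # ensemble CHF: average over trees
```

Here the first node alone gives a CHF of 1/3, 5/6, 5/6, 5/6 on the grid, and the ensemble CHF is 1/9, 4/9, 4/9, 17/18.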
Setting na.action = "na.impute" imputes missing test data (x-variables or outcomes). Imputation uses the grow-forest, so that only training data are used when imputing test data; this avoids biasing error rates and VIMP (Ishwaran et al. 2008).
For competing risks, the ensemble conditional CHF (CCHF) is computed for each event type in addition to the ensemble CHF.
If outcome = "test", the ensemble is calculated specifically from the survival information in the test data (which must therefore be present). In this case, the terminal node estimates of the grow-forest are recalculated using survival data from the test set. This yields a modified predictor in which the topology of the forest is based solely on the training data, but the predicted value is based on the test data. Error rates and VIMP are calculated by bootstrapping the test data and using out-of-bag estimation to ensure unbiased estimates. See Examples 2 and 3 below for illustration.
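The bootstrap/out-of-bag bookkeeping mentioned above can be pictured with a short base R sketch: for each tree a bootstrap sample of the test rows is drawn, and a test case counts toward the error estimate only through trees whose bootstrap sample excluded it. The sizes and seed below are hypothetical, and how the package draws its samples internally is not documented on this page.

```r
# Hypothetical sketch of bootstrap / out-of-bag bookkeeping on a test set
set.seed(42)
n.test <- 10    # number of test cases (illustrative)
ntree  <- 200   # one bootstrap draw per tree (illustrative)

# oob[i, b] is TRUE when test case i is NOT drawn in bootstrap sample b
oob <- matrix(FALSE, nrow = n.test, ncol = ntree)
for (b in seq_len(ntree)) {
  inbag <- sample(seq_len(n.test), n.test, replace = TRUE)
  oob[, b] <- !(seq_len(n.test) %in% inbag)
}

rowMeans(oob)             # share of trees for which each case is out-of-bag
(1 - 1 / n.test)^n.test   # expected share, roughly exp(-1) for large n
```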
An object of class (rsf, predict), which is a list with the following components:
call |
The original grow call to rsf. |
forest |
The grow forest. |
ntree |
Number of trees in the grow forest. |
leaf.count |
Number of terminal nodes for each tree in the grow forest. Vector of length ntree. |
timeInterest |
Sorted unique event times from grow (training) data. Ensemble values given for these time points only. |
n |
Sample size of the test data (depends upon NAs; see na.action). |
ndead |
Number of deaths in test data (can be NULL). |
time |
Vector of survival times from test data (can be NULL). |
cens |
Vector of censoring indicators from test data (can be NULL). |
predictorNames |
Character vector of variable names. |
predictors |
Data frame comprising x-variables used for prediction. |
ensemble |
Matrix containing the ensemble CHF for the test data. Each row corresponds to a test data individual's CHF evaluated at each of the time points in timeInterest. For competing risks, a 3-D array where the 3rd dimension is for the ensemble CHF and each of the CCHFs, respectively. |
mortality |
Vector containing the ensemble mortality for each individual in the test data. If outcome = "train", ensemble mortality should be interpreted in terms of the total number of training deaths. |
err.rate |
Vector of length ntree containing the error rate of the test data. For competing risks, a matrix with rows corresponding to the ensemble CHF and each of the CCHFs, respectively. Can be NULL. If outcome = "test", only the error rate for the combined forest is returned. |
importance |
VIMP of each variable in the test data. For competing risks, a matrix with rows corresponding to the ensemble CHF and each of the CCHFs, respectively. Can be NULL. |
proximity |
If proximity = TRUE, a matrix recording the proximity of the inputs from the test data is computed. The value returned is a vector of the lower diagonal of the matrix. Use plot.proximity() to extract this information. |
imputedIndv |
Vector of indices of records in test data with missing values. Can be NULL. |
imputedData |
Data frame comprising imputed test data. The first two columns are censoring and survival time, respectively; the remaining columns are the x-variables. Row i contains imputed outcomes and x-variables for row imputedIndv[i] of predictors. Can be NULL. |
splitDepth |
Matrix whose entry [i, j] is the mean minimal depth of variable j for case i in the test data. Used for variable selection (see max.subtree). Can be NULL. |
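Because row i of imputedData corresponds to row imputedIndv[i] of predictors, the imputed x-variables can be written back into the predictor data frame by index. The sketch below uses small hypothetical stand-ins for pred$predictors, pred$imputedIndv and pred$imputedData, and assumes the x-variable columns of imputedData carry the same names as those of predictors.

```r
# Hypothetical stand-ins for pred$predictors, pred$imputedIndv and
# pred$imputedData (in practice these come from a predict.rsf call)
predictors  <- data.frame(age = c(50, NA, 61), bili = c(1.2, 0.8, NA))
imputedIndv <- c(2, 3)
imputedData <- data.frame(cens = c(1, 0), time = c(400, 900),
                          age = c(55, 61), bili = c(0.8, 1.0))

# Columns 1-2 of imputedData are censoring and time; the rest are x-variables.
# Match x-variables by name (assumed to agree with names(predictors)).
xvars     <- names(predictors)
completed <- predictors
completed[imputedIndv, xvars] <- imputedData[, xvars]
completed
```

Imputed outcomes (the first two columns) could be merged into the test outcomes in the same way.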
The key deliverable is the matrix ensemble, which contains the ensemble CHF for each individual in the test data evaluated at a set of distinct time points.
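Several quantities follow directly from the ensemble matrix. Survival curves are exp(-CHF), as in Example 3 below. Following the definitions in Ishwaran et al. (2008), ensemble mortality is the CHF summed over the time grid, and the prediction error is 1 - C, where C is Harrell's concordance index computed with mortality as the risk score. The base R sketch below uses a hypothetical 3 x 4 ensemble matrix and a hypothetical harrell_c() helper; for simplicity it skips pairs with tied observed times, so it may not match the package's error rate exactly.

```r
# Hypothetical stand-in for pred$ensemble: 3 test cases x 4 time points
ensemble <- rbind(c(0.1, 0.3, 0.6, 1.0),
                  c(0.0, 0.1, 0.2, 0.4),
                  c(0.2, 0.5, 0.9, 1.5))

surv      <- exp(-ensemble)     # ensemble survival curves, S(t) = exp(-H(t))
mortality <- rowSums(ensemble)  # ensemble mortality: CHF summed over the grid

# Harrell's concordance index with a risk score (higher = worse prognosis).
# Simplification: pairs with tied observed times are skipped.
harrell_c <- function(time, status, risk) {
  concordant  <- 0
  permissible <- 0
  for (i in seq_along(time)) {
    if (status[i] != 1) next  # the earlier time in a pair must be a death
    for (j in seq_along(time)) {
      if (time[i] < time[j]) {
        permissible <- permissible + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5
        }
      }
    }
  }
  concordant / permissible
}

time    <- c(300, 900, 150)  # hypothetical observed test times
status  <- c(1, 0, 1)        # 1 = death, 0 = censored
c.index <- harrell_c(time, status, mortality)
err     <- 1 - c.index       # error rate in the 1 - C sense
```

In this toy case the individual with the highest mortality dies first, so every permissible pair is concordant and the error rate is 0.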
Hemant Ishwaran hemant.ishwaran@gmail.com and Udaya B. Kogalur ubk2101@columbia.edu
L. Breiman (2001). Random forests, Machine Learning, 45:5-32.
H. Ishwaran, U.B. Kogalur, E.H. Blackstone and M.S. Lauer (2008). Random survival forests, Ann. Appl. Statist., 2:841-860.
H. Ishwaran and U.B. Kogalur (2007). Random survival forests for R, R News, 7/2:25-31.
competing.risk, plot.ensemble, plot.variable, plot.error, pmml2rsf, print.rsf, rsf, rsf2rfz, rsf2pmml.
#------------------------------------------------------------
# Example 1: Veteran's Administration lung cancer data

data(veteran, package = "randomSurvivalForest")
pt.train <- sample(1:nrow(veteran), round(nrow(veteran)*0.80))
veteran.out <- rsf(Surv(time, status) ~ ., forest = TRUE,
                   data = veteran[pt.train , ])
veteran.pred <- predict(veteran.out, veteran[-pt.train , ])

## Not run:
#------------------------------------------------------------
# Example 2: Get out-of-bag error rate using the training
# data as test data (pbc example)

data(pbc, package = "randomSurvivalForest")
pbc.grow <- rsf(Surv(days, status) ~ ., pbc, nsplit = 3, forest = TRUE)
pbc.pred <- predict(pbc.grow, pbc, outcome = "test")
# last entry of err.rate is the error rate of the full grow forest
cat("GROW error rate :", round(tail(pbc.grow$err.rate, 1), 3), "\n")
cat("PRED error rate :", round(pbc.pred$err.rate, 3), "\n")

#------------------------------------------------------------
# Example 3: Verify reproducibility of forest (pbc data)

# primary call
data(pbc, package = "randomSurvivalForest")
pt.train <- sample(1:nrow(pbc), round(nrow(pbc)*0.50))
pbc.out <- rsf(Surv(days, status) ~ ., nsplit = 3, forest = TRUE,
               data = pbc[pt.train, ])

# make separate predict calls using the outcome option
pbc.train <- predict(pbc.out, pbc[-pt.train, ], outcome = "train")
pbc.test  <- predict(pbc.out, pbc[-pt.train, ], outcome = "test")

# check forest reproducibility by comparing predicted survival curves
timeInterest <- pbc.out$timeInterest
surv.train <- exp(-pbc.train$ensemble)
surv.test  <- exp(-pbc.test$ensemble)
matplot(timeInterest, t(surv.train - surv.test), type = "l")

# test reproducibility by repeating B times
# compute l1-difference in predicted survival
B <- 25
l1.valid <- rep(NA, B)
for (b in 1:B) {
  cat("Replication:", b, "\n")
  pt.train <- sample(1:nrow(pbc), round(nrow(pbc)*0.50))
  pbc.out <- rsf(Surv(days, status) ~ ., nsplit = 3, forest = TRUE,
                 data = pbc[pt.train, ])
  surv.train <- exp(-predict(pbc.out, pbc[-pt.train, ], outcome = "train")$ensemble)
  surv.test  <- exp(-predict(pbc.out, pbc[-pt.train, ], outcome = "test")$ensemble)
  # store the l1-difference for replication b
  l1.valid[b] <- mean(apply(abs(surv.train - surv.test), 1, mean, na.rm = TRUE),
                      na.rm = TRUE)
}
cat("l1-reproducibility:", round(mean(l1.valid, na.rm = TRUE), 3), "\n")

## End(Not run)