Machine Learning - Random Forest

Fits a random forest of classification or regression trees.


To run a Random Forest model:

1. In Displayr, select Insert > More > Machine Learning > Random Forest. In Q, select Automate > Browse Online Library > Machine Learning > Random Forest.
2. Under Inputs > Random Forest > Outcome select your outcome variable.
3. Under Inputs > Random Forest > Predictor(s) select your predictor variables.
4. Make any other selections as required.


Categorical outcome

The table below shows the variable importance as computed by a Random Forest. The column called MeanDecreaseAccuracy contains a measure of the extent to which a variable improves the accuracy of the forest in predicting the classification. Higher values mean that the variable improves prediction. In a rough sense, it can be interpreted as showing the amount of increase in classification accuracy that is provided by including the variable in the model (a more precise statement of the meaning is complicated, and requires a detailed understanding of the underlying mechanics of random forests). In this example, x1 is clearly the most important variable, followed by x2, and x3.
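The idea behind MeanDecreaseAccuracy can be illustrated with a small permutation test: shuffle one variable, leave everything else alone, and see how much accuracy drops. The sketch below is only a simplified illustration. The toy data, the fixed rule standing in for a fitted forest, and the helper names (makeRng, shuffled, permutationImportance) are all assumptions made for the example. The randomForest package does this per tree on out-of-bag cases and averages the results, which this sketch does not attempt.

```javascript
// Small seeded generator so the shuffle is repeatable (hypothetical helper).
function makeRng(seed) {
  let state = seed >>> 0;
  return function () {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state / 4294967296;
  };
}

// Fisher-Yates shuffle that returns a shuffled copy of `values`.
function shuffled(values, rng) {
  const out = values.slice();
  for (let i = out.length - 1; i > 0; i--) {
    const j = Math.floor(rng() * (i + 1));
    [out[i], out[j]] = [out[j], out[i]];
  }
  return out;
}

// Proportion of rows whose predicted label matches the observed label.
function accuracy(rows, labels, predict) {
  let correct = 0;
  rows.forEach((row, i) => { if (predict(row) === labels[i]) correct++; });
  return correct / rows.length;
}

// Drop in accuracy after shuffling a single variable across rows.
function permutationImportance(rows, labels, predict, variable, rng) {
  const baseline = accuracy(rows, labels, predict);
  const permutedValues = shuffled(rows.map(r => r[variable]), rng);
  const permutedRows = rows.map((r, i) =>
    Object.assign({}, r, { [variable]: permutedValues[i] }));
  return baseline - accuracy(permutedRows, labels, predict);
}

// Toy data: the outcome depends on x1 only; x2 is noise (illustrative values).
const toyRows = [
  { x1: 1, x2: 5 }, { x1: 2, x2: 1 }, { x1: 8, x2: 4 },
  { x1: 9, x2: 2 }, { x1: 3, x2: 9 }, { x1: 7, x2: 7 }
];
const toyLabels = ["A", "A", "B", "B", "A", "B"];
const toyPredict = row => (row.x1 > 5 ? "B" : "A"); // stand-in for a fitted model

const importanceRng = makeRng(12321);
const importanceX1 = permutationImportance(toyRows, toyLabels, toyPredict, "x1", importanceRng);
const importanceX2 = permutationImportance(toyRows, toyLabels, toyPredict, "x2", importanceRng);
```

Because the stand-in rule never looks at x2, shuffling x2 cannot change accuracy, so its importance is exactly zero. Shuffling x1 can only lower accuracy or leave it unchanged.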

The first three columns show the importance of each variable at improving accuracy, broken down by category of the outcome variable. In this example, x1's importance as a predictor is largely due to its usefulness in predicting membership of Group C, whereas x2 primarily improves prediction of Group A, followed by Group C, and has a marginally deleterious impact on prediction of Group B.

Importance (MeanDecreaseGini) provides a more nuanced measure of importance, which factors in both the contribution that variable makes to accuracy, and the degree of misclassification (e.g., if a variable improves the probability of an observation being classified to a segment from 55% to 90%, this will show up in the Importance (MeanDecreaseGini), but not in MeanDecreaseAccuracy). As with MeanDecreaseAccuracy, high numbers indicate that a variable is more important as a predictor.
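The 55% versus 90% example can be made concrete with the Gini impurity of a node, computed as one minus the sum of squared class proportions. The sketch below is illustrative only: the function names and the two-class proportions are assumptions, and it shows a single node rather than the averaging over all splits and trees that the randomForest package performs.

```javascript
// Gini impurity of a node: 1 - sum of squared class proportions.
function giniImpurity(proportions) {
  return 1 - proportions.reduce((sum, p) => sum + p * p, 0);
}

// Index of the class with the largest proportion (the predicted class).
function predictedClass(proportions) {
  return proportions.indexOf(Math.max(...proportions));
}

// Two-class node before and after a split on some variable (illustrative values).
const nodeBefore = [0.55, 0.45];
const nodeAfter = [0.90, 0.10];

const giniBefore = giniImpurity(nodeBefore); // about 0.495
const giniAfter = giniImpurity(nodeAfter);   // about 0.18
const giniDecrease = giniBefore - giniAfter; // about 0.315

// The predicted class is the same in both cases, so accuracy is unchanged.
const samePrediction = predictedClass(nodeBefore) === predictedClass(nodeAfter);
```

The predicted class does not change, so an accuracy-based measure sees no improvement, while the impurity falls substantially.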

Numeric outcome

The table below shows the random forest outputs for a numeric outcome variable. The first column can be interpreted as indicating the extent to which different variables explain the variance in the dependent variable. The second column can be interpreted as showing the extent to which different variables reduce uncertainty in the predictions of the model. As with the description of the categorical variable random forest, these are only rough "translations" of the true meaning of these metrics. It is not clear which metric is better for judging importance.

Outcome variables which are numeric but only have two non-missing unique values will be treated as categorical.
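As a rough illustration of this rule, the check below counts the distinct non-missing values of a numeric outcome. It is a sketch only: the function name is hypothetical, and representing missing values as null, undefined or NaN is an assumption about the data rather than a description of how the underlying R code detects them.

```javascript
// True when a numeric outcome has exactly two distinct non-missing values,
// which is the case the text says is treated as categorical.
function hasTwoUniqueNonMissingValues(values) {
  const nonMissing = values.filter(
    v => v !== null && v !== undefined && !Number.isNaN(v));
  return new Set(nonMissing).size === 2;
}

const binaryOutcome = [0, 1, 1, null, 0];    // two distinct values plus a missing value
const continuousOutcome = [1.5, 2, 3, 2];    // three distinct values

const binaryIsCategorical = hasTwoUniqueNonMissingValues(binaryOutcome);
const continuousIsCategorical = hasTwoUniqueNonMissingValues(continuousOutcome);
```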


Outcome The variable to be predicted by the predictors. It may be either numeric, in which case a forest of regression trees is estimated, or categorical, in which case a forest of classification trees is estimated.

Predictors The variable(s) to predict the outcome.

Algorithm The machine learning algorithm. Defaults to Random Forest but may be changed to other machine learning methods.


Importance Produces importance tables, as illustrated above.
Detail This returns the default output from randomForest in the randomForest package. It includes a confusion matrix for classification trees, and the percentage of variance explained for regression trees.
Prediction-Accuracy Table Produces a table relating the observed and predicted outcome. Also known as a confusion matrix.
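For readers unfamiliar with confusion matrices, the following sketch shows how a table relating observed and predicted categories can be tallied. The function names and the example labels are hypothetical and chosen for illustration. This is not the code used to produce the Prediction-Accuracy Table output.

```javascript
// Cross-tabulate observed (rows) against predicted (columns) categories.
function predictionAccuracyTable(observed, predicted) {
  const categories = Array.from(new Set(observed.concat(predicted))).sort();
  const table = {};
  for (const o of categories) {
    table[o] = {};
    for (const p of categories) table[o][p] = 0;
  }
  observed.forEach((o, i) => { table[o][predicted[i]] += 1; });
  return table;
}

// Overall accuracy: correctly classified cases (the diagonal) over all cases.
function overallAccuracy(table) {
  let correct = 0;
  let total = 0;
  for (const o of Object.keys(table)) {
    for (const p of Object.keys(table[o])) {
      total += table[o][p];
      if (o === p) correct += table[o][p];
    }
  }
  return correct / total;
}

// Illustrative labels only.
const observedGroups = ["A", "A", "B", "B", "C", "C"];
const predictedGroups = ["A", "B", "B", "B", "C", "A"];
const accuracyTable = predictionAccuracyTable(observedGroups, predictedGroups);
const tableAccuracy = overallAccuracy(accuracyTable); // 4 of 6 cases correct
```

Each row of the table corresponds to an observed category and each column to a predicted category, so off-diagonal cells show which categories are confused with which.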

Missing data See Missing Data Options.

Variable names Displays Variable Names in the output instead of labels.

Sort by importance Sort the rows by importance (the last column in the table).

Weight Where a weight has been set for the R Output, a new data set is generated via resampling, and this new data set is used in the estimation. This causes the resulting measures of prediction accuracy (R-squared and out-of-bag estimates) to be overly optimistic. The unweighted model should be used when evaluating prediction accuracy.
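One common way to build a resampled data set from weights is to draw cases with replacement, with probability proportional to each case's weight. The sketch below illustrates that general idea only. The exact resampling scheme used here is not described in this page, so the function names, the seeded generator and the sampling rule are all assumptions.

```javascript
// Small seeded generator so the resample is repeatable (hypothetical helper).
function makeSeededRandom(seed) {
  let state = seed >>> 0;
  return function () {
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state / 4294967296;
  };
}

// Draw rows.length cases with replacement, with probability proportional
// to weight. Assumes non-negative weights with a positive sum.
function resampleByWeight(rows, weights, random) {
  const cumulative = [];
  let running = 0;
  for (const w of weights) {
    running += w;
    cumulative.push(running);
  }
  const total = running;
  return rows.map(() => {
    const u = random() * total;
    // First case whose cumulative weight exceeds the draw.
    return rows[cumulative.findIndex(c => u < c)];
  });
}

// Illustrative cases and weights: "b" and "d" have zero weight.
const caseIds = ["a", "b", "c", "d"];
const caseWeights = [1, 0, 3, 0];
const resampled = resampleByWeight(caseIds, caseWeights, makeSeededRandom(12321));
```

Under this scheme, heavily weighted cases tend to appear more than once in the resampled data, while cases with zero weight never appear.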

Filter Any filters that have been set are automatically applied to the data before the model is estimated.

Additional options are available by editing the code.


Prediction-Accuracy Table Creates a table showing the observed and predicted values, as a heatmap.


Predicted Values Creates a new variable containing predicted values for each case in the data.

Probabilities of Each Response Creates new variables containing predicted probabilities of each response.


Uses the randomForest algorithm from the randomForest package.

Breiman, L. (2001), Random Forests, Machine Learning 45(1), 5-32.

More information

This blog post explains random forests.
This post describes the data fitting process.
The calculation of variable importance is described here.


var controls = [];

var algorithm = form.comboBox({label: "Algorithm",
                               alternatives: ["CART", "Deep Learning", "Gradient Boosting", "Linear Discriminant Analysis",
                                              "Random Forest", "Regression", "Support Vector Machine"],
                               name: "formAlgorithm", default_value: "Random Forest",
                               prompt: "Machine learning or regression algorithm for fitting the model"});

algorithm = algorithm.getValue();

var regressionType = "";
if (algorithm == "Regression") {
    // The regression type is only offered when the Regression algorithm is selected
    var regressionTypeControl = form.comboBox({label: "Regression type", 
                                               alternatives: ["Linear", "Binary Logit", "Ordered Logit", "Multinomial Logit", "Poisson",
                                                              "Quasi-Poisson", "NBD"], 
                                               name: "formRegressionType", default_value: "Linear",
                                               prompt: "Select type according to outcome variable type"});
    regressionType = regressionTypeControl.getValue();
}

missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Imputation (replace missing values with estimates)"];

if (algorithm == "Support Vector Machine")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Gradient Boosting") 
    output_options = ["Accuracy", "Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Random Forest")
    output_options = ["Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Deep Learning")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Cross Validation", "Network Layers"];
if (algorithm == "Linear Discriminant Analysis")
    output_options = ["Means", "Detail", "Prediction-Accuracy Table", "Scatterplot", "Moonplot"];

if (algorithm == "CART") {
    output_options = ["Sankey", "Tree", "Text", "Prediction-Accuracy Table", "Cross Validation"];
    missing_data_options = ["Error if missing data", "Exclude cases with missing data",
                             "Use partial data", "Imputation (replace missing values with estimates)"];
}
if (algorithm == "Regression") {
    if (regressionType == "Multinomial Logit")
        output_options = ["Summary", "Detail", "ANOVA"];
    else if (regressionType == "Linear")
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Shapley Regression", "Jaccard Coefficient", "Correlation", "Effects Plot"];
    else
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Effects Plot"];
}

var outputControl = form.comboBox({label: "Output", prompt: "The type of output used to show the results",
                                   alternatives: output_options, name: "formOutput",
                                   default_value: output_options[0]});
var output = outputControl.getValue();

if (algorithm == "Regression") {
    if (regressionType == "Linear") {
        if (output == "Jaccard Coefficient" || output == "Correlation")
            missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Use partial data (pairwise correlations)"];
        else
            missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Dummy variable adjustment", "Use partial data (pairwise correlations)", "Multiple imputation"];
    } else
        missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Dummy variable adjustment", "Multiple imputation"];
}

var missingControl = form.comboBox({label: "Missing data", 
                                    alternatives: missing_data_options, name: "formMissing", default_value: "Exclude cases with missing data",
                                    prompt: "Options for handling cases with missing data"});
var missing = missingControl.getValue();
controls.push(form.checkBox({label: "Variable names", name: "formNames", default_value: false, prompt: "Display names instead of labels"}));


if (algorithm == "Support Vector Machine")
    controls.push(form.textBox({label: "Cost", name: "formCost", default_value: 1, type: "number",
                                prompt: "High cost produces a complex model with risk of overfitting, low cost produces a simpler model with risk of underfitting"}));

if (algorithm == "Gradient Boosting") {
    controls.push(form.comboBox({label: "Booster", 
                                 alternatives: ["gbtree", "gblinear"], name: "formBooster", default_value: "gbtree",
                                 prompt: "Boost tree or linear underlying models"}));
    controls.push(form.checkBox({label: "Grid search", name: "formSearch", default_value: false,
                                 prompt: "Search for optimal hyperparameters"}));
}

if (algorithm == "Random Forest")
    if (output == "Importance")
        controls.push(form.checkBox({label: "Sort by importance", name: "formImportance", default_value: true}));

if (algorithm == "Deep Learning") {
    controls.push(form.numericUpDown({name:"formEpochs", label:"Maximum epochs", default_value: 10, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                                      prompt: "Number of rounds of training"}));
    controls.push(form.textBox({name: "formHiddenLayers", label: "Hidden layers", prompt: "Comma delimited list of the number of nodes in each hidden layer", required: true}));
    controls.push(form.checkBox({label: "Normalize predictors", name: "formNormalize", default_value: true,
                                 prompt: "Normalize to zero mean and unit variance"}));
}

if (algorithm == "Linear Discriminant Analysis") {
    // Color pickers only apply to the Scatterplot output
    if (output == "Scatterplot") {
        controls.push(form.colorPicker({label: "Outcome color", name: "formOutColor", default_value:"#5B9BD5"}));
        controls.push(form.colorPicker({label: "Predictors color", name: "formPredColor", default_value:"#ED7D31"}));
    }
    controls.push(form.comboBox({label: "Prior", alternatives: ["Equal", "Observed"], name: "formPrior", default_value: "Observed",
                                 prompt: "Probabilities of group membership"}));
}

if (algorithm == "CART") {
    controls.push(form.comboBox({label: "Pruning", alternatives: ["Minimum error", "Smallest tree", "None"], 
                                 name: "formPruning", default_value: "Minimum error",
                                 prompt: "Remove nodes after tree has been built"}));
    controls.push(form.checkBox({label: "Early stopping", name: "formStopping", default_value: false,
                                 prompt: "Stop building tree when fit does not improve"}));
    controls.push(form.comboBox({label: "Predictor category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                                 name: "formPredictorCategoryLabels", default_value: "Abbreviated labels",
                                 prompt: "Labelling of predictor categories in the tree"}));
    controls.push(form.comboBox({label: "Outcome category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                                 name: "formOutcomeCategoryLabels", default_value: "Full labels",
                                 prompt: "Labelling of outcome categories in the tree"}));
    controls.push(form.checkBox({label: "Allow long-running calculations", name: "formLongRunningCalculations", default_value: false,
                                 prompt: "Allow predictors with more than 30 categories"}));
}

var stacked_check = false;
if (algorithm == "Regression") {
    if (missing == "Multiple imputation")
        controls.push(form.dropBox({label: "Auxiliary variables",
                                    types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                                    name: "formAuxiliaryVariables", required: false, multi:true,
                                    prompt: "Additional variables to use when imputing missing values"}));
    controls.push(form.comboBox({label: "Correction", alternatives: ["None", "False Discovery Rate", "Bonferroni"], name: "formCorrection",
                                 default_value: "None", prompt: "Multiple comparisons correction applied when computing p-values of post-hoc comparisons"}));
    var is_RIA_or_shapley = output == "Relative Importance Analysis" || output == "Shapley Regression";
    var is_Jaccard_or_Correlation = output == "Jaccard Coefficient" || output == "Correlation";
    if (regressionType == "Linear" && missing != "Use partial data (pairwise correlations)" && missing != "Multiple imputation")
        controls.push(form.checkBox({label: "Robust standard errors", name: "formRobustSE", default_value: false,
                                     prompt: "Standard errors are robust to violations of assumption of constant variance"}));
    if (is_RIA_or_shapley)
        controls.push(form.checkBox({label: "Absolute importance scores", name: "formAbsoluteImportance", default_value: false,
                                     prompt: "Show absolute instead of signed importances"}));
    if (regressionType != "Multinomial Logit" && (is_RIA_or_shapley || is_Jaccard_or_Correlation || output == "Summary"))
        controls.push(form.dropBox({label: "Crosstab interaction", name: "formInteraction", types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
                                    required: false, prompt: "Categorical variable to test for interaction with other variables"}));
    if (regressionType !== "Multinomial Logit")
        controls.push(form.numericUpDown({name : "formOutlierProportion", label:"Automated outlier removal percentage", default_value: 0, 
                                          minimum:0, maximum:49.9, increment:0.1,
                                          prompt: "Data points removed and model refitted based on the residual values in the model using the full dataset"}));
    var stacked_check_box = form.checkBox({label: "Stack data", name: "formStackedData", default_value: false,
                                           prompt: "Allow input into the Outcome control to be a single multi variable and Predictors to be a single grid variable"});
    stacked_check = stacked_check_box.getValue();
}

controls.push(form.numericUpDown({name:"formSeed", label:"Random seed", default_value: 12321, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                                  prompt: "Initializes randomization for imputation and certain algorithms"}));

var outcome = form.dropBox({label: "Outcome", 
                            types: [ stacked_check ? "VariableSet: BinaryMulti, NominalMulti, OrdinalMulti, NumericMulti" : "Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                            multi: false,
                            name: "formOutcomeVariable",
                            prompt: "Dependent target variable to be predicted"});
var predictors = form.dropBox({label: "Predictor(s)",
                               types:[ stacked_check ? "VariableSet: BinaryGrid, NumericGrid" : "Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
                               name: "formPredictorVariables", multi: stacked_check ? false : true,
                               prompt: "Independent input variables"});


form.setHeading((regressionType == "" ? "" : (regressionType + " ")) + algorithm);

model <- MachineLearning(formula = if (isTRUE(get0("formStackedData"))) as.formula(NULL) else QFormula(formOutcomeVariable ~ formPredictorVariables),
                         algorithm = formAlgorithm,
                         weights = QPopulationWeight, subset = QFilter,
                         missing = formMissing,
                         output = if (formOutput == "Shapley Regression") "Shapley regression" else formOutput,
                         show.labels = !formNames,
                         seed = get0("formSeed"),
                         cost = get0("formCost"),
                         booster = get0("formBooster"),
                = get0("formSearch"),
                = get0("formImportance"),
                         hidden.nodes = get0("formHiddenLayers"),
                         max.epochs = get0("formEpochs"),
                         normalize = get0("formNormalize"),
                         outcome.color = get0("formOutColor"),
                         predictors.color = get0("formPredColor"),
                         prior = get0("formPrior"),
                         prune = get0("formPruning"),
                         early.stopping = get0("formStopping"),
                         predictor.level.treatment = get0("formPredictorCategoryLabels"),
                         outcome.level.treatment = get0("formOutcomeCategoryLabels"),
                         long.running.calculations = get0("formLongRunningCalculations"),
                         type = get0("formRegressionType"),
                = get0("formAuxiliaryVariables"),
                         correction = get0("formCorrection"),
                = get0("formRobustSE", ifnotfound = FALSE),
                         importance.absolute = get0("formAbsoluteImportance"),
                         interaction = get0("formInteraction"),
                = if (get0("formRegressionType", ifnotfound = "") != "Multinomial Logit") get0("formOutlierProportion")/100 else NULL,
                = get0("formStackedData"),
                = if (isTRUE(get0("formStackedData"))) list(Y = get0("formOutcomeVariable"), X = get0("formPredictorVariables")) else NULL)