Machine Learning - Classification And Regression Trees (CART)

A Classification And Regression Tree (CART) is a predictive model that explains how an outcome variable's values can be predicted from the values of other variables. A CART output is a decision tree in which each fork is a split on a predictor variable and each end node contains a prediction for the outcome variable.
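
The structure described above can be illustrated with a minimal sketch. The tree, its predictors (age, income), thresholds and predictions are all hypothetical, chosen only to show how forks and end nodes combine to produce a prediction; this is not how rpart stores a tree.

```javascript
// Hypothetical tree: each fork tests one predictor against a threshold,
// and each end node (leaf) holds a prediction for the outcome.
const exampleTree = {
  predictor: "age", threshold: 30,
  left: { prediction: "No" },              // age < 30
  right: {                                 // age >= 30
    predictor: "income", threshold: 50000,
    left: { prediction: "No" },            // income < 50000
    right: { prediction: "Yes" },          // income >= 50000
  },
};

// Walk down the tree until an end node is reached.
function predictOutcome(node, row) {
  while (node.prediction === undefined) {
    node = row[node.predictor] < node.threshold ? node.left : node.right;
  }
  return node.prediction;
}

console.log(predictOutcome(exampleTree, { age: 42, income: 65000 })); // "Yes"
console.log(predictOutcome(exampleTree, { age: 25, income: 90000 })); // "No"
```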

For an introduction to decision trees, see this blog post. A second blog post discusses the Pruning and Early stopping options, and a third describes how trees are built.

Example

An interactive tree created using the Sankey output option:

Options

Outcome: Variable to be predicted.

Predictors: Variables that will be considered as predictors of the Outcome. Predictors that are considered to be uninformative are automatically excluded from the model.

Algorithm: The machine learning algorithm. Defaults to CART but may be changed to other machine learning methods.

Output: How the tree should be displayed. The choices are:

  • Sankey: An interactive tree. This is the default.
  • Tree: A greyscale tree plot.
  • Text: A text representation of the tree.
  • Prediction-Accuracy Table: Produces a table relating the observed and predicted outcome. Also known as a confusion matrix.
  • Cross Validation: A plot of the cross-validation accuracy versus the size of the tree in terms of the number of leaves.
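
As an illustration of what a Prediction-Accuracy Table (confusion matrix) contains, the sketch below cross-tabulates observed against predicted outcomes. The function name predictionAccuracyTable and the data are hypothetical; the actual table is produced by the R code at the end of this page.

```javascript
// Count how often each observed category was predicted as each category.
function predictionAccuracyTable(observed, predicted) {
  const labels = [...new Set([...observed, ...predicted])].sort();
  const counts = {};
  for (const o of labels) {
    counts[o] = {};
    for (const p of labels) counts[o][p] = 0;
  }
  let correct = 0;
  observed.forEach((o, i) => {
    counts[o][predicted[i]] += 1;        // row = observed, column = predicted
    if (o === predicted[i]) correct += 1;
  });
  return { labels, counts, accuracy: correct / observed.length };
}

// Hypothetical data for six cases.
const observedOutcome  = ["Yes", "Yes", "No", "No", "No",  "Yes"];
const predictedOutcome = ["Yes", "No",  "No", "No", "Yes", "Yes"];
const accuracyTable = predictionAccuracyTable(observedOutcome, predictedOutcome);
console.log(accuracyTable.counts);    // counts by observed row and predicted column
console.log(accuracyTable.accuracy);  // proportion of cases predicted correctly
```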

Missing data: Method for dealing with missing data. See Missing Data Options.

Pruning: The type of post-pruning applied to the tree. Choices are:

  • Minimum error: Prune back leaf nodes to create the tree with the smallest cross-validation error.
  • Smallest tree: Prune to create the smallest tree whose cross-validation error is at most one standard error greater than the minimum error.
  • None: Retain the tree as it was built. Note that choosing this option without Early stopping is prone to overfitting.
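
The difference between the three choices can be sketched as a selection over a table of cross-validation results. The table values, the function name choosePruning, and the use of the standard error of the minimum-error tree as the tolerance are assumptions made for illustration, not a description of the rpart implementation.

```javascript
// Hypothetical cross-validation results: one row per candidate tree size.
const cvResults = [
  { leaves: 1,  error: 0.40, se: 0.03 },
  { leaves: 2,  error: 0.31, se: 0.03 },
  { leaves: 4,  error: 0.26, se: 0.02 },
  { leaves: 7,  error: 0.25, se: 0.02 },
  { leaves: 12, error: 0.28, se: 0.02 },
];

function choosePruning(rows, rule) {
  // "None": keep the tree as built, i.e. the largest candidate.
  if (rule === "None") return rows.reduce((a, b) => (b.leaves > a.leaves ? b : a));
  // "Minimum error": the tree with the smallest cross-validation error.
  const best = rows.reduce((a, b) => (b.error < a.error ? b : a));
  if (rule === "Minimum error") return best;
  // "Smallest tree": the smallest tree within one standard error of the minimum.
  const limit = best.error + best.se;
  return rows
    .filter((r) => r.error <= limit)
    .reduce((a, b) => (b.leaves < a.leaves ? b : a));
}

console.log(choosePruning(cvResults, "Minimum error").leaves); // 7
console.log(choosePruning(cvResults, "Smallest tree").leaves); // 4
console.log(choosePruning(cvResults, "None").leaves);          // 12
```

In this made-up table, Smallest tree gives up 0.01 of cross-validation error in exchange for a tree with 4 leaves instead of 7.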

Early stopping: Whether to stop splitting nodes before the fit stops improving. Setting this may decrease the time to build the tree, potentially at the cost of not finding the tree with the best accuracy. See here for more detail.

Variable names: Displays variable names instead of labels in the output.

Predictor category labels: Whether to shorten category labels from categorical predictor variables. The choices are:

  • Full labels: The complete labels.
  • Abbreviated labels: Labels that have been shortened by taking the first few letters from each word.
  • Letters: Letters from the alphabet where "a" corresponds to the first category, "b" to the second category, and so on.
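
A rough sketch of the three labelling choices follows. The function relabel and the choice of three letters per word are assumptions; the exact abbreviation rule used by the underlying R code is not stated here. The Letters branch of this sketch only covers up to 26 categories.

```javascript
// Relabel categories according to the chosen treatment.
function relabel(labels, treatment, lettersPerWord = 3) {
  if (treatment === "Letters") {
    // "a" for the first category, "b" for the second, and so on.
    return labels.map((_, i) => String.fromCharCode(97 + i));
  }
  if (treatment === "Abbreviated labels") {
    // Keep the first few letters of each word (three is an assumed value).
    return labels.map((label) =>
      label.trim().split(/\s+/).map((w) => w.slice(0, lettersPerWord)).join(" "));
  }
  return labels.slice(); // "Full labels"
}

const categoryLabels = ["Strongly agree", "Somewhat agree", "Disagree"];
console.log(relabel(categoryLabels, "Full labels"));        // unchanged
console.log(relabel(categoryLabels, "Abbreviated labels")); // ["Str agr", "Som agr", "Dis"]
console.log(relabel(categoryLabels, "Letters"));            // ["a", "b", "c"]
```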

Outcome category labels: Same as above but for the outcome variable.

Allow long-running calculations: A categorical predictor with m categories can require evaluating up to 2^(m - 1) - 1 candidate splits, so calculations may run for a long time. Checking this box allows categorical variables with more than 30 categories to be included in Predictors.
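
To show why the limit of 30 categories matters, the sketch below counts the ways of dividing m categories into two non-empty groups, which is 2^(m - 1) - 1 and so of the order of 2^(m - 1) as noted above. The function name candidateSplits is hypothetical, and the count assumes an exhaustive search over all binary groupings.

```javascript
// Number of ways to divide m categories into two non-empty groups.
function candidateSplits(m) {
  if (!Number.isInteger(m) || m < 2) return 0; // nothing to split
  return 2 ** (m - 1) - 1;
}

for (const m of [3, 10, 20, 31]) {
  console.log(m + " categories: " + candidateSplits(m) + " candidate splits");
}
// 3 categories: 3 candidate splits
// 10 categories: 511 candidate splits
// 20 categories: 524287 candidate splits
// 31 categories: 1073741823 candidate splits
```

Each additional category roughly doubles the count, so a predictor with 31 categories already implies more than a billion groupings.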

Acknowledgements

Uses the R packages rpart and partykit.

Code

form.dropBox({label: "Outcome",
            types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
            name: "formOutcomeVariable",
            prompt: "Dependent target variable to be predicted"});
form.dropBox({label: "Predictor(s)",
            types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
            name: "formPredictorVariables", multi:true,
            prompt: "Independent input variables"});

// ALGORITHM
var algorithm = form.comboBox({label: "Algorithm",
               alternatives: ["CART", "Deep Learning", "Gradient Boosting", "Linear Discriminant Analysis",
                              "Random Forest", "Regression", "Support Vector Machine"],
               name: "formAlgorithm", default_value: "CART",
               prompt: "Machine learning or regression algorithm for fitting the model"}).getValue();
var regressionType = "";
if (algorithm == "Regression")
    regressionType = form.comboBox({label: "Regression type", 
                                        alternatives: ["Linear", "Binary Logit", "Ordered Logit", "Multinomial Logit", "Poisson",
                                                                                                          "Quasi-Poisson", "NBD"], 
                                        name: "formRegressionType", default_value: "Linear",
                                        prompt: "Select type according to outcome variable type"}).getValue();
form.setHeading((regressionType == "" ? "" : (regressionType + " ")) + algorithm);

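// DEFAULT CONTROLS
// Declared here so the per-algorithm branches below assign to script variables rather than implicit globals.
var output_options;
var missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Imputation (replace missing values with estimates)"];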

// AMEND DEFAULT CONTROLS PER ALGORITHM
if (algorithm == "Support Vector Machine")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Gradient Boosting") 
    output_options = ["Accuracy", "Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Random Forest")
    output_options = ["Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Deep Learning")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Cross Validation", "Network Layers"];
if (algorithm == "Linear Discriminant Analysis")
    output_options = ["Means", "Detail", "Prediction-Accuracy Table", "Scatterplot", "Moonplot"];

if (algorithm == "CART") {
    output_options = ["Sankey", "Tree", "Text", "Prediction-Accuracy Table", "Cross Validation"];
    missing_data_options = ["Error if missing data", "Exclude cases with missing data",
                             "Use partial data", "Imputation (replace missing values with estimates)"]
}
if (algorithm == "Regression") {
    if (regressionType == "Multinomial Logit")
        output_options = ["Summary", "Detail", "ANOVA"];
    else
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Effects Plot"]
    if (regressionType == "Linear")
        missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Use partial data (pairwise correlations)", "Multiple imputation"];
    else
        missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Multiple imputation"];
}

// COMMON CONTROLS FOR ALL ALGORITHMS
var output = form.comboBox({label: "Output", 
              alternatives: output_options, name: "formOutput", default_value: output_options[0]}).getValue();
var missing = form.comboBox({label: "Missing data", 
              alternatives: missing_data_options, name: "formMissing", default_value: "Exclude cases with missing data",
              prompt: "Options for handling cases with missing data"}).getValue();
form.checkBox({label: "Variable names", name: "formNames", default_value: false, prompt: "Display names instead of labels"});

// CONTROLS FOR SPECIFIC ALGORITHMS

if (algorithm == "Support Vector Machine")
    form.textBox({label: "Cost", name: "formCost", default_value: 1, type: "number",
                  prompt: "High cost produces a complex model with risk of overfitting, low cost produces a simpler model with risk of underfitting"});

if (algorithm == "Gradient Boosting") {
    form.comboBox({label: "Booster",
                  alternatives: ["gbtree", "gblinear"], name: "formBooster", default_value: "gbtree",
                  prompt: "Boost tree or linear underlying models"});
    form.checkBox({label: "Grid search", name: "formSearch", default_value: false,
                   prompt: "Search for optimal hyperparameters"});
}

if (algorithm == "Random Forest")
    if (output == "Importance")
        form.checkBox({label: "Sort by importance", name: "formImportance", default_value: true});

if (algorithm == "Deep Learning") {
    form.numericUpDown({name:"formEpochs", label:"Maximum epochs", default_value: 10, minimum: 1, maximum: 1000000,
                        prompt: "Number of rounds of training"});
    form.textBox({name: "formHiddenLayers", label: "Hidden layers", prompt: "Comma delimited list of the number of nodes in each hidden layer", required: true});
    form.checkBox({label: "Normalize predictors", name: "formNormalize", default_value: true,
                   prompt: "Normalize to zero mean and unit variance"});
}

if (algorithm == "Linear Discriminant Analysis") {
    if (output == "Scatterplot")
    {
        form.colorPicker({label: "Outcome color", name: "formOutColor", default_value:"#5B9BD5"});
        form.colorPicker({label: "Predictors color", name: "formPredColor", default_value:"#ED7D31"});
    }
    form.comboBox({label: "Prior", alternatives: ["Equal", "Observed"], name: "formPrior", default_value: "Observed",
                   prompt: "Probabilities of group membership"});
}

if (algorithm == "CART") {
    form.comboBox({label: "Pruning", alternatives: ["Minimum error", "Smallest tree", "None"], 
                   name: "formPruning", default_value: "Minimum error",
                   prompt: "Remove nodes after tree has been built"})
    form.checkBox({label: "Early stopping", name: "formStopping", default_value: false,
                   prompt: "Stop building tree when fit does not improve"});
    form.comboBox({label: "Predictor category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                   name: "formPredictorCategoryLabels", default_value: "Abbreviated labels",
                   prompt: "Labelling of predictor categories in the tree"})
    form.comboBox({label: "Outcome category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                   name: "formOutcomeCategoryLabels", default_value: "Full labels",
                   prompt: "Labelling of outcome categories in the tree"})
    form.checkBox({label: "Allow long-running calculations", name: "formLongRunningCalculations", default_value: false,
                   prompt: "Allow predictors with more than 30 categories"});
}

if (algorithm == "Regression") {
    if (missing == "Multiple imputation")
        form.dropBox({label: "Auxiliary variables",
            types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
            name: "formAuxiliaryVariables", required: false, multi:true,
            prompt: "Additional variables to use when imputing missing values"});
    form.comboBox({label: "Correction", alternatives: ["None", "False Discovery Rate", "Bonferroni"], name: "formCorrection",
                   default_value: "None", prompt: "Multiple comparisons correction applied when computing p-values of post-hoc comparisons"});
    var is_RIA = (output == "Relative Importance Analysis");
    if (regressionType == "Linear" && missing != "Use partial data (pairwise correlations)" && missing != "Multiple imputation")
        form.checkBox({label: "Robust standard errors", name: "formRobustSE", default_value: false,
                       prompt: "Standard errors are robust to violations of assumption of constant variance"});
    if (output == "Relative Importance Analysis")
        form.checkBox({label: "Absolute importance scores", name: "formAbsoluteImportance", default_value: false,
                       prompt: "Show absolute instead of signed importances"});
    if (regressionType != "Multinomial Logit" && (is_RIA || output == "Summary"))
        form.dropBox({label: "Crosstab interaction", name: "formInteraction", types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
                      required: false, prompt: "Categorical variable to test for interaction with other variables"});
}

form.numericUpDown({name:"formSeed", label:"Random seed", default_value: 12321, minimum: 1, maximum: 1000000,
                    prompt: "Initializes randomization for imputation and certain algorithms"});

library(flipMultivariates)

model <- MachineLearning(formula = QFormula(formOutcomeVariable ~ formPredictorVariables),
                                    algorithm = formAlgorithm,
                                    weights = QPopulationWeight, subset = QFilter,
                                    missing = formMissing, output = formOutput, show.labels = !formNames,
                                    seed = get0("formSeed"),
                                    cost = get0("formCost"),
                                    booster = get0("formBooster"),
                                    grid.search = get0("formSearch"),
                                    sort.by.importance = get0("formImportance"),
                                    hidden.nodes = get0("formHiddenLayers"),
                                    max.epochs = get0("formEpochs"),
                                    normalize = get0("formNormalize"),
                                    outcome.color = get0("formOutColor"),
                                    predictors.color = get0("formPredColor"),
                                    prior = get0("formPrior"),
                                    prune = get0("formPruning"),
                                    early.stopping = get0("formStopping"),
                                    predictor.level.treatment = get0("formPredictorCategoryLabels"),
                                    outcome.level.treatment = get0("formOutcomeCategoryLabels"),
                                    long.running.calculations = get0("formLongRunningCalculations"), 
                                    type = get0("formRegressionType"),
                                    auxiliary.data = get0("formAuxiliaryVariables"),
                                    correction = get0("formCorrection"),
                                    robust.se = get0("formRobustSE", ifnotfound = FALSE),
                                    importance.absolute = get0("formAbsoluteImportance"),
                                    interaction = get0("formInteraction"))