Machine Learning - Gradient Boosting


Creates a predictive model for either regression or classification from an ensemble of underlying tree or linear regression models.

This blog post explains gradient boosting and this post describes an example of predicting customer churn.

Background

Boosting is a method for combining a series of simple individual models to create a more powerful model. An initial model (either tree or linear regression) is fitted to the data. A second model is then built that focuses on accurately predicting for the cases where the first model performs poorly relative to its target outcomes. The combination of these two models is expected to be better than either model alone. The process is then repeated with each successive model attempting to correct for the shortcomings of the combined ensemble of all previous models.

The best possible next model, when combined with previous models, minimises the overall prediction error. The key idea behind gradient boosting is to set the target outcomes for this next model in order to minimise the error. The target outcome for each case depends on how much a change in that case's prediction impacts the overall prediction error.

If a small increase in the prediction causes a large drop in error for a case, then the next target outcome is a high value. This means that if the new model predicts close to its target, then the error is reduced.

If a small increase in the prediction causes no change in error for a case, then the next target outcome is zero, because changing this prediction does not decrease the error.

The name gradient boosting arises because the target outcomes are set based on the gradient of the error with respect to the prediction for each case. Each new model takes a step in the direction of the negative gradient, that is, the direction that most reduces the prediction error, in the space of per-case predictions.
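
The sketch below hand-rolls this idea in R for squared-error loss, where each case's target outcome is simply its residual (the negative gradient of squared error). The simulated data, learning rate and use of small rpart trees are assumptions for illustration only; the feature itself is built on the xgboost package.

library(rpart)

set.seed(12321)
x <- runif(200, 0, 10)
y <- sin(x) + rnorm(200, sd = 0.3)
training.data <- data.frame(x = x)

learning.rate <- 0.1
prediction <- rep(mean(y), length(y))    # initial model: predict the overall mean

for (round in 1:100) {
    # Target outcome for the next model: the residual, i.e. how far the current
    # ensemble's prediction is from the observed value for each case.
    training.data$target <- y - prediction
    small.tree <- rpart(target ~ x, data = training.data, maxdepth = 2)
    prediction <- prediction + learning.rate * predict(small.tree, training.data)
}

cat("Training RMSE after boosting:", sqrt(mean((y - prediction)^2)), "\n")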

Example

With the inputs as follows,

[Screenshot of the example inputs: GradientBoostInput.PNG]

The output is a chart showing the relative importance of the predictor variables; the most important variable has an importance of 1. Note that categorical predictors with more than 2 levels are split into individual binary variables. The variables are grouped into clusters of similar importance.
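
To give a sense of where these importance scores come from, below is a rough sketch using the xgboost package directly: a categorical predictor is expanded into binary variables before fitting, and the gain-based importances are scaled so the largest is 1. The warpbreaks data, the parameters and the exact encoding are assumptions for illustration; the scaling and encoding used by this feature may differ in detail.

library(xgboost)

# Expand the categorical predictors of warpbreaks (wool, tension) into binary variables.
X <- model.matrix(breaks ~ . - 1, data = warpbreaks)
dtrain <- xgb.DMatrix(data = X, label = warpbreaks$breaks)

fit <- xgb.train(params = list(booster = "gbtree", max_depth = 3, objective = "reg:squarederror"),
                 data = dtrain, nrounds = 50)

# Gain-based importance per binary variable, scaled so the most important is 1.
importance <- xgb.importance(model = fit)
relative.importance <- importance$Gain / max(importance$Gain)
names(relative.importance) <- importance$Feature
print(relative.importance)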

Options

Outcome The variable to be predicted by the predictor variables. It may be either a numeric or categorical variable.

Predictors The variable(s) used to predict the outcome.

Algorithm The machine learning algorithm. Defaults to Gradient Boosting but may be changed to other machine learning methods.

Output

Accuracy Produces measures of the goodness of model fit. For categorical outcomes, the breakdown by category is shown.
Importance Produces a chart showing the importance of the predictors in determining the outcome. Only available for the gbtree booster.
Prediction-Accuracy Table Produces a table relating the observed and predicted outcome. Also known as a confusion matrix.
Detail Text output from the underlying xgboost package.

Missing data See Missing Data Options.

Booster The underlying model to be boosted. A choice between gbtree (boosted trees) and gblinear (boosted linear models).

Grid search Whether to search the parameter space in order to tune the model. If not checked, the default parameters of xgboost are used. Checking this will usually produce a more accurate predictor, at the cost of a longer run time (see the sketch at the end of this section).

Variable names Displays Variable Names in the output instead of labels.

Weight Where a weight has been set for the R Output, a new data set is generated via resampling, and this new data set is used in the estimation.

Filter The data is automatically filtered using any applied filters prior to estimating the model.

Additional options are available by editing the code.
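
To make the Booster and Grid search options more concrete, the sketch below shows the corresponding ideas when calling the xgboost package directly: choosing the type of underlying model to boost, and picking hyperparameters by cross-validated error over a small grid. The mtcars data, the grid and the settings are illustrative assumptions only, not the parameters or search strategy this feature uses.

library(xgboost)

X <- model.matrix(mpg ~ . - 1, data = mtcars)
dtrain <- xgb.DMatrix(data = X, label = mtcars$mpg)

# Booster: the same call with booster = "gblinear" boosts linear models instead of trees.
# Grid search: fit each candidate parameter combination and keep the one with the
# lowest cross-validated error.
grid <- expand.grid(max_depth = c(2, 4, 6), eta = c(0.1, 0.3))
grid$cv.rmse <- NA
for (i in seq_len(nrow(grid))) {
    params <- list(booster = "gbtree", max_depth = grid$max_depth[i], eta = grid$eta[i],
                   objective = "reg:squarederror")
    cv <- xgb.cv(params = params, data = dtrain, nrounds = 50, nfold = 5, verbose = 0)
    grid$cv.rmse[i] <- min(cv$evaluation_log$test_rmse_mean)
}
print(grid)    # choose the row with the lowest cross-validated RMSE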

Acknowledgements

Uses the xgboost algorithm from the xgboost package by Tianqi Chen.

Code

form.dropBox({label: "Outcome", 
            types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
            name: "formOutcomeVariable",
            prompt: "Independent target variable to be predicted"});
form.dropBox({label: "Predictor(s)",
            types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
            name: "formPredictorVariables", multi:true,
            prompt: "Dependent input variables"});

// ALGORITHM
var algorithm = form.comboBox({label: "Algorithm",
               alternatives: ["CART", "Deep Learning", "Gradient Boosting", "Linear Discriminant Analysis",
                              "Random Forest", "Regression", "Support Vector Machine"],
               name: "formAlgorithm", default_value: "Gradient Boosting",
               prompt: "Machine learning or regression algorithm for fitting the model"}).getValue();
var regressionType = "";
if (algorithm == "Regression")
    regressionType = form.comboBox({label: "Regression type", 
                                        alternatives: ["Linear", "Binary Logit", "Ordered Logit", "Multinomial Logit", "Poisson",
                                                                                                          "Quasi-Poisson", "NBD"], 
                                        name: "formRegressionType", default_value: "Linear",
                                        prompt: "Select type according to outcome variable type"}).getValue();
form.setHeading((regressionType == "" ? "" : (regressionType + " ")) + algorithm);

// DEFAULT CONTROLS
var missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Imputation (replace missing values with estimates)"];
var output_options;

// AMEND DEFAULT CONTROLS PER ALGORITHM
if (algorithm == "Support Vector Machine")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Gradient Boosting") 
    output_options = ["Accuracy", "Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Random Forest")
    output_options = ["Importance", "Prediction-Accuracy Table", "Detail"];
if (algorithm == "Deep Learning")
    output_options = ["Accuracy", "Prediction-Accuracy Table", "Cross Validation", "Network Layers"];
if (algorithm == "Linear Discriminant Analysis")
    output_options = ["Means", "Detail", "Prediction-Accuracy Table", "Scatterplot", "Moonplot"];

if (algorithm == "CART") {
    output_options = ["Sankey", "Tree", "Text", "Prediction-Accuracy Table", "Cross Validation"];
    missing_data_options = ["Error if missing data", "Exclude cases with missing data",
                             "Use partial data", "Imputation (replace missing values with estimates)"]
}
if (algorithm == "Regression") {
    if (regressionType == "Multinomial Logit")
        output_options = ["Summary", "Detail", "ANOVA"];
    else
        output_options = ["Summary", "Detail", "ANOVA", "Relative Importance Analysis", "Effects Plot"]
    if (regressionType == "Linear")
        missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Use partial data (pairwise correlations)", "Multiple imputation"];
    else
        missing_data_options = ["Error if missing data", "Exclude cases with missing data", "Multiple imputation"];
}

// COMMON CONTROLS FOR ALL ALGORITHMS
var output = form.comboBox({label: "Output", 
              alternatives: output_options, name: "formOutput", default_value: output_options[0]}).getValue();
var missing = form.comboBox({label: "Missing data", 
              alternatives: missing_data_options, name: "formMissing", default_value: "Exclude cases with missing data",
              prompt: "Options for handling cases with missing data"}).getValue();
form.checkBox({label: "Variable names", name: "formNames", default_value: false, prompt: "Display names instead of labels"});

// CONTROLS FOR SPECIFIC ALGORITHMS

if (algorithm == "Support Vector Machine")
    form.textBox({label: "Cost", name: "formCost", default_value: 1, type: "number",
                  prompt: "High cost produces a complex model with risk of overfitting, low cost produces a simpler mode with risk of underfitting"});

if (algorithm == "Gradient Boosting") {
    form.comboBox({label: "Booster", 
                  alternatives: ["gbtree", "gblinear"], name: "formBooster", default_value: "gbtree",
                  prompt: "Boost tree or linear underlying models"})
    form.checkBox({label: "Grid search", name: "formSearch", default_value: false,
                   prompt: "Search for optimal hyperparameters"});
}

if (algorithm == "Random Forest")
    if (output == "Importance")
        form.checkBox({label: "Sort by importance", name: "formImportance", default_value: true});

if (algorithm == "Deep Learning") {
    form.numericUpDown({name:"formEpochs", label:"Maximum epochs", default_value: 10, minimum: 1, maximum: 1000000,
                        prompt: "Number of rounds of training"});
    form.textBox({name: "formHiddenLayers", label: "Hidden layers", prompt: "Comma delimited list of the number of nodes in each hidden layer", required: true});
    form.checkBox({label: "Normalize predictors", name: "formNormalize", default_value: true,
                   prompt: "Normalize to zero mean and unit variance"});
}

if (algorithm == "Linear Discriminant Analysis") {
    if (output == "Scatterplot")
    {
        form.colorPicker({label: "Outcome color", name: "formOutColor", default_value:"#5B9BD5"});
        form.colorPicker({label: "Predictors color", name: "formPredColor", default_value:"#ED7D31"});
    }
    form.comboBox({label: "Prior", alternatives: ["Equal", "Observed",], name: "formPrior", default_value: "Observed",
                   prompt: "Probabilities of group membership"})
}

if (algorithm == "CART") {
    form.comboBox({label: "Pruning", alternatives: ["Minimum error", "Smallest tree", "None"], 
                   name: "formPruning", default_value: "Minimum error",
                   prompt: "Remove nodes after tree has been built"})
    form.checkBox({label: "Early stopping", name: "formStopping", default_value: false,
                   prompt: "Stop building tree when fit does not improve"});
    form.comboBox({label: "Predictor category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                   name: "formPredictorCategoryLabels", default_value: "Abbreviated labels",
                   prompt: "Labelling of predictor categories in the tree"})
    form.comboBox({label: "Outcome category labels", alternatives: ["Full labels", "Abbreviated labels", "Letters"],
                   name: "formOutcomeCategoryLabels", default_value: "Full labels",
                   prompt: "Labelling of outcome categories in the tree"})
    form.checkBox({label: "Allow long-running calculations", name: "formLongRunningCalculations", default_value: false,
                   prompt: "Allow predictors with more than 30 categories"});
}

if (algorithm == "Regression") {
    if (missing == "Multiple imputation")
        form.dropBox({label: "Auxiliary variables",
            types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"], 
            name: "formAuxiliaryVariables", required: false, multi:true,
            prompt: "Additional variables to use when imputing missing values"});
    form.comboBox({label: "Correction", alternatives: ["None", "False Discovery Rate", "Bonferroni"], name: "formCorrection",
                   default_value: "None", prompt: "Multiple comparisons correction applied when computing p-values of post-hoc comparisons"});
    var is_RIA = (output == "Relative Importance Analysis");
    if (regressionType == "Linear" && missing != "Use partial data (pairwise correlations)" && missing != "Multiple imputation")
        form.checkBox({label: "Robust standard errors", name: "formRobustSE", default_value: false,
                       prompt: "Standard errors are robust to violations of assumption of constant variance"});
    if (output == "Relative Importance Analysis")
        form.checkBox({label: "Absolute importance scores", name: "formAbsoluteImportance", default_value: false,
                       prompt: "Show absolute instead of signed importances"});
    if (regressionType != "Multinomial Logit" && (is_RIA || output == "Summary"))
        form.dropBox({label: "Crosstab interaction", name: "formInteraction", types:["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
                      required: false, prompt: "Categorical variable to test for interaction with other variables"});
}

form.numericUpDown({name:"formSeed", label:"Random seed", default_value: 12321, minimum: 1, maximum: 1000000,
                    prompt: "Initializes randomization for imputation and certain algorithms"});

# R code: fits the model specified by the controls above, using the flipMultivariates package.
library(flipMultivariates)

model <- MachineLearning(formula = QFormula(formOutcomeVariable ~ formPredictorVariables),
                                    algorithm = formAlgorithm,
                                    weights = QPopulationWeight, subset = QFilter,
                                    missing = formMissing, output = formOutput, show.labels = !formNames,
                                    seed = get0("formSeed"),
                                    cost = get0("formCost"),
                                    booster = get0("formBooster"),
                                    grid.search = get0("formSearch"),
                                    sort.by.importance = get0("formImportance"),
                                    hidden.nodes = get0("formHiddenLayers"),
                                    max.epochs = get0("formEpochs"),
                                    normalize = get0("formNormalize"),
                                    outcome.color = get0("formOutColor"),
                                    predictors.color = get0("formPredColor"),
                                    prior = get0("formPrior"),
                                    prune = get0("formPruning"),
                                    early.stopping = get0("formStopping"),
                                    predictor.level.treatment = get0("formPredictorCategoryLabels"),
                                    outcome.level.treatment = get0("formOutcomeCategoryLabels"),
                                    long.running.calculations = get0("formLongRunningCalculations"),
                                    type = get0("formRegressionType"),
                                    auxiliary.data = get0("formAuxiliaryVariables"),
                                    correction = get0("formCorrection"),
                                    robust.se = get0("formRobustSE", ifnotfound = FALSE),
                                    importance.absolute = get0("formAbsoluteImportance"),
                                    interaction = get0("formInteraction"))