Machine Learning - Compare Models

From Q
Jump to navigation Jump to search

Compare the performance of multiple Machine Learning and Regression models by producing a table of metrics from each model

Compare the performance of multiple Machine Learning and Regression models by producing a table of metrics from each model. The metrics are computed from each model's training data. Optionally, a filter specifying evaluation data (usually a testing sample independent of the training sample) may also be provided. The models may either already exist or be created for the comparison.

Usage

To compare machine learning models:

1. In Displayr, select Anything > Advanced Analysis > Machine Learning > Compare Models. In Q, select Create > Classifier > Compare Models.
2. Under Inputs > Existing or new models select whether you would like to compare existing models or create multiple new models for comparison.
3. If your selection is to work with Existing models then under Inputs > EXISTING MODELS > Input models select the models you want to compare and change any other settings.

Machinelearning comparemodels inputs.PNG

4. If your selection is to work with New models then under Inputs > COMMON INPUTS select your Outcome and Predictor(s) variables.
5. For New models you should also select which MODEL X > Algorithm to use for each of the models to compare.

Machinelearning comparemodels inputs2.PNG

Example


Options

Existing or new models Choose to use existing machine learning models or create new models to compare.

Existing models

Input models At least 2 existing machine learning or regression models.

Ensemble Whether to create an ensemble model by combining the predictions of the underlying models.


New models

Outcome The variable to be predicted by the predictor variables.

Predictors The variable(s) to predict the Outcome.

Missing data See Missing Data Options.

Variable names Display variable names instead of labels in the output.

Random seed Seed used to initialize the (pseudo) random number generator for imputation and for model fitting algorithms that use randomization.

Ensemble Whether to create an ensemble model by combining the predictions of the underlying models.

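Evaluation filter An optional filter selecting the evaluation data (usually a testing sample independent of the training sample), used to calculate out-of-sample performance metrics.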

Models For each model, select a machine learning algorithm and the desired settings for that model. See the model-specific pages listed below for more details.

For model-specific options see Classification And Regression Trees (CART), Linear Discriminant Analysis, Random Forest, Support Vector Machine, Deep Learning, Gradient Boosting or Regression.


Code

// Control groups (form.group) are only used when the file format version is
// greater than 10.9, i.e. in Displayr and later versions of Q.
var allow_control_groups = Q.fileFormatVersion() > 10.9;

// Top-level choice: compare models that already exist in the document, or
// build new models from the outcome/predictor controls of this form.
// The selected string ("Existing models" / "New models") drives which controls follow.
var existing = form.comboBox({label: "Existing or new models",
              alternatives: ["Existing models", "New models"], name: "formExisting", default_value: "Existing models",
              prompt: "Whether the underlying models are existing objects or new models will be built"}).getValue();

if (existing == "New models") {  // Common controls shared by all new models

    if (allow_control_groups)
        form.group("Common inputs");

    // The outcome is the dependent variable; the predictors are the independent ones.
    form.dropBox({label: "Outcome",
                  types: ["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
                  name: "formOutcomeVariable",
                  prompt: "Dependent target variable to be predicted"});
    form.dropBox({label: "Predictor(s)",
                  types: ["Variable: Numeric, Date, Money, Categorical, OrderedCategorical"],
                  name: "formPredictorVariables", multi: true,
                  prompt: "Independent input variables"});

    // Selected missing-data strategy; kept in `missing` so later controls can depend on it.
    var missing = form.comboBox({label: "Missing data",
                  alternatives: ["Error if missing data", "Exclude cases with missing data", "Imputation (replace missing values with estimates)"], name: "formMissing", default_value: "Exclude cases with missing data",
                  prompt: "Options for handling cases with missing data"}).getValue();
    form.checkBox({label: "Variable names", name: "formNames", default_value: false, prompt: "Display names instead of labels"});

    form.numericUpDown({name: "formSeed", label: "Random seed", default_value: 12321, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                        prompt: "Initializes randomization for imputation and certain algorithms"});

} else {    // Select existing models

    if (allow_control_groups)
        form.group("Existing models");
    // At least two existing Machine Learning or Regression outputs are required.
    var modelsInput = form.dropBox({label: "Input models", types: ["RItem:MachineLearning,Regression"], name: "formModels",
                                    multi: true, required: true, min_inputs: 2,
                                    prompt: "Select at least 2 Machine Learning or Regression models"});
}

// Ensemble or comparison: offered for both existing and new models.
var ensemble = form.checkBox({label: "Ensemble", name: "formEnsemble", default_value: false,
                              prompt: "Whether to create an ensemble of the models"});
var heading_text;
var plural_text;
if (ensemble.getValue()) {
    heading_text = 'Ensemble of Machine Learning Models';
    plural_text = 'Ensembles of Machine Learning Models';
    // Ensemble-only controls. When Ensemble is unticked these controls are not
    // created, and the R code falls back to its get0() defaults.
    form.checkBox({label: "Optimal ensemble", name: "formOptimalEnsemble", default_value: false,
                   prompt: "Find the ensemble with the best performance."});
    // NOTE: `output` holds the combo-box control itself, not its selected value.
    var output = form.comboBox({label: "Output",
              alternatives: ["Comparison", "Ensemble"], name: "formOutput", default_value: "Comparison",
              prompt: "A table comparing the models, or a prediction-accuracy table for the ensemble."});
} else {
    heading_text = 'Compare Machine Learning Models';
    plural_text = 'Comparisons of Machine Learning Models';
}

// Use setObjectInspectorTitle (singular and plural titles) when the form object
// provides it; otherwise fall back to a single heading.
if (!!form.setObjectInspectorTitle)
    form.setObjectInspectorTitle(heading_text, plural_text);
else
    form.setHeading(heading_text);

if (existing == "New models") {    // Evaluation filter and per-model controls

    if (allow_control_groups)
        form.group("Evaluation filter");

    // Optional filter selecting the evaluation (e.g. testing) sample.
    form.dropBox({label: "Evaluation filter",
                  types: ["v:!hidden:filter"],
                  name: "formEvaluationFilter",
                  required: false,
                  prompt: "Used to calculate out-of-sample performance metrics"});

    // One "Model N" group per model. Every Algorithm box defaults to " " (no model),
    // and the loop keeps adding groups while the latest one has an algorithm selected.
    // That always leaves one trailing blank group for adding another model.
    // The R code counts models the same way: formAlgorithm1, formAlgorithm2, ...
    // up to the first " ".
    var model = 0;
    var algorithm = "first";
    while (algorithm != " ") {

        ++model;
        if (allow_control_groups)
            form.group("Model " + model);

        algorithm = form.comboBox({label: "Algorithm",
                   alternatives: [" ", "CART", "Deep Learning", "Gradient Boosting", "Linear Discriminant Analysis",
                                  "Random Forest", "Regression", "Support Vector Machine"],
                   name: "formAlgorithm" + model, required: false, default_value: " ",
                   prompt: "Machine learning or regression algorithm for fitting the model"}).getValue();

        // CONTROLS FOR SPECIFIC ALGORITHMS. Each control name is suffixed with the model number.
        // Random Forest has no extra controls in this form. formImportance<N>, formOutColor<N>,
        // formPredColor<N> and formAuxiliaryVariables<N> are never created here, so the
        // R code's get0() reads them as NULL.

        if (algorithm == "Support Vector Machine")
            form.textBox({label: "Cost", name: "formCost" + model, default_value: 1, type: "number",
                          prompt: "High cost produces a complex model with risk of overfitting, low cost produces a simpler model with risk of underfitting"});

        if (algorithm == "Gradient Boosting") {
            form.comboBox({label: "Booster",
                           alternatives: ["gbtree", "gblinear"], name: "formBooster" + model, default_value: "gbtree",
                           prompt: "Boost tree or linear underlying models"});
            form.checkBox({label: "Grid search", name: "formSearch" + model, default_value: false,
                           prompt: "Search for optimal hyperparameters"});
        }

        if (algorithm == "Deep Learning") {
            form.numericUpDown({name: "formEpochs" + model, label: "Maximum epochs", default_value: 10, minimum: 1, maximum: Number.MAX_SAFE_INTEGER,
                                prompt: "Number of rounds of training"});
            form.textBox({name: "formHiddenLayers" + model, label: "Hidden layers", prompt: "Comma delimited list of the number of nodes in each hidden layer", required: true});
            form.checkBox({label: "Normalize predictors", name: "formNormalize" + model, default_value: true,
                           prompt: "Normalize to zero mean and unit variance"});
        }

        if (algorithm == "Linear Discriminant Analysis")
            form.comboBox({label: "Prior", alternatives: ["Equal", "Observed"], name: "formPrior" + model, default_value: "Observed",
                           prompt: "Probabilities of group membership"});

        if (algorithm == "CART") {
            form.comboBox({label: "Pruning", alternatives: ["Minimum error", "Smallest tree", "None"],
                           name: "formPruning" + model, default_value: "Minimum error",
                           prompt: "Remove nodes after tree has been built"});
            form.checkBox({label: "Early stopping", name: "formStopping" + model, default_value: false,
                           prompt: "Stop building tree when fit does not improve"});
            form.checkBox({label: "Allow long-running calculations", name: "formLongRunningCalculations" + model, default_value: false,
                           prompt: "Allow predictors with more than 30 categories"});
        }

        if (algorithm == "Regression") {
            var regressionType = form.comboBox({label: "Regression type",
                                                alternatives: ["Linear", "Binary Logit", "Ordered Logit", "Multinomial Logit", "Poisson",
                                                               "Quasi-Poisson", "NBD"],
                                                name: "formRegressionType" + model, default_value: "Linear",
                                                prompt: "Select type according to outcome variable type"}).getValue();
            form.comboBox({label: "Correction", alternatives: ["None", "False Discovery Rate", "Bonferroni"], name: "formCorrection" + model,
                           default_value: "None", prompt: "Multiple comparisons correction applied when computing p-values of post-hoc comparisons"});
            // Shown for linear regression only. The Missing data control of this form offers
            // neither "Multiple imputation" nor "Use partial data (pairwise correlations)",
            // so the regression type is the only condition.
            if (regressionType == "Linear")
                form.checkBox({label: "Robust standard errors", name: "formRobustSE" + model, default_value: false,
                               prompt: "Standard errors are robust to violations of assumption of constant variance"});
        }
    }
}
library(flipMultivariates)

# Build the comparison (or ensemble) output.
#  - Existing models: the selected models go straight to MachineLearningEnsemble,
#    evaluated on the current filter (QFilter) and weight (QPopulationWeight).
#  - New models: the settings of every "Model N" group are collected into
#    models.args and passed to MachineLearningMulti with the outcome/predictor formula.
# Controls the form may not have created (hidden options) are read with get0(),
# which returns NULL or the supplied default instead of raising an error.
comparison <- if (formExisting == "Existing models") {
    MachineLearningEnsemble(models = formModels,
                            compare.only = !get0("formEnsemble", ifnotfound = FALSE),
                            optimal.ensemble = get0("formOptimalEnsemble", ifnotfound = FALSE),
                            evaluation.subset = QFilter,
                            evaluation.weights = QPopulationWeight,
                            output = get0("formOutput", ifnotfound = "Comparison"))
} else { # new models
    # Count model groups: numbered from 1, ending at the first group whose
    # Algorithm is blank (" "). A control that does not exist also ends the count,
    # instead of producing a zero-length `while` condition error.
    models <- 0
    while (get0(paste0("formAlgorithm", models + 1), ifnotfound = " ") != " ")
        models <- models + 1
    if (models == 0)
        stop("At least one model must be specified.")

    # One argument list per model; settings that do not apply to a model's
    # algorithm are not created by the form and are therefore NULL.
    models.args <- vector("list", models)
    for (i in seq_len(models)) {
        models.args[[i]] <- list(algorithm = get0(paste0("formAlgorithm", i)),
                                 cost = get0(paste0("formCost", i)),
                                 booster = get0(paste0("formBooster", i)),
                                 grid.search = get0(paste0("formSearch", i)),
                                 sort.by.importance = get0(paste0("formImportance", i)),
                                 hidden.nodes = get0(paste0("formHiddenLayers", i)),
                                 max.epochs = get0(paste0("formEpochs", i)),
                                 normalize = get0(paste0("formNormalize", i)),
                                 outcome.color = get0(paste0("formOutColor", i)),
                                 predictors.color = get0(paste0("formPredColor", i)),
                                 prior = get0(paste0("formPrior", i)),
                                 prune = get0(paste0("formPruning", i)),
                                 early.stopping = get0(paste0("formStopping", i)),
                                 long.running.calculations = get0(paste0("formLongRunningCalculations", i)),
                                 type = get0(paste0("formRegressionType", i)),
                                 auxiliary.data = get0(paste0("formAuxiliaryVariables", i)),
                                 correction = get0(paste0("formCorrection", i)),
                                 robust.se = get0(paste0("formRobustSE", i), ifnotfound = FALSE),
                                 importance.absolute = get0(paste0("formAbsoluteImportance", i)), # not used
                                 interaction = get0(paste0("formInteraction", i))) # not used
    }

    MachineLearningMulti(QFormula(formOutcomeVariable ~ formPredictorVariables),
                         weights = QPopulationWeight,
                         subset = QFilter,
                         evaluation.subset = get0("formEvaluationFilter"),
                         missing = formMissing,
                         show.labels = !formNames,
                         seed = get0("formSeed"),
                         models.args = models.args,
                         compare.only = !get0("formEnsemble", ifnotfound = FALSE),
                         optimal.ensemble = get0("formOptimalEnsemble", ifnotfound = FALSE),
                         output = get0("formOutput", ifnotfound = "Comparison"))
}