
predict

Predict labels using classification tree model

    Description


    label = predict(tree,X) returns a vector of predicted class labels for the predictor data in the table or matrix X, based on the trained classification tree tree.


    label = predict(tree,X,Subtrees=subtrees) also prunes tree to the level specified by subtrees, before predicting labels.


    [label,score,node,cnum] = predict(___) also returns the following, using any of the input argument combinations in the previous syntaxes:

    • A matrix of classification scores (score) indicating the likelihood that a label comes from a particular class. For classification trees, scores are posterior probabilities. For each observation in X, the predicted class label corresponds to the minimum expected misclassification cost among all classes.

    • A vector of predicted node numbers for the classification (node).

    • A vector of predicted class numbers for the classification (cnum).
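    As a quick sketch of how the four outputs relate, the following code (which assumes Fisher's iris data and a tree grown on all 150 observations with default settings) requests every output and checks three properties stated on this page: cnum indexes into Mdl.ClassNames, each row of score holds posterior probabilities that sum to 1, and each entry of node is a leaf of the tree.

```matlab
load fisheriris                          % meas: 150-by-4 numeric, species: 150-by-1 cell
Mdl = fitctree(meas,species);            % tree grown with default settings

% Labels, posterior probabilities, leaf node numbers, and class numbers
[label,score,node,cnum] = predict(Mdl,meas);

% cnum indexes into Mdl.ClassNames, so it reproduces the predicted labels
labelFromCnum = Mdl.ClassNames(cnum);
allMatch = isequal(labelFromCnum,label)

% Posterior probabilities in each row of score sum to 1
maxDeviation = max(abs(sum(score,2) - 1))

% Every entry of node is a leaf (not a branch node) of Mdl
allLeaves = ~any(Mdl.IsBranchNode(node))
```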

    Examples


    Examine predictions for a few rows in a data set left out of training.

    Load Fisher's iris data set.

    load fisheriris

    Partition the data into training (50%) and validation (50%) sets.

    n = size(meas,1);
    rng(1) % For reproducibility
    idxTrn = false(n,1);
    idxTrn(randsample(n,round(0.5*n))) = true;
    idxVal = ~idxTrn;

    Grow a classification tree using the training set.

    Mdl = fitctree(meas(idxTrn,:),species(idxTrn));

    Predict labels for the validation data, and display several predicted labels. Count the number of misclassified observations.

    label = predict(Mdl,meas(idxVal,:));
    label(randsample(numel(label),5))
    ans = 5x1 cell
        {'setosa'    }
        {'setosa'    }
        {'setosa'    }
        {'virginica' }
        {'versicolor'}
    
    
    numMisclass = sum(~strcmp(label,species(idxVal)))
    numMisclass = 3
    

    The software misclassifies three out-of-sample observations.

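    The next example estimates class posterior probabilities by using subtrees pruned to different levels. Load Fisher's iris data set.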

    load fisheriris

    Partition the data into training (50%) and validation (50%) sets.

    n = size(meas,1);
    rng(1) % For reproducibility
    idxTrn = false(n,1);
    idxTrn(randsample(n,round(0.5*n))) = true;
    idxVal = ~idxTrn;

    Grow a classification tree using the training set, and then view it.

    Mdl = fitctree(meas(idxTrn,:),species(idxTrn));
    view(Mdl,"Mode","graph")

    The resulting tree has four levels.

    Estimate posterior probabilities for the validation set by using subtrees pruned to levels 1 and 3. Display several posterior probabilities.

    [~,Posterior] = predict(Mdl,meas(idxVal,:), ...
        Subtrees=[1 3]);
    Mdl.ClassNames
    ans = 3x1 cell
        {'setosa'    }
        {'versicolor'}
        {'virginica' }
    
    
    Posterior(randsample(size(Posterior,1),5),:,:)
    ans = 
    ans(:,:,1) =
    
        1.0000         0         0
        1.0000         0         0
        1.0000         0         0
             0         0    1.0000
             0    0.8571    0.1429
    
    
    ans(:,:,2) =
    
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
    
    

    The elements of Posterior are class posterior probabilities:

    • Rows correspond to observations in the validation set.

    • Columns correspond to the classes as listed in Mdl.ClassNames.

    • Pages correspond to the subtrees.

    The subtree pruned to level 1 is more confident in its predictions than the subtree pruned to level 3 (that is, the root node), which assigns the same posterior probabilities to every observation.
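    Because the level-3 subtree is only the root node, its posterior probabilities should equal the class prior probabilities, which under the default Prior setting of fitctree are the class proportions in the training data. The sketch below repeats the partition from this example and checks that relationship; using max(Mdl.PruneList) for the root level and counting classes with ismember and accumarray are choices made for this illustration.

```matlab
load fisheriris
n = size(meas,1);
rng(1)                                   % same partition as in the example
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true;
Mdl = fitctree(meas(idxTrn,:),species(idxTrn));

% Posterior probabilities from the fully pruned subtree (root node only)
rootLevel = max(Mdl.PruneList);
[~,PostRoot] = predict(Mdl,meas(~idxTrn,:),Subtrees=rootLevel);

% Class proportions among the training observations, ordered as Mdl.ClassNames
[~,cls] = ismember(species(idxTrn),Mdl.ClassNames);
classFrac = accumarray(cls,1,[numel(Mdl.ClassNames) 1])'/numel(cls);

% Every row of PostRoot should match classFrac
maxDiff = max(max(abs(PostRoot - classFrac)))
```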

    Input Arguments


    tree

    Trained classification tree, specified as a ClassificationTree model object trained with fitctree, or a CompactClassificationTree model object created with compact.

    X

    Predictor data to be classified, specified as a numeric matrix or a table.

    Each row of X corresponds to one observation, and each column corresponds to one variable.

    For a numeric matrix:

    • The variables that make up the columns of X must have the same order as the predictor variables used to train tree.

    • If you train tree using a table (for example, Tbl), then X can be a numeric matrix if Tbl contains all numeric predictor variables. To treat numeric predictors in Tbl as categorical during training, identify categorical predictors using the CategoricalPredictors name-value argument of fitctree. If Tbl contains heterogeneous predictor variables (for example, numeric and categorical data types) and X is a numeric matrix, then predict issues an error.
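    For instance, in the following sketch a tree is trained on a table of the four numeric iris predictors (the variable names SL, SW, PL, and PW are made up for this illustration), and predict then accepts a plain numeric matrix whose columns are in the same order as the training predictors.

```matlab
load fisheriris
% Table with four numeric predictors (hypothetical names) plus the response
Tbl = array2table(meas,VariableNames={'SL','SW','PL','PW'});
Tbl.Species = species;
Mdl = fitctree(Tbl,"Species");           % all predictors in Tbl are numeric

% A numeric matrix works because its columns follow the training predictor order
labelFromMatrix = predict(Mdl,meas);
labelFromTable = predict(Mdl,Tbl);       % predict ignores the Species variable
sameLabels = isequal(labelFromMatrix,labelFromTable)
```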

    For a table:

    • predict does not support multicolumn variables or cell arrays other than cell arrays of character vectors.

    • If you train tree using a table (for example, Tbl), then all predictor variables in X must have the same variable names and data types as those used to train tree (stored in tree.PredictorNames). However, the column order of X does not need to correspond to the column order of Tbl. Tbl and X can contain additional variables (response variables, observation weights, and so on), but predict ignores them.

    • If you train tree using a numeric matrix, then the predictor names in tree.PredictorNames and corresponding predictor variable names in X must be the same. To specify predictor names during training, use the PredictorNames name-value argument of fitctree. All predictor variables in X must be numeric vectors. X can contain additional variables (response variables, observation weights, and so on), but predict ignores them.
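    Conversely, the sketch below trains on a numeric matrix, names the predictors with the PredictorNames name-value argument (again using the hypothetical names SL, SW, PL, and PW), and predicts from a table whose variable names match. The extra variable RowID is an illustrative addition that predict ignores.

```matlab
load fisheriris
names = {'SL','SW','PL','PW'};           % hypothetical predictor names
Mdl = fitctree(meas,species,PredictorNames=names);

% Table with matching numeric predictor variables plus one unused variable
X = array2table(meas,VariableNames=names);
X.RowID = (1:size(meas,1))';             % not a predictor, so predict ignores it

labelFromTable = predict(Mdl,X);
labelFromMatrix = predict(Mdl,meas);
sameLabels = isequal(labelFromTable,labelFromMatrix)
```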

    Data Types: table | double | single

    Name-Value Arguments

    Subtrees

    Pruning level, specified as a vector of nonnegative integers in ascending order or "all".

    If you specify a vector, then all elements must be at least 0 and at most max(tree.PruneList). 0 indicates the full, unpruned tree, and max(tree.PruneList) indicates the completely pruned tree (that is, just the root node).

    If you specify "all", then predict operates on all subtrees (that is, the entire pruning sequence). This specification is equivalent to using 0:max(tree.PruneList).
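    A short check of this equivalence, assuming a tree grown on the full iris data with the default Prune="on" setting so that the pruning sequence is available:

```matlab
load fisheriris
Mdl = fitctree(meas,species);            % Prune="on" by default, so PruneList is filled in

% "all" and the explicit vector 0:max(Mdl.PruneList) request the same subtrees
labelAll = predict(Mdl,meas,Subtrees="all");
labelVec = predict(Mdl,meas,Subtrees=0:max(Mdl.PruneList));
sameLabels = isequal(labelAll,labelVec)

% One column of labels per subtree: N-by-T with T = max(Mdl.PruneList) + 1
size(labelAll)
```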

    predict prunes tree to each level specified by subtrees, and then estimates the corresponding output arguments. The size of subtrees determines the size of some output arguments.

    To use subtrees, the PruneList and PruneAlpha properties of tree must be nonempty. In other words, either grow tree with Prune="on" (the fitctree default) or compute the pruning sequence afterward by using prune.
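    For example, a tree grown with Prune="off" has an empty pruning sequence, and calling prune with only the tree as input fills it in, after which Subtrees can be used. A sketch on the full iris data:

```matlab
load fisheriris
% Grow without computing the pruning sequence
MdlNoPrune = fitctree(meas,species,Prune="off");
noSequence = isempty(MdlNoPrune.PruneList)   % Subtrees cannot be used yet

% prune with only the tree as input fills in the optimal pruning sequence
MdlPruned = prune(MdlNoPrune);
[~,score] = predict(MdlPruned,meas,Subtrees=[0 1]);
size(score)                                  % N-by-K-by-T
```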

    Data Types: single | double | char | string

    Output Arguments


    label

    Predicted class labels, returned as a categorical or character array, logical or numeric vector, or cell array of character vectors. Each entry of label corresponds to the class with the minimal expected cost for the corresponding row of X.

    Suppose subtrees is a numeric vector containing T elements, and X has N rows.

    • If the response data type is char and T = 1, then label is a character matrix containing N rows. Each row contains the predicted label produced by subtrees.

    • If the response data type is char and T > 1, then label is an N-by-T cell array. Column j of label contains the vector of predicted labels produced by subtree subtrees(j).

    • Otherwise, label is an N-by-T array that has the same data type as the response. Column j of label contains the vector of predicted labels produced by subtree subtrees(j). (The software treats string arrays as cell arrays of character vectors.)

    score

    Posterior probabilities, returned as a numeric matrix of size N-by-K, where N is the number of observations (rows) in X, and K is the number of classes (in tree.ClassNames). score(i,j) is the posterior probability that row i in X is of class j in tree.ClassNames.

    If subtrees has T elements, and X has N rows, then score is an N-by-K-by-T array, and node and cnum are N-by-T matrices.
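    The following sketch shows these shapes for N = 10 rows, K = 3 iris classes, and T = 2 subtrees. Because species is a cell array of character vectors, label comes back as an N-by-T cell array.

```matlab
load fisheriris
Mdl = fitctree(meas,species);
X = meas(1:10,:);                        % N = 10 observations
subtrees = [0 1];                        % T = 2 (needs max(Mdl.PruneList) >= 1)

[label,score,node,cnum] = predict(Mdl,X,Subtrees=subtrees);
size(label)    % 10-by-2 cell array
size(score)    % 10-by-3-by-2
size(node)     % 10-by-2
size(cnum)     % 10-by-2
```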

    node

    Node numbers for the predicted classes, returned as a numeric vector. Each entry corresponds to the predicted node in tree for the corresponding row of X.

    cnum

    Class numbers corresponding to the predicted labels, returned as a numeric vector. Each entry of cnum corresponds to the predicted class number for the corresponding row of X.

    More About


    Predicted Class Label

    predict classifies by minimizing the expected misclassification cost:

    $$\hat{y} = \underset{y=1,\ldots,K}{\arg\min} \sum_{j=1}^{K} \hat{P}(j \mid x)\, C(y \mid j),$$

    where:

    • $\hat{y}$ is the predicted classification.

    • $K$ is the number of classes.

    • $\hat{P}(j \mid x)$ is the posterior probability of class $j$ for observation $x$.

    • $C(y \mid j)$ is the cost of classifying an observation as $y$ when its true class is $j$.
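    The rule can be checked directly. With C(y|j) stored as Mdl.Cost(j,y), the expected costs for all observations are the matrix product score*Mdl.Cost, and the column index of each row minimum should reproduce the labels from predict. This sketch assumes a tree grown on the full iris data with the default cost matrix; if two classes tied for the minimum, min would pick the first one, which might not match the tie-breaking inside predict.

```matlab
load fisheriris
Mdl = fitctree(meas,species);
[label,score] = predict(Mdl,meas);

% Expected cost of predicting class y: sum over j of score(n,j)*Mdl.Cost(j,y)
CE = score*Mdl.Cost;
[~,yhat] = min(CE,[],2);                 % argmin over the K classes
labelManual = Mdl.ClassNames(yhat);

agreement = mean(strcmp(labelManual,label))
```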

    Score (tree)

    For trees, the score of a classification at a leaf node is the posterior probability of that class at the node. The posterior probability of a class at a node is the number of training observations that reach the node and belong to that class, divided by the total number of training observations that reach the node.
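    Under the default empirical priors and without observation weights, this definition can be checked against the ClassCount property, which holds the number of training observations of each class at each node: normalizing the counts at each predicted leaf should reproduce score. A sketch on the full iris data:

```matlab
load fisheriris
Mdl = fitctree(meas,species);            % default empirical priors, no weights
[~,score,node] = predict(Mdl,meas);

% Class counts of training observations at each observation's leaf
counts = Mdl.ClassCount(node,:);
frac = counts./sum(counts,2);            % fraction of each class at that leaf

maxDiff = max(abs(score(:) - frac(:)))   % expected to be near 0
```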

    For an example, see Posterior Probability Definition for Classification Tree.

    True Misclassification Cost

    The true misclassification cost is the cost of classifying an observation into an incorrect class.

    You can set the true misclassification cost per class by using the Cost name-value argument when you create the classifier. Cost(i,j) is the cost of classifying an observation into class j when its true class is i. By default, Cost(i,j)=1 if i~=j, and Cost(i,j)=0 if i=j. In other words, the cost is 0 for correct classification and 1 for incorrect classification.

    Expected Cost

    The expected misclassification cost per observation is an averaged cost of classifying the observation into each class.

    Suppose you have Nobs observations that you want to classify with a trained classifier, and you have K classes. You place the observations into a matrix X with one observation per row.

    The expected cost matrix CE has size Nobs-by-K. Each row of CE contains the expected (average) cost of classifying the observation into each of the K classes. CE(n,k) is

    $$\sum_{i=1}^{K} \hat{P}(i \mid X(n))\, C(k \mid i),$$

    where:

    • $K$ is the number of classes.

    • $\hat{P}(i \mid X(n))$ is the posterior probability of class $i$ for observation $X(n)$.

    • $C(k \mid i)$ is the true misclassification cost of classifying an observation as $k$ when its true class is $i$.
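    The expected cost matrix is a single matrix product, CE = P*Cost, where P holds the posterior probabilities (one row per observation, one column per class) and Cost(i,k) = C(k|i). The numbers below are hypothetical and chosen only to show that a nondefault cost can change the minimizing class. With the default cost, CE(n,k) = 1 - P(k|X(n)), so the minimum falls on the most probable class.

```matlab
% Hypothetical posterior probabilities for Nobs = 2 observations and K = 3 classes
P = [0.7 0.2 0.1;
     0.1 0.3 0.6];

% Default cost: 0 for a correct classification, 1 otherwise; Cost(i,k) = C(k|i)
CostDefault = ones(3) - eye(3);

% Hypothetical cost: misclassifying a true class-2 observation costs 5
CostCustom = [0 1 1;
              5 0 1;
              1 1 0];

% CE(n,k) = sum over i of P(n,i)*Cost(i,k)
CEdefault = P*CostDefault;     % approximately [0.3 0.8 0.9; 0.9 0.7 0.4]
CEcustom  = P*CostCustom;      % approximately [1.1 0.8 0.9; 2.1 0.7 0.4]

[~,kDefault] = min(CEdefault,[],2)   % [1; 3]
[~,kCustom]  = min(CEcustom,[],2)    % [2; 3]
```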

    Predictive Measure of Association

    The predictive measure of association is a value that indicates the similarity between decision rules that split observations. Among all possible decision splits that are compared to the optimal split (found by growing the tree), the best surrogate decision split yields the maximum predictive measure of association. The second-best surrogate split has the second-largest predictive measure of association.

    Suppose $x_j$ and $x_k$ are predictor variables $j$ and $k$, respectively, and $j \neq k$. At node $t$, the predictive measure of association between the optimal split $x_j < u$ and a surrogate split $x_k < v$ is

    $$\lambda_{jk} = \frac{\min(P_L, P_R) - \left(1 - P_{L_j L_k} - P_{R_j R_k}\right)}{\min(P_L, P_R)}.$$

    • $P_L$ is the proportion of observations in node $t$ such that $x_j < u$. The subscript $L$ stands for the left child of node $t$.

    • $P_R$ is the proportion of observations in node $t$ such that $x_j \ge u$. The subscript $R$ stands for the right child of node $t$.

    • $P_{L_j L_k}$ is the proportion of observations at node $t$ such that $x_j < u$ and $x_k < v$.

    • $P_{R_j R_k}$ is the proportion of observations at node $t$ such that $x_j \ge u$ and $x_k \ge v$.

    • Observations with missing values for $x_j$ or $x_k$ do not contribute to the proportion calculations.

    $\lambda_{jk}$ is a value in $(-\infty, 1]$. If $\lambda_{jk} > 0$, then $x_k < v$ is a worthwhile surrogate split for $x_j < u$.
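    The formula can be evaluated directly from its definitions. This sketch uses hypothetical predictor values at a single node, unweighted proportions (fitctree may also account for observation weights, which this sketch ignores), and NaN to mark missing values, which are dropped before the proportions are computed. The quantity 1 - PLL - PRR is the proportion of observations on which the two splits disagree, and min(PL,PR) is the disagreement of the naive rule that sends every observation to the larger child.

```matlab
% Hypothetical values of predictors xj and xk at node t (NaN marks a missing value)
xj = [1 2 3 4 5 6 7 8 9];
xk = [1 2 3 5 4 6 7 8 NaN];
u = 4.5;                                 % optimal split:   xj < u
v = 4.5;                                 % surrogate split: xk < v

% Observations missing xj or xk do not contribute to the proportions
ok = ~isnan(xj) & ~isnan(xk);
xj = xj(ok);
xk = xk(ok);

PL  = mean(xj < u);                      % sent left by the optimal split
PR  = mean(xj >= u);                     % sent right by the optimal split
PLL = mean(xj < u & xk < v);             % sent left by both splits
PRR = mean(xj >= u & xk >= v);           % sent right by both splits

% Disagreement of the surrogate, relative to the naive rule's disagreement
lambda = (min(PL,PR) - (1 - PLL - PRR))/min(PL,PR)   % 0.5 for these values
```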

    Algorithms

    predict generates predictions by following the branches of tree until it reaches either a leaf node or a branch node whose split predictor value is missing. If predict reaches a leaf node, it returns the classification of that node.
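    To make the traversal concrete, the sketch below walks each training observation down a tree grown on the full iris data by using the ClassificationTree properties IsBranchNode, CutPredictor, CutPoint, Children, and NodeClass, sending an observation to the left child when its predictor value is less than the cut point. It assumes that all predictors are numeric and that no values are missing, so surrogate handling never comes into play; it illustrates the rule rather than the internal implementation of predict.

```matlab
load fisheriris
Mdl = fitctree(meas,species);            % numeric predictors, so each split is x < CutPoint
[label,~,node] = predict(Mdl,meas);

nodeManual = zeros(size(meas,1),1);
for i = 1:size(meas,1)
    t = 1;                                           % start at the root node
    while Mdl.IsBranchNode(t)
        p = strcmp(Mdl.PredictorNames,Mdl.CutPredictor{t});  % predictor split at node t
        if meas(i,p) < Mdl.CutPoint(t)
            t = Mdl.Children(t,1);                   % left child
        else
            t = Mdl.Children(t,2);                   % right child
        end
    end
    nodeManual(i) = t;                               % leaf reached
end

sameNodes = isequal(nodeManual,node)
sameLabels = isequal(Mdl.NodeClass(nodeManual),label)  % leaf classes give the labels
```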

    If predict reaches a node with a missing value for a predictor, its behavior depends on the setting of the Surrogate name-value argument when fitctree constructs tree.

    • Surrogate = "off" (default): predict returns the label with the largest number of training samples that reach the node.

    • Surrogate = "on": predict uses the best surrogate split at the node. If all surrogate split variables with positive predictive measure of association are missing, predict returns the label with the largest number of training samples that reach the node. For a definition, see Predictive Measure of Association.
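    A small sketch of the Surrogate = "on" case: grow a tree on the full iris data with surrogate splits, then remove the petal length (column 3) from three setosa observations. Which surrogate predict uses depends on the fitted tree, so the code does not name one; in the iris data, petal width also separates setosa from the other species, so the rows are still expected to be classified as setosa.

```matlab
load fisheriris
% Grow a tree that stores surrogate splits
Mdl = fitctree(meas,species,Surrogate="on");

% Copy three setosa rows and mark their petal lengths (column 3) as missing
Xnew = meas(1:3,:);
Xnew(:,3) = NaN;

label = predict(Mdl,Xnew)                % surrogate splits route these rows
```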

    Alternative Functionality

    Simulink Block

    To integrate the prediction of a classification tree model into Simulink®, you can use the ClassificationTree Predict block in the Statistics and Machine Learning Toolbox™ library or a MATLAB® Function block with the predict function. For examples, see Predict Class Labels Using ClassificationTree Predict Block and Predict Class Labels Using MATLAB Function Block.

    When deciding which approach to use, consider the following:

    • If you use the Statistics and Machine Learning Toolbox library block, you can use the Fixed-Point Tool (Fixed-Point Designer) to convert a floating-point model to fixed point.

    • Support for variable-size arrays must be enabled for a MATLAB Function block with the predict function.

    • If you use a MATLAB Function block, you can use MATLAB functions for preprocessing or post-processing before or after predictions in the same MATLAB Function block.


    Version History

    Introduced in R2011a