swarmchart
Description
swarmchart(explainer) creates a swarm chart, or scatter plot with jittered (offset) points, for each predictor in explainer.BlackboxModel.PredictorNames, where explainer is a shapley object. For each predictor, the function displays the Shapley values for the query points in explainer.QueryPoints. The corresponding swarm chart shows the distribution of the Shapley values.
If explainer.BlackboxModel is a classification model, the function displays swarm charts for class explainer.BlackboxModel.ClassNames(1) by default.
swarmchart(explainer,Name=Value) specifies additional options using one or more name-value arguments. For example, specify NumImportantPredictors=5 to create swarm charts for the five predictors with the greatest mean absolute Shapley values (explainer.MeanAbsoluteShapley).
swarmchart(ax,___) displays the swarm charts in the target axes ax. Specify ax as the first argument in any of the previous syntaxes.
s = swarmchart(___) returns an array of Scatter objects. Use s to query or modify the properties (Scatter Properties) of an object after you create it.
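The following sketch combines the ax and s syntaxes to compare two classes side by side. It assumes the fisheriris sample data set and a fitcecoc model, neither of which appears in the examples on this page, and it uses only 30 query points to keep the computation short.

```matlab
% Assumption: Statistics and Machine Learning Toolbox (R2024a or later) and the fisheriris sample data
load fisheriris                           % meas: 150-by-4 numeric matrix, species: 150-by-1 cell array
mdl = fitcecoc(meas,species);             % multiclass classification model
explainer = shapley(mdl,meas, ...
    QueryPoints=meas(1:5:end,:));         % 30 query points

% Draw the swarm charts for two classes in separate target axes
t = tiledlayout(1,2);
ax1 = nexttile(t);
s1 = swarmchart(ax1,explainer,ClassName="setosa");
ax2 = nexttile(t);
s2 = swarmchart(ax2,explainer,ClassName="virginica");

% Modify the returned Scatter objects after creation: smaller markers in the second chart
set(s2,'SizeData',12)
```

Because ClassName accepts character vectors and string scalars interchangeably, the string "setosa" matches the character vector class name stored in mdl.ClassNames.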
Examples
Shapley Value Swarm Charts for Classification Model
Train a classification model and create a shapley object. Then visualize the Shapley values for multiple query points by using the swarmchart object function.
Load the CreditRating_Historical data set. The data set contains customer IDs and their financial ratios, industry labels, and credit ratings.
tbl = readtable("CreditRating_Historical.dat");
Display the first three rows of the table.
head(tbl,3)
     ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry    Rating
    _____    _____    _____    _______    ________    _____    ________    ______
    62394    0.013    0.104     0.036      0.447      0.142        3       {'BB'}
    48608    0.232    0.335     0.062      1.969      0.281        8       {'A' }
    42444    0.311    0.367     0.074      1.935      0.366        1       {'A' }
Train a blackbox model of credit ratings by using the fitcecoc function. Use the variables from the second through seventh columns in tbl as the predictor variables. A recommended practice is to specify the class names to set the order of the classes.
blackbox = fitcecoc(tbl,"Rating", ...
    PredictorNames=tbl.Properties.VariableNames(2:7), ...
    CategoricalPredictors="Industry", ...
    ClassNames={'AAA','AA','A','BBB','BB','B','CCC'});
Create a shapley object that explains the predictions for multiple query points. For faster computation, subsample 10% of the observations from tbl with stratification and use the samples to compute the Shapley values. Specify the sampled observations as the query points.
rng("default") % For reproducibility
c = cvpartition(tbl.Rating,"Holdout",0.10);   % stratified 10% holdout
sampleTbl = tbl(test(c),:);
explainer = shapley(blackbox,sampleTbl, ...
    QueryPoints=sampleTbl);
Visualize the Shapley values by using the swarmchart object function.
swarmchart(explainer)
By default, the function shows the Shapley values for the first class, AAA. For each predictor, the function displays the Shapley values for the query points. The corresponding swarm chart shows the distribution of the Shapley values. The function determines the order of the predictors by using the mean absolute Shapley values.
For class AAA, the Shapley values for the RE_TA predictor seem to follow the trend of the predictor values. That is, query points with lower RE_TA values seem to have lower RE_TA Shapley values, and query points with higher RE_TA values seem to have higher RE_TA Shapley values. You can use data tips to see the predictor values of the query points.
Shapley Value Swarm Charts for Regression Model
Train a regression model and create a shapley object. Use the object function fit to compute the Shapley values for the specified query points. Then plot the Shapley values for multiple query points by using the swarmchart object function.
Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.
load carbig
Create a table containing the predictor variables Acceleration, Cylinders, and so on, as well as the response variable MPG.
tbl = table(Acceleration,Cylinders,Displacement, ...
Horsepower,Model_Year,Weight,MPG);
Removing missing values in a training set helps reduce memory consumption and speed up training for the fitrkernel function. Remove missing values in tbl.
tbl = rmmissing(tbl);
Train a blackbox model of MPG by using the fitrkernel function. Specify the Cylinders and Model_Year variables as categorical predictors. Standardize the remaining predictors.
rng("default") % For reproducibility
mdl = fitrkernel(tbl,"MPG",CategoricalPredictors=[2 5], ...
    Standardize=true);
Create a shapley object. Because mdl does not contain training data, specify the data set tbl.
explainer = shapley(mdl,tbl)
explainer =
            BlackboxModel: [1×1 RegressionKernel]
              QueryPoints: []
           BlackboxFitted: []
            ShapleyValues: []
                        X: [392×7 table]
    CategoricalPredictors: [2 5]
                   Method: "interventional-kernel"
                Intercept: 22.7326
               NumSubsets: 64
explainer stores the training data tbl in the X property.
Compute the Shapley values for all observations in tbl. Speed up computations by using the UseParallel name-value argument, if you have a Parallel Computing Toolbox™ license.
explainer = fit(explainer,tbl,UseParallel=true);
Starting parallel pool (parpool) using the 'Processes' profile ...
10-Jan-2024 14:09:35: Job Queued. Waiting for parallel pool job with ID 5 to start ...
Connected to parallel pool with 6 workers.
For a regression model, shapley computes Shapley values using the predicted response and stores them in the ShapleyValues property. Because explainer contains Shapley values for multiple query points, display the mean absolute Shapley values instead.
explainer.MeanAbsoluteShapley
ans=6×2 table
      Predictor       ShapleyValue
    ______________    ____________
    "Acceleration"      0.52233
    "Cylinders"         1.0412
    "Displacement"      0.80485
    "Horsepower"        0.7589
    "Model_Year"        0.82285
    "Weight"            0.98453
For each predictor, the mean absolute Shapley value is the absolute value of the Shapley values, averaged across all query points. The Cylinders predictor has the greatest mean absolute Shapley value, and the Acceleration predictor has the smallest mean absolute Shapley value.
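To make the definition concrete, the following sketch computes mean absolute Shapley values from a small, hypothetical matrix S whose rows are predictors and whose columns are query points. This layout is an assumption for illustration only and is not necessarily how explainer.ShapleyValues stores its values.

```matlab
% Hypothetical Shapley values: rows are predictors, columns are query points
S = [ 0.5  -1.0   0.3
     -2.0   1.0   1.5
      0.1   0.2  -0.3];
predictorNames = ["x1"; "x2"; "x3"];   % hypothetical predictor names

% Average the absolute Shapley values across query points (columns)
meanAbs = mean(abs(S),2);

meanAbsTbl = table(predictorNames,meanAbs, ...
    VariableNames=["Predictor","ShapleyValue"])
```

Taking the absolute value before averaging prevents positive and negative contributions from canceling: predictor x2 has Shapley values of opposite signs but still gets the largest mean absolute value, 1.5.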
Visualize the Shapley values by using the swarmchart object function. Specify to use the "copper" colormap.
swarmchart(explainer,ColorMap="copper")
For each predictor, the function displays the Shapley values for the query points. The corresponding swarm chart shows the distribution of the Shapley values. The function determines the order of the predictors by using the mean absolute Shapley values.
Query points with low Weight values seem to have large positive Shapley values. That is, for these query points, the Weight predictor increases the predicted MPG value relative to the average prediction. Similarly, query points with high Weight values seem to have large negative Shapley values. That is, for these query points, the Weight predictor decreases the predicted MPG value relative to the average prediction. These results match the idea that car weight is negatively correlated with MPG.
Input Arguments
explainer
— Object explaining blackbox model
shapley object
Object explaining the blackbox model, specified as a shapley object. explainer must contain Shapley values; that is, explainer.ShapleyValues must be nonempty.
ax
— Axes for plot
Axes object
Axes for the plot, specified as an Axes object. If you do not specify ax, then swarmchart creates the plot using the current axes. For more information on creating an Axes object, see axes.
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Example: swarmchart(explainer,NumImportantPredictors=5,ColorMap="copper") creates a swarm chart for each of the five predictors with the greatest mean absolute Shapley values and uses the copper colormap to indicate the range of predictor values.
NumImportantPredictors
— Number of important predictors
min(M,10), where M is the number of predictors (default) | positive integer
Number of important predictors to plot, specified as a positive integer. The swarmchart function plots the Shapley values of the specified number of predictors with the greatest mean absolute Shapley values.
Example: NumImportantPredictors=5 specifies to plot the five most important predictors. The swarmchart function determines the order of importance by using the mean absolute Shapley values.
Data Types: single | double
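The selection that NumImportantPredictors describes can be sketched as a descending sort on the mean absolute Shapley values. The sketch reuses the values from the carbig example above. The request for three predictors is hypothetical, and the ranking code is an illustration of the documented behavior, not the function's internal implementation.

```matlab
% Mean absolute Shapley values from the carbig example
names = ["Acceleration"; "Cylinders"; "Displacement"; ...
         "Horsepower"; "Model_Year"; "Weight"];
meanAbs = [0.52233; 1.0412; 0.80485; 0.7589; 0.82285; 0.98453];

M = numel(names);
defaultK = min(M,10);          % documented default for NumImportantPredictors
k = min(3,M);                  % hypothetical request: NumImportantPredictors=3

% Rank predictors by decreasing mean absolute Shapley value and keep the top k
[~,order] = sort(meanAbs,'descend');
topNames = names(order(1:k))
```

With six predictors, the default plots all six, because min(6,10) is 6.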
ClassName
— Class label to plot
explainer.BlackboxModel.ClassNames(1) (default) | numeric scalar | logical scalar | character vector | string scalar | categorical scalar
Class label to plot, specified as a numeric scalar, logical scalar, character vector, string scalar, or categorical scalar. The value and data type of the ClassName value must match one of the class names in the ClassNames property of the machine learning model in explainer (explainer.BlackboxModel.ClassNames). Note that the software accepts character vectors, string scalars, and categorical scalars interchangeably.
This argument is valid only when the machine learning model (BlackboxModel) in explainer is a classification model.
Example: ClassName="AAA"
Data Types: single | double | logical | char | string | categorical
YJitter
— Type of jitter along y-dimension
"density"
(default) | "rand"
Type of jitter (spacing of points) along the y-dimension, specified as one of the following values:
"density": Jitter the points using the kernel density estimate of the Shapley values.
"rand": Jitter the points randomly with a uniform distribution.
Example: YJitter="rand"
Data Types: char | string
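To see the difference between the two jitter types, the following sketch calls the base MATLAB swarmchart function (not the shapley object function) on hypothetical Shapley values for a single predictor, with the jitter applied along the y-dimension. It assumes that the shapley swarm charts space their points analogously to the YJitter property of the base function.

```matlab
% Base MATLAB swarmchart, used here only to illustrate the two jitter types
rng("default")
phi = randn(300,1);        % hypothetical Shapley values for a single predictor
row = ones(300,1);         % every point sits on the same predictor row (y = 1)

tiledlayout(2,1)
ax1 = nexttile;
s1 = swarmchart(ax1,phi,row,'XJitter','none','YJitter','density');
title(ax1,'YJitter = density')
ax2 = nexttile;
s2 = swarmchart(ax2,phi,row,'XJitter','none','YJitter','rand');
title(ax2,'YJitter = rand')
```

With "density", the vertical spread is widest where the Shapley values are most concentrated, so the outline of each swarm approximates the distribution. With "rand", the vertical offsets are uniform regardless of how dense the values are.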
ColorMap
— Colormap for swarm charts
"default"
(default) | colormap name | "bluered"
Colormap for the swarm charts, specified as a predefined colormap name or "bluered". A value of "default" sets the colormap to the default colormap for the target axes ax. A value of "bluered" sets the colormap to a color scale that ranges from blue to red. For more information on the available colormaps, see map.
For more information on how swarmchart maps predictor values to the colormap, see Color Assignment for Predictor Values.
Example: ColorMap="parula"
Example: ColorMap="bluered"
Data Types: char | string
More About
Shapley Values
In game theory, the Shapley value of a player is the average marginal contribution of the player in a cooperative game. In the context of machine learning prediction, the Shapley value of a feature for a query point explains the contribution of the feature to a prediction (response for regression or score of each class for classification) at the specified query point.
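For reference, the standard game-theoretic definition of the Shapley value of feature j for a query point x with M features is given below. The value function v_x depends on the computation method (for example, the interventional method that appears in the regression example), so treat this as the general form rather than the exact computation that shapley performs.

```latex
\varphi_j(x) = \sum_{S \subseteq \mathcal{M} \setminus \{j\}}
  \frac{|S|!\,(M - |S| - 1)!}{M!}
  \bigl[ v_x(S \cup \{j\}) - v_x(S) \bigr]
```

Here, \mathcal{M} = \{1, \dots, M\} is the set of features, and v_x(S) is the expected prediction when the features in S are fixed to their values in x. The weights average the marginal contribution of feature j over all orders in which features can be added.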
The Shapley value of a feature for a query point is the contribution of the feature to the deviation from the average prediction. For a query point, the sum of the Shapley values for all features corresponds to the total deviation of the prediction from the average. That is, the sum of the average prediction and the Shapley values for all features corresponds to the prediction for the query point.
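The additivity property can be checked numerically in a case with a known closed form. For a linear model with an interventional value function, the Shapley value of feature j is w_j(x_j - E[X_j]). The sketch below uses a hypothetical linear model and hypothetical background data, not a shapley object.

```matlab
rng("default")
X = randn(100,3);                 % hypothetical background data with 3 features
w = [2; -1; 0.5];                 % hypothetical linear model coefficients
b = 3;                            % hypothetical intercept
f = @(Z) Z*w + b;                 % linear prediction function

x = [1 0.5 -2];                   % hypothetical query point

% Closed-form interventional Shapley values of a linear model: w_j*(x_j - mean of feature j)
phi = w' .* (x - mean(X,1));

avgPred = mean(f(X));             % average prediction over the background data
residual = f(x) - (avgPred + sum(phi))   % zero up to floating-point rounding
```

The residual vanishes because avgPred + sum(phi) equals mean(X,1)*w + b + (x - mean(X,1))*w, which simplifies to x*w + b.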
For more details, see Shapley Values for Machine Learning Model.
Tips
Use swarmchart when explainer contains Shapley values for many query points.
Algorithms
Color Assignment for Predictor Values
swarmchart maps predictor values to the colormap specified by the ColorMap name-value argument as follows:
For each numeric predictor, the function determines the nonoutlier minimum and maximum values. The function maps the outliers and extrema (minimum and maximum values) to the appropriate colormap endpoints, and maps the remaining values to the interior of the colormap range using normalization.
For each nonnumeric predictor, the function uniformly maps categories to colors in the colormap. The color order of the categories is arbitrary.
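The two mapping rules can be sketched as follows. The outlier rule (the default median-based rule of isoutlier), the colormap size, and the category ordering are assumptions for illustration. swarmchart might use a different outlier definition, and, as noted above, its category color order is arbitrary.

```matlab
% Hypothetical numeric predictor values; 100 is an outlier
v = [1; 2; 3; 4; 5; 100];
cmap = parula(256);                    % any n-by-3 colormap matrix works here

% Numeric predictor: nonoutlier extrema define the color range
tf = isoutlier(v);                     % assumption: default median/MAD rule of isoutlier
lo = min(v(~tf));
hi = max(v(~tf));
t = zeros(size(v));                    % constant predictor: every value gets the first color
if hi > lo
    % Clip outliers to [lo,hi] so they land on the colormap endpoints, then normalize to [0,1]
    t = (min(max(v,lo),hi) - lo) / (hi - lo);
end
idx = 1 + round(t*(size(cmap,1) - 1));
numericColors = cmap(idx,:);

% Nonnumeric predictor: spread the categories uniformly over the colormap
c = categorical(["a"; "b"; "c"; "a"]);
catIdx = round(linspace(1,size(cmap,1),numel(categories(c))));
categoryColors = cmap(catIdx(double(c)),:);
```

Clipping before normalizing keeps a single extreme value from compressing all other values into a narrow band of colors.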
Version History
Introduced in R2024a