Shapley Output Functions
A Shapley output function is a function that is called when the shapley or fit function is about to start iterating and at the end of every iteration. The output function can stop Shapley computations, create plots, save information to your workspace, or perform calculations using query point information.
To use the OutputFcn name-value argument in the call to shapley or fit, write a custom output function with this signature:

stop = outputfcn(x,results,state)

The shapley or fit function passes the variables x, results, and state to your output function. Your output function returns stop, which you set to true to stop the iterations, or false to allow the iterations to continue.
x contains the Shapley values for the query point at the current iteration.

results is a structure with these fields:

- Iteration: Current iteration number
- QueryPointIndex: Index of the query point evaluated at the current iteration
- TimePerQuery: Time, in seconds, spent computing the Shapley values for the query point at the current iteration
- Method: Method used to compute the Shapley values for the query point at the current iteration
state has these possible values:

- "init": shapley or fit is about to start iterating.
- "iter": shapley or fit just finished an iteration.
- "done": shapley or fit just finished its final iteration.
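An output function only has to accept x, results, and state and return a logical stop flag. A stateless stopping rule can therefore be written as an anonymous function. The following sketch is not part of the original examples. It stops after a hypothetical budget of 20 query points. The mock results structure uses made-up values so that you can check the function by hand before you pass it to shapley.

```matlab
% Hypothetical budget of query points (not part of the original examples)
maxQueryPoints = 20;

% Stateless output function: during an "iter" call, return true once the
% iteration count reaches the budget. strcmp accepts state as either a
% string or a character vector, and && skips reading results otherwise.
stopAfterN = @(x,results,state) strcmp(state,"iter") && ...
    results.Iteration >= maxQueryPoints;

% Mock results structure with the documented fields (made-up values)
mockResults = struct('Iteration',20,'QueryPointIndex',20, ...
    'TimePerQuery',0.5,'Method',"interventional-kernel");
stopNow = stopAfterN([],mockResults,"iter")
```

Pass the handle directly, for example OutputFcn=stopAfterN in a call to shapley. Logic that must remember values between calls, such as a running total, still needs a function with a persistent variable.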
Stop Shapley Value Computations Early
Train a classification model. Compute the Shapley values for multiple query points. Specify to stop the Shapley computations if they take too much time, and plot the partial results.
Load the CreditRating_Historical data set. The data set contains customer IDs and their financial ratios, industry labels, and credit ratings.
tbl = readtable("CreditRating_Historical.dat");
Display the first three rows of the table.
head(tbl,3)
      ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry    Rating
    _____    _____    _____    _______    ________    _____    ________    ______

    62394    0.013    0.104     0.036      0.447      0.142       3        {'BB'}
    48608    0.232    0.335     0.062      1.969      0.281       8        {'A' }
    42444    0.311    0.367     0.074      1.935      0.366       1        {'A' }
Train a blackbox model of credit ratings by using the fitcecoc function. Use the variables from the second through seventh columns in tbl as the predictor variables. A recommended practice is to specify the class names to set the order of the classes.
blackbox = fitcecoc(tbl,"Rating", ...
    PredictorNames=tbl.Properties.VariableNames(2:7), ...
    CategoricalPredictors="Industry", ...
    ClassNames={'AAA','AA','A','BBB','BB','B','CCC'});
Create a shapley object that explains the predictions for multiple query points. For faster computation, subsample 10% of the observations from tbl with stratification and use the samples to compute the Shapley values. Specify the sampled observations as the query points.
Use the output function earlystop (shown at the end of this example) to stop the Shapley value computations early if the cumulative computation time exceeds 60 seconds. If shapley stops early, the output function creates two new variables in the workspace: totalTime and numQueryPoints.
rng("default") % For reproducibility
c = cvpartition(tbl.Rating,"Holdout",0.10);
sampleTbl = tbl(test(c),:);
explainer = shapley(blackbox,sampleTbl, ...
    QueryPoints=sampleTbl,NumObservationsToSample="all", ...
    OutputFcn=@earlystop);
Warning: Iterations terminated prematurely by user.
Display the total Shapley value computation time.
totalTime
totalTime = 60.2992
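The time only slightly exceeds 60 seconds. The output function checks the cumulative time only after each query point finishes, so the computations stop at the first iteration that crosses the limit.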
Compare the total number of observations in sampleTbl to the number of query points whose Shapley values were computed by shapley.
numObservations = size(sampleTbl,1)
numObservations = 393
numQueryPoints
numQueryPoints = 102
Display a swarm chart of the partial results.
swarmchart(explainer)
Output Function
The output function earlystop uses the query point computation times in results (results.TimePerQuery) to determine whether to stop Shapley computations early. If the cumulative computation time exceeds 60 seconds, the function stops early. This code creates the earlystop output function.
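function stop = earlystop(~,results,state)
persistent totalTime % Cumulative Shapley computation time, in seconds
stop = false;
switch state
    case "init"
        totalTime = 0; % Reset the total before the first iteration
    case "iter"
        % Add the time spent on the query point from this iteration
        totalTime = totalTime + results.TimePerQuery;
        if totalTime > 60
            % Export the partial-run information, then stop
            assignin("base","totalTime",totalTime)
            assignin("base","numQueryPoints",results.Iteration)
            stop = true;
        end
    otherwise
end
end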
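earlystop sums results.TimePerQuery, which counts only the time spent computing each query point. If you would rather cap elapsed wall-clock time, including any work that shapley does between query points, one possible sketch is to capture a tic timer in an anonymous function. This sketch is an assumption and is not part of the original example. The 60-second budget mirrors earlystop, and the timer should start immediately before the call to shapley.

```matlab
% Hypothetical wall-clock budget in seconds (mirrors the 60 s in earlystop)
budgetSeconds = 60;

% Start a timer; the anonymous function captures its identifier
timerId = tic;

% Return true during an "iter" call once the elapsed time exceeds the budget
wallclockStop = @(x,results,state) strcmp(state,"iter") && ...
    toc(timerId) > budgetSeconds;

% Right after tic, the budget has not been used up
stopNow = wallclockStop([],[],"iter")
```

In practice, run timerId = tic and create the handle right before a call such as explainer = shapley(blackbox,sampleTbl,QueryPoints=sampleTbl,NumObservationsToSample="all",OutputFcn=wallclockStop). Unlike earlystop, this version does not export totalTime or numQueryPoints to the workspace.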
Find Method Used for Individual Shapley Value Computations
Train an ensemble model that uses tree weak learners with surrogate splits. Compute the Shapley values for multiple query points using predictor data that contains missing values. In this case, the Shapley value computation algorithm might not be the same for all query points. Use an output function to determine the method used to compute the Shapley values for each query point.
Load the fisheriris data set, which contains measurements for 150 irises, and create a table. SepalLength, SepalWidth, PetalLength, and PetalWidth are the predictor variables, and Species is the response variable.
fisheriris = readtable("fisheriris.csv");
Partition the data into two sets. Use 50% of the observations for training and 50% of the observations for computing Shapley values.
rng("default")
c = cvpartition(fisheriris.Species,"Holdout",0.5);
trainTbl = fisheriris(training(c),:);
queryTbl = fisheriris(test(c),:);
For this example, add a missing value to the second observation in queryTbl by setting its PetalWidth value (the fourth variable) to NaN.
queryTbl{2,4} = NaN;
queryTbl(2,:)
ans=1×5 table
SepalLength SepalWidth PetalLength PetalWidth Species
___________ __________ ___________ __________ __________
4.9 3 1.4 NaN {'setosa'}
Train a classification ensemble by using the fitcensemble function. Specify to use tree stumps with surrogate splits as the weak learners.
tree = templateTree(Surrogate="on",MaxNumSplits=1);
mdl = fitcensemble(trainTbl,"Species",Learners=tree);
Create a shapley object that explains the predictions for the query points in queryTbl. Use the queryTbl predictor data to compute the Shapley values.
Use the output function methodinfo
(shown at the end of this example) to find the Shapley value computation algorithm used for each query point. The function also returns the index of the query point evaluated at each iteration.
explainer = shapley(mdl,queryTbl,QueryPoints=queryTbl, ...
OutputFcn=@methodinfo)
Warning: Computations might be slow when the tree-based model uses surrogate splits for prediction. In this case, the software uses a mix of 'interventional-kernel' and 'interventional-tree'.
explainer = 
  shapley explainer with the following mean absolute Shapley values:

      Predictor        setosa       versicolor    virginica 
    _____________    __________    __________    __________

    "SepalLength"      0.056765       0.23593       0.17916
    "SepalWidth"     4.8278e-16    3.7358e-16    2.9962e-16
    "PetalLength"        4.4249        1.6484        3.1843
    "PetalWidth"         0.1696       0.52159       0.69119
The warning message indicates that shapley might use a mix of the Tree SHAP algorithm with an interventional value function and the Kernel SHAP algorithm with an interventional value function. The Method property of the explainer object reflects this information with the value "interventional-mix".
explainer.Method
ans = "interventional-mix"
Create a table containing the method information for each query point.
methodInfoTbl = table(queryPointIndex',methodType', ...
    VariableNames=["QueryPointIndex","Method"])
methodInfoTbl=75×2 table
QueryPointIndex Method
_______________ _______________________
1 "interventional-kernel"
2 "interventional-kernel"
3 "interventional-kernel"
4 "interventional-kernel"
5 "interventional-kernel"
6 "interventional-kernel"
7 "interventional-kernel"
8 "interventional-kernel"
9 "interventional-kernel"
10 "interventional-kernel"
11 "interventional-kernel"
12 "interventional-kernel"
13 "interventional-kernel"
14 "interventional-kernel"
15 "interventional-kernel"
16 "interventional-kernel"
⋮
unique(methodInfoTbl.Method)
ans = "interventional-kernel"
In this example, every query point uses the "interventional-kernel" method.
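Here, unique returns a single method. When a model does produce a mix, a useful next step can be to count the methods and list the query points that did not use the Kernel SHAP method. The following sketch builds a small stand-in table, demoMethodTbl, with made-up values. Its "interventional-tree" row is hypothetical and does not come from the iris example. The sketch then uses groupcounts and logical indexing. Using a separate name leaves the real methodInfoTbl untouched.

```matlab
% Stand-in for methodInfoTbl with made-up values; the "interventional-tree"
% row is hypothetical and does not come from the iris example
demoMethodTbl = table((1:4)', ...
    ["interventional-kernel";"interventional-kernel"; ...
     "interventional-tree";"interventional-kernel"], ...
    VariableNames=["QueryPointIndex","Method"]);

% Number of query points per computation method (groups are sorted)
counts = groupcounts(demoMethodTbl,"Method")

% Query points whose Shapley values did not use the Kernel SHAP method
otherPoints = demoMethodTbl.QueryPointIndex( ...
    demoMethodTbl.Method ~= "interventional-kernel")
```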
As a convenience, the output function methodinfo additionally returns the Shapley values for each query point. This information is also available in the Shapley property of explainer.
Find the Shapley values for the second query point. Recall from the table methodInfoTbl that the function evaluated the second query point during the second iteration.
rowNames = explainer.Shapley{:,1};
varNames = ...
    explainer.Shapley.Properties.VariableNames(2:end);
queryPointInfo = array2table(shapleyValues(:,:,2), ...
    RowNames=rowNames,VariableNames=varNames)
queryPointInfo=4×3 table
setosa versicolor virginica
___________ ___________ __________
SepalLength 0.037345 -0.15521 0.11787
SepalWidth -9.7992e-16 -4.8954e-16 6.7612e-16
PetalLength 6.6859 -2.0038 -4.6821
PetalWidth 0.067022 0.20267 -0.2697
For an example that shows how to find the Shapley values for a specific query point without using an output function, see Investigate One Query Point After Fitting Multiple Query Points.
Output Function
The output function methodinfo records the query point index (results.QueryPointIndex), Shapley values (x), and method (results.Method) at each iteration. The function returns the information to the MATLAB® workspace as the variables queryPointIndex, shapleyValues, and methodType, respectively. This code creates the methodinfo output function.
function stop = methodinfo(x,results,state)
persistent queryPointIndex
persistent shapleyValues
persistent methodType
stop = false;
switch state
    case "init"
        queryPointIndex = [];
        shapleyValues = zeros(4,3,1); % 4 predictors by 3 classes
        methodType = "";
    case "iter"
        % Store the information for this iteration in position results.Iteration
        queryPointIndex(results.Iteration) = results.QueryPointIndex;
        shapleyValues(:,:,results.Iteration) = x;
        methodType(results.Iteration) = results.Method;
    case "done"
        % Return the collected information to the base workspace
        assignin("base","queryPointIndex",queryPointIndex)
        assignin("base","shapleyValues",shapleyValues)
        assignin("base","methodType",methodType)
    otherwise
end
end
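methodinfo preallocates shapleyValues with zeros(4,3,1), which matches the 4 predictors and 3 classes of the iris model. To reuse the function with other models, one possible change assumes that iterations arrive in order. Initialize shapleyValues = [] in the "init" case, and replace the "iter" assignment with shapleyValues = cat(3,shapleyValues,x). The loop below simulates three "iter" calls with made-up 4-by-3 values. It shows that the first concatenation takes its page size from x.

```matlab
% Start from an empty array instead of zeros(4,3,1)
shapleyValues = [];

% Simulate three "iter" calls; each x is a made-up 4-by-3 stand-in for the
% Shapley values of one query point (predictors by classes)
for iteration = 1:3
    x = iteration*ones(4,3);
    % cat ignores the initial 0-by-0 empty array, so the first x sets
    % the page size and later pages are appended along dimension 3
    shapleyValues = cat(3,shapleyValues,x);
end

size(shapleyValues)
```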
See Also
shapley | fit | plot | boxchart | swarmchart