Video length is 5:49

Model Interpretability in MATLAB

From the series: Machine Learning in Finance

Interpretable machine learning (or in deep learning, “Explainable AI”) provides techniques and algorithms that overcome the black-box nature of AI models. By revealing how various features contribute (or do not contribute) to predictions, you can validate that the model is using the right evidence for its predictions and reveal model biases that were not apparent during training.

Get an overview of model interpretability and the use cases it addresses. For engineers and scientists who are interested in adopting machine learning but wary of black-box models, we explain how interpretability can help satisfy regulations, build trust in machine learning, and validate that models are working. This is particularly important in industries like finance and medical devices, where regulations set strict guidelines. We provide an overview of interpretability methods for machine learning and show how to apply them in MATLAB®. We demonstrate interpretability in the context of a medical application: classifying heart arrhythmias based on ECG signals.

Published: 19 Aug 2020

In recent years, we have seen AI and machine learning algorithms match or surpass human performance in many tasks that require intelligence, such as diagnosing from medical images and driving vehicles. However, what is missing at the heart of these achievements is an intuitive understanding of how these algorithms work.

This video explains why interpretability is important and what interpretability methods exist. It then demonstrates how to use these techniques in MATLAB. Specifically, we will look at LIME, partial dependence plots, and permuted predictor importance. We will examine interpretability in the context of classifying electrocardiograms (ECGs). The techniques described can be applied to any model, and no medical background is required to follow along.

Why do we need interpretability? To start, machine learning models are not straightforward to understand, and more accurate models are usually less interpretable. Interpretability methods are also needed to help navigate regulatory hurdles in the medical, finance, and security industries.

Interpretability is also needed to ensure that models use the right evidence and to reveal biases in the training data. One recent high-profile failure of AI was in credit scoring, where an algorithm reportedly gave men higher credit limits than women. This could be due to biases in the training data, biases in the real-time data, or something else. Interpretability techniques help us detect and prevent these issues.

For our example, we apply interpretability to machine learning models trained to classify heartbeats as normal or abnormal, based on ECG data from two publicly available databases. The ECG represents the heart's response to electrical stimulation from the sinus node and is typically decomposed into characteristic waves, including the QRS complex. We use MATLAB's Wavelet Toolbox to automatically extract the locations of the QRS complexes from the raw signal data. From there, we extract eight features based on the R-peaks to use for training.
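As a rough illustration of this feature-extraction step, the sketch below turns R-peak locations into per-beat interval features. It is plain Python, not MATLAB. It assumes the R-peaks have already been detected and are given as sample indices at a known sampling frequency. The feature names `pre_rr`, `post_rr`, and `amplitude` are hypothetical. The video does not define its eight features or how its RR0, RR1, and RR2 intervals are measured.

```python
def rr_features(r_peaks, amplitudes, fs):
    """Per-beat features from R-peak locations.

    r_peaks    -- R-peak positions as sample indices, in increasing order
    amplitudes -- signal value at each R-peak (same length as r_peaks)
    fs         -- sampling frequency in Hz

    The feature set is hypothetical and chosen only for illustration.
    """
    if len(r_peaks) != len(amplitudes):
        raise ValueError("need exactly one amplitude per R-peak")
    # RR interval k: time in seconds from R-peak k to R-peak k + 1.
    rr = [(b - a) / fs for a, b in zip(r_peaks, r_peaks[1:])]
    features = []
    # Only interior beats have an RR interval on both sides.
    for i in range(1, len(r_peaks) - 1):
        features.append({
            "pre_rr": rr[i - 1],    # interval ending at beat i
            "post_rr": rr[i],       # interval starting at beat i
            "amplitude": amplitudes[i],
        })
    return features


# Toy example: 360 Hz sampling, and the last beat arrives early.
beats = rr_features([0, 360, 720, 900], [1.1, 1.2, 0.9, 1.0], fs=360)
for beat in beats:
    print(beat)
```

The first and last beats are skipped because each lacks a neighboring beat on one side. A real pipeline would also need a policy for missed or spurious peak detections.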

Once we have the features, we can train models quickly using the Classification Learner app. Here, we trained a decision tree as an example of an inherently interpretable model, alongside two more complex ones: a random forest and an SVM. If accuracy were all that mattered, we would simply pick the highest-performing model. However, in situations such as predicting end-of-life care, interpretability is of great importance. We want to make sure that the model is making predictions using the right evidence. We also want to understand the situations in which the model may make errors.

Using permuted predictor importance in MATLAB, we see that for our best-performing model, the random forest, the amplitudes of the R-waves are among the important predictors. We can then use partial dependence plots in MATLAB to quantify the effect of the R-amplitude on the model output. We see that as the amplitude approaches 0, the predicted probability of an abnormal heartbeat classification changes by 5%.
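Permuted predictor importance measures how much a model's accuracy drops when one predictor's values are shuffled. Shuffling breaks that predictor's link to the label while keeping its distribution intact. The minimal Python sketch below makes a few assumptions:

- It uses a held-out data set.
- It uses accuracy as the metric.
- It averages over a few random shuffles.

MATLAB's importance estimate for bagged ensembles is based on out-of-bag observations instead, so the numbers are not directly comparable. The toy `model` is a hypothetical stand-in for a trained classifier.

```python
import random


def accuracy(model, X, y):
    """Fraction of rows where the model's predicted label equals the true label."""
    return sum(model(row) == label for row, label in zip(X, y)) / len(y)


def permutation_importance(model, X, y, n_repeats=5, seed=0):
    """Mean drop in accuracy when each predictor column is shuffled.

    model -- callable mapping one row (a list of feature values) to a label
    X, y  -- held-out rows and their true labels
    """
    rng = random.Random(seed)
    baseline = accuracy(model, X, y)
    importances = []
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)  # breaks the link between feature j and y
            # Copy each row, replacing only feature j with its shuffled value.
            X_perm = [list(row[:j]) + [v] + list(row[j + 1:])
                      for row, v in zip(X, column)]
            drops.append(baseline - accuracy(model, X_perm, y))
        importances.append(sum(drops) / n_repeats)
    return importances


# Toy data: the label depends only on feature 0; feature 1 is pure noise.
data_rng = random.Random(1)
X = [[data_rng.random(), data_rng.random()] for _ in range(200)]
y = [int(row[0] > 0.5) for row in X]


def model(row):
    # Stand-in for a trained classifier that happens to use only feature 0.
    return int(row[0] > 0.5)


print(permutation_importance(model, X, y))
```

Shuffling the noise feature leaves every prediction unchanged, so its importance is exactly zero. Shuffling feature 0 costs the model roughly half of its accuracy. A large importance flags a feature the model relies on. It says nothing about whether that reliance is justified, which is why the next step checks the result against domain knowledge.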

However, this contradicts our domain knowledge: experts say that R-amplitude levels should have little effect on the classification of a heartbeat. We want to make sure that this bias in the data does not carry over into our model. So next, we retrain our models without the amplitudes as predictors. With the bias removed, we can see how our new decision tree works on a global level. Instead of relying on R-amplitudes, the tree treats the RR0 and RR2 intervals as the most important predictors.

For more complex models like our random forest, we again use partial dependence plots to see how our most important predictors affect the model. We see that extremely short RR1 intervals generally lead to a higher probability of an abnormal heartbeat classification, which makes intuitive sense.
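A partial dependence curve answers this question: what does the model predict, on average, if one feature is set to the value v for every observation? The Python sketch below computes such a curve. It sweeps a feature over a grid and averages the model's predicted probability over the data set, leaving all other features at their observed values. `toy_abnormal_proba` is a hypothetical model in which shorter intervals raise the probability of an abnormal classification, mirroring the trend described above. It is not the video's random forest.

```python
import math


def partial_dependence(predict_proba, X, feature, grid):
    """Average predicted probability with one feature forced to each grid value.

    predict_proba -- callable mapping one row to a probability in [0, 1]
    X             -- rows used to average over the other features
    """
    curve = []
    for value in grid:
        total = 0.0
        for row in X:
            modified = list(row)
            modified[feature] = value  # other features keep their observed values
            total += predict_proba(modified)
        curve.append(total / len(X))
    return curve


def toy_abnormal_proba(row):
    # Hypothetical model: a shorter interval (row[0], in seconds) raises the
    # probability of "abnormal"; row[1] has a small additional effect.
    interval, other = row
    return 1.0 / (1.0 + math.exp(10.0 * (interval - 0.6) - 0.5 * other))


X = [[0.4 + 0.05 * i, (i % 3) / 2] for i in range(10)]
grid = [0.3, 0.5, 0.7, 0.9, 1.1]
curve = partial_dependence(toy_abnormal_proba, X, feature=0, grid=grid)
for value, p in zip(grid, curve):
    print(f"interval={value:.1f}s  mean P(abnormal)={p:.3f}")
```

Because only the model's predictions are used, the same function can be run on two different models over the same grid to compare their curves.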

We can also use partial dependence plots to compare different models. The plot of the same feature for the SVM shows a trend similar to that of our random forest. However, the SVM's plot is far smoother, suggesting that the SVM is less sensitive to variance in the input data, which makes it a more interpretable model.

Beyond understanding how these models work on a global scale, other situations may call for us to understand how they work for individual predictions. LIME (Local Interpretable Model-agnostic Explanations) is a technique that looks at the data points and model predictions around a point of interest. From these, it builds a simple linear model that approximates our complex model near that point. The coefficients of this approximate linear model serve as proxies for how much each feature contributes to predictions around the point of interest.
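The LIME idea can be sketched in four steps:

1. Perturb the query point.
2. Score the perturbations with the complex model.
3. Weight each perturbation by its closeness to the query.
4. Fit a weighted linear model.

The plain-Python version below follows that recipe under simplifying assumptions. It uses Gaussian perturbations with a per-feature `scale`, an exponential proximity kernel, and a tiny ridge term for numerical stability. It is not MATLAB's `lime` implementation, which may differ in how it generates and weights synthetic data. `toy_abnormal_proba` is a hypothetical black-box model.

```python
import math
import random


def solve_linear(A, b):
    """Solve A x = b with Gaussian elimination and partial pivoting."""
    n = len(A)
    M = [list(A[i]) + [b[i]] for i in range(n)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        x[i] = (M[i][n] - sum(M[i][c] * x[c] for c in range(i + 1, n))) / M[i][i]
    return x


def local_surrogate(predict_proba, query, scale, n_samples=500,
                    kernel_width=1.0, ridge=1e-6, seed=0):
    """Fit a proximity-weighted linear model to `predict_proba` around `query`.

    Returns [intercept, w_1, ..., w_d]. Regressors are offsets from the query,
    so the intercept is the surrogate's prediction at the query point and w_j
    is the local change in output per unit change of feature j.
    """
    rng = random.Random(seed)
    design, targets, weights = [], [], []
    for _ in range(n_samples):
        # Perturb each feature with Gaussian noise scaled to its typical spread.
        z = [q + rng.gauss(0.0, s) for q, s in zip(query, scale)]
        # Squared distance in standardized units -> exponential proximity weight.
        dist2 = sum(((zi - qi) / s) ** 2 for zi, qi, s in zip(z, query, scale))
        design.append([1.0] + [zi - qi for zi, qi in zip(z, query)])
        targets.append(predict_proba(z))
        weights.append(math.exp(-dist2 / kernel_width ** 2))
    # Weighted ridge normal equations: (X'WX + ridge*I) beta = X'Wy.
    p = len(query) + 1
    A = [[sum(w * r[i] * r[j] for r, w in zip(design, weights))
          + (ridge if i == j else 0.0) for j in range(p)] for i in range(p)]
    b = [sum(w * r[i] * t for r, t, w in zip(design, targets, weights))
         for i in range(p)]
    return solve_linear(A, b)


def toy_abnormal_proba(row):
    # Hypothetical black-box model: only row[0] (an interval in seconds) matters.
    return 1.0 / (1.0 + math.exp(10.0 * (row[0] - 0.6)))


coefs = local_surrogate(toy_abnormal_proba, query=[0.6, 0.0], scale=[0.05, 1.0])
print("surrogate prediction at query:", coefs[0], "weights:", coefs[1:])
```

The toy model ignores its second feature, so the second weight comes out near zero. The first weight is negative: near the query point, longer intervals lower the predicted probability of an abnormal heartbeat. The weights are in raw feature units. Multiplying each by its feature's `scale` puts them on a comparable footing before ranking features by importance.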

Let's look at an observation that our SVM misclassifies as normal. The value of RR0 for this observation is 0.0528. From the partial dependence plots earlier, we note that at values around 0.05, the probability of predicting an abnormal heartbeat goes down. LIME also places a high negative weight on RR0. Together, the high value of RR0 and its negative weight drive down the probability of predicting an abnormal heartbeat, which explains the misclassification.

However, LIME has some limitations. It is an approximation of our model and by no means an exact representation of how the model works. To illustrate this, there are situations where the prediction of our complex model does not match that of the approximation. When this happens, try running LIME again with different parameters, such as a larger number of important predictors, until the predictions agree.

We have demonstrated how to use interpretability techniques in MATLAB to compare different models, reveal data biases, and understand why predictions go wrong. Even without a data science background, we can all be part of the movement to make machine learning explainable. See the links below for more information about any of the techniques introduced in the video. Similar interpretability techniques also exist for neural networks, so be sure to check out those resources as well.