Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

使用长短期记忆网络对 ECG 信号进行分类

此示例说明如何使用深度学习和信号处理对来自 PhysioNet 2017 Challenge 的心电图 (ECG) 数据进行分类。具体而言,该示例使用长短期记忆网络和时频分析。

简介

ECG 记录一段时间内人体心脏的电活动。医生使用 ECG 检测患者的心跳是正常还是不规则。

心房纤维性颤动 (AFib) 是一种不规则的心跳,当心脏的上腔(即心房)与下腔(即心室)失去协调时就会发生心房纤维性颤动。

此示例使用的 ECG 数据来自 PhysioNet 2017 Challenge [1]、[2]、[3],可从 https://physionet.org/challenge/2017/ 获得。数据由一组以 300 Hz 采样的 ECG 信号组成,由一组专家分成四个不同类:正常 (N)、AFib (A)、其他心律 (O) 和含噪记录 (~)。此示例说明如何使用深度学习来自动化分类过程。该过程使用一个二类分类器,该二类分类器可以将正常 ECG 信号和显示 AFib 符号的信号区分开来。

本示例使用长短期记忆 (LSTM) 网络,这是一种循环神经网络 (RNN),非常适合研究序列和时序数据。LSTM 网络可以学习序列的时间步之间的长期相关性。LSTM 层 (lstmLayer (Deep Learning Toolbox)) 可以前向分析时间序列,而双向 LSTM 层 (bilstmLayer (Deep Learning Toolbox)) 可以前向和后向分析时间序列。此示例使用双向 LSTM 层。

为了加快训练过程,请在具有 GPU 的机器上运行此示例。如果您的机器同时有 GPU 和 Parallel Computing Toolbox™,则 MATLAB® 会自动使用 GPU 进行训练;否则,它使用 CPU。

加载并检查数据

运行 ReadPhysionetData 脚本,从 PhysioNet 网站下载数据,并生成包含适当格式的 ECG 信号的 MAT 文件 (PhysionetData.mat)。下载数据可能需要几分钟。使用一个条件语句,限定仅在当前文件夹中不存在 PhysionetData.mat 时才运行该脚本。

if ~isfile('PhysionetData.mat')
    ReadPhysionetData         
end
load PhysionetData

加载操作会向工作区添加两个变量:SignalsLabelsSignals 是保存 ECG 信号的元胞数组。Labels 是分类数组,它保存信号的对应真实值标签。

Signals(1:5)
ans=5×1 cell array
    {1×9000  double}
    {1×9000  double}
    {1×18000 double}
    {1×9000  double}
    {1×18000 double}

Labels(1:5)
ans = 5×1 categorical
     N 
     N 
     N 
     A 
     A 

使用 summary 函数查看数据中包含多少 AFib 信号和正常信号。

summary(Labels)
     A       738 
     N      5050 

生成信号长度的直方图。大多数信号的长度是 9000 个采样。

L = cellfun(@length,Signals);
h = histogram(L);
xticks(0:3000:18000);
xticklabels(0:3000:18000);
title('Signal Lengths')
xlabel('Length')
ylabel('Count')

可视化每个类中一个信号的一段。AFib 心跳间隔不规则,而正常心跳会周期性发生。AFib 心跳信号还经常缺失 P 波,P 波在正常心跳信号的 QRS 复合波之前出现。正常信号的绘图会显示 P 波和 QRS 复合波。

normal = Signals{1};
aFib = Signals{4};

subplot(2,1,1)
plot(normal)
title('Normal Rhythm')
xlim([4000,5200])
ylabel('Amplitude (mV)')
text(4330,150,'P','HorizontalAlignment','center')
text(4370,850,'QRS','HorizontalAlignment','center')

subplot(2,1,2)
plot(aFib)
title('Atrial Fibrillation')
xlim([4000,5200])
xlabel('Samples')
ylabel('Amplitude (mV)')

准备用于训练的数据

在训练期间,trainNetwork 函数将数据分成小批量。然后,该函数在同一个小批量中填充或截断信号,使它们都具有相同的长度。过多的填充或截断会对网络性能产生负面影响,因为网络可能会根据添加或删除的信息错误地解释信号。

为避免过度填充或截断,请对 ECG 信号应用 segmentSignals 函数,使它们的长度都为 9000 个样本。该函数会忽略少于 9000 个样本的信号。如果信号的样本超过 9000 个,segmentSignals 会将其分成尽可能多的包含 9000 个样本的信号段,并忽略剩余样本。例如,具有 18500 个样本的信号将变为两个包含 9000 个样本的信号,剩余的 500 个样本被忽略。

[Signals,Labels] = segmentSignals(Signals,Labels);

查看 Signals 数组的前五个元素,以验证每个条目的长度现在为 9000 个样本。

Signals(1:5)
ans=5×1 cell array
    {1×9000 double}
    {1×9000 double}
    {1×9000 double}
    {1×9000 double}
    {1×9000 double}

使用原始信号数据训练分类器

要设计分类器,请使用上一节中生成的原始信号。将信号分成一个训练集(用于训练分类器)和一个测试集(用于基于新数据测试分类器的准确度)。

使用 summary 函数显示 AFib 信号与正常信号的比率为 718:4937,约为 1:7。

summary(Labels)
     A       718 
     N      4937 

由于约 7/8 的信号是正常信号,因此分类器会发现通过简单地将所有信号分类为正常信号就可达到高准确度。为了避免这种偏置,需要通过复制数据集中的 AFib 信号来增加 AFib 数据,以便正常信号和 AFib 信号的数量相同。这种复制通常称为过采样,是深度学习中使用的一种数据增强形式。

根据信号所属的类划分信号。

afibX = Signals(Labels=='A');
afibY = Labels(Labels=='A');

normalX = Signals(Labels=='N');
normalY = Labels(Labels=='N');

接下来,使用 dividerand 将每个类的目标随机分为训练集和测试集。

[trainIndA,~,testIndA] = dividerand(718,0.9,0.0,0.1);
[trainIndN,~,testIndN] = dividerand(4937,0.9,0.0,0.1);

XTrainA = afibX(trainIndA);
YTrainA = afibY(trainIndA);

XTrainN = normalX(trainIndN);
YTrainN = normalY(trainIndN);

XTestA = afibX(testIndA);
YTestA = afibY(testIndA);

XTestN = normalX(testIndN);
YTestN = normalY(testIndN);

现在有 646 个 AFib 信号和 4443 个正常信号用于训练。要在每个类中获得相同数量的信号,请使用前 4438 个正常信号,然后使用 repmat 对前 634 个 AFib 信号重复七次。

对于测试集,现在有 72 个 AFib 信号和 494 个正常信号。使用前 490 个正常信号,然后使用 repmat 对前 70 个 AFib 信号重复七次。默认情况下,神经网络会在训练前随机对数据进行乱序处理,以确保相邻信号不都有相同的标签。

XTrain = [repmat(XTrainA(1:634),7,1); XTrainN(1:4438)];
YTrain = [repmat(YTrainA(1:634),7,1); YTrainN(1:4438)];

XTest = [repmat(XTestA(1:70),7,1); XTestN(1:490)];
YTest = [repmat(YTestA(1:70),7,1); YTestN(1:490);];

现在,正常信号和 AFib 信号在训练集和测试集中的分布均衡。

summary(YTrain)
     A      4438 
     N      4438 
summary(YTest)
     A      490 
     N      490 

定义 LSTM 网络架构

LSTM 网络可以学习序列数据的时间步之间的长期相关性。此示例使用双向 LSTM 层 bilstmLayer,因为它前向和后向检测序列。

由于输入信号各有一个维度,将输入大小指定是大小为 1 的序列。指定输出大小为 100 的一个双向 LSTM 层,并输出序列的最后一个元素。此命令指示双向 LSTM 层将输入时序映射到 100 个特征,然后为全连接层准备输出。最后,通过包含大小为 2 的全连接层,后跟 softmax 层和分类层,来指定两个类。

layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(100,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         2 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

接下来指定分类器的训练选项。将 'MaxEpochs' 设置为 10,以允许基于训练数据对网络进行 10 轮训练。'MiniBatchSize' 为 150 指示网络一次分析 150 个训练信号。'InitialLearnRate' 为 0.01 有助于加快训练过程。指定 'SequenceLength' 为 1000 以将信号分解成更小的片段,这样机器就不会因为一次处理太多数据而耗尽内存。将 'GradientThreshold' 设置为 1 以防止梯度过大,从而稳定训练过程。将 'Plots' 指定为 'training-progress',以生成显示训练随迭代次数的增加而变化的进度图。将 'Verbose' 设置为 false 以隐藏对应于图中所示数据的表输出。如果您要查看此表,请将 'Verbose' 设置为 true

此示例使用自适应矩估计 (ADAM) 求解器。与默认的具有动量的随机梯度下降 (SGDM) 求解器相比,ADAM 在使用 LSTM 之类的 RNN 时性能更好。

options = trainingOptions('adam', ...
    'MaxEpochs',10, ...
    'MiniBatchSize', 150, ...
    'InitialLearnRate', 0.01, ...
    'SequenceLength', 1000, ...
    'GradientThreshold', 1, ...
    'ExecutionEnvironment',"auto",...
    'plots','training-progress', ...
    'Verbose',false);

训练 LSTM 网络

通过使用 trainNetwork 用指定的训练选项和层架构训练 LSTM 网络。由于训练集很大,训练过程可能需要几分钟。

net = trainNetwork(XTrain,YTrain,layers,options);

训练进度图的顶部子图表示训练准确度,即基于每个小批量的分类准确度。当训练在成功进行时,此值通常会逐渐增大,直到 100%。底部子图显示训练损失,即基于每个小批量的交叉熵损失。当训练在成功进行时,该值通常会逐渐降低,直到为零。

如果训练未收敛,绘图可能会在各值之间振荡,而不会呈现向上或向下趋势。这种振荡意味着训练准确度没有提高,训练损失没有减少。这种情况可能发生在训练开始时,或者在训练准确度有初步提高后,绘图可能趋于平稳。在许多情况下,更改训练选项可以帮助网络实现收敛。减少 MiniBatchSize 或减少 InitialLearnRate 可能会导致更长的训练时间,但这可能有助于网络更好地学习。

分类器的训练准确度在约 50% 和约 60% 之间震荡,在 10 轮结束时,训练已进行了几分钟。

可视化训练和测试准确度

计算训练准确度,该准确度表示分类器对于所训练信号的准确度。首先,对训练数据进行分类。

trainPred = classify(net,XTrain,'SequenceLength',1000);

在分类问题中,混淆矩阵用于可视化分类器对于一组已知真实数值的数据上的性能。目标类是信号的真实值标签,输出类是网络分配给信号的标签。坐标区标签表示类标签 AFib (A) 和 Normal (N)。

使用 confusionchart 命令计算用于测试数据预测的总体分类准确度。将 'RowSummary' 指定为 'row-normalized' 以在行汇总中显示真正率和假正率。此外,将 'ColumnSummary' 指定为 'column-normalized' 以在列汇总中显示正预测值和假发现率。

LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 61.7283
figure
confusionchart(YTrain,trainPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

现在用相同的网络对测试数据进行分类。

testPred = classify(net,XTest,'SequenceLength',1000);

计算测试准确度,并使用混淆矩阵将分类性能可视化。

LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
LSTMAccuracy = 66.2245
figure
confusionchart(YTest,testPred,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

通过特征提取提高性能

从数据中提取特征有助于提高分类器的训练和测试准确度。为了决定提取哪些特征,本示例采用的方法是先计算时频图像(如频谱图),然后使用它们来训练卷积神经网络 (CNN) [4]、[5]。

可视化每个信号类型的频谱图。

fs = 300;

figure
subplot(2,1,1);
pspectrum(normal,fs,'spectrogram','TimeResolution',0.5)
title('Normal Signal')

subplot(2,1,2);
pspectrum(aFib,fs,'spectrogram','TimeResolution',0.5)
title('AFib Signal')

因为本示例使用 LSTM 而不是 CNN,必须转换该方法以应用于一维信号。时频 (TF) 矩从频谱图中提取信息。每个矩都可以用作一维特征以输入到 LSTM。

探查时域中的两个 TF 矩:

  • 瞬时频率 (instfreq)

  • 谱熵 (pentropy)

instfreq 函数估计信号的时变频率,作为功率谱图的第一个矩。该函数使用时间窗上的短时傅里叶变换计算频谱图。在本示例中,该函数使用 255 个时间窗。该函数的时间输出对应于时间窗的中心。

可视化每个信号类型的瞬时频率。

[instFreqA,tA] = instfreq(aFib,fs);
[instFreqN,tN] = instfreq(normal,fs);

figure
subplot(2,1,1);
plot(tN,instFreqN)
title('Normal Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

subplot(2,1,2);
plot(tA,instFreqA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')

使用 cellfun instfreq 函数应用于训练集和测试集中的每个单元。

instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,'UniformOutput',false);
instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,'UniformOutput',false);

谱熵测量信号的频谱的尖度或平坦度。具有尖峰频谱的信号(如正弦波之和)具有低谱熵。具有平坦频谱的信号(如白噪声)具有高谱熵。pentropy 函数基于功率谱估计谱熵。与瞬时频率估计情况一样,pentropy 使用 255 个时间窗来计算频谱图。函数的时间输出对应于时间窗的中心。

可视化每个信号类型的谱熵。

[pentropyA,tA2] = pentropy(aFib,fs);
[pentropyN,tN2] = pentropy(normal,fs);

figure

subplot(2,1,1)
plot(tN2,pentropyN)
title('Normal Signal')
ylabel('Spectral Entropy')

subplot(2,1,2)
plot(tA2,pentropyA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Spectral Entropy')

使用 cellfunpentropy 函数应用于训练中和测试集中的每个单元。

pentropyTrain = cellfun(@(x)pentropy(x,fs)',XTrain,'UniformOutput',false);
pentropyTest = cellfun(@(x)pentropy(x,fs)',XTest,'UniformOutput',false);

串联这些特征,使新的训练集和测试集中的每个单元都有两个维度(即两个特征)。

XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,'UniformOutput',false);
XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);

可视化新输入的格式。每个单元不再包含一个长度为 9000 个样本的信号;现在它包含两个长度为 255 个样本的特征。

XTrain2(1:5)
ans=5×1 cell array
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}
    {2×255 double}

标准化数据

瞬时频率和谱熵的均值相差几乎一个数量级。而且,瞬时频率均值可能太高,以致 LSTM 无法高效学习。当网络适合于均值和极差较大的数据时,大的输入可能会减慢网络的学习和收敛速度 [6]。

mean(instFreqN)
ans = 5.5615
mean(pentropyN)
ans = 0.6326

使用训练集均值和标准差来标准化训练集和测试集。标准化,或 z 分数,是一种在训练过程中提高网络性能的常用方法。

XV = [XTrain2{:}];
mu = mean(XV,2);
sg = std(XV,[],2);

XTrainSD = XTrain2;
XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,'UniformOutput',false);

XTestSD = XTest2;
XTestSD = cellfun(@(x)(x-mu)./sg,XTestSD,'UniformOutput',false);

显示标准化瞬时频率和谱熵的均值。

instFreqNSD = XTrainSD{1}(1,:);
pentropyNSD = XTrainSD{1}(2,:);

mean(instFreqNSD)
ans = -0.3211
mean(pentropyNSD)
ans = -0.2416

修改 LSTM 网络架构

现在每个信号都有两个维度,就有必要通过将输入序列大小指定为 2 来修改网络架构。指定输出大小为 100 的一个双向 LSTM 层,并输出序列的最后一个元素。通过使用一个大小为 2 的全连接层,后跟 softmax 层和分类层,来指定两个类。

layers = [ ...
    sequenceInputLayer(2)
    bilstmLayer(100,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 2 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         2 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

指定训练选项。将最大轮数设置为 30,以允许基于训练数据对网络进行 30 轮训练。

options = trainingOptions('adam', ...
    'MaxEpochs',30, ...
    'MiniBatchSize', 150, ...
    'InitialLearnRate', 0.01, ...
    'GradientThreshold', 1, ...
    'ExecutionEnvironment',"auto",...
    'plots','training-progress', ...
    'Verbose',false);

用时频特征训练 LSTM 网络

通过使用 trainNetwork 用指定的训练选项和层架构训练 LSTM 网络。

net2 = trainNetwork(XTrainSD,YTrain,layers,options);

训练准确度有很大提高。交叉熵损失趋向于 0。而且,训练所需的时间减少,因为 TF 矩比原始序列短。

可视化训练和测试准确度

使用更新后的 LSTM 网络对训练数据进行分类。将分类性能可视化为混淆矩阵。

trainPred2 = classify(net2,XTrainSD);
LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
LSTMAccuracy = 83.5962
figure
confusionchart(YTrain,trainPred2,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

使用更新后的网络对测试数据进行分类。绘制混淆矩阵以检查测试准确度。

testPred2 = classify(net2,XTestSD);

LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
LSTMAccuracy = 80.1020
figure
confusionchart(YTest,testPred2,'ColumnSummary','column-normalized',...
              'RowSummary','row-normalized','Title','Confusion Chart for LSTM');

结论

此示例说明如何使用 LSTM 网络构建分类器来检测 ECG 信号中的心房颤动。该过程使用过采样来避免在主要由健康被测者组成的人群中检测异常情况时出现的分类偏置。使用原始信号数据训练 LSTM 网络会导致分类准确度差。对每个信号使用两个时频矩特征来训练网络可显著提高分类性能,同时减少训练时间。

参考资料

[1] AF Classification from a Short Single Lead ECG Recording: the PhysioNet/Computing in Cardiology Challenge, 2017. https://physionet.org/challenge/2017/

[2] Clifford, Gari, Chengyu Liu, Benjamin Moody, Li-wei H. Lehman, Ikaro Silva, Qiao Li, Alistair Johnson, and Roger G. Mark."AF Classification from a Short Single Lead ECG Recording:The PhysioNet Computing in Cardiology Challenge 2017."Computing in Cardiology (Rennes:IEEE).Vol. 44, 2017, pp. 1–4.

[3] Goldberger, A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch.Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley."PhysioBank, PhysioToolkit, and PhysioNet:Components of a New Research Resource for Complex Physiologic Signals".Circulation.Vol. 101, No. 23, 13 June 2000, pp. e215–e220. http://circ.ahajournals.org/content/101/23/e215.full

[4] Pons, Jordi, Thomas Lidy, and Xavier Serra."Experimenting with Musically Motivated Convolutional Neural Networks".14th International Workshop on Content-Based Multimedia Indexing (CBMI).June 2016.

[5] Wang, D."Deep learning reinvents the hearing aid," IEEE Spectrum, Vol. 54, No. 3, March 2017, pp. 32–37. doi:10.1109/MSPEC.2017.7864754.

[6] Brownlee, Jason.How to Scale Data for Long Short-Term Memory Networks in Python.7 July 2017. https://machinelearningmastery.com/how-to-scale-data-for-long-short-term-memory-networks-in-python/.

另请参阅

函数

相关主题