本页面提供的是上一版软件的文档。当前版本中已删除对应的英文页面。

使用深度学习进行调制分类

此示例说明如何使用卷积神经网络 (CNN) 进行调制分类。您将生成合成的、通道减损波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类。然后用软件定义的无线电 (SDR) 硬件和无线信号测试 CNN。

使用 CNN 预测调制类型

本示例中经过训练的 CNN 可识别以下八种数字调制类型和三种模拟调制类型:

  • 二相相移键控 (BPSK)

  • 四相相移键控 (QPSK)

  • 八相相移键控 (8-PSK)

  • 十六相正交幅值调制 (16-QAM)

  • 六十四相正交幅值调制 (64-QAM)

  • 四相脉冲幅值调制 (PAM4)

  • 高斯频移键控 (GFSK)

  • 连续相位频移键控 (CPFSK)

  • 广播 FM (B-FM)

  • 双边带幅值调制 (DSB-AM)

  • 单边带幅值调制 (SSB-AM)

modulationTypes = categorical(["BPSK", "QPSK", "8PSK", ...
  "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
  "B-FM", "DSB-AM", "SSB-AM"]);

首先,加载经过训练的网络。有关网络训练的详细信息,请参阅“训练 CNN”一节。

load trainedModulationClassificationNetwork
trainedNet
trainedNet = 
  SeriesNetwork with properties:

         Layers: [28x1 nnet.cnn.layer.Layer]
     InputNames: {'Input Layer'}
    OutputNames: {'Output'}

经过训练的 CNN 接受 1024 个通道减损样本,并预测每个帧的调制类型。生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 BPSK 帧。使用 randi 函数生成一些随机位,使用 pskmod 函数对这些位进行 BPSK 调制,使用 rcosdesign 函数设计平方根升余弦脉冲整形滤波器,并使用 filter 函数对符号进行脉冲整形。然后使用 CNN 预测帧的调制类型。

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(123456)
% Random bits
d = randi([0 1],1024,1);
% BPSK modulation
syms = pskmod(d,2);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));
% Channel
channel = helperModClassTestChannel(...
  'SampleRate',200e3, ...
  'SNR',30, ...
  'PathDelays',[0 1.8 3.4] / 200e3, ...
  'AveragePathGains',[0 -2 -10], ...
  'KFactor',4, ...
  'MaximumDopplerShift',4, ...
  'MaximumClockOffset',5, ...
  'CenterFrequency',902e6);
rx = channel(tx);
% Plot transmitted and received signals
scope = dsp.TimeScope(2,200e3,'YLimits',[-1 1],'ShowGrid',true,...
  'LayoutDimensions',[2 1],'TimeSpan',45e-3);
scope(tx,rx)

% Frame generation for classification
unknownFrames = getNNFrames(rx,'Unknown');
% Classification
[prediction1,score1] = classify(trainedNet,unknownFrames);

返回分类器预测,这类似于硬判决。网络正确地将帧识别为 BPSK 帧。有关生成调制信号的详细信息,请参阅附录:调制器

prediction1
prediction1 = 7x1 categorical array
     BPSK 
     BPSK 
     BPSK 
     BPSK 
     BPSK 
     BPSK 
     BPSK 

分类器还返回一个包含每一帧分数的向量。分数对应于每个帧具有预测的调制类型的概率。绘制分数图。

plotScores(score1,modulationTypes)

接下来,使用 CNN 对 PAM4 帧进行分类。

% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35, 4, 8);
tx = filter(filterCoeffs, 1, upsample(syms,8));
% Channel
rx = channel(tx);
% Plot transmitted and received signals
scope = dsp.TimeScope(2,200e3,'YLimits',[-2 2],'ShowGrid',true,...
  'LayoutDimensions',[2 1],'TimeSpan',45e-3);
scope(tx,rx)

% Frame generation for classification
unknownFrames = getNNFrames(rx,'Unknown');
% Classification
[estimate2,score2] = classify(trainedNet,unknownFrames);
estimate2
estimate2 = 7x1 categorical array
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 

plotScores(score2,modulationTypes)

我们首先需要用已知(即已加标签的)数据训练 CNN,然后才能使用 CNN 进行调制分类或执行任何其他任务。此示例的第一部分说明如何使用 Communications Toolbox 功能(如调制器、滤波器和通道减损)来生成合成的训练数据。第二部分着重于针对调制分类任务来定义、训练和测试 CNN。第三部分通过软件定义无线电 (SDR) 平台使用无线信号来测试网络性能。

生成用于训练的波形

为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。我们在网络训练阶段使用训练和验证帧。使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个样本表示一个符号。网络根据单个帧而不是多个连续帧(如视频)作出每个决定。假设数字和模拟调制类型的中心频率分别为 902 MHz 和 100 MHz。

要快速运行此示例,请使用经过训练的网络并生成少量训练帧。要在您的计算机上训练网络,请选择“Train network now”选项(即,将 trainNow 设置为 true)。

trainNow = false;
if trainNow == true
  numFramesPerModType = 10000;
else
  numFramesPerModType = 500;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
symbolsPerFrame = spf / sps;
fs = 200e3;             % Sample rate
fc = [902e6 100e6];     % Center frequencies

创建通道减损

让每帧通过通道并具有

  • AWGN

  • 莱斯多径衰落

  • 时钟偏移,导致中心频率偏移和采样时间漂移

由于本示例中的网络基于单个帧作出决定,因此每个帧必须通过独立的通道。

AWGN

通道增加 SNR 为 30 dB 的 AWGN。由于帧经过归一化,因此噪声标准差可以计算为

SNR = 30;
std = sqrt(10.^(-SNR/10))
std = 0.0316

使用 comm.AWGNChannel 实现通道,

awgnChannel = comm.AWGNChannel(...
  'NoiseMethod', 'Signal to noise ratio (SNR)', ...
  'SignalPower', 1, ...
  'SNR', SNR)
awgnChannel = 
  comm.AWGNChannel with properties:

     NoiseMethod: 'Signal to noise ratio (SNR)'
             SNR: 30
     SignalPower: 1
    RandomStream: 'Global stream'

莱斯多径

通道使用 comm.RicianChannel System object 通过莱斯多径衰落通道传递信号。假设延迟分布为 [0 1.8 3.4] 个样本,对应的平均路径增益为 [0 -2 -10] dB。K 因子为 4,最大多普勒频移为 4 Hz,等效于 902 MHz 的步行速度。使用以下设置实现通道。

multipathChannel = comm.RicianChannel(...
  'SampleRate', fs, ...
  'PathDelays', [0 1.8 3.4]/fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4)
multipathChannel = 
  comm.RicianChannel with properties:

                SampleRate: 200000
                PathDelays: [0 9.0000e-06 1.7000e-05]
          AveragePathGains: [0 -2 -10]
        NormalizePathGains: true
                   KFactor: 4
    DirectPathDopplerShift: 0
    DirectPathInitialPhase: 0
       MaximumDopplerShift: 4
           DopplerSpectrum: [1x1 struct]

  Show all properties

时钟偏移

时钟偏移是发送器和接收器的内部时钟源不准确造成的。时钟偏移导致中心频率(用于将信号下变频至基带)和数模转换器采样率不同于理想值。通道仿真器使用时钟偏移因子 C,表示为 C=1+Δclock106,其中 Δclock 是时钟偏移。对于每个帧,通道基于 [-maxΔclock maxΔclock] 范围内一组均匀分布的值生成一个随机 Δclock 值,其中 maxΔclock 是最大时钟偏移。时钟偏移以百万分率 (ppm) 为单位测量。对于本示例,假设最大时钟偏移为 5 ppm。

maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);

频率偏移

基于时钟偏移因子 C 和中心频率,对每帧进行频率偏移。使用 comm.PhaseFrequencyOffset 实现通道。

offset = -(C-1)*fc(1);
frequencyShifter = comm.PhaseFrequencyOffset(...
  'SampleRate', fs, ...
  'FrequencyOffset', offset)
frequencyShifter = 
  comm.PhaseFrequencyOffset with properties:

              PhaseOffset: 0
    FrequencyOffsetSource: 'Property'
          FrequencyOffset: -2.4386e+03
               SampleRate: 200000

采样率偏移

基于时钟偏移因子 C,对每帧进行采样率偏移。使用 interp1 函数实现通道,以 C×fs 的新速率对帧进行重新采样。

合并后的通道

使用 helperModClassTestChannel 对象对帧应用所有三种通道减损。

channel = helperModClassTestChannel(...
  'SampleRate', fs, ...
  'SNR', SNR, ...
  'PathDelays', [0 1.8 3.4] / fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4, ...
  'MaximumClockOffset', 5, ...
  'CenterFrequency', 902e6)
channel = 
  helperModClassTestChannel with properties:

                    SNR: 30
        CenterFrequency: 902000000
             SampleRate: 200000
             PathDelays: [0 9.0000e-06 1.7000e-05]
       AveragePathGains: [0 -2 -10]
                KFactor: 4
    MaximumDopplerShift: 4
     MaximumClockOffset: 5

您可以使用 info 对象函数查看有关通道的基本信息。

chInfo = info(channel)
chInfo = struct with fields:
               ChannelDelay: 6
     MaximumFrequencyOffset: 4510
    MaximumSampleRateOffset: 1

波形生成

创建一个循环,它为每种调制类型生成通道减损的帧并将这些帧及其对应标签存储在 frameStore 中。从每帧的开头删除随机数量的样本,以去除瞬变并确保帧相对于符号边界具有随机起点。

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(1235)
tic

numModulationTypes = length(modulationTypes);

channelInfo = info(channel);
frameStore = helperModClassFrameStore(...
  numFramesPerModType*numModulationTypes,spf,modulationTypes);
transDelay = 50;
for modType = 1:numModulationTypes
  fprintf('%s - Generating %s frames\n', ...
    datestr(toc/86400,'HH:MM:SS'), modulationTypes(modType))
  numSymbols = (numFramesPerModType / sps);
  dataSrc = getSource(modulationTypes(modType), sps, 2*spf, fs);
  modulator = getModulator(modulationTypes(modType), sps, fs);
  if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
    % Analog modulation types use a center frequency of 100 MHz
    channel.CenterFrequency = 100e6;
  else
    % Digital modulation types use a center frequency of 902 MHz
    channel.CenterFrequency = 902e6;
  end
  
  for p=1:numFramesPerModType
    % Generate random data
    x = dataSrc();
    
    % Modulate
    y = modulator(x);
    
    % Pass through independent channels
    rxSamples = channel(y);
    
    % Remove transients from the beginning, trim to size, and normalize
    frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
    
    % Add to frame store
    add(frameStore, frame, modulationTypes(modType));
  end
end
00:00:00 - Generating BPSK frames
00:00:03 - Generating QPSK frames
00:00:06 - Generating 8PSK frames
00:00:13 - Generating 16QAM frames
00:00:16 - Generating 64QAM frames
00:00:19 - Generating PAM4 frames
00:00:21 - Generating GFSK frames
00:00:24 - Generating CPFSK frames
00:00:26 - Generating B-FM frames
00:00:40 - Generating DSB-AM frames
00:00:43 - Generating SSB-AM frames

接下来,将帧分为训练数据、验证数据和测试数据。默认情况下,frameStore 将 I/Q 基带样本按行放置在输出帧中。输出帧的大小为 [2xspf×1×N],其中第一行是同相采样,第二行是正交采样。

[mcfsTraining,mcfsValidation,mcfsTest] = splitData(frameStore,...
  [percentTrainingSamples,percentValidationSamples,percentTestSamples]);
[rxTraining,rxTrainingLabel] = get(mcfsTraining);
[rxValidation,rxValidationLabel] = get(mcfsValidation);
[rxTest,rxTestLabel] = get(mcfsTest);
% Plot the amplitude of the real and imaginary parts of the example frames
% against the sample number
plotTimeDomain(rxTest,rxTestLabel,modulationTypes,fs)

% Plot a spectrogram of the example frames
plotSpectrogram(rxTest,rxTestLabel,modulationTypes,fs,sps)

通过确保标签(调制类型)分布均匀,避免训练数据中的类不平衡。绘制标签分布图,以检查生成的标签是否分布均匀。

% Plot the label distributions
figure
subplot(3,1,1)
histogram(rxTrainingLabel)
title("Training Label Distribution")
subplot(3,1,2)
histogram(rxValidationLabel)
title("Validation Label Distribution")
subplot(3,1,3)
histogram(rxTestLabel)
title("Test Label Distribution")

训练 CNN

本示例使用的 CNN 由六个卷积层和一个全连接层组成。除最后一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层。在最后一个卷积层中,最大池化层被一个平均池化层取代。输出层具有 softmax 激活。有关网络设计指导原则,请参阅Deep Learning Tips and Tricks

dropoutRate = 0.5;
numModTypes = numel(modulationTypes);
netWidth = 1;
filterSize = [1 sps];
poolSize = [1 2];
modClassNet = [
  imageInputLayer([2 spf 1], 'Normalization', 'none', 'Name', 'Input Layer')
  
  convolution2dLayer(filterSize, 16*netWidth, 'Padding', 'same', 'Name', 'CNN1')
  batchNormalizationLayer('Name', 'BN1')
  reluLayer('Name', 'ReLU1')
  maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool1')
  
  convolution2dLayer(filterSize, 24*netWidth, 'Padding', 'same', 'Name', 'CNN2')
  batchNormalizationLayer('Name', 'BN2')
  reluLayer('Name', 'ReLU2')
  maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool2')
  
  convolution2dLayer(filterSize, 32*netWidth, 'Padding', 'same', 'Name', 'CNN3')
  batchNormalizationLayer('Name', 'BN3')
  reluLayer('Name', 'ReLU3')
  maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool3')
  
  convolution2dLayer(filterSize, 48*netWidth, 'Padding', 'same', 'Name', 'CNN4')
  batchNormalizationLayer('Name', 'BN4')
  reluLayer('Name', 'ReLU4')
  maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool4')
  
  convolution2dLayer(filterSize, 64*netWidth, 'Padding', 'same', 'Name', 'CNN5')
  batchNormalizationLayer('Name', 'BN5')
  reluLayer('Name', 'ReLU5')
  maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool5')
  
  convolution2dLayer(filterSize, 96*netWidth, 'Padding', 'same', 'Name', 'CNN6')
  batchNormalizationLayer('Name', 'BN6')
  reluLayer('Name', 'ReLU6')
  
  averagePooling2dLayer([1 ceil(spf/32)], 'Name', 'AP1')
  
  fullyConnectedLayer(numModTypes, 'Name', 'FC1')
  softmaxLayer('Name', 'SoftMax')
  
  classificationLayer('Name', 'Output') ]
modClassNet = 
  28x1 Layer array with layers:

     1   'Input Layer'   Image Input             2x1024x1 images
     2   'CNN1'          Convolution             16 1x8 convolutions with stride [1  1] and padding 'same'
     3   'BN1'           Batch Normalization     Batch normalization
     4   'ReLU1'         ReLU                    ReLU
     5   'MaxPool1'      Max Pooling             1x2 max pooling with stride [1  2] and padding [0  0  0  0]
     6   'CNN2'          Convolution             24 1x8 convolutions with stride [1  1] and padding 'same'
     7   'BN2'           Batch Normalization     Batch normalization
     8   'ReLU2'         ReLU                    ReLU
     9   'MaxPool2'      Max Pooling             1x2 max pooling with stride [1  2] and padding [0  0  0  0]
    10   'CNN3'          Convolution             32 1x8 convolutions with stride [1  1] and padding 'same'
    11   'BN3'           Batch Normalization     Batch normalization
    12   'ReLU3'         ReLU                    ReLU
    13   'MaxPool3'      Max Pooling             1x2 max pooling with stride [1  2] and padding [0  0  0  0]
    14   'CNN4'          Convolution             48 1x8 convolutions with stride [1  1] and padding 'same'
    15   'BN4'           Batch Normalization     Batch normalization
    16   'ReLU4'         ReLU                    ReLU
    17   'MaxPool4'      Max Pooling             1x2 max pooling with stride [1  2] and padding [0  0  0  0]
    18   'CNN5'          Convolution             64 1x8 convolutions with stride [1  1] and padding 'same'
    19   'BN5'           Batch Normalization     Batch normalization
    20   'ReLU5'         ReLU                    ReLU
    21   'MaxPool5'      Max Pooling             1x2 max pooling with stride [1  2] and padding [0  0  0  0]
    22   'CNN6'          Convolution             96 1x8 convolutions with stride [1  1] and padding 'same'
    23   'BN6'           Batch Normalization     Batch normalization
    24   'ReLU6'         ReLU                    ReLU
    25   'AP1'           Average Pooling         1x32 average pooling with stride [1  1] and padding [0  0  0  0]
    26   'FC1'           Fully Connected         11 fully connected layer
    27   'SoftMax'       Softmax                 softmax
    28   'Output'        Classification Output   crossentropyex

使用 analyzeNetwork 函数显示网络架构的交互式可视化,检测网络的错误和问题,并获取有关网络层的详细信息。此网络有 98323 个可学习参数。

analyzeNetwork(modClassNet)

接下来配置 TrainingOptionsSGDM 以使用小批量大小为 256 的 SGDM 求解器。将最大轮数设置为 12,因为更多轮数不会提供进一步的训练优势。通过将执行环境设置为 'gpu',在 GPU 上训练网络。将初始学习率设置为 2x10-2。每 9 轮后将学习率降低十分之一。将 'Plots' 设置为 'training-progress' 以对训练进度绘图。在 NVIDIA Titan Xp GPU 上,网络需要大约 25 分钟来完成训练。

maxEpochs = 12;
miniBatchSize = 256;
validationFrequency = floor(numel(rxTrainingLabel)/miniBatchSize);
options = trainingOptions('sgdm', ...
  'InitialLearnRate',2e-2, ...
  'MaxEpochs',maxEpochs, ...
  'MiniBatchSize',miniBatchSize, ...
  'Shuffle','every-epoch', ...
  'Plots','training-progress', ...
  'Verbose',false, ...
  'ValidationData',{rxValidation,rxValidationLabel}, ...
  'ValidationFrequency',validationFrequency, ...
  'LearnRateSchedule', 'piecewise', ...
  'LearnRateDropPeriod', 9, ...
  'LearnRateDropFactor', 0.1, ...
  'ExecutionEnvironment', 'gpu');

或者训练网络,或者使用已经过训练的网络。默认情况下,此示例使用经过训练的网络。

if trainNow == true
  fprintf('%s - Training the network\n', datestr(toc/86400,'HH:MM:SS'))
  trainedNet = trainNetwork(rxTraining,rxTrainingLabel,modClassNet,options);
else
  load trainedModulationClassificationNetwork
end

如训练进度图所示,网络在大约 12 轮后收敛于几乎 90% 的准确度。

通过获得测试帧的分类准确度来评估经过训练的网络。结果表明,该网络对这组波形实现的准确度达到 90% 左右。

fprintf('%s - Classifying test frames\n', datestr(toc/86400,'HH:MM:SS'))
00:01:15 - Classifying test frames
rxTestPred = classify(trainedNet,rxTest);
testAccuracy = mean(rxTestPred == rxTestLabel);
disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 90.5455%

绘制测试帧的混淆矩阵。如矩阵所示,网络混淆了 16-QAM 和 64-QAM 帧。此问题是预料之中的,因为每个帧只携带 128 个符号,而 16-QAM 是 64-QAM 的子集。该网络还混淆了 QPSK 和 8-PSK 帧,因为在通道衰落和频率偏移引发相位旋转后,这些调制类型的星座图看起来相似。

figure
cm = confusionchart(rxTestLabel, rxTestPred);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 740 424];

I/Q 作为页

默认情况下,frameStore 将 I/Q 基带样本按行放在二维数组中。由于卷积过滤器的大小为 [1xsps],卷积层独立处理同相和正交分量。仅在全连接层中,才会合并来自同相和正交分量的信息。

另一种方法是将 I/Q 样本表示为 [1xSPFx2] 大小的三维数组,该数组将同相和正交分量放在第三个维度(页)中。这种方法混合了 I 和 Q(甚至是卷积层)中的信息,使我们能够更好地利用相位信息。将帧存储的 'OutputFormat' 属性设置为 "IQAsPages",并将输入层的大小设置为 [1xSPFx2]。

% Put the data in [1xspfx2] format
mcfsTraining.OutputFormat = "IQAsPages";
[rxTraining,rxTrainingLabel] = get(mcfsTraining);
mcfsValidation.OutputFormat = "IQAsPages";
[rxValidation,rxValidationLabel] = get(mcfsValidation);
mcfsTest.OutputFormat = "IQAsPages";
[rxTest,rxTestLabel] = get(mcfsTest);

% Set the options
options = trainingOptions('sgdm', ...
  'InitialLearnRate',2e-2, ...
  'MaxEpochs',maxEpochs, ...
  'MiniBatchSize',miniBatchSize, ...
  'Shuffle','every-epoch', ...
  'Plots','training-progress', ...
  'Verbose',false, ...
  'ValidationData',{rxValidation,rxValidationLabel}, ...
  'ValidationFrequency',validationFrequency, ...
  'LearnRateSchedule', 'piecewise', ...
  'LearnRateDropPeriod', 9, ...
  'LearnRateDropFactor', 0.1, ...
  'ExecutionEnvironment', 'gpu');

% Set the input layer input size to [1xspfx2]
modClassNet(1) = ...
  imageInputLayer([1 spf 2], 'Normalization', 'none', 'Name', 'Input Layer');

下面对网络的分析表明,每个卷积过滤器的维度为 1×8×2,这使得卷积层能够在计算一个过滤器输出时同时使用 I 和 Q 数据。

analyzeNetwork(modClassNet)

% Train or load the pretrained modified network
if trainNow == true
  fprintf('%s - Training the network\n', datestr(toc/86400,'HH:MM:SS'))
  trainedNet = trainNetwork(rxTraining,rxTrainingLabel,modClassNet,options);
else
  load trainedModulationClassificationNetwork2
end

如训练进度图所示,网络在大约 12 轮后收敛于超过 95% 的准确度。将 I/Q 分量表示为页而不是行可以将网络的准确度提高大约 5%。

通过获得测试帧的分类准确度来评估经过训练的网络。结果表明,该网络对这组波形实现的准确度达到 95% 左右。

fprintf('%s - Classifying test frames\n', datestr(toc/86400,'HH:MM:SS'))
00:01:24 - Classifying test frames
rxTestPred = classify(trainedNet,rxTest);
testAccuracy = mean(rxTestPred == rxTestLabel);
disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 95.2727%

绘制测试帧的混淆矩阵。如矩阵所示,将 I/Q 分量表示为页而不是行,可极大地提高网络准确区分 16-QAM 和 64-QAM 帧以及 QPSK 和 8-PSK 帧的能力。

figure
cm = confusionchart(rxTestLabel, rxTestPred);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 740 424];

使用 SDR 进行测试

使用 sdrTest 函数,通过无线信号测试经过训练的网络的性能。要执行此测试,您必须有专用的 SDR 用于发送和接收。您可以使用两个 ADALM-PLUTO 无线电,或使用一个 ADALM-PLUTO 无线电和一个 USRP® 无线电分别进行发送和接收。您必须安装 Communications Toolbox Support Package for ADALM-PLUTO Radio。如果使用 USRP® 无线电,还必须安装 Communications Toolbox Support Package for USRP® Radio。sdrTest 函数使用生成训练信号所用的同一调制函数,然后使用 ADALM-PLUTO 无线电进行传输。使用针对信号接收配置的 SDR(ADALM-PLUTO 或 USRP® 无线电)捕获通道减损的信号,而不对通道进行仿真。使用经过训练的网络和先前使用的同一 classify 函数来预测调制类型。运行下一个代码段会生成一个混淆矩阵,并输出测试准确度。

radioPlatform = "ADALM-PLUTO";

switch radioPlatform
  case "ADALM-PLUTO"
    if isPlutoSDRInstalled() == true
      radios = findPlutoRadio();
      if length(radios) >= 2
        sdrTest(radios);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
  case {"USRP B2xx","USRP X3xx","USRP N2xx"}
    if (isUSRPInstalled() == true) && (isPlutoSDRInstalled() == true)
      txRadio = findPlutoRadio();
      rxRadio = findsdru();
      switch radioPlatform
        case "USRP B2xx"
          idx = contains({rxRadio.Platform}, {'B200','B210'});
        case "USRP X3xx"
          idx = contains({rxRadio.Platform}, {'X300','X310'});
        case "USRP N2xx"
          idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
      end
      rxRadio = rxRadio(idx);
      if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
        sdrTest(rxRadio);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
end

当使用两个相隔约 2 英尺的固定 ADALM-PLUTO 无线电时,网络可实现 99% 的总体准确度,混淆矩阵如下图所示。结果可因试验设置而异。

进一步探查

要提高准确度,可以优化网络参数,例如过滤器数量、过滤器大小;或者优化网络结构,例如添加更多层,使用不同激活层等。

Communication Toolbox 提供了更多调制类型和通道减损。有关详细信息,请参阅Modulation (Communications Toolbox)和Channel Models (Communications Toolbox)部分。您还可以使用 LTE ToolboxWLAN Toolbox5G Toolbox 添加标准特定信号。您还可以使用 Phased Array System Toolbox 添加雷达信号。

附录:调制器部分提供用于生成调制信号的 MATLAB 函数。您还可以探查以下函数和 System object 以获得详细信息:

参考资料

  1. O'Shea, T. J., J. Corgan, and T. C. Clancy."Convolutional Radio Modulation Recognition Networks."Preprint, submitted June 10, 2016. https://arxiv.org/abs/1602.04105

  2. O'Shea, T. J., T. Roy, and T. C. Clancy."Over-the-Air Deep Learning Based Radio Signal Classification."IEEE Journal of Selected Topics in Signal Processing.Vol. 12, Number 1, 2018, pp. 168–179.

  3. Liu, X., D. Yang, and A. E. Gamal."Deep Neural Network Architectures for Modulation Classification."Preprint, submitted January 5, 2018. https://arxiv.org/abs/1712.00443v3

附录:辅助函数

function testAccuracy = sdrTest(radios)
%sdrTest Test CNN performance with over-the-air signals
%   A = sdrTest sends test frames from an ADALM-PLUTO radio, receives using
%   an ADALM-PLUTO or USRP radio, performs classification with the trained 
%   network and returns the overall classification accuracy. Transmitting 
%   radio uses transmit-repeat capability to send the same waveform repeatedly 
%   without loading the main loop.

modulationTypes = categorical(["BPSK", "QPSK", "8PSK", ...
  "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", "B-FM"]);
load trainedModulationClassificationNetwork2 trainedNet
numFramesPerModType = 100;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
fs = 200e3;             % Sample rate

txRadio = sdrtx('Pluto');
txRadio.RadioID = 'usb:0';
txRadio.CenterFrequency = 902e6;
txRadio.BasebandSampleRate = fs;

if isfield(radios, 'Platform')
  radioPlatform = "USRP";
  % Configure USRP radio as the receiver. 
  rxRadio = comm.SDRuReceiver("Platform",radios(1).Platform);
  switch radios(1).Platform
    case {"B200","B210"}
      masterClockRate = 5e6;
      rxRadio.SerialNum = radios(1).SerialNum;
    case {"N200/N210/USRP2"}
      masterClockRate = 100e6;
      rxRadio.IPAddress = radios(1).IPAddress;
    case {"X300","X310"}
      masterClockRate = 120e6;
      rxRadio.IPAddress = radios(1).IPAddress;
  end
  rxRadio.MasterClockRate = masterClockRate;
  rxRadio.DecimationFactor = masterClockRate/fs;
  radioInfo = info(rxRadio);
  maximumGain = radioInfo.MaximumGain;
  minimumGain = radioInfo.MinimumGain;
else
  radioPlatform = "PlutoSDR";
  rxRadio = sdrrx('Pluto');
  rxRadio.RadioID = 'usb:1';
  rxRadio.BasebandSampleRate = fs;
  rxRadio.ShowAdvancedProperties = true;
  rxRadio.EnableQuadratureCorrection = false;
  rxRadio.GainSource = "Manual";
  maximumGain = 73;
  minimumGain = -10;
end
rxRadio.SamplesPerFrame = spf;
% Use burst mode with numFramesInBurst set to 1, so that each capture 
% (call to the receiver) will return an independent fresh frame even 
% though the radio overruns.
rxRadio.EnableBurstMode = true;
rxRadio.NumFramesInBurst = 1;
rxRadio.OutputDataType = 'single';

% Display Tx and Rx radios
txRadio
rxRadio

% Set random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(1235)
tic

numModulationTypes = length(modulationTypes);
txModType = repmat(modulationTypes(1),numModulationTypes*numFramesPerModType,1);
estimatedModType = repmat(modulationTypes(1),numModulationTypes*numFramesPerModType,1);
frameCnt = 1;
for modType = 1:numModulationTypes
  fprintf('%s - Testing %s frames\n', ...
    datestr(toc/86400,'HH:MM:SS'), modulationTypes(modType))
  dataSrc = getSource(modulationTypes(modType), sps, 2*spf, fs);
  modulator = getModulator(modulationTypes(modType), sps, fs);
  if contains(char(modulationTypes(modType)), {'B-FM'})...
           && (radioPlatform == "PlutoSDR")
    % Analog modulation types use a center frequency of 100 MHz
    txRadio.CenterFrequency = 100e6;
    rxRadio.CenterFrequency = 100e6;
  else
    % Digital modulation types use a center frequency of 902 MHz
    txRadio.CenterFrequency = 902e6;
    rxRadio.CenterFrequency = 902e6;
  end
  
  disp('Starting transmitter')
  x = dataSrc();
  y = modulator(x);
  % Remove filter transients
  y = y(4*sps+1:end,1);
  maxVal = max(max(abs(real(y))), max(abs(imag(y))));
  y = y *0.8/maxVal;
  % Download waveform signal to radio and repeatedly transmit it over the air
  transmitRepeat(txRadio, complex(y));
  
  disp('Adjusting receiver gain')
  rxRadio.Gain = maximumGain;
  gainAdjusted = false;
  while ~gainAdjusted
    for p=1:20
      rx = rxRadio();
    end
    maxAmplitude = max([abs(real(rx)); abs(imag(rx))]);
    if (maxAmplitude < 0.8) || (rxRadio.Gain <= minimumGain)
      gainAdjusted = true;
    else
      rxRadio.Gain = rxRadio.Gain - 3;
    end
  end
  
  disp('Starting receiver and test')
  for p=1:numFramesPerModType
    rx = rxRadio();
    
    frameEnergy = sum(abs(rx).^2);
    rx = rx / sqrt(frameEnergy);
    reshapedRx(1,:,1,1) = real(rx);
    reshapedRx(1,:,2,1) = imag(rx);
    
    % Classify
    txModType(frameCnt) = modulationTypes(modType);
    estimatedModType(frameCnt) = classify(trainedNet, reshapedRx);
    
    frameCnt = frameCnt + 1;
    
    % Pause for some time to get an independent channel. The pause duration 
    % together with the processing time of a single loop must be greater 
    % than the channel coherence time. Assume channel coherence time is less 
    % than 0.1 seconds.
    pause(0.1)
  end
  disp('Releasing Tx radio')
  release(txRadio);
  testAccuracy = mean(txModType(1:frameCnt-1) == estimatedModType(1:frameCnt-1));
  disp("Test accuracy: " + testAccuracy*100 + "%")
end
disp('Releasing Rx radio')
release(rxRadio);
testAccuracy = mean(txModType == estimatedModType);
disp("Final test accuracy: " + testAccuracy*100 + "%")

figure
cm = confusionchart(txModType, estimatedModType);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 740 424];
end

function modulator = getModulator(modType, sps, fs)
%getModulator Modulation function selector
%   MOD = getModulator(TYPE,SPS,FS) returns the modulator function handle
%   MOD based on TYPE. SPS is the number of samples per symbol and FS is
%   the sample rate.

switch modType
  case "BPSK"
    modulator = @(x)bpskModulator(x,sps);
  case "QPSK"
    modulator = @(x)qpskModulator(x,sps);
  case "8PSK"
    modulator = @(x)psk8Modulator(x,sps);
  case "16QAM"
    modulator = @(x)qam16Modulator(x,sps);
  case "64QAM"
    modulator = @(x)qam64Modulator(x,sps);
  case "GFSK"
    modulator = @(x)gfskModulator(x,sps);
  case "CPFSK"
    modulator = @(x)cpfskModulator(x,sps);
  case "PAM4"
    modulator = @(x)pam4Modulator(x,sps);
  case "B-FM"
    modulator = @(x)bfmModulator(x, fs);
  case "DSB-AM"
    modulator = @(x)dsbamModulator(x, fs);
  case "SSB-AM"
    modulator = @(x)ssbamModulator(x, fs);
end
end

function src = getSource(modType, sps, spf, fs)
%getSource Source selector for modulation types
%    SRC = getSource(TYPE,SPS,SPF,FS) returns the data source
%    for the modulation type TYPE, with the number of samples
%    per symbol SPS, the number of samples per frame SPF, and
%    the sampling frequency FS.

switch modType
  case {"BPSK","GFSK","CPFSK"}
    M = 2;
    src = @()randi([0 M-1],spf/sps,1);
  case {"QPSK","PAM4"}
    M = 4;
    src = @()randi([0 M-1],spf/sps,1);
  case "8PSK"
    M = 8;
    src = @()randi([0 M-1],spf/sps,1);
  case "16QAM"
    M = 16;
    src = @()randi([0 M-1],spf/sps,1);
  case "64QAM"
    M = 64;
    src = @()randi([0 M-1],spf/sps,1);
  case {"B-FM","DSB-AM","SSB-AM"}
    src = @()getAudio(spf,fs);
end
end

function x = getAudio(spf,fs)
%getAudio Audio source for analog modulation types
%    A = getAudio(SPF,FS) returns the audio source A, with the
%    number of samples per frame SPF, and the sample rate FS.

persistent audioSrc audioRC

if isempty(audioSrc)
  audioSrc = dsp.AudioFileReader('audio_mix_441.wav',...
    'SamplesPerFrame',spf,'PlayCount',inf);
  audioRC = dsp.SampleRateConverter('Bandwidth',30e3,...
    'InputSampleRate',audioSrc.SampleRate,...
    'OutputSampleRate',fs);
  [~,decimFactor] = getRateChangeFactors(audioRC);
  audioSrc.SamplesPerFrame = ceil(spf / fs * audioSrc.SampleRate / decimFactor) * decimFactor;
end

x = audioRC(audioSrc());
x = x(1:spf,1);
end

function frames = getNNFrames(rx,modType)
%getNNFrames Generate formatted frames for neural networks
%   F = getNNFrames(X,MODTYPE) formats the input X, into frames
%   that can be used with the neural network designed in this
%   example, and returns the frames in the output F.

frames = helperModClassFrameGenerator(rx,1024,1024,32,8);
frameStore = helperModClassFrameStore(10,1024,categorical({modType}));
add(frameStore,frames,modType);
frames = get(frameStore);
end

function plotScores(score,labels)
%plotScores Plot classification scores of frames
%   plotScores(SCR,LABELS) plots the classification scores SCR as a stacked
%   bar for each frame. SCR is a matrix in which each row is the score for a
%   frame.

co = [0.08 0.9 0.49;
  0.52 0.95 0.70;
  0.36 0.53 0.96;
  0.09 0.54 0.67;
  0.48 0.99 0.26;
  0.95 0.31 0.17;
  0.52 0.85 0.95;
  0.08 0.72 0.88;
  0.12 0.45 0.69;
  0.22 0.11 0.49;
  0.65 0.54 0.71];
figure; ax = axes('ColorOrder',co,'NextPlot','replacechildren');
bar(ax,[score; nan(2,11)],'stacked'); legend(categories(labels),'Location','best');
xlabel('Frame Number'); ylabel('Score'); title('Classification Scores')
end

function plotTimeDomain(rxTest,rxTestLabel,modulationTypes,fs)
%plotTimeDomain Time domain plots of frames

numRows = ceil(length(modulationTypes) / 4);
spf = size(rxTest,2);
t = 1000*(0:spf-1)/fs;
if size(rxTest,1) == 2
  IQAsRows = true;
else
  IQAsRows = false;
end
for modType=1:length(modulationTypes)
  subplot(numRows, 4, modType);
  idxOut = find(rxTestLabel == modulationTypes(modType), 1);
  if IQAsRows
    rxI = rxTest(1,:,1,idxOut);
    rxQ = rxTest(2,:,1,idxOut);
  else
    rxI = rxTest(1,:,1,idxOut);
    rxQ = rxTest(1,:,2,idxOut);
  end
  plot(t,squeeze(rxI), '-'); grid on; axis equal; axis square
  hold on
  plot(t,squeeze(rxQ), '-'); grid on; axis equal; axis square
  hold off
  title(string(modulationTypes(modType)));
  xlabel('Time (ms)'); ylabel('Amplitude')
end
end

function plotSpectrogram(rxTest,rxTestLabel,modulationTypes,fs,sps)
%plotSpectrogram Spectrogram of frames

if size(rxTest,1) == 2
  IQAsRows = true;
else
  IQAsRows = false;
end
numRows = ceil(length(modulationTypes) / 4);
for modType=1:length(modulationTypes)
  subplot(numRows, 4, modType);
  idxOut = find(rxTestLabel == modulationTypes(modType), 1);
  if IQAsRows
    rxI = rxTest(1,:,1,idxOut);
    rxQ = rxTest(2,:,1,idxOut);
  else
    rxI = rxTest(1,:,1,idxOut);
    rxQ = rxTest(1,:,2,idxOut);
  end
  rx = squeeze(rxI) + 1i*squeeze(rxQ);
  spectrogram(rx,kaiser(sps),0,1024,fs,'centered');
  title(string(modulationTypes(modType)));
end
h = gcf; delete(findall(h.Children, 'Type', 'ColorBar'))
end

function flag = isPlutoSDRInstalled
%isPlutoSDRInstalled Check if ADALM-PLUTO Radio HSP is installed

spkg = matlabshared.supportpkg.getInstalled;
flag = ~isempty(spkg) && any(contains({spkg.Name},'ADALM-PLUTO','IgnoreCase',true));
end

function flag = isUSRPInstalled
%isUSRPInstalled Check if USRP Radio HSP is installed

spkg = matlabshared.supportpkg.getInstalled;
flag = ~isempty(spkg) && any(contains({spkg.Name},'USRP','IgnoreCase',true));
end

附录:调制器

function y = bpskModulator(x,sps)
%bpskModulator BPSK modulator with pulse shaping
%   Y = bpskModulator(X,SPS) BPSK modulates the input X, and returns the
%   root-raised cosine pulse shaped signal Y. X must be a column vector
%   of values in the set [0 1]. The root-raised cosine filter has a
%   roll-off factor of 0.35 and spans four symbols. The output signal
%   Y has unit power.

persistent filterCoeffs
if isempty(filterCoeffs)
  filterCoeffs = rcosdesign(0.35, 4, sps);
end
% Modulate
syms = pskmod(x,2);
% Pulse shape
y = filter(filterCoeffs, 1, upsample(syms,sps));
end

function y = qpskModulator(x,sps)
%qpskModulator QPSK modulator with pulse shaping
%   Y = qpskModulator(X,SPS) QPSK modulates the input X, and returns the
%   root-raised cosine pulse shaped signal Y. X must be a column vector
%   of values in the set [0 3]. The root-raised cosine filter has a
%   roll-off factor of 0.35 and spans four symbols. The output signal
%   Y has unit power.

persistent filterCoeffs
if isempty(filterCoeffs)
  filterCoeffs = rcosdesign(0.35, 4, sps);
end
% Modulate
syms = pskmod(x,4,pi/4);
% Pulse shape
y = filter(filterCoeffs, 1, upsample(syms,sps));
end

function y = psk8Modulator(x,sps)
%psk8Modulator 8-PSK modulator with pulse shaping
%   Y = psk8Modulator(X,SPS) 8-PSK modulates the input X, and returns the
%   root-raised cosine pulse shaped signal Y. X must be a column vector
%   of values in the set [0 7]. The root-raised cosine filter has a
%   roll-off factor of 0.35 and spans four symbols. The output signal
%   Y has unit power.

persistent filterCoeffs
if isempty(filterCoeffs)
  filterCoeffs = rcosdesign(0.35, 4, sps);
end
% Modulate
syms = pskmod(x,8);
% Pulse shape
y = filter(filterCoeffs, 1, upsample(syms,sps));
end

function y = qam16Modulator(x,sps)
%qam16Modulator 16-QAM modulator with pulse shaping
%   Y = qam16Modulator(X,SPS) 16-QAM modulates the input X, and returns the
%   root-raised cosine pulse shaped signal Y. X must be a column vector
%   of values in the set [0 15]. The root-raised cosine filter has a
%   roll-off factor of 0.35 and spans four symbols. The output signal
%   Y has unit power.

persistent filterCoeffs
if isempty(filterCoeffs)
  filterCoeffs = rcosdesign(0.35, 4, sps);
end
% Modulate and pulse shape
syms = qammod(x,16,'UnitAveragePower',true);
% Pulse shape
y = filter(filterCoeffs, 1, upsample(syms,sps));
end

function y = qam64Modulator(x,sps)
%qam64Modulator 64-QAM modulator with pulse shaping
%   Y = qam64Modulator(X,SPS) 64-QAM modulates the input X, and returns the
%   root-raised cosine pulse shaped signal Y. X must be a column vector
%   of values in the set [0 63]. The root-raised cosine filter has a
%   roll-off factor of 0.35 and spans four symbols. The output signal
%   Y has unit power.

persistent filterCoeffs
if isempty(filterCoeffs)
  filterCoeffs = rcosdesign(0.35, 4, sps);
end
% Modulate
syms = qammod(x,64,'UnitAveragePower',true);
% Pulse shape
y = filter(filterCoeffs, 1, upsample(syms,sps));
end

function y = pam4Modulator(x,sps)
%pam4Modulator PAM4 modulator with pulse shaping
%   Y = pam4Modulator(X,SPS) PAM4 modulates the input X, and returns the
%   root-raised cosine pulse shaped signal Y. X must be a column vector
%   of values in the set [0 3]. The root-raised cosine filter has a
%   roll-off factor of 0.35 and spans four symbols. The output signal
%   Y has unit power.

persistent filterCoeffs amp
if isempty(filterCoeffs)
  filterCoeffs = rcosdesign(0.35, 4, sps);
  amp = 1 / sqrt(mean(abs(pammod(0:3, 4)).^2));
end
% Modulate
syms = amp * pammod(x,4);
% Pulse shape
y = filter(filterCoeffs, 1, upsample(syms,sps));
end

function y = gfskModulator(x,sps)
%gfskModulator GFSK modulator
%   Y = gfskModulator(X,SPS) GFSK modulates the input X and returns the
%   signal Y. X must be a column vector of values in the set [0 1]. The
%   BT product is 0.35 and the modulation index is 1. The output signal
%   Y has unit power.

persistent mod meanM
if isempty(mod)
  M = 2;
  mod = comm.CPMModulator(...
    'ModulationOrder', M, ...
    'FrequencyPulse', 'Gaussian', ...
    'BandwidthTimeProduct', 0.35, ...
    'ModulationIndex', 1, ...
    'SamplesPerSymbol', sps);
  meanM = mean(0:M-1);
end
% Modulate
y = mod(2*(x-meanM));
end

function y = cpfskModulator(x,sps)
%cpfskModulator CPFSK modulator
%   Y = cpfskModulator(X,SPS) CPFSK modulates the input X and returns
%   the signal Y. X must be a column vector of values in the set [0 1].
%   the modulation index is 0.5. The output signal Y has unit power.

persistent mod meanM
if isempty(mod)
  M = 2;
  mod = comm.CPFSKModulator(...
    'ModulationOrder', M, ...
    'ModulationIndex', 0.5, ...
    'SamplesPerSymbol', sps);
  meanM = mean(0:M-1);
end
% Modulate
y = mod(2*(x-meanM));
end

function y = bfmModulator(x,fs)
%bfmModulator Broadcast FM modulator
%   Y = bfmModulator(X,FS) broadcast FM modulates the input X and returns
%   the signal Y at the sample rate FS. X must be a column vector of
%   audio samples at the sample rate FS. The frequency deviation is 75 kHz
%   and the pre-emphasis filter time constant is 75 microseconds.

persistent mod
if isempty(mod)
  mod = comm.FMBroadcastModulator(...
    'AudioSampleRate', fs, ...
    'SampleRate', fs);
end
y = mod(x);
end

function y = dsbamModulator(x,fs)
%dsbamModulator Double sideband AM modulator
%   Y = dsbamModulator(X,FS) double sideband AM modulates the input X and
%   returns the signal Y at the sample rate FS. X must be a column vector of
%   audio samples at the sample rate FS. The IF frequency is 50 kHz.

y = ammod(x,50e3,fs);
end

function y = ssbamModulator(x,fs)
%ssbamModulator Single sideband AM modulator
%   Y = ssbamModulator(X,FS) single sideband AM modulates the input X and
%   returns the signal Y at the sample rate FS. X must be a column vector of
%   audio samples at the sample rate FS. The IF frequency is 50 kHz.

y = ssbmod(x,50e3,fs);
end

另请参阅

|

相关主题