Main Content

本页面提供的是上一版软件的文档。当前版本中已删除对应的英文页面。

使用深度学习进行调制分类

此示例说明如何使用卷积神经网络 (CNN) 进行调制分类。您将生成合成的、通道减损波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类。然后用软件定义的无线电 (SDR) 硬件和无线信号测试 CNN。

使用 CNN 预测调制类型

本示例中经过训练的 CNN 可识别以下八种数字调制类型和三种模拟调制类型:

  • 二相相移键控 (BPSK)

  • 四相相移键控 (QPSK)

  • 八相相移键控 (8-PSK)

  • 十六相正交幅值调制 (16-QAM)

  • 六十四相正交幅值调制 (64-QAM)

  • 四相脉冲幅值调制 (PAM4)

  • 高斯频移键控 (GFSK)

  • 连续相位频移键控 (CPFSK)

  • 广播 FM (B-FM)

  • 双边带幅值调制 (DSB-AM)

  • 单边带幅值调制 (SSB-AM)

modulationTypes = categorical(["BPSK", "QPSK", "8PSK", ...
  "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
  "B-FM", "DSB-AM", "SSB-AM"]);

首先,加载经过训练的网络。有关网络训练的详细信息,请参阅“训练 CNN”一节。

load trainedModulationClassificationNetwork
trainedNet
trainedNet = 
  SeriesNetwork with properties:

         Layers: [28x1 nnet.cnn.layer.Layer]
     InputNames: {'Input Layer'}
    OutputNames: {'Output'}

经过训练的 CNN 接受 1024 个通道减损采样,并预测每个帧的调制类型。生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 PAM4 帧。使用以下函数生成合成信号来测试 CNN。然后使用 CNN 预测帧的调制类型。

  • randi:生成随机位

  • pammod (Communications Toolbox):PAM4 调制位

  • rcosdesign (Signal Processing Toolbox):设计平方根升余弦脉冲整形滤波器

  • filter:脉冲确定符号的形状

  • comm.RicianChannel (Communications Toolbox):应用莱斯多径通道

  • comm.PhaseFrequencyOffset (Communications Toolbox):应用时钟偏移引起的相位和/或频率偏移

  • interp1:应用时钟偏移引起的计时漂移

  • awgn (Communications Toolbox):添加 AWGN

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));

% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel(...
  'SampleRate', fs, ...
  'PathDelays', [0 1.8 3.4] / 200e3, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4);

frequencyShifter = comm.PhaseFrequencyOffset(...
  'SampleRate', fs);

% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);

% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;

% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);

% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);

% Add noise
rx = awgn(outTimeDrift,SNR,0);

% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);

% Classification
[prediction1,score1] = classify(trainedNet,unknownFrames);

返回分类器预测,这类似于硬判决。网络将帧正确识别为 PAM4 帧。有关生成调制信号的详细信息,请参阅 helperModClassGetModulator 函数。

prediction1
prediction1 = 7x1 categorical
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 

分类器还返回一个包含每一帧分数的向量。分数对应于每个帧具有预测的调制类型的概率。绘制分数图。

helperModClassPlotScores(score1,modulationTypes)

我们首先需要用已知(即已加标签的)数据训练 CNN,然后才能使用 CNN 进行调制分类或执行任何其他任务。此示例的第一部分说明如何使用 Communications Toolbox 功能(如调制器、滤波器和通道减损)来生成合成的训练数据。第二部分着重于针对调制分类任务来定义、训练和测试 CNN。第三部分通过软件定义无线电 (SDR) 平台使用无线信号来测试网络性能。

生成用于训练的波形

为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。我们在网络训练阶段使用训练和验证帧。使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个采样表示一个符号。网络根据单个帧而不是多个连续帧(如视频)作出每个决定。假设数字和模拟调制类型的中心频率分别为 902 MHz 和 100 MHz。

要快速运行此示例,请使用经过训练的网络并生成少量训练帧。要在您的计算机上训练网络,请选择“Train network now”选项(即,将 trainNow 设置为 true)。

trainNow = false;
if trainNow == true
  numFramesPerModType = 10000;
else
  numFramesPerModType = 500;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
symbolsPerFrame = spf / sps;
fs = 200e3;             % Sample rate
fc = [902e6 100e6];     % Center frequencies

创建通道减损

让每帧通过通道并具有

  • AWGN

  • 莱斯多径衰落

  • 时钟偏移,导致中心频率偏移和采样时间漂移

由于本示例中的网络基于单个帧作出决定,因此每个帧必须通过独立的通道。

AWGN

通道增加 SNR 为 30 dB 的 AWGN。使用 awgn (Communications Toolbox) 函数实现通道。

莱斯多径

通道使用 comm.RicianChannel (Communications Toolbox) System object 通过莱斯多径衰落通道传递信号。假设延迟分布为 [0 1.8 3.4] 个样本,对应的平均路径增益为 [0 -2 -10] dB。K 因子为 4,最大多普勒频移为 4 Hz,等效于 902 MHz 的步行速度。使用以下设置实现通道。

时钟偏移

时钟偏移是发送器和接收器的内部时钟源不准确造成的。时钟偏移导致中心频率(用于将信号下变频至基带)和数模转换器采样率不同于理想值。通道仿真器使用时钟偏移因子 C,表示为 C=1+Δclock106,其中 Δclock 是时钟偏移。对于每个帧,通道基于 [-maxΔclock maxΔclock] 范围内一组均匀分布的值生成一个随机 Δclock 值,其中 maxΔclock 是最大时钟偏移。时钟偏移以百万分率 (ppm) 为单位测量。对于本示例,假设最大时钟偏移为 5 ppm。

maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);

频率偏移

基于时钟偏移因子 C 和中心频率,对每帧进行频率偏移。使用 comm.PhaseFrequencyOffset (Communications Toolbox) 实现通道。

采样率偏移

基于时钟偏移因子 C,对每帧进行采样率偏移。使用 interp1 函数实现通道,以 C×fs 的新速率对帧进行重新采样。

合并后的通道

使用 helperModClassTestChannel 对象对帧应用所有三种通道减损。

channel = helperModClassTestChannel(...
  'SampleRate', fs, ...
  'SNR', SNR, ...
  'PathDelays', [0 1.8 3.4] / fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4, ...
  'MaximumClockOffset', 5, ...
  'CenterFrequency', 902e6)
channel = 
  helperModClassTestChannel with properties:

                    SNR: 30
        CenterFrequency: 902000000
             SampleRate: 200000
             PathDelays: [0 9.0000e-06 1.7000e-05]
       AveragePathGains: [0 -2 -10]
                KFactor: 4
    MaximumDopplerShift: 4
     MaximumClockOffset: 5

您可以使用 info 对象函数查看有关通道的基本信息。

chInfo = info(channel)
chInfo = struct with fields:
               ChannelDelay: 6
     MaximumFrequencyOffset: 4510
    MaximumSampleRateOffset: 1

波形生成

创建一个循环,它为每种调制类型生成通道减损的帧并将这些帧及其对应标签存储在 MAT 文件中。通过将数据保存到文件中,您无需每次运行此示例时都生成数据。您还可以更高效地共享数据。

从每帧的开头删除随机数量的样本,以去除瞬变并确保帧相对于符号边界具有随机起点。

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(1235)
tic

numModulationTypes = length(modulationTypes);

channelInfo = info(channel);
transDelay = 50;
dataDirectory = fullfile(tempdir,"ModClassDataFiles");
disp("Data file directory is " + dataDirectory)
Data file directory is /tmp/Bdoc20b_1465442_200336/ModClassDataFiles
fileNameRoot = "frame";

% Check if data files exist
dataFilesExist = false;
if exist(dataDirectory,'dir')
  files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));
  if length(files) == numModulationTypes*numFramesPerModType
    dataFilesExist = true;
  end
end

if ~dataFilesExist
  disp("Generating data and saving in data files...")
  [success,msg,msgID] = mkdir(dataDirectory);
  if ~success
    error(msgID,msg)
  end
  for modType = 1:numModulationTypes
    fprintf('%s - Generating %s frames\n', ...
      datestr(toc/86400,'HH:MM:SS'), modulationTypes(modType))
    
    label = modulationTypes(modType);
    numSymbols = (numFramesPerModType / sps);
    dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs);
    modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);
    if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
      % Analog modulation types use a center frequency of 100 MHz
      channel.CenterFrequency = 100e6;
    else
      % Digital modulation types use a center frequency of 902 MHz
      channel.CenterFrequency = 902e6;
    end
    
    for p=1:numFramesPerModType
      % Generate random data
      x = dataSrc();
      
      % Modulate
      y = modulator(x);
      
      % Pass through independent channels
      rxSamples = channel(y);
      
      % Remove transients from the beginning, trim to size, and normalize
      frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
      
      % Save data file
      fileName = fullfile(dataDirectory,...
        sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p));
      save(fileName,"frame","label")
    end
  end
else
  disp("Data files exist. Skip data generation.")
end
Generating data and saving in data files...
00:00:00 - Generating BPSK frames
00:00:05 - Generating QPSK frames
00:00:09 - Generating 8PSK frames
00:00:18 - Generating 16QAM frames
00:00:22 - Generating 64QAM frames
00:00:28 - Generating PAM4 frames
00:00:34 - Generating GFSK frames
00:00:38 - Generating CPFSK frames
00:00:44 - Generating B-FM frames
00:01:04 - Generating DSB-AM frames
00:01:08 - Generating SSB-AM frames
% Plot the amplitude of the real and imaginary parts of the example frames
% against the sample number
helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)

% Plot the spectrogram of the example frames
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)

创建数据存储

使用 signalDatastore 对象来管理包含生成的复杂波形的文件。如果每个文件可以单独放入内存中,但整个集合不一定能放入内存,则数据存储特别有用。

frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);

将复信号变换为实数数组

此示例中的深度学习网络需要实数输入,而接收的信号包含复数基带采样。将复信号变换为实数值四维数组。输出帧的大小为 1×spf×2×N,其中第一页(第三个维度)是同相采样,第二页是正交采样。当卷积滤波器的大小为 1×spf 时,这种方法可确保混合 I 和 Q(甚至是卷积层)中的信息,使我们能够更好地利用相位信息。有关详细信息,请参阅 helperModClassIQAsPages

frameDSTrans = transform(frameDS,@helperModClassIQAsPages);

拆分为训练、验证和测试

接下来,将帧分为训练数据、验证数据和测试数据。有关详细信息,请参阅 helperModClassSplitData

splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDSTrans,validDSTrans,testDSTrans] = helperModClassSplitData(frameDSTrans,splitPercentages);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 4).
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 2: 0% complete
Evaluation 0% complete
- Pass 1 of 2: Completed in 18 sec
- Pass 2 of 2: Completed in 20 sec
Evaluation completed in 41 sec

将数据导入内存

神经网络训练是迭代进行的。在每次迭代中,数据存储从文件中读取数据,变换数据,然后更新网络系数。如果数据可放入计算机的内存中,则将数据从文件导入内存可以消除重复的文件读取和变换过程,从而加快训练速度。这样,只需执行一次从文件读取并变换数据的操作。使用磁盘上的数据文件训练此网络需要大约 110 分钟,而使用内存中的数据只需要大约 50 分钟。

将文件中的所有数据导入内存。这些文件有两个变量:framelabel,对数据存储的每个 read 调用都返回一个元胞数组,其中第一个元素是 frame,第二个元素是 label。使用 transform 函数 helperModClassReadFramehelperModClassReadLabel 读取帧和标签。如果您有 Parallel Computing Toolbox 许可证,请使用 tall 数组来实现变换函数的并行处理。由于 gather 函数默认情况下在第一个维度上串联 read 函数的输出,因此以元胞数组的形式返回帧并在第四个维度上手动进行串联。

% Gather the training and validation frames into the memory
trainFramesTall = tall(transform(trainDSTrans, @helperModClassReadFrame));
rxTrainFrames = gather(trainFramesTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 7.7 sec
Evaluation completed in 7.7 sec
rxTrainFrames = cat(4, rxTrainFrames{:});
validFramesTall = tall(transform(validDSTrans, @helperModClassReadFrame));
rxValidFrames = gather(validFramesTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 1.8 sec
Evaluation completed in 1.8 sec
rxValidFrames = cat(4, rxValidFrames{:});

% Gather the training and validation labels into the memory
trainLabelsTall = tall(transform(trainDSTrans, @helperModClassReadLabel));
rxTrainLabels = gather(trainLabelsTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 2: Completed in 5.5 sec
- Pass 2 of 2: Completed in 7.8 sec
Evaluation completed in 14 sec
validLabelsTall = tall(transform(validDSTrans, @helperModClassReadLabel));
rxValidLabels = gather(validLabelsTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 2: Completed in 1.6 sec
- Pass 2 of 2: Completed in 2 sec
Evaluation completed in 4 sec

训练 CNN

本示例使用的 CNN 由六个卷积层和一个全连接层组成。除最后一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层。在最后一个卷积层中,最大池化层被一个平均池化层取代。输出层具有 softmax 激活。有关网络设计指导原则,请参阅Deep Learning Tips and Tricks

modClassNet = helperModClassCNN(modulationTypes,sps,spf);

接下来配置 TrainingOptionsSGDM 以使用小批量大小为 256 的 SGDM 求解器。将最大轮数设置为 12,因为更多轮数不会提供进一步的训练优势。默认情况下,'ExecutionEnvironment' 属性设置为 'auto',其中 trainNetwork 函数使用 GPU(如果可用)或使用 CPU(如果 GPU 不可用)。要使用 GPU,您必须拥有 Parallel Computing Toolbox 许可证。将初始学习率设置为 2x10-2。每 9 轮后将学习率降低十分之一。将 'Plots' 设置为 'training-progress' 以对训练进度绘图。在 NVIDIA Titan Xp GPU 上,网络需要大约 25 分钟来完成训练。

maxEpochs = 12;
miniBatchSize = 256;
options = helperModClassTrainingOptions(maxEpochs,miniBatchSize,...
  numel(rxTrainLabels),rxValidFrames,rxValidLabels);

或者训练网络,或者使用已经过训练的网络。默认情况下,此示例使用经过训练的网络。

if trainNow == true
  fprintf('%s - Training the network\n', datestr(toc/86400,'HH:MM:SS'))
  trainedNet = trainNetwork(rxTrainFrames,rxTrainLabels,modClassNet,options);
else
  load trainedModulationClassificationNetwork
end

如训练进度图所示,网络在大约 12 轮后收敛于超过 95% 的准确度。

通过获得测试帧的分类准确度来评估经过训练的网络。结果表明,该网络对这组波形实现的准确度达到 94% 左右。

fprintf('%s - Classifying test frames\n', datestr(toc/86400,'HH:MM:SS'))
00:03:52 - Classifying test frames
% Gather the test frames into the memory
testFramesTall = tall(transform(testDSTrans, @helperModClassReadFrame));
rxTestFrames = gather(testFramesTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 1.6 sec
Evaluation completed in 1.6 sec
rxTestFrames = cat(4, rxTestFrames{:});

% Gather the test labels into the memory
testLabelsTall = tall(transform(testDSTrans, @helperModClassReadLabel));
rxTestLabels = gather(testLabelsTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 2: Completed in 1.6 sec
- Pass 2 of 2: Completed in 2 sec
Evaluation completed in 4.4 sec
rxTestPred = classify(trainedNet,rxTestFrames);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 95.4545%

绘制测试帧的混淆矩阵。如矩阵所示,网络混淆了 16-QAM 和 64-QAM 帧。此问题是预料之中的,因为每个帧只携带 128 个符号,而 16-QAM 是 64-QAM 的子集。该网络还混淆了 QPSK 和 8-PSK 帧,因为在通道衰落和频率偏移引发相位旋转后,这些调制类型的星座图看起来相似。

figure
cm = confusionchart(rxTestLabels, rxTestPred);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 740 424];

使用 SDR 进行测试

使用 helperModClassSDRTest 函数,通过无线信号测试经过训练的网络的性能。要执行此测试,您必须有专用的 SDR 用于发送和接收。您可以使用两个 ADALM-PLUTO 无线电,或使用一个 ADALM-PLUTO 无线电和一个 USRP® 无线电分别进行发送和接收。您必须安装 Communications Toolbox Support Package for ADALM-PLUTO Radio。如果使用 USRP® 无线电,还必须安装 Communications Toolbox Support Package for USRP® RadiohelperModClassSDRTest 函数使用生成训练信号所用的同一调制函数,然后使用 ADALM-PLUTO 无线电进行传输。使用针对信号接收配置的 SDR(ADALM-PLUTO 或 USRP® 无线电)捕获通道减损的信号,而不对通道进行仿真。使用经过训练的网络和先前使用的同一 classify 函数来预测调制类型。运行下一个代码段会生成一个混淆矩阵,并输出测试准确度。

radioPlatform = "ADALM-PLUTO";

switch radioPlatform
  case "ADALM-PLUTO"
    if helperIsPlutoSDRInstalled() == true
      radios = findPlutoRadio();
      if length(radios) >= 2
        helperModClassSDRTest(radios);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
  case {"USRP B2xx","USRP X3xx","USRP N2xx"}
    if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
      txRadio = findPlutoRadio();
      rxRadio = findsdru();
      switch radioPlatform
        case "USRP B2xx"
          idx = contains({rxRadio.Platform}, {'B200','B210'});
        case "USRP X3xx"
          idx = contains({rxRadio.Platform}, {'X300','X310'});
        case "USRP N2xx"
          idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
      end
      rxRadio = rxRadio(idx);
      if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
        helperModClassSDRTest(rxRadio);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
end

当使用两个相隔约 2 英尺的固定 ADALM-PLUTO 无线电时,网络可实现 99% 的总体准确度,混淆矩阵如下图所示。结果可因试验设置而异。

进一步探查

要提高准确度,可以优化超参数,例如滤波器数量、滤波器大小;或者优化网络结构,例如添加更多层,使用不同激活层等。

Communication Toolbox 提供了更多调制类型和通道减损。有关详细信息,请参阅Modulation (Communications Toolbox)Propagation and Channel Models (Communications Toolbox)部分。您还可以使用 LTE ToolboxWLAN Toolbox5G Toolbox 添加标准特定信号。您还可以使用 Phased Array System Toolbox 添加雷达信号。

helperModClassGetModulator 函数提供用于生成调制信号的 MATLAB 函数。您还可以探查以下函数和 System object 以获得详细信息:

参考资料

  1. O'Shea, T. J., J. Corgan, and T. C. Clancy."Convolutional Radio Modulation Recognition Networks."Preprint, submitted June 10, 2016. https://arxiv.org/abs/1602.04105

  2. O'Shea, T. J., T. Roy, and T. C. Clancy."Over-the-Air Deep Learning Based Radio Signal Classification."IEEE Journal of Selected Topics in Signal Processing.Vol. 12, Number 1, 2018, pp. 168–179.

  3. Liu, X., D. Yang, and A. E. Gamal."Deep Neural Network Architectures for Modulation Classification."Preprint, submitted January 5, 2018. https://arxiv.org/abs/1712.00443v3

另请参阅

|

相关主题