Main Content

使用深度学习进行调制分类

此示例说明如何使用卷积神经网络 (CNN) 进行调制分类。您将生成合成的、通道损伤波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类。然后用软件无线电 (SDR) 硬件和空口信号测试 CNN。

使用 CNN 预测调制类型

本示例中经过训练的 CNN 可识别以下八种数字调制类型和三种模拟调制类型:

  • 二相相移键控 (BPSK)

  • 四相相移键控 (QPSK)

  • 八相相移键控 (8-PSK)

  • 十六相正交调幅 (16-QAM)

  • 六十四相正交调幅 (64-QAM)

  • 四相脉冲振幅调制 (PAM4)

  • 高斯频移键控 (GFSK)

  • 连续相位频移键控 (CPFSK)

  • 广播 FM (B-FM)

  • 双边带振幅调制 (DSB-AM)

  • 单边带振幅调制 (SSB-AM)

modulationTypes = categorical(sort(["BPSK", "QPSK", "8PSK", ...
  "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
  "B-FM", "DSB-AM", "SSB-AM"]));

首先,加载经过训练的网络。有关网络训练的详细信息,请参阅“训练 CNN”一节。

load trainedModulationClassificationNetwork
trainedNet
trainedNet = 
  dlnetwork with properties:

         Layers: [19×1 nnet.cnn.layer.Layer]
    Connections: [18×2 table]
     Learnables: [22×3 table]
          State: [10×3 table]
     InputNames: {'Input Layer'}
    OutputNames: {'SoftMax'}
    Initialized: 1

  View summary with summary.

经过训练的 CNN 接受 1024 个通道损伤采样,并预测每个帧的调制类型。生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 PAM4 帧。使用以下函数生成合成信号来测试 CNN。然后使用 CNN 预测帧的调制类型。

  • randi:生成随机位

  • pammod (Communications Toolbox):PAM4 调制位

  • rcosdesign (Signal Processing Toolbox):设计平方根升余弦脉冲整形滤波器

  • filter:脉冲确定符号的形状

  • comm.RicianChannel (Communications Toolbox):应用莱斯多径通道

  • comm.PhaseFrequencyOffset (Communications Toolbox):应用时钟偏移引起的相位和/或频率偏移

  • interp1:应用时钟偏移引起的计时漂移

  • awgn (Communications Toolbox):添加 AWGN

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));

% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel(...
  'SampleRate', fs, ...
  'PathDelays', [0 1.8 3.4] / 200e3, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4);

frequencyShifter = comm.PhaseFrequencyOffset(...
  'SampleRate', fs);

% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);

% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;

% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);

% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);

% Add noise
rx = awgn(outTimeDrift,SNR,0);

% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);

% Classification
scores1 = predict(trainedNet,unknownFrames);
prediction1 = scores2label(scores1,modulationTypes);

返回分类器预测,这类似于硬判决。网络将帧正确识别为 PAM4 帧。有关生成调制信号的详细信息,请参阅 helperModClassGetModulator 函数。

prediction1
prediction1 = 7×1 categorical
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 

分类器还返回一个包含每一帧分数的向量。分数对应于每个帧具有预测的调制类型的概率。绘制分数图。

helperModClassPlotScores(scores1,modulationTypes)

我们首先需要用已知(即已加标签的)数据训练 CNN,然后才能使用 CNN 进行调制分类或执行任何其他任务。此示例的第一部分说明如何使用 Communications Toolbox™ 功能(如调制器、滤波器和通道损伤)来生成合成的训练数据。第二部分着重于针对调制分类任务来定义、训练和测试 CNN。第三部分通过软件无线电 (SDR) 平台使用空口信号来测试网络性能。

生成用于训练的波形

为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。我们在网络训练阶段使用训练和验证帧。使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个采样表示一个符号。网络根据单个帧而不是多个连续帧(如视频)作出每个决定。假设数字和模拟调制类型的中心频率分别为 902 MHz 和 100 MHz。

要快速运行此示例,请使用经过训练的网络并生成少量训练帧。要在您的计算机上训练网络,请选择“Train network now”选项(即,将 trainNow 设置为 true)。

trainNow = false;
if trainNow == true
  numFramesPerModType = 10000;
else
  numFramesPerModType = 200;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
fs = 200e3;             % Sample rate
fc = [902e6 100e6];     % Center frequencies

创建通道损伤

让每帧通过通道并具有

  • AWGN

  • 莱斯多径衰落

  • 时钟偏移,导致中心频率偏移和采样时间漂移

由于本示例中的网络基于单个帧作出决定,因此每个帧必须通过独立的通道。

AWGN

通道增加 SNR 为 30 dB 的 AWGN。使用 awgn (Communications Toolbox) 函数实现通道。

莱斯多径

通道使用 comm.RicianChannel (Communications Toolbox) System object™ 通过莱斯多径衰落通道传递信号。假设延迟分布为 [0 1.8 3.4] 个样本,对应的平均路径增益为 [0 -2 -10] dB。K 因子为 4,最大多普勒频移为 4 Hz,等效于 902 MHz 的步行速度。使用以下设置实现通道。

时钟偏移

时钟偏移是发射机和接收机的内部时钟源不准确造成的。时钟偏移导致中心频率(用于将信号下变频至基带)和数模转换器采样率不同于理想值。通道仿真器使用时钟偏移因子 C,表示为 C=1+Δclock106,其中 Δclock 是时钟偏移。对于每个帧,通道基于 [-maxΔclock maxΔclock] 范围内一组均匀分布的值生成一个随机 Δclock 值,其中 maxΔclock 是最大时钟偏移。时钟偏移以百万分率 (ppm) 为单位测量。对于本示例,假设最大时钟偏移为 5 ppm。

maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);

频率偏移

基于时钟偏移因子 C 和中心频率,对每帧进行频率偏移。使用 comm.PhaseFrequencyOffset (Communications Toolbox) 实现通道。

采样率偏移

基于时钟偏移因子 C,对每帧进行采样率偏移。使用 interp1 函数实现通道,以 C×fs 的新速率对帧进行重新采样。

合并后的通道

使用 helperModClassTestChannel 对象对帧应用所有三种通道损伤。

channel = helperModClassTestChannel(...
  'SampleRate', fs, ...
  'SNR', SNR, ...
  'PathDelays', [0 1.8 3.4] / fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4, ...
  'MaximumClockOffset', 5, ...
  'CenterFrequency', 902e6)
channel = 
  helperModClassTestChannel with properties:

                    SNR: 30
        CenterFrequency: 902000000
             SampleRate: 200000
             PathDelays: [0 9.0000e-06 1.7000e-05]
       AveragePathGains: [0 -2 -10]
                KFactor: 4
    MaximumDopplerShift: 4
     MaximumClockOffset: 5

您可以使用 info 对象函数查看有关通道的基本信息。

chInfo = info(channel)
chInfo = struct with fields:
               ChannelDelay: 6
     MaximumFrequencyOffset: 4510
    MaximumSampleRateOffset: 1

波形生成

创建一个循环,它为每种调制类型生成通道损伤的帧并将这些帧及其对应标签存储在 MAT 文件中。通过将数据保存到文件中,您无需每次运行此示例时都生成数据。您还可以更高效地共享数据。

从每帧的开头删除随机数量的样本,以去除瞬变并确保帧相对于符号边界具有随机起点。

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(12)
tic
numModulationTypes = length(modulationTypes);
channelInfo = info(channel);
transDelay = 50;
pool = getPoolSafe();
if ~isa(pool,"parallel.ClusterPool")
  dataDirectory = fullfile(tempdir,"ModClassDataFiles");
else
  dataDirectory = uigetdir("","Select network location to save data files");
end
disp("Data file directory is " + dataDirectory)
Data file directory is C:\TEMP\ModClassDataFiles
fileNameRoot = "frame";

% Check if data files exist
dataFilesExist = false;
if exist(dataDirectory,'dir')
  files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));
  if length(files) == numModulationTypes*numFramesPerModType
    dataFilesExist = true;
  end
end

if ~dataFilesExist
  disp("Generating data and saving in data files...")
  [success,msg,msgID] = mkdir(dataDirectory);
  if ~success
    error(msgID,msg)
  end
  for modType = 1:numModulationTypes
    elapsedTime = seconds(toc);
    elapsedTime.Format = 'hh:mm:ss';
    fprintf('%s - Generating %s frames\n', ...
      elapsedTime, modulationTypes(modType))
    
    label = modulationTypes(modType);
    numSymbols = (numFramesPerModType / sps);
    dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs);
    modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);
    if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
      % Analog modulation types use a center frequency of 100 MHz
      channel.CenterFrequency = 100e6;
    else
      % Digital modulation types use a center frequency of 902 MHz
      channel.CenterFrequency = 902e6;
    end
    
    for p=1:numFramesPerModType
      % Generate random data
      x = dataSrc();
      
      % Modulate
      y = modulator(x);
      
      % Pass through independent channels
      rxSamples = channel(y);
      
      % Remove transients from the beginning, trim to size, and normalize
      frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
      
      % Save data file
      fileName = fullfile(dataDirectory,...
        sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p));
      save(fileName,"frame","label")
    end
  end
else
  disp("Data files exist. Skip data generation.")
end
Generating data and saving in data files...
00:00:00 - Generating 16QAM frames
00:00:01 - Generating 64QAM frames
00:00:02 - Generating 8PSK frames
00:00:04 - Generating B-FM frames
00:00:05 - Generating BPSK frames
00:00:07 - Generating CPFSK frames
00:00:08 - Generating DSB-AM frames
00:00:10 - Generating GFSK frames
00:00:11 - Generating PAM4 frames
00:00:12 - Generating QPSK frames
00:00:14 - Generating SSB-AM frames
% Plot the amplitude of the real and imaginary parts of the example frames
% against the sample number
helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)

% Plot the spectrogram of the example frames
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)

创建数据存储

使用 signalDatastore 对象来管理包含生成的复杂波形的文件。如果每个文件可以单独放入内存中,但整个集合不一定能放入内存,则数据存储特别有用。

frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);

拆分为训练、验证和测试

接下来,将帧分为训练数据、验证数据和测试数据。有关详细信息,请参阅 helperModClassSplitData

splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDS,validDS,testDS] = helperModClassSplitData(frameDS,splitPercentages);

将数据导入内存

神经网络训练是迭代进行的。在每次迭代中,数据存储从文件中读取数据,变换数据,然后更新网络系数。如果数据可放入计算机的内存中,则将数据从文件导入内存可以消除重复的文件读取和变换过程,从而加快训练速度。这样,只需执行一次从文件读取并变换数据的操作。

将文件中的所有数据导入内存。这些文件有两个变量:framelabel,对数据存储的每个 read 调用都返回一个元胞数组,其中第一个元素是 frame,第二个元素是 label。使用 transform 函数 helperModClassReadFramehelperModClassReadLabel 读取帧和标签。如果您拥有 Parallel Computing Toolbox™ 许可证,请使用 "UseParallel" 选项设置为 truereadall 来启用变换函数的并行处理。由于 readall 函数默认情况下在第一个维度上串联 read 函数的输出,因此以元胞数组的形式返回帧并在第四个维度上手动进行串联。

% Read the training and validation frames into the memory
pctExists = parallelComputingLicenseExists();
trainFrames = transform(trainDS, @helperModClassReadFrame);
rxTrainFrames = readall(trainFrames,"UseParallel",pctExists);
validFrames = transform(validDS, @helperModClassReadFrame);
rxValidFrames = readall(validFrames,"UseParallel",pctExists);

% Read the training and validation labels into the memory
trainLabels = transform(trainDS, @helperModClassReadLabel);
rxTrainLabels = readall(trainLabels,"UseParallel",pctExists);
validLabels = transform(validDS, @helperModClassReadLabel);
rxValidLabels = readall(validLabels,"UseParallel",pctExists);

训练 CNN

本示例使用的 CNN 由五个卷积层和一个全连接层组成。除最后一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层。在最后一个卷积层中,最大池化层被一个全局平均池化层取代。输出层具有 softmax 激活。有关网络设计指导原则,请参阅Deep Learning Tips and Tricks

modClassNet = helperModClassCNN(modulationTypes,sps,spf);

接下来配置 TrainingOptionsSGDM 以使用小批量大小为 1024 的 SGDM 求解器。将最大轮数设置为 20,因为更多轮数不会提供进一步的训练优势。默认情况下,'ExecutionEnvironment' 属性设置为 'auto',其中 trainNetwork 函数使用 GPU(如果可用)或使用 CPU(如果 GPU 不可用)。要使用 GPU,您必须拥有 Parallel Computing Toolbox 许可证。将初始学习率设置为 3x10-1。每 6 轮后将学习率降低五分之四。将 'Plots' 设置为 'training-progress' 以对训练进度绘图。在 NVIDIA® GeForce RTX 3080 GPU 上,网络需要大约 3 分钟来完成训练。

maxEpochs = 20;
miniBatchSize = 1024;
trainingPlots = "none";
metrics = [];
verbose = true;
validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize);
options = trainingOptions('sgdm', ...
  InitialLearnRate = 3e-1, ...
  MaxEpochs = maxEpochs, ...
  MiniBatchSize = miniBatchSize, ...
  Shuffle = 'every-epoch', ...
  Plots = trainingPlots, ...
  Verbose = verbose, ...
  ValidationData = {rxValidFrames,rxValidLabels}, ...
  ValidationFrequency = validationFrequency, ...
  ValidationPatience = 5, ...
  Metrics = metrics, ...
  LearnRateSchedule = 'piecewise', ...
  LearnRateDropPeriod = 6, ...
  LearnRateDropFactor = 0.75, ...
  OutputNetwork='best-validation-loss');

或者训练网络,或者使用已经过训练的网络。默认情况下,此示例使用经过训练的网络。

if trainNow == true
  elapsedTime = seconds(toc);
  elapsedTime.Format = 'hh:mm:ss';
  fprintf('%s - Training the network\n', elapsedTime)
  trainedNet = trainnet(rxTrainFrames,rxTrainLabels,modClassNet,"crossentropy",options);
else
  load trainedModulationClassificationNetwork
end

下图显示一个运行示例,其中 trainingPlots 设置为“Training progress”,metric 设置为“Accuracy”,而 verbose 设置为 false。网络在大约 20 轮后收敛于大约 97% 的准确度。

通过获得测试帧的分类准确度来评估经过训练的网络。结果表明,该网络对这组波形实现的准确度达到 96% 左右。

elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Classifying test frames\n', elapsedTime)
00:00:32 - Classifying test frames
% Read the test frames into the memory
testFrames = transform(testDS, @helperModClassReadFrame);
rxTestFrames = readall(testFrames,"UseParallel",pctExists);

% Read the test labels into the memory
testLabels = transform(testDS, @helperModClassReadLabel);
rxTestLabels = readall(testLabels,"UseParallel",pctExists);

scores = predict(trainedNet,cat(3,rxTestFrames{:}));
rxTestPred = scores2label(scores,modulationTypes);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 97.7273%

绘制测试帧的混淆矩阵。如矩阵所示,网络混淆了 16-QAM 和 64-QAM 帧。此问题是预料之中的,因为每个帧只携带 128 个符号,而 16-QAM 是 64-QAM 的子集。该网络还会混淆 DSB-AM 和 SSB-AM 帧,因为 SSB-AM 信号恰好包含 DSB-AM 信号频谱的一半。

figure
cm = confusionchart(rxTestLabels, rxTestPred);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 950 550];

使用 SDR 进行测试

使用 helperModClassSDRTest 函数,通过空口信号测试经过训练的网络的性能。要执行此测试,您必须有专用的 SDR 用于发送和接收。您可以使用两个 ADALM-PLUTO 无线电,或使用一个 ADALM-PLUTO 无线电和一个 USRP® 无线电分别进行发送和接收。您必须Install Support Package for Analog Devices ADALM-PLUTO Radio (Communications Toolbox)。如果您使用的是 USRP® 无线电,您还必须Install Communications Toolbox Support Package for USRP Radio (Communications Toolbox)helperModClassSDRTest 函数使用生成训练信号所用的同一调制函数,然后使用 ADALM-PLUTO 无线电进行传输。使用针对信号接收配置的 SDR(ADALM-PLUTO 或 USRP® 无线电)捕获通道损伤的信号,而不对通道进行仿真。使用经过训练的网络和先前使用的同一 perdict 函数来预测调制类型。运行下一个代码段会生成一个混淆矩阵,并输出测试准确度。

radioPlatform = "ADALM-PLUTO";

switch radioPlatform
  case "ADALM-PLUTO"
    if helperIsPlutoSDRInstalled() == true
      radios = findPlutoRadio();
      if length(radios) >= 2
        helperModClassSDRTest(radios);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
  case {"USRP B2xx","USRP X3xx","USRP N2xx"}
    if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
      txRadio = findPlutoRadio();
      rxRadio = findsdru();
      switch radioPlatform
        case "USRP B2xx"
          idx = contains({rxRadio.Platform}, {'B200','B210'});
        case "USRP X3xx"
          idx = contains({rxRadio.Platform}, {'X300','X310'});
        case "USRP N2xx"
          idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
      end
      rxRadio = rxRadio(idx);
      if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
        helperModClassSDRTest(rxRadio);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
end
Selected radios not found. Skipping over-the-air test.

当使用两个相隔约 2 英尺的固定 ADALM-PLUTO 无线电时,网络可实现 99% 的总体准确度,混淆矩阵如下图所示。结果可因试验设置而异。

进一步探查

要提高准确度,可以优化超参数,例如滤波器数量、滤波器大小;或者优化网络结构,例如添加更多层,使用不同激活层等。

Communication Toolbox 提供了更多调制类型和通道损伤。有关详细信息,请参阅Modulation (Communications Toolbox)Propagation and Channel Models (Communications Toolbox)部分。您还可以使用 LTE ToolboxWLAN Toolbox5G Toolbox 添加标准特定信号。您还可以使用 Phased Array System Toolbox 添加雷达信号。

辅助文件

helperModClassGetModulator 函数提供用于生成调制信号的 MATLAB® 函数。您还可以探查以下函数和 System object 以获得详细信息:

局部函数

function pool = getPoolSafe()
if exist("gcp","file") && license('test','distrib_computing_toolbox')
  pool = gcp;
  if isempty(pool)
    pool = parpool;
  end
else
  pool = [];
end
end

参考资料

  1. O'Shea, T. J., J. Corgan, and T. C. Clancy."Convolutional Radio Modulation Recognition Networks."Preprint, submitted June 10, 2016. https://arxiv.org/abs/1602.04105

  2. O'Shea, T. J., T. Roy, and T. C. Clancy."Over-the-Air Deep Learning Based Radio Signal Classification."IEEE Journal of Selected Topics in Signal Processing.Vol. 12, Number 1, 2018, pp. 168–179.

  3. Liu, X., D. Yang, and A. E. Gamal."Deep Neural Network Architectures for Modulation Classification."Preprint, submitted January 5, 2018. https://arxiv.org/abs/1712.00443v3

另请参阅

| |

相关主题