使用深度学习进行调制分类
此示例说明如何使用卷积神经网络 (CNN) 进行调制分类。您将生成合成的、通道损伤波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类。然后用软件无线电 (SDR) 硬件和空口信号测试 CNN。
使用 CNN 预测调制类型
本示例中经过训练的 CNN 可识别以下八种数字调制类型和三种模拟调制类型:
二相相移键控 (BPSK)
四相相移键控 (QPSK)
八相相移键控 (8-PSK)
十六相正交调幅 (16-QAM)
六十四相正交调幅 (64-QAM)
四相脉冲振幅调制 (PAM4)
高斯频移键控 (GFSK)
连续相位频移键控 (CPFSK)
广播 FM (B-FM)
双边带振幅调制 (DSB-AM)
单边带振幅调制 (SSB-AM)
modulationTypes = categorical(sort(["BPSK", "QPSK", "8PSK", ... "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ... "B-FM", "DSB-AM", "SSB-AM"]));
首先,加载经过训练的网络。有关网络训练的详细信息,请参阅“训练 CNN”一节。
load trainedModulationClassificationNetwork
trainedNet
trainedNet = dlnetwork with properties: Layers: [19×1 nnet.cnn.layer.Layer] Connections: [18×2 table] Learnables: [22×3 table] State: [10×3 table] InputNames: {'Input Layer'} OutputNames: {'SoftMax'} Initialized: 1 View summary with summary.
经过训练的 CNN 接受 1024 个通道损伤采样,并预测每个帧的调制类型。生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 PAM4 帧。使用以下函数生成合成信号来测试 CNN。然后使用 CNN 预测帧的调制类型。
randi
:生成随机位pammod
(Communications Toolbox):PAM4 调制位rcosdesign
(Signal Processing Toolbox):设计平方根升余弦脉冲整形滤波器filter
:脉冲确定符号的形状comm.RicianChannel
(Communications Toolbox):应用莱斯多径通道comm.PhaseFrequencyOffset
(Communications Toolbox):应用时钟偏移引起的相位和/或频率偏移interp1
:应用时钟偏移引起的计时漂移awgn
(Communications Toolbox):添加 AWGN
% Set the random number generator to a known state to be able to regenerate % the same frames every time the simulation is run rng(123456) % Random bits d = randi([0 3], 1024, 1); % PAM4 modulation syms = pammod(d,4); % Square-root raised cosine filter filterCoeffs = rcosdesign(0.35,4,8); tx = filter(filterCoeffs,1,upsample(syms,8)); % Channel SNR = 30; maxOffset = 5; fc = 902e6; fs = 200e3; multipathChannel = comm.RicianChannel(... 'SampleRate', fs, ... 'PathDelays', [0 1.8 3.4] / 200e3, ... 'AveragePathGains', [0 -2 -10], ... 'KFactor', 4, ... 'MaximumDopplerShift', 4); frequencyShifter = comm.PhaseFrequencyOffset(... 'SampleRate', fs); % Apply an independent multipath channel reset(multipathChannel) outMultipathChan = multipathChannel(tx); % Determine clock offset factor clockOffset = (rand() * 2*maxOffset) - maxOffset; C = 1 + clockOffset / 1e6; % Add frequency offset frequencyShifter.FrequencyOffset = -(C-1)*fc; outFreqShifter = frequencyShifter(outMultipathChan); % Add sampling time drift t = (0:length(tx)-1)' / fs; newFs = fs * C; tp = (0:length(tx)-1)' / newFs; outTimeDrift = interp1(t, outFreqShifter, tp); % Add noise rx = awgn(outTimeDrift,SNR,0); % Frame generation for classification unknownFrames = helperModClassGetNNFrames(rx); % Classification scores1 = predict(trainedNet,unknownFrames); prediction1 = scores2label(scores1,modulationTypes);
返回分类器预测,这类似于硬判决。网络将帧正确识别为 PAM4 帧。有关生成调制信号的详细信息,请参阅 helperModClassGetModulator 函数。
prediction1
prediction1 = 7×1 categorical
PAM4
PAM4
PAM4
PAM4
PAM4
PAM4
PAM4
分类器还返回一个包含每一帧分数的向量。分数对应于每个帧具有预测的调制类型的概率。绘制分数图。
helperModClassPlotScores(scores1,modulationTypes)
我们首先需要用已知(即已加标签的)数据训练 CNN,然后才能使用 CNN 进行调制分类或执行任何其他任务。此示例的第一部分说明如何使用 Communications Toolbox™ 功能(如调制器、滤波器和通道损伤)来生成合成的训练数据。第二部分着重于针对调制分类任务来定义、训练和测试 CNN。第三部分通过软件无线电 (SDR) 平台使用空口信号来测试网络性能。
生成用于训练的波形
为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。我们在网络训练阶段使用训练和验证帧。使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个采样表示一个符号。网络根据单个帧而不是多个连续帧(如视频)作出每个决定。假设数字和模拟调制类型的中心频率分别为 902 MHz 和 100 MHz。
要快速运行此示例,请使用经过训练的网络并生成少量训练帧。要在您的计算机上训练网络,请选择“Train network now”选项(即,将 trainNow 设置为 true)。
trainNow = false; if trainNow == true numFramesPerModType = 10000; else numFramesPerModType = 200; end percentTrainingSamples = 80; percentValidationSamples = 10; percentTestSamples = 10; sps = 8; % Samples per symbol spf = 1024; % Samples per frame fs = 200e3; % Sample rate fc = [902e6 100e6]; % Center frequencies
创建通道损伤
让每帧通过通道并具有
AWGN
莱斯多径衰落
时钟偏移,导致中心频率偏移和采样时间漂移
由于本示例中的网络基于单个帧作出决定,因此每个帧必须通过独立的通道。
AWGN
通道增加 SNR 为 30 dB 的 AWGN。使用 awgn
(Communications Toolbox) 函数实现通道。
莱斯多径
通道使用 comm.RicianChannel
(Communications Toolbox) System object™ 通过莱斯多径衰落通道传递信号。假设延迟分布为 [0 1.8 3.4] 个样本,对应的平均路径增益为 [0 -2 -10] dB。K 因子为 4,最大多普勒频移为 4 Hz,等效于 902 MHz 的步行速度。使用以下设置实现通道。
时钟偏移
时钟偏移是发射机和接收机的内部时钟源不准确造成的。时钟偏移导致中心频率(用于将信号下变频至基带)和数模转换器采样率不同于理想值。通道仿真器使用时钟偏移因子 ,表示为 ,其中 是时钟偏移。对于每个帧,通道基于 [ ] 范围内一组均匀分布的值生成一个随机 值,其中 是最大时钟偏移。时钟偏移以百万分率 (ppm) 为单位测量。对于本示例,假设最大时钟偏移为 5 ppm。
maxDeltaOff = 5; deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff; C = 1 + (deltaOff/1e6);
频率偏移
基于时钟偏移因子 和中心频率,对每帧进行频率偏移。使用 comm.PhaseFrequencyOffset
(Communications Toolbox) 实现通道。
采样率偏移
基于时钟偏移因子 ,对每帧进行采样率偏移。使用 interp1
函数实现通道,以 的新速率对帧进行重新采样。
合并后的通道
使用 helperModClassTestChannel 对象对帧应用所有三种通道损伤。
channel = helperModClassTestChannel(... 'SampleRate', fs, ... 'SNR', SNR, ... 'PathDelays', [0 1.8 3.4] / fs, ... 'AveragePathGains', [0 -2 -10], ... 'KFactor', 4, ... 'MaximumDopplerShift', 4, ... 'MaximumClockOffset', 5, ... 'CenterFrequency', 902e6)
channel = helperModClassTestChannel with properties: SNR: 30 CenterFrequency: 902000000 SampleRate: 200000 PathDelays: [0 9.0000e-06 1.7000e-05] AveragePathGains: [0 -2 -10] KFactor: 4 MaximumDopplerShift: 4 MaximumClockOffset: 5
您可以使用 info 对象函数查看有关通道的基本信息。
chInfo = info(channel)
chInfo = struct with fields:
ChannelDelay: 6
MaximumFrequencyOffset: 4510
MaximumSampleRateOffset: 1
波形生成
创建一个循环,它为每种调制类型生成通道损伤的帧并将这些帧及其对应标签存储在 MAT 文件中。通过将数据保存到文件中,您无需每次运行此示例时都生成数据。您还可以更高效地共享数据。
从每帧的开头删除随机数量的样本,以去除瞬变并确保帧相对于符号边界具有随机起点。
% Set the random number generator to a known state to be able to regenerate % the same frames every time the simulation is run rng(12) tic numModulationTypes = length(modulationTypes); channelInfo = info(channel); transDelay = 50; pool = getPoolSafe(); if ~isa(pool,"parallel.ClusterPool") dataDirectory = fullfile(tempdir,"ModClassDataFiles"); else dataDirectory = uigetdir("","Select network location to save data files"); end disp("Data file directory is " + dataDirectory)
Data file directory is C:\TEMP\ModClassDataFiles
fileNameRoot = "frame"; % Check if data files exist dataFilesExist = false; if exist(dataDirectory,'dir') files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot))); if length(files) == numModulationTypes*numFramesPerModType dataFilesExist = true; end end if ~dataFilesExist disp("Generating data and saving in data files...") [success,msg,msgID] = mkdir(dataDirectory); if ~success error(msgID,msg) end for modType = 1:numModulationTypes elapsedTime = seconds(toc); elapsedTime.Format = 'hh:mm:ss'; fprintf('%s - Generating %s frames\n', ... elapsedTime, modulationTypes(modType)) label = modulationTypes(modType); numSymbols = (numFramesPerModType / sps); dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs); modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs); if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'}) % Analog modulation types use a center frequency of 100 MHz channel.CenterFrequency = 100e6; else % Digital modulation types use a center frequency of 902 MHz channel.CenterFrequency = 902e6; end for p=1:numFramesPerModType % Generate random data x = dataSrc(); % Modulate y = modulator(x); % Pass through independent channels rxSamples = channel(y); % Remove transients from the beginning, trim to size, and normalize frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps); % Save data file fileName = fullfile(dataDirectory,... sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p)); save(fileName,"frame","label") end end else disp("Data files exist. Skip data generation.") end
Generating data and saving in data files...
00:00:00 - Generating 16QAM frames 00:00:01 - Generating 64QAM frames 00:00:02 - Generating 8PSK frames 00:00:04 - Generating B-FM frames 00:00:05 - Generating BPSK frames 00:00:07 - Generating CPFSK frames 00:00:08 - Generating DSB-AM frames 00:00:10 - Generating GFSK frames 00:00:11 - Generating PAM4 frames 00:00:12 - Generating QPSK frames 00:00:14 - Generating SSB-AM frames
% Plot the amplitude of the real and imaginary parts of the example frames % against the sample number helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)
% Plot the spectrogram of the example frames
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)
创建数据存储
使用 signalDatastore
对象来管理包含生成的复杂波形的文件。如果每个文件可以单独放入内存中,但整个集合不一定能放入内存,则数据存储特别有用。
frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);
拆分为训练、验证和测试
接下来,将帧分为训练数据、验证数据和测试数据。有关详细信息,请参阅 helperModClassSplitData。
splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples]; [trainDS,validDS,testDS] = helperModClassSplitData(frameDS,splitPercentages);
将数据导入内存
神经网络训练是迭代进行的。在每次迭代中,数据存储从文件中读取数据,变换数据,然后更新网络系数。如果数据可放入计算机的内存中,则将数据从文件导入内存可以消除重复的文件读取和变换过程,从而加快训练速度。这样,只需执行一次从文件读取并变换数据的操作。
将文件中的所有数据导入内存。这些文件有两个变量:frame
和 label
,对数据存储的每个 read
调用都返回一个元胞数组,其中第一个元素是 frame
,第二个元素是 label
。使用 transform
函数 helperModClassReadFrame 和 helperModClassReadLabel 读取帧和标签。如果您拥有 Parallel Computing Toolbox™ 许可证,请使用 "UseParallel"
选项设置为 true
的 readall
来启用变换函数的并行处理。由于 readall
函数默认情况下在第一个维度上串联 read
函数的输出,因此以元胞数组的形式返回帧并在第四个维度上手动进行串联。
% Read the training and validation frames into the memory pctExists = parallelComputingLicenseExists(); trainFrames = transform(trainDS, @helperModClassReadFrame); rxTrainFrames = readall(trainFrames,"UseParallel",pctExists); validFrames = transform(validDS, @helperModClassReadFrame); rxValidFrames = readall(validFrames,"UseParallel",pctExists); % Read the training and validation labels into the memory trainLabels = transform(trainDS, @helperModClassReadLabel); rxTrainLabels = readall(trainLabels,"UseParallel",pctExists); validLabels = transform(validDS, @helperModClassReadLabel); rxValidLabels = readall(validLabels,"UseParallel",pctExists);
训练 CNN
本示例使用的 CNN 由五个卷积层和一个全连接层组成。除最后一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层。在最后一个卷积层中,最大池化层被一个全局平均池化层取代。输出层具有 softmax 激活。有关网络设计指导原则,请参阅Deep Learning Tips and Tricks。
modClassNet = helperModClassCNN(modulationTypes,sps,spf);
接下来配置 TrainingOptionsSGDM
以使用小批量大小为 1024 的 SGDM 求解器。将最大轮数设置为 20,因为更多轮数不会提供进一步的训练优势。默认情况下,'ExecutionEnvironment'
属性设置为 'auto'
,其中 trainNetwork
函数使用 GPU(如果可用)或使用 CPU(如果 GPU 不可用)。要使用 GPU,您必须拥有 Parallel Computing Toolbox 许可证。将初始学习率设置为 。每 6 轮后将学习率降低五分之四。将 'Plots'
设置为 'training-progress'
以对训练进度绘图。在 NVIDIA® GeForce RTX 3080 GPU 上,网络需要大约 3 分钟来完成训练。
maxEpochs = 20; miniBatchSize = 1024; trainingPlots = "none"; metrics = []; verbose = true; validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize); options = trainingOptions('sgdm', ... InitialLearnRate = 3e-1, ... MaxEpochs = maxEpochs, ... MiniBatchSize = miniBatchSize, ... Shuffle = 'every-epoch', ... Plots = trainingPlots, ... Verbose = verbose, ... ValidationData = {rxValidFrames,rxValidLabels}, ... ValidationFrequency = validationFrequency, ... ValidationPatience = 5, ... Metrics = metrics, ... LearnRateSchedule = 'piecewise', ... LearnRateDropPeriod = 6, ... LearnRateDropFactor = 0.75, ... OutputNetwork='best-validation-loss');
或者训练网络,或者使用已经过训练的网络。默认情况下,此示例使用经过训练的网络。
if trainNow == true elapsedTime = seconds(toc); elapsedTime.Format = 'hh:mm:ss'; fprintf('%s - Training the network\n', elapsedTime) trainedNet = trainnet(rxTrainFrames,rxTrainLabels,modClassNet,"crossentropy",options); else load trainedModulationClassificationNetwork end
下图显示一个运行示例,其中 trainingPlots
设置为“Training progress”,metric
设置为“Accuracy”,而 verbose
设置为 false
。网络在大约 20 轮后收敛于大约 97% 的准确度。
通过获得测试帧的分类准确度来评估经过训练的网络。结果表明,该网络对这组波形实现的准确度达到 96% 左右。
elapsedTime = seconds(toc); elapsedTime.Format = 'hh:mm:ss'; fprintf('%s - Classifying test frames\n', elapsedTime)
00:00:32 - Classifying test frames
% Read the test frames into the memory testFrames = transform(testDS, @helperModClassReadFrame); rxTestFrames = readall(testFrames,"UseParallel",pctExists); % Read the test labels into the memory testLabels = transform(testDS, @helperModClassReadLabel); rxTestLabels = readall(testLabels,"UseParallel",pctExists); scores = predict(trainedNet,cat(3,rxTestFrames{:})); rxTestPred = scores2label(scores,modulationTypes); testAccuracy = mean(rxTestPred == rxTestLabels); disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 97.7273%
绘制测试帧的混淆矩阵。如矩阵所示,网络混淆了 16-QAM 和 64-QAM 帧。此问题是预料之中的,因为每个帧只携带 128 个符号,而 16-QAM 是 64-QAM 的子集。该网络还会混淆 DSB-AM 和 SSB-AM 帧,因为 SSB-AM 信号恰好包含 DSB-AM 信号频谱的一半。
figure cm = confusionchart(rxTestLabels, rxTestPred); cm.Title = 'Confusion Matrix for Test Data'; cm.RowSummary = 'row-normalized'; cm.Parent.Position = [cm.Parent.Position(1:2) 950 550];
使用 SDR 进行测试
使用 helperModClassSDRTest 函数,通过空口信号测试经过训练的网络的性能。要执行此测试,您必须有专用的 SDR 用于发送和接收。您可以使用两个 ADALM-PLUTO 无线电,或使用一个 ADALM-PLUTO 无线电和一个 USRP® 无线电分别进行发送和接收。您必须Install Support Package for Analog Devices ADALM-PLUTO Radio (Communications Toolbox)。如果您使用的是 USRP® 无线电,您还必须Install Communications Toolbox Support Package for USRP Radio (Communications Toolbox)。helperModClassSDRTest
函数使用生成训练信号所用的同一调制函数,然后使用 ADALM-PLUTO 无线电进行传输。使用针对信号接收配置的 SDR(ADALM-PLUTO 或 USRP® 无线电)捕获通道损伤的信号,而不对通道进行仿真。使用经过训练的网络和先前使用的同一 perdict
函数来预测调制类型。运行下一个代码段会生成一个混淆矩阵,并输出测试准确度。
radioPlatform = "ADALM-PLUTO"; switch radioPlatform case "ADALM-PLUTO" if helperIsPlutoSDRInstalled() == true radios = findPlutoRadio(); if length(radios) >= 2 helperModClassSDRTest(radios); else disp('Selected radios not found. Skipping over-the-air test.') end end case {"USRP B2xx","USRP X3xx","USRP N2xx"} if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true) txRadio = findPlutoRadio(); rxRadio = findsdru(); switch radioPlatform case "USRP B2xx" idx = contains({rxRadio.Platform}, {'B200','B210'}); case "USRP X3xx" idx = contains({rxRadio.Platform}, {'X300','X310'}); case "USRP N2xx" idx = contains({rxRadio.Platform}, 'N200/N210/USRP2'); end rxRadio = rxRadio(idx); if (length(txRadio) >= 1) && (length(rxRadio) >= 1) helperModClassSDRTest(rxRadio); else disp('Selected radios not found. Skipping over-the-air test.') end end end
Selected radios not found. Skipping over-the-air test.
当使用两个相隔约 2 英尺的固定 ADALM-PLUTO 无线电时,网络可实现 99% 的总体准确度,混淆矩阵如下图所示。结果可因试验设置而异。
进一步探查
要提高准确度,可以优化超参数,例如滤波器数量、滤波器大小;或者优化网络结构,例如添加更多层,使用不同激活层等。
Communication Toolbox 提供了更多调制类型和通道损伤。有关详细信息,请参阅Modulation (Communications Toolbox)和Propagation and Channel Models (Communications Toolbox)部分。您还可以使用 LTE Toolbox、WLAN Toolbox 和 5G Toolbox 添加标准特定信号。您还可以使用 Phased Array System Toolbox 添加雷达信号。
辅助文件
helperModClassGetModulator 函数提供用于生成调制信号的 MATLAB® 函数。您还可以探查以下函数和 System object 以获得详细信息:
局部函数
function pool = getPoolSafe() if exist("gcp","file") && license('test','distrib_computing_toolbox') pool = gcp; if isempty(pool) pool = parpool; end else pool = []; end end
参考资料
O'Shea, T. J., J. Corgan, and T. C. Clancy."Convolutional Radio Modulation Recognition Networks."Preprint, submitted June 10, 2016. https://arxiv.org/abs/1602.04105
O'Shea, T. J., T. Roy, and T. C. Clancy."Over-the-Air Deep Learning Based Radio Signal Classification."IEEE Journal of Selected Topics in Signal Processing.Vol. 12, Number 1, 2018, pp. 168–179.
Liu, X., D. Yang, and A. E. Gamal."Deep Neural Network Architectures for Modulation Classification."Preprint, submitted January 5, 2018. https://arxiv.org/abs/1712.00443v3
另请参阅
trainnet
| trainingOptions
| dlnetwork