主要内容

卷积 LSTM 网络的代码生成

此示例说明如何为包含卷积和双向长短期记忆 (BiLSTM) 层的深度学习网络生成 MEX 函数。生成的函数不使用任何第三方库。生成的 MEX 函数以视频帧序列形式从指定的视频文件中读取数据,并输出对视频中的活动进行分类的标签。有关此网络训练的详细信息,请参阅示例使用深度学习对视频进行分类 (Deep Learning Toolbox)。有关支持的编译器的详细信息,请参阅 使用 MATLAB Coder 进行深度学习的前提条件

此示例在 Mac®、Linux® 和 Windows® 平台上受支持。但 MATLAB® Online™ 不支持此示例。

准备输入视频

使用 readvideo 辅助函数读取视频文件 pushup.mp4。要观看视频,请循环播放视频文件的各个帧,并使用 imshow 函数。

filename = "pushup.mp4";
video = readVideo(filename);
numFrames = size(video,4);
figure
for i = 1:numFrames
    frame = video(:,:,:,i);
    imshow(frame/255);
    drawnow
end

通过使用 centerCrop 辅助函数,将输入视频帧居中裁剪到经过训练的网络的输入大小。

inputSize = [224 224 3];
video = centerCrop(video,inputSize);

video_classify 入口函数

video_classify.m 入口函数获取图像序列,并将其传递给经过训练的网络进行预测。此函数使用示例使用深度学习对视频进行分类 (Deep Learning Toolbox)中的卷积 LSTM 网络。该函数将文件 net.mat 中的网络对象加载到持久变量中,然后使用 classify (Deep Learning Toolbox) 函数来执行预测。该函数在后续调用中将重用该持久性对象。

type('video_classify.m')
function out = video_classify(in) %#codegen
%   Copyright 2021-2024 The MathWorks, Inc.

% A persistent object dlnet is used to load the dlnetwork object. At the
% first call to this function, the persistent object is constructed and
% setup. When the function is called subsequent times, the same object is
% reused to call predict on inputs, thus avoiding reconstructing and
% reloading the network object. A categorial arrary labels is also loaded

persistent dlnet;
persistent labels;

if isempty(dlnet)
    dlnet = coder.loadDeepLearningNetwork('dlnet.mat');
    labels = coder.load('labels.mat');
end

% The dlnetwork object require dlarrays as inputs, convert input to a
% dlarray
dlIn = dlarray(in, 'SSCT');

% pass input to network and perform prediction
dlOut = predict(dlnet, dlIn); 
scores = extractdata(dlOut);

classNames = labels.classNames;

% Convert prediction scores to labels
out = scores2label(scores,classNames,1);

下载预训练的网络

运行 downloadVideoClassificationNetwork 辅助函数以下载视频分类网络,并将网络保存在 MAT 文件 net.mat 中。

downloadVideoClassificationNetwork();

生成 MEX 函数

要生成 MEX 函数,请创建一个名为 cfgcoder.MexCodeConfig 对象。将 cfgTargetLang 属性设置为 C++。要生成不使用任何第三方库的代码,请通过将 targetlib 设置为 none 来使用 coder.DeepLearningConfig 函数。将其赋给 cfg 对象的 DeepLearningConfig 属性。

cfg = coder.config('mex');
cfg.TargetLang = 'C++';
cfg.DeepLearningConfig = coder.DeepLearningConfig('none');

使用 coder.typeof 函数指定入口函数的输入参量的类型和大小。在此示例中,输入为单精度类型,大小为 224×224×3,且序列长度可变。

Input = coder.typeof(single(0),[224 224 3 Inf],[false false false true]);

通过运行 codegen 命令生成 MEX 函数。

codegen -config cfg video_classify -args {Input} -report
Code generation successful: View report

运行生成的 MEX 函数

使用居中裁剪的视频输入运行生成的 MEX 函数。

output = video_classify_mex(single(video))
output = categorical
     pushup 

将预测叠加到输入视频上。

video = readVideo(filename);
numFrames = size(video,4);
figure
for i = 1:numFrames
    frame = video(:,:,:,i);
    frame = insertText(frame, [1 1], char(output), 'TextColor', [255 255 255],'FontSize',30, 'BoxColor', [0 0 0]);
    imshow(frame/255);
    drawnow
end

辅助函数

readVideo 辅助函数在 MATLAB 中或 Jetson™ 设备中读取视频文件,并将其以四维数组形式返回。

function video = readVideo(filename, frameSize)

if coder.target('MATLAB')
    vr = VideoReader(filename);
else
    hwobj = jetson();
    vr = VideoReader(hwobj, filename, 'Width', frameSize(1), 'Height', frameSize(2));
end
H = vr.Height;
W = vr.Width;
C = 3;

% Preallocate video array
numFrames = floor(vr.Duration * vr.FrameRate);
video = zeros(H,W,C,numFrames);

% Read frames
i = 0;
while hasFrame(vr)
    i = i + 1;
    video(:,:,:,i) = readFrame(vr);
end

% Remove unallocated frames
if size(video,4) > i
    video(:,:,:,i+1:end) = [];
end

end

centerCrop 辅助函数根据视频的方向将其裁剪到正方形,并将其调整为指定的输入大小。

function videoResized = centerCrop(video,inputSize)
%   Copyright 2020-2021 The MathWorks, Inc.

sz = size(video);
videoTmp = video;

if sz(1) < sz(2)
    % Video is landscape
    idx = floor((sz(2) - sz(1))/2);
    videoTmp(:,1:(idx-1),:,:) = [];
    videoTmp(:,(sz(1)+1):end,:,:) = [];
    
elseif sz(2) < sz(1)
    % Video is portrait
    idx = floor((sz(1) - sz(2))/2);
    videoTmp(1:(idx-1),:,:,:) = [];
    videoTmp((sz(2)+1):end,:,:,:) = [];
end

videoResized = imresize(videoTmp,inputSize(1:2));
videoResized = reshape(videoResized, inputSize(1), inputSize(2), inputSize(3), []);
end

另请参阅

| |

主题