主要内容

dlfeval

评估用于自定义训练循环的深度学习模型

说明

dlfeval 函数用于评估启用了自动微分的深度学习模型并计算相关函数。要计算梯度,请使用 dlgradient 函数。

提示

对于大多数深度学习任务,您可以使用预训练神经网络,并使其适应您自己的数据。有关说明如何使用迁移学习来重新训练卷积神经网络以对一组新图像进行分类的示例,请参阅重新训练神经网络以对新图像进行分类。您也可以使用 trainnettrainingOptions 函数从头创建和训练神经网络。

如果 trainingOptions 函数没有提供您的任务所需的训练选项,则您可以使用自动微分创建自定义训练循环。要了解详细信息,请参阅使用自定义训练循环训练网络

如果 trainnet 函数没有提供您的任务所需的损失函数,您可以将 trainnet 的自定义损失函数指定为函数句柄。对于需要比预测值和目标值更多输入的损失函数(例如,需要访问神经网络或额外输入的损失函数),请使用自定义训练循环来训练模型。要了解详细信息,请参阅使用自定义训练循环训练网络

如果 Deep Learning Toolbox™ 没有提供您的任务所需的层,则您可以创建一个自定义层。要了解详细信息,请参阅定义自定义深度学习层。对于无法指定为由层组成的网络的模型,可以将模型定义为函数。要了解详细信息,请参阅Train Network Using Model Function

有关对哪项任务使用哪种训练方法的详细信息,请参阅Train Deep Learning Model in MATLAB

[y1,...,yk] = dlfeval(fun,x1,...,xn) 基于输入参量 x1,...,xn 计算深度学习数组函数 fun。传递给 dlfeval 的函数可以包含对 dlgradient 的调用,这些调用使用自动微分根据输入 x1,...,xn 计算梯度。

示例

示例

全部折叠

罗森布罗克函数是用于优化的标准测试函数。rosenbrock.m 辅助函数计算函数值并使用自动微分来计算其梯度。

type rosenbrock.m
function [y,dydx] = rosenbrock(x)

y = 100*(x(2) - x(1).^2).^2 + (1 - x(1)).^2;
dydx = dlgradient(y,x);

end

要计算点 [–1,2] 处的罗森布罗克函数值及其梯度,请创建该点的 dlarray,然后在函数句柄 @rosenbrock 上调用 dlfeval

x0 = dlarray([-1,2]);
[fval,gradval] = dlfeval(@rosenbrock,x0)
fval = 
  1×1 dlarray

   104

gradval = 
  1×2 dlarray

   396   200

或者,将罗森布罗克函数定义为具有两个输入(x1 和 x2)的函数。

type rosenbrock2.m
function [y,dydx1,dydx2] = rosenbrock2(x1,x2)

y = 100*(x2 - x1.^2).^2 + (1 - x1).^2;
[dydx1,dydx2] = dlgradient(y,x1,x2);

end

调用 dlfeval 来对表示输入 –12 的两个 dlarray 参量计算 rosenbrock2

x1 = dlarray(-1);
x2 = dlarray(2);
[fval,dydx1,dydx2] = dlfeval(@rosenbrock2,x1,x2)
fval = 
  1×1 dlarray

   104

dydx1 = 
  1×1 dlarray

   396

dydx2 = 
  1×1 dlarray

   200

绘制单位正方形中多个点的罗森布罗克函数梯度图。首先,初始化表示计算点和函数输出的数组。

[X1 X2] = meshgrid(linspace(0,1,10));
X1 = dlarray(X1(:));
X2 = dlarray(X2(:));
Y = dlarray(zeros(size(X1)));
DYDX1 = Y;
DYDX2 = Y;

在循环中计算函数值。使用 quiver 绘制结果。

for i = 1:length(X1)
    [Y(i),DYDX1(i),DYDX2(i)] = dlfeval(@rosenbrock2,X1(i),X2(i));
end
quiver(extractdata(X1),extractdata(X2),extractdata(DYDX1),extractdata(DYDX2))
xlabel('x1')
ylabel('x2')

Figure contains an axes object. The axes object with xlabel x1, ylabel x2 contains an object of type quiver.

使用 dlgradientdlfeval 计算涉及复数的函数的值和梯度。您可以计算复数梯度,也可以将梯度仅限为实数。

定义在此示例末尾列出的 complexFun 函数。此函数实现以下复数公式:

f(x)=(2+3i)x

定义在此示例末尾列出的 gradFun 函数。此函数调用 complexFun 并使用 dlgradient 来计算结果关于输入的梯度。对于自动微分,要微分的值(即根据输入计算出的函数值)必须是实数标量,因此该函数在计算梯度之前会取结果实部之和。该函数返回函数值的实部和梯度(可能是复数)。

在复平面上定义 -2 到 2 以及 -2i 到 2i 之间的采样点,并转换为 dlarray

functionRes = linspace(-2,2,100);
x = functionRes + 1i*functionRes.';
x = dlarray(x);

计算每个样本点处的函数值和梯度。

[y, grad] = dlfeval(@gradFun,x);
y = extractdata(y);

定义要显示其梯度的样本点。

gradientRes = linspace(-2,2,11);
xGrad = gradientRes + 1i*gradientRes.';

提取这些样本点处的梯度值。

[~,gradPlot] = dlfeval(@gradFun,dlarray(xGrad));
gradPlot = extractdata(gradPlot);

绘制结果。使用 imagesc 在复平面上显示函数的值。使用 quiver 来显示梯度的方向和幅值。

imagesc([-2,2],[-2,2],y);
axis xy
colorbar
hold on
quiver(real(xGrad),imag(xGrad),real(gradPlot),imag(gradPlot),"k");
xlabel("Real")
ylabel("Imaginary")
title("Real Value and Gradient","Re$(f(x)) = $ Re$((2+3i)x)$","interpreter","latex")

该函数的梯度在整个复平面上是相同的。提取自动微分计算出的梯度值。

grad(1,1)
ans = 
  1×1 dlarray

   2.0000 - 3.0000i

经检查,函数的复数导数的值为

df(x)dx=2+3i

然而,Re(f(x)) 函数不是解析函数,因此没有定义复数导数。对于 MATLAB 中的自动微分,要微分的值必须始终是实数,因此该函数永远不能是复数解析函数。在这种情况下,计算导数,使得返回的梯度指向上升坡度最大的方向,如图所示。这是通过将函数 Re(f(x)):C R 解释为函数 Re(f(xR+ixI)):R × R R 来完成的。

function y = complexFun(x)
    y = (2+3i)*x;    
end

function [y,grad] = gradFun(x)
    y = complexFun(x);
    y = real(y);

    grad = dlgradient(sum(y,"all"),x);
end

输入参数

全部折叠

要计算的函数,指定为函数句柄。如果 fun 包含 dlgradient 调用,则 dlfeval 使用自动微分来计算梯度。在此梯度计算中,dlgradient 调用的每个参量都必须是 dlarray 或包含 dlarray 的元胞数组、结构体或表。dlfeval 的输入参量数必须与 fun 的输入参量数相同。

示例: @rosenbrock

数据类型: function_handle

函数参量,指定为任何 MATLAB 数据类型或 dlnetwork 对象。不支持量化的 dlnetwork 对象。

作为 dlgradient 调用中的微分变量的输入参量 xj 必须是跟踪的 dlarray 或包含跟踪的 dlarray 的元胞数组、结构体或表。超参数或常量数据数组等额外变量不必是 dlarray

要为深度学习计算梯度,您可以提供 dlnetwork 对象作为函数参量,并在 fun 内计算网络的前向传导。

示例: dlarray([1 2;3 4])

数据类型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | logical | char | string | struct | table | cell | function_handle | categorical | datetime | duration | calendarDuration | fi
复数支持:

输出参量

全部折叠

函数输出,以任何数据类型的形式返回。如果输出来自 dlgradient 调用,则输出为 dlarray

提示

算法

全部折叠

扩展功能

全部展开

版本历史记录

在 R2019b 中推出