lbfgsupdate
Syntax
Description
Update the network learnable parameters in a custom training loop using the limited-memory BFGS (L-BFGS) algorithm.
The L-BFGS algorithm [1] is a quasi-Newton method that approximates the Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm. Use the L-BFGS algorithm for small networks and data sets that you can process in a single batch.
Note
This function applies the L-BFGS optimization algorithm to update network parameters in custom training loops. To train a neural network with the L-BFGS solver using the trainnet function, use the trainingOptions function and set the solver to "lbfgs".
[netUpdated,solverStateUpdated] = lbfgsupdate(net,lossFcn,solverState)
updates the learnable parameters of the network net using the L-BFGS algorithm with the specified loss function and solver state. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.
[parametersUpdated,solverStateUpdated] = lbfgsupdate(parameters,lossFcn,solverState)
updates the learnable parameters in parameters using the L-BFGS algorithm with the specified loss function and solver state. Use this syntax in a training loop to iteratively update the learnable parameters of a network defined as a function.
___ = lbfgsupdate(___,Name=Value)
specifies additional options using one or more name-value arguments.
Examples
Update Learnable Parameters in Neural Network
Read the transmission casing data from the CSV file "transmissionCasingData.csv".
filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="String");
Convert the labels for prediction to categorical using the convertvars function.
labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,"categorical");
To train a network using categorical features, first convert the categorical predictors to the categorical data type using the convertvars function, specifying a string array that contains the names of all the categorical input variables.
categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalPredictorNames,"categorical");
Loop over the categorical input variables. For each variable, convert the categorical values to one-hot encoded vectors using the onehotencode function.
for i = 1:numel(categoricalPredictorNames)
    name = categoricalPredictorNames(i);
    tbl.(name) = onehotencode(tbl.(name),2);
end
View the first few rows of the table.
head(tbl)
SigMean   SigMedian  SigRMS  SigVar   SigPeak  SigPeak2Peak  SigSkewness  SigKurtosis  SigCrestFactor  SigMAD   SigRangeCumSum  SigCorrDimension  SigApproxEntropy  SigLyapExponent  PeakFreq  HighFreqPower  EnvPower  PeakSpecKurtosis  SensorCondition  ShaftCondition  GearToothCondition
________  _________  ______  _______  _______  ____________  ___________  ___________  ______________  _______  ______________  ________________  ________________  _______________  ________  _____________  ________  ________________  _______________  ______________  __________________
-0.94876  -0.9722    1.3726  0.98387  0.81571  3.6314        -0.041525    2.2666       2.0514          0.8081   28562           1.1429            0.031581          79.931           0         6.75e-06       3.23e-07  162.13            0  1             1  0            No Tooth Fault
-0.97537  -0.98958   1.3937  0.99105  0.81571  3.6314        -0.023777    2.2598       2.0203          0.81017  29418           1.1362            0.037835          70.325           0         5.08e-08       9.16e-08  226.12            0  1             1  0            No Tooth Fault
1.0502    1.0267     1.4449  0.98491  2.8157   3.6314        -0.04162     2.2658       1.9487          0.80853  31710           1.1479            0.031565          125.19           0         6.74e-06       2.85e-07  162.13            0  1             0  1            No Tooth Fault
1.0227    1.0045     1.4288  0.99553  2.8157   3.6314        -0.016356    2.2483       1.9707          0.81324  30984           1.1472            0.032088          112.5            0         4.99e-06       2.4e-07   162.13            0  1             0  1            No Tooth Fault
1.0123    1.0024     1.4202  0.99233  2.8157   3.6314        -0.014701    2.2542       1.9826          0.81156  30661           1.1469            0.03287           108.86           0         3.62e-06       2.28e-07  230.39            0  1             0  1            No Tooth Fault
1.0275    1.0102     1.4338  1.0001   2.8157   3.6314        -0.02659     2.2439       1.9638          0.81589  31102           1.0985            0.033427          64.576           0         2.55e-06       1.65e-07  230.39            0  1             0  1            No Tooth Fault
1.0464    1.0275     1.4477  1.0011   2.8157   3.6314        -0.042849    2.2455       1.9449          0.81595  31665           1.1417            0.034159          98.838           0         1.73e-06       1.55e-07  230.39            0  1             0  1            No Tooth Fault
1.0459    1.0257     1.4402  0.98047  2.8157   3.6314        -0.035405    2.2757       1.955           0.80583  31554           1.1345            0.0353            44.223           0         1.11e-06       1.39e-07  230.39            0  1             0  1            No Tooth Fault
Extract the training data.
predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ...
    "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ...
    "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ...
    "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"];
XTrain = table2array(tbl(:,predictorNames));
numInputFeatures = size(XTrain,2);
Extract the targets and convert them to one-hot encoded vectors.
TTrain = tbl.(labelName);
TTrain = onehotencode(TTrain,2);
numClasses = size(TTrain,2);
Convert the predictors and targets to dlarray objects with format "BC" (batch, channel).
XTrain = dlarray(XTrain,"BC");
TTrain = dlarray(TTrain,"BC");
Define the network architecture.
numHiddenUnits = 16;
layers = [
    featureInputLayer(numInputFeatures)
    fullyConnectedLayer(numHiddenUnits)
    layerNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layers);
Define the modelLoss function, listed in the Model Loss Function section of the example. This function takes as input a neural network, input data, and targets. The function returns the loss and the gradients of the loss with respect to the network learnable parameters.
The lbfgsupdate function requires a loss function with the syntax [loss,gradients] = f(net). Create a function handle that evaluates the modelLoss function with dlfeval using the fixed training data, so that the handle takes a single input argument, the network.
lossFcn = @(net) dlfeval(@modelLoss,net,XTrain,TTrain);
Initialize an L-BFGS solver state object with a maximum history size of 3 and an initial inverse Hessian approximation factor of 1.1.
solverState = lbfgsState( ...
    HistorySize=3, ...
    InitialInverseHessianFactor=1.1);
Train the network for a maximum of 200 iterations. Stop training early when the norm of the gradients or the norm of the step is smaller than 0.00001, or when the line search fails. Print the training loss on the first iteration and every 10 iterations.
maxIterations = 200;
gradientTolerance = 1e-5;
stepTolerance = 1e-5;

iteration = 0;

while iteration < maxIterations
    iteration = iteration + 1;
    [net, solverState] = lbfgsupdate(net,lossFcn,solverState);

    % Print the loss on the first iteration and every 10 iterations.
    if iteration==1 || mod(iteration,10)==0
        fprintf("Iteration %d: Loss: %e\n",iteration,solverState.Loss);
    end

    % Stop early on small gradients or steps, or if the line search fails.
    if solverState.GradientsNorm < gradientTolerance || ...
            solverState.StepNorm < stepTolerance || ...
            solverState.LineSearchStatus == "failed"
        break
    end
end
Iteration 1: Loss: 9.343236e-01
Iteration 10: Loss: 4.721475e-01
Iteration 20: Loss: 4.678575e-01
Iteration 30: Loss: 4.666964e-01
Iteration 40: Loss: 4.665921e-01
Iteration 50: Loss: 4.663871e-01
Iteration 60: Loss: 4.662519e-01
Iteration 70: Loss: 4.660451e-01
Iteration 80: Loss: 4.645303e-01
Iteration 90: Loss: 4.591753e-01
Iteration 100: Loss: 4.562556e-01
Iteration 110: Loss: 4.531167e-01
Iteration 120: Loss: 4.489444e-01
Iteration 130: Loss: 4.392228e-01
Iteration 140: Loss: 4.347853e-01
Iteration 150: Loss: 4.341757e-01
Iteration 160: Loss: 4.325102e-01
Iteration 170: Loss: 4.321948e-01
Iteration 180: Loss: 4.318990e-01
Iteration 190: Loss: 4.313784e-01
Iteration 200: Loss: 4.311314e-01
Model Loss Function
The modelLoss function takes as input a neural network net, input data X, and targets T. The function returns the loss and the gradients of the loss with respect to the network learnable parameters.
function [loss, gradients] = modelLoss(net, X, T)
    % Forward pass, cross-entropy loss, and gradients w.r.t. net.Learnables.
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
Input Arguments
net — Neural network
dlnetwork object
Neural network, specified as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:
Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Parameter value, specified as a cell array containing a dlarray object.
parameters — Learnable parameters
dlarray object | numeric array | cell array | structure | table
Learnable parameters, specified as a dlarray object, a numeric array, a cell array, a structure, or a table.
If you specify parameters as a table, it must contain these variables:
Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Parameter value, specified as a cell array containing a dlarray object.
You can specify parameters as a container of learnable parameters for your network using a cell array, structure, or table, or using nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray objects or numeric values with the data type double or single.
If parameters is a numeric array, then lossFcn must not use the dlgradient function.
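With a numeric array, the loss function therefore has to supply the gradient itself. The following is a minimal sketch, assuming Deep Learning Toolbox R2023a or later; the quadratic loss, the target vector c, the iteration cap, and the tolerance are hypothetical and chosen only so that the minimizer x = c is known in advance. The anonymous function uses deal to return both outputs, so it calls neither dlgradient nor dlfeval.

```matlab
% Hypothetical problem: minimize sum((x - c).^2), whose minimizer is x = c.
c = [3; -1; 0.5];                                  % hypothetical target vector
lossFcn = @(x) deal(sum((x - c).^2), 2*(x - c));   % [loss,gradients] with an analytic gradient

x = zeros(3,1);                                    % learnable parameters as a numeric array
solverState = lbfgsState(HistorySize=3);

for iteration = 1:50
    [x, solverState] = lbfgsupdate(x,lossFcn,solverState);

    % Stop when the gradient is small or the line search fails.
    if solverState.GradientsNorm < 1e-6 || solverState.LineSearchStatus == "failed"
        break
    end
end

disp(x.')
```

Because lossFcn returns an exact gradient, no automatic differentiation is involved; for a network defined as a function with dlarray parameters, use the dlfeval pattern described for lossFcn instead.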
lossFcn — Loss function
function handle | AcceleratedFunction object
Loss function, specified as a function handle or an AcceleratedFunction object with the syntax [loss,gradients] = f(net), where loss and gradients correspond to the loss and the gradients of the loss with respect to the learnable parameters, respectively.
To parameterize a model loss function that calls the dlgradient function, specify the loss function as @(net) dlfeval(@modelLoss,net,arg1,...,argN), where modelLoss is a function with the syntax [loss,gradients] = modelLoss(net,arg1,...,argN) that returns the loss and the gradients of the loss with respect to the learnable parameters in net, given the arguments arg1,...,argN.
If parameters is a numeric array, then the loss function must not use the dlgradient or dlfeval functions.
If the loss function has more than two outputs, also specify the NumLossFunctionOutputs argument.
Data Types: function_handle
solverState — Solver state
lbfgsState object | []
Solver state, specified as an lbfgsState object or [].
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Example: lbfgsupdate(net,lossFcn,solverState,LineSearchMethod="strong-wolfe") updates the learnable parameters in net and searches for a learning rate that satisfies the strong Wolfe conditions.
LineSearchMethod — Method to find a suitable learning rate
"weak-wolfe" (default) | "strong-wolfe" | "backtracking"
Method to find a suitable learning rate, specified as one of these values:
"weak-wolfe" — Search for a learning rate that satisfies the weak Wolfe conditions. This method maintains a positive definite approximation of the inverse Hessian matrix.
"strong-wolfe" — Search for a learning rate that satisfies the strong Wolfe conditions. This method maintains a positive definite approximation of the inverse Hessian matrix.
"backtracking" — Search for a learning rate that satisfies sufficient decrease conditions. This method does not maintain a positive definite approximation of the inverse Hessian matrix.
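The sufficient decrease and Wolfe conditions named by these options are standard in the optimization literature. The following base-MATLAB sketch writes them out in their usual textbook form for a hypothetical quadratic loss. The constants c1 = 1e-4 and c2 = 0.9, the initial trial rate of 4, and the halving factor are common choices assumed here, not values documented for lbfgsupdate, and the loop is only an illustration of backtracking, not the function's internal line search.

```matlab
% Hypothetical quadratic loss with minimizer wStar (not from the documentation).
wStar = [1; -2];
f     = @(w) 0.5*sum((w - wStar).^2);
gradf = @(w) w - wStar;

w = [4; 3];                 % current parameters
g = gradf(w);               % current gradient
p = -g;                     % descent direction (L-BFGS would use -inv(B)*g)
c1 = 1e-4;                  % sufficient decrease constant (assumed value)
c2 = 0.9;                   % curvature constant (assumed value)

% Textbook conditions on a trial learning rate eta along direction p.
sufficientDecrease = @(eta) f(w + eta*p) <= f(w) + c1*eta*(g'*p);
weakCurvature      = @(eta) gradf(w + eta*p)'*p >= c2*(g'*p);
strongCurvature    = @(eta) abs(gradf(w + eta*p)'*p) <= c2*abs(g'*p);

% Backtracking: halve eta until the sufficient decrease condition holds.
eta = 4;                    % initial trial learning rate (assumed value)
maxLineSearchIterations = 20;
accepted = false;
for it = 1:maxLineSearchIterations
    if sufficientDecrease(eta)
        accepted = true;
        break
    end
    eta = eta/2;
end
wNew = w + eta*p;           % eta ends at 1 here, which lands exactly on wStar

% eta = 1.95 overshoots the minimizer: the directional derivative there is
% large and positive, so the weak conditions hold but the strong ones do not.
fprintf("weak: %d, strong: %d\n", ...
    sufficientDecrease(1.95) && weakCurvature(1.95), ...
    sufficientDecrease(1.95) && strongCurvature(1.95));   % prints "weak: 1, strong: 0"
```

The strong Wolfe conditions bound the magnitude of the directional derivative, which rules out steps that overshoot far past a minimum along the search direction; backtracking checks only the decrease in the loss.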
MaxNumLineSearchIterations — Maximum number of line search iterations
20 (default) | positive integer
Maximum number of line search iterations to determine the learning rate, specified as a positive integer.
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
NumLossFunctionOutputs — Number of loss function outputs
2 (default) | integer greater than or equal to two
Number of loss function outputs, specified as an integer greater than or equal to two. Set this option when lossFcn has more than two output arguments.
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
Output Arguments
netUpdated — Updated network
dlnetwork object
Updated network, returned as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object.
parametersUpdated — Updated learnable parameters
dlarray | numeric array | cell array | structure | table
Updated learnable parameters, returned as an object with the same type as parameters.
solverStateUpdated — Updated solver state
lbfgsState object
Updated solver state, returned as an lbfgsState object.
Algorithms
Limited-Memory BFGS
The algorithm updates learnable parameters $W$ at iteration $k+1$ using the update step given by
$$W_{k+1} = W_k - \eta_k B_k^{-1} \nabla J(W_k),$$
where $W_k$ denotes the weights at iteration $k$, $\eta_k$ is the learning rate at iteration $k$, $B_k$ is an approximation of the Hessian matrix at iteration $k$, and $\nabla J(W_k)$ denotes the gradients of the loss with respect to the learnable parameters at iteration $k$.
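As a sanity check on this update rule, consider a hypothetical quadratic loss $J(W) = \tfrac{1}{2} W^\top A W - b^\top W$ with a symmetric positive definite matrix $A$, so that the learning rate is $\eta_k$ and the Hessian is exactly $A$. If $B_k$ were that exact Hessian and $\eta_k = 1$, a single step would land on the minimizer:

```latex
\nabla J(W_k) = A W_k - b,
\qquad
W_{k+1} = W_k - A^{-1}\left(A W_k - b\right) = A^{-1} b,
\qquad
\nabla J(W_{k+1}) = A A^{-1} b - b = 0.
```

L-BFGS replaces $A^{-1}$ with a limited-memory approximation and selects $\eta_k$ with a line search, so it typically needs several iterations to approach this point.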
The L-BFGS algorithm computes the matrix-vector product $B_k^{-1} \nabla J(W_k)$ directly. The algorithm does not require computing the inverse of $B_k$.
To save memory, the L-BFGS algorithm does not store and invert the dense Hessian matrix $B$. Instead, the algorithm uses the approximation $B_{k-m}^{-1} \approx \lambda_k I$, where $m$ is the history size, the inverse Hessian factor $\lambda_k$ is a scalar, and $I$ is the identity matrix. In place of a dense matrix, the algorithm stores only this scalar inverse Hessian factor, together with the step and gradient differences from the last $m$ iterations. The algorithm updates the inverse Hessian factor at each step.
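To compute the matrix-vector product $B_k^{-1} \nabla J(W_k)$ directly, the L-BFGS algorithm uses this recursive algorithm:
Set $r = B_{k-m}^{-1} \nabla J(W_k)$, where $m$ is the history size.
For $i = m, \ldots, 1$:
    Let $\beta = \frac{1}{s_{k-i}^\top y_{k-i}}\, y_{k-i}^\top r$, where $s_{k-i}$ and $y_{k-i}$ are the step and gradient differences for iteration $k-i$, respectively.
    Set $r = r + s_{k-i}\left(a_{k-i} - \beta\right)$, where $a_{k-i}$ is derived from $s$, $y$, and the gradients of the loss with respect to the learnable parameters. For more information, see [1].
Return $B_k^{-1} \nabla J(W_k) = r$.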
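The recursion above is a compressed description of the standard two-loop recursion found in optimization textbooks. The following base-MATLAB sketch writes out that textbook form for random, hypothetical data: a first backward loop computes the a terms (named alpha here) and reduces the gradient, the scalar inverse Hessian factor is applied, and a second forward loop applies the beta update. The choice lambda = s'*y/(y'*y) from the newest pair is a common heuristic assumed here; it is not necessarily how lbfgsupdate sets its inverse Hessian factor, and the code is not the function's internal implementation.

```matlab
% Sketch of the textbook L-BFGS two-loop recursion (illustrative only).
rng(0);                                  % reproducible random data
n = 5;                                   % number of parameters (hypothetical)
m = 3;                                   % history size
A = randn(n);
A = A'*A + n*eye(n);                     % SPD Hessian of a hypothetical quadratic loss
S = randn(n,m);                          % step differences s, oldest column first
Y = A*S;                                 % matching gradient differences y (so s'*y > 0)
g = randn(n,1);                          % current gradient of the loss
lambda = (S(:,m)'*Y(:,m)) / (Y(:,m)'*Y(:,m));  % scalar inverse Hessian factor (assumed heuristic)

rho = 1 ./ sum(S.*Y,1);                  % rho_i = 1/(s_i'*y_i), 1-by-m
alpha = zeros(1,m);

% First loop, newest to oldest: computes the a terms and reduces the gradient.
q = g;
for i = m:-1:1
    alpha(i) = rho(i) * (S(:,i)'*q);
    q = q - alpha(i)*Y(:,i);
end

% Apply the initial inverse Hessian approximation lambda*I.
r = lambda*q;

% Second loop, oldest to newest: beta and the update r = r + s*(a - beta).
for i = 1:m
    beta = rho(i) * (Y(:,i)'*r);
    r = r + S(:,i)*(alpha(i) - beta);
end

% r now approximates inv(B_k)*g; the parameter update is W = W - eta*r.
```

The recursion touches only the m stored (s, y) pairs and a few vectors, so its cost is O(mn) in time and memory rather than the O(n^2) needed to store a dense inverse Hessian.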
References
[1] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for Large Scale Optimization." Mathematical Programming 45, no. 1 (August 1989): 503-528. https://doi.org/10.1007/BF01589116.
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
The lbfgsupdate function supports GPU array input with these usage notes and limitations:
When lossFcn outputs data of type gpuArray or dlarray with underlying data of type gpuArray, this function runs on the GPU.
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2023a
See Also
adamupdate | rmspropupdate | sgdmupdate | dlupdate | lbfgsState | dlnetwork | dlarray | dlgradient | dljacobian | dldivergence | dllaplacian