LSTM
This code implements forward propagation and backward propagation for a Long Short-Term Memory (LSTM) recurrent neural network. Note that this code is part of a library; see below for how to use it. This file (LSTM.m) extends the BaseLayer class (classdef LSTM < BaseLayer), but this inheritance can be removed.
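For reference, a standard LSTM cell without peephole connections computes the following at each time step t. Here x_t is the input (dimension vis), h_t the hidden output and c_t the memory cell (dimension hid), and ⊙ denotes the element-wise product. The description above does not say which variant LSTM.m implements (for example, whether it uses peephole connections), so treat these equations as background rather than a specification of the file.

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + R_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + R_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + R_o h_{t-1} + b_o) \\
g_t &= \tanh(W_g x_t + R_g h_{t-1} + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Backpropagation through time (BPTT) applies the chain rule to these equations from t = T back to t = 1.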
<initialization>
obj.vis = vis; % dimension of the input vector
obj.hid = hid; % dimension of the output (hidden) vector
obj.T = T; % number of time steps unrolled for BPTT
obj.batchSize = batchSize; % mini-batch size
obj.prms = cell(obj.prmNum, 1); % container for the parameters
obj.states = cell(obj.stateNum, 1); % container for the gate, memory cell, and hidden unit states
obj.delta = zeros(vis, batchSize, T); % delta to be passed to the lower layer
initPrms(obj); % initialize the parameters
initStates(obj); % initialize the states
obj.gprms = obj.prms; % gradient container with the same shapes as prms
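As a rough illustration of what this initialization sets up, here is a Python sketch. The class name LSTMLayerSketch, the parameter layout (input weights W, recurrent weights R and biases b, each with the four gate blocks stacked into 4*hid rows), the uniform initialization range, and the state names are all assumptions; the library's initPrms and initStates may do something different. Arrays are nested lists indexed [t][row][column], so delta holds the same entries as the MATLAB vis x batchSize x T array with time moved to the front. The MATLAB line obj.gprms = obj.prms copies the parameters to get containers of the right shapes; the sketch allocates zeros of those shapes instead.

```python
import random


def zeros(*shape):
    """Nested lists of zeros with the given shape."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(*shape[1:]) for _ in range(shape[0])]


class LSTMLayerSketch:
    """Hypothetical Python analogue of the initialization step (not the library's code)."""

    # Assumed state names: input/forget/output gates, candidate, memory cell, hidden units.
    STATE_NAMES = ("i", "f", "o", "g", "c", "h")

    def __init__(self, vis, hid, T, batch_size, seed=0):
        self.vis = vis                # dimension of the input vector
        self.hid = hid                # dimension of the output (hidden) vector
        self.T = T                    # number of time steps unrolled for BPTT
        self.batch_size = batch_size  # mini-batch size
        rng = random.Random(seed)
        # Assumed layout: four gate blocks stacked into 4*hid rows, small uniform init.
        self.prms = {
            "W": [[rng.uniform(-0.1, 0.1) for _ in range(vis)] for _ in range(4 * hid)],
            "R": [[rng.uniform(-0.1, 0.1) for _ in range(hid)] for _ in range(4 * hid)],
            "b": zeros(4 * hid),
        }
        # One hid x batch_size slice per time step for each gate, the cell and the hidden units.
        self.states = {name: zeros(T, hid, batch_size) for name in self.STATE_NAMES}
        # Delta for the lower layer, indexed [t][input dimension][batch column].
        self.delta = zeros(T, vis, batch_size)
        # Gradient containers with the same shapes as the parameters, filled with zeros.
        self.gprms = {
            "W": zeros(4 * hid, vis),
            "R": zeros(4 * hid, hid),
            "b": zeros(4 * hid),
        }
```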
<forward propagation>
LSTM.affineTrans(input); % input: array of size vis x batchSize x T
output = LSTM.nonlinearTrans(); % apply the nonlinearities and return the layer output
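The two calls split the forward pass into an affine part and a nonlinear part. One plausible reading, sketched in Python under assumptions: the input projection W x_t + b does not depend on the previous hidden state, so it can be computed for all T steps up front, while the recurrent term R h_{t-1} and the gate nonlinearities must be applied step by step. The functions affine_trans and nonlinear_trans are hypothetical stand-ins for affineTrans and nonlinearTrans; they process a single sequence (a mini-batch repeats the same computation per column), start from zero hidden and cell states, and assume the standard no-peephole cell with gate blocks stacked as [i; f; o; g].

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(M, v):
    """Matrix (list of rows) times vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def affine_trans(xs, W, b):
    """W x_t + b for every time step; independent of h, so computed up front."""
    return [[u + bb for u, bb in zip(matvec(W, x), b)] for x in xs]


def nonlinear_trans(pre, R, hid):
    """Step through time: add R h_{t-1}, apply the gate nonlinearities, update c and h."""
    h, c, outputs = [0.0] * hid, [0.0] * hid, []
    for p in pre:
        a = [u + v for u, v in zip(p, matvec(R, h))]
        i = [sigmoid(z) for z in a[:hid]]             # input gate
        f = [sigmoid(z) for z in a[hid:2 * hid]]      # forget gate
        o = [sigmoid(z) for z in a[2 * hid:3 * hid]]  # output gate
        g = [math.tanh(z) for z in a[3 * hid:]]       # candidate cell input
        c = [ff * cc + ii * gg for ff, cc, ii, gg in zip(f, c, i, g)]  # memory cell
        h = [oo * math.tanh(cc) for oo, cc in zip(o, c)]              # hidden output
        outputs.append(h)
    return outputs
```

Precomputing the input projection for the whole sequence turns T small matrix-vector products into one larger batched product in a matrix library, which is a common reason for this split.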
<backward propagation>
dgate = LSTM.bpropGate(delta); % delta propagated from the upper layer
newDelta = LSTM.bpropDelta(dgate); % delta to be passed to the lower layer
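Likewise, bpropGate and bpropDelta can be read as the two halves of backpropagation through time: first the gradient with respect to each gate pre-activation is accumulated backwards from the last time step to the first, then the delta for the lower layer is W^T times that gradient at each step. The Python sketch below illustrates this under the same assumptions as a standard LSTM (no peepholes, gate blocks stacked as [i; f; o; g], one sequence, zero initial state). The names forward, bprop_gate and bprop_delta are hypothetical, not the library's API, and a compact forward pass is repeated so the block runs on its own.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def mat_t_vec(M, v):
    """M^T v for M given as a list of rows."""
    return [sum(M[r][col] * v[r] for r in range(len(M))) for col in range(len(M[0]))]


def forward(xs, W, R, b, hid):
    """Standard LSTM forward pass; caches (i, f, o, g, c_prev, c, h) per step."""
    h, c, cache = [0.0] * hid, [0.0] * hid, []
    for x in xs:
        a = [u + v + bb for u, v, bb in zip(matvec(W, x), matvec(R, h), b)]
        i = [sigmoid(z) for z in a[:hid]]
        f = [sigmoid(z) for z in a[hid:2 * hid]]
        o = [sigmoid(z) for z in a[2 * hid:3 * hid]]
        g = [math.tanh(z) for z in a[3 * hid:]]
        c_prev = c
        c = [ff * cp + ii * gg for ff, cp, ii, gg in zip(f, c_prev, i, g)]
        h = [oo * math.tanh(cc) for oo, cc in zip(o, c)]
        cache.append((i, f, o, g, c_prev, c, h))
    return cache


def bprop_gate(delta, cache, R, hid):
    """BPTT: dL/d(gate pre-activations) per step, given dL/dh_t from the upper layer."""
    T = len(cache)
    dgate = [None] * T
    dh_next = [0.0] * hid  # gradient reaching h_t through R from step t+1
    dc_next = [0.0] * hid  # gradient reaching c_t from c_{t+1}
    for t in reversed(range(T)):
        i, f, o, g, c_prev, c, _ = cache[t]
        dh = [d + n for d, n in zip(delta[t], dh_next)]
        tc = [math.tanh(v) for v in c]
        dc = [dk * ok * (1 - tk * tk) + dn for dk, ok, tk, dn in zip(dh, o, tc, dc_next)]
        da_i = [dck * gk * ik * (1 - ik) for dck, gk, ik in zip(dc, g, i)]
        da_f = [dck * cp * fk * (1 - fk) for dck, cp, fk in zip(dc, c_prev, f)]
        da_o = [dk * tk * ok * (1 - ok) for dk, tk, ok in zip(dh, tc, o)]
        da_g = [dck * ik * (1 - gk * gk) for dck, ik, gk in zip(dc, i, g)]
        dgate[t] = da_i + da_f + da_o + da_g  # stacked as [i; f; o; g]
        dh_next = mat_t_vec(R, dgate[t])
        dc_next = [dck * fk for dck, fk in zip(dc, f)]
    return dgate


def bprop_delta(dgate, W):
    """Delta for the lower layer: dL/dx_t = W^T dgate_t."""
    return [mat_t_vec(W, d) for d in dgate]
```

Parameter gradients are omitted here; under the same assumptions they would be sums over time of outer products, e.g. dW = sum_t dgate_t x_t^T, dR = sum_t dgate_t h_{t-1}^T, and db = sum_t dgate_t. A central finite-difference check on the inputs is a quick way to confirm a backward pass like this.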
Cite As
Yuto Ozaki (2024). LSTM (https://www.mathworks.com/matlabcentral/fileexchange/56993-lstm), MATLAB Central File Exchange. Retrieved .
Platform Compatibility
Windows macOS Linux
Version | Published | Release Notes
---|---|---
1.0.0.0 | | Description is modified