機械学習の入力エラーについて

2 次查看(过去 30 天)
Yuuki
Yuuki 2020-11-23
评论: Yuuki 2020-11-30
LSTMの学習方法について質問です.
最下部に示したコードを実行したとき,「予測子はシーケンスの N 行 1 列の cell 配列でなければなりません。」が表示されうまく学習できません.
入力データは1タイムステップに t-2, t-1, t のデータが含まれており,それに対応する出力データは t+1 のデータとなっています.
ここで学習に用いるデータを
net = trainNetwork(XTrain_C, YTrain_C, layers, options);
のように,Cのみを用いるようにすると上手く実行できるのですが,元コードのようにA,C,Dの3つの時系列データを学習させたモデルを作成しようとするとエラー文が表示されてしまいます.
下記のコードをどう修正すれば実行可能になりますでしょうか?
clear all
close all
%% Make dataset
A = zeros(1,100);
B = zeros(1,100);
C = zeros(1,100);
D = zeros(1,100);
% A
for i = 1:100
if i <= 40
A(:,i) = i / 40;
elseif (41 <= i) && (i <= 45)
A(:,i) = 1 - ((i - 40) / 5);
elseif 46 <= i
A(:,i) = 0;
end
end
% B
for i = 1:100
if i <= 60
B(:,i) = i / 60;
elseif (61 <= i) && (i <= 65)
B(:,i) = 1 - ((i - 60) / 5);
elseif 66 <= i
B(:,i) = 0;
end
end
% C
for i = 1:100
if i <= 80
C(:,i) = i / 80;
elseif (81 <= i) && (i <= 85)
C(:,i) = 1 - ((i - 80) / 5);
elseif 86 <= i
C(:,i) = 0;
end
end
% D
for i = 1:100
if i <= 40
D(:,i) = i / 20;
elseif (21 <= i) && (i <= 25)
D(:,i) = 1 - ((i - 20) / 5);
elseif 26 <= i
D(:,i) = 0;
end
end
%% Plot
plot(1:100, A(1,:),'LineWidth',2);hold on
plot(1:100, B(1,:),'LineWidth',2);hold on
plot(1:100, C(1,:),'LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('A','B','C','Location','northwest')
grid on
%% Preparing for ML
% A
for i = 1:97
XTrain_A{1,i} = A(:,i:i+2).';
YTrain_A{1,i} = A(:,i+3);
end
% C
for i = 1:97
XTrain_C{1,i} = C(:,i:i+2).';
YTrain_C{1,i} = C(:,i+3);
end
% D
for i = 1:97
XTrain_D{1,i} = D(:,i:i+2).';
YTrain_D{1,i} = D(:,i+3);
end
% Input
XTrain{1,1} = XTrain_D;
XTrain{2,1} = XTrain_A;
XTrain{3,1} = XTrain_C;
YTrain{1,1} = YTrain_D;
YTrain{2,1} = YTrain_A;
YTrain{3,1} = YTrain_C;
%% TrainNetwork
numFeatures = 3;
numResponses = 1;
numHiddenUnits = 300;
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(20)
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('adam', ...
'MaxEpochs',200, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.0001, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',50, ...
'LearnRateDropFactor',0.2, ...
'Verbose',0, ...
'Plots','training-progress');
net = trainNetwork(XTrain, YTrain, layers, options);
%% Test
Result = zeros(1,100);
Result(:,1:3) = B(1,1:3);
for i = 1:97
[net,Result(1,i+1)] = predictAndUpdateState(net, Result(:,i:i+2).');
end
%% Plot result
plot(1:100, B(1,:),'k','LineWidth',2);hold on
plot(1:100, Result(1,:),'r','LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('B','Predection','Location','northwest')
grid on
  7 个评论
Naoya
Naoya 2020-11-30
はい、その理解となります。 t-2, t-1 を含めずにまずはお試し頂ければと思います。
Yuuki
Yuuki 2020-11-30
Naoya様
ご返信ありがとうございます.
何度か5×Sで試したもののあまり精度が出ず,他論文でt-5~tのデータを入力としt+1を出力する例を見たため,同様の方法で精度が上がらないかと思い上記の質問を設けた次第です.
もう少し他の方法で改善を試みようと思います.

请先登录,再进行评论。

回答(0 个)

类别

Help CenterFile Exchange 中查找有关 Deep Learning Toolbox 的更多信息

产品


版本

R2019b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!