ニューラルネットワー​クの学習をdoubl​e型で行うことはでき​ますか?

2 次查看(过去 30 天)
Fumiya Watanabe
Fumiya Watanabe 2018-6-26
ニューラルネットワークの学習をdouble型で行うことはできますか?
現在、ある実数値ベクトルを入力とする回帰問題をNeural Network Toolboxを用いて実現しようとしています。 このベクトル入力を画像入力として扱うことで実現を考えています。しかしながら、trainNetworkを実行するとsingle型として扱われてしまう問題が生じており、解決法がわからず困っております。
例えば、次の自作の回帰層を考えます。
classdef testLayer < nnet.layer.RegressionLayer
methods
function layer = testLayer()
end
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(0);
end
function dLdX = backwardLoss(layer, Y, T)
dLdX = gpuArray(zeros(size(Y)));
end
end
end
この自作回帰層を用いて、次のように学習を実行します。
%%学習データ
x_in = rand(10, 1, 1, 6);
y_tr = rand(6, 5);
%%層構造とオプションの定義
layers = [
imageInputLayer([10 1 1], 'Normalization', 'none', 'Name', 'Input')
fullyConnectedLayer(2, 'Name', 'Layer1')
reluLayer('Name', 'ReLU1')
fullyConnectedLayer(5, 'Name', 'Output')
testLayer
];
layers(end).Name = 'Regression';
options = trainingOptions(...
'sgdm',...
'InitialLearnRate', 0.001, ...
'MiniBatchSize', 3, ...
'MaxEpochs', 1);
%%学習開始
net = trainNetwork(x_in, y_tr, layers, options);
すると、次のエラーが発生します。
エラー: trainNetwork (line 154)
Incorrect type of dLdX for 'backwardLoss' in the output layer. Expected gpuArray of underlying type 'single', but instead has
underlying type 'double'.
上記の自作回帰層で、gpuArrayの内部をsingleにキャストすることで実行することが可能となるのですが、実際に使っている自作回帰層ではdouble型でないと計算できない関数を利用しているため、
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(single(myfun(double(Y), double(T))));
end
のようなキャストをしていく必要が生じてしまいます。これを避けるために学習をdouble型で実行したいのですが、解決法はありますでしょうか。

采纳的回答

Naoya
Naoya 2018-6-29
Neural Network Toolbox で提供される 畳み込みニューラルネットワークですが、trainNetwork 側で与えるデータ型は single, double 両方を受け付けます。
しかしながら、基本的にGPU上では単精度演算として扱われますので、GPU へ渡すゲートウェイとなるデータ型は single型となってしまいます。
  3 个评论
Naoya
Naoya 2018-7-3
ご連絡ありがとうございます。 cpuモードの場合でも backwardLoss 関数のゲートウェイは single型にする必要があります。
Fumiya Watanabe
Fumiya Watanabe 2018-7-5
ご回答ありがとうございます。
入力としてはdouble型を受け付けるが、計算内部はGPU・CPUどちらの場合でもsingle型で実行される形になっており、自作の層を扱う場合はsingle型でほかの層とのやり取りが必要であると理解いたしました。 ありがとうございました。

请先登录,再进行评论。

更多回答(0 个)

类别

Help CenterFile Exchange 中查找有关 深層学習データの前処理 的更多信息

产品


版本

R2018a

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!