How can I transfer the model parameters of a well-trained NN to another one?
I have two NNs, net_1 and net_2, where net_1 is untrained and net_2 has been well trained. I want to transfer the knowledge of net_2 to net_1 so that net_1 performs as well as net_2, so I wrote the following code. However, after setting the weights and biases of net_1 to those of net_2, I find that net_1 behaves very badly: for example, net_2(-2) = 3.999 but net_1(-2) = 32.249, whereas I expect net_1 to output a value very similar to that of net_2. Could anyone please tell me whether there is anything wrong with my code? Thanks.
(Please note that I do not want to use the operation net_1 = net_2 to achieve this purpose.)
clear all
%%
% Task: To fit a non-linear function f(x) = x.^2
%%
D=1e4; % number of training samples
layers_neurons=[64];
%% Net 1: untrained network
net_1 = feedforwardnet(layers_neurons);
[data1,target1] = gen_data_sample(10);
net_1 = configure(net_1, data1, target1);
%% Net 2: well-trained network
[data2,target2] = gen_data_sample(D);
net_2 = feedforwardnet(layers_neurons); % doc feedforwardnet for more details
net_2 = configure(net_2, data2, target2);
net_2 = train(net_2,data2, target2); % , 'useGPU', 'yes', 'useparallel', 'yes'
%% Transfer the knowledge of Net 2 to Net 1
net_1.IW = net_2.IW;
net_1.LW = net_2.LW;
net_1.b = net_2.b;
%% Test and Compare Net 1 and Net 2
net_1(-2)
net_2(-2)
%%
function [input,output] = gen_data_sample(D)
%%
input = -20+(20-(-20))*rand(1, D);
output = input.^2;
end
Accepted Answer
Divya Gaddipati
2019-12-5
Before you assign the weights of “net_2” to “net_1”, initialize net_1 from net_2 using the init function:
net_1 = init(net_2);
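This should resolve your issue. Copying IW, LW and b is not enough on its own: a feedforwardnet also preprocesses its inputs and targets (by default with removeconstantrows and mapminmax), and configure derives the mapminmax ranges from the data it is given. Your net_1 was configured on only 10 random samples, so its input and output scaling differs from that of net_2, which was configured on data2. init(net_2) keeps net_2's configuration, including these processing settings, and only reinitializes the weights and biases, so copying the weights and biases afterwards reproduces net_2.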
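A minimal end-to-end sketch of the init-based transfer, assuming the Deep Learning Toolbox and the same 64-neuron network as in the question. For brevity it inlines the sampling that gen_data_sample performs, and it sets trainParam.showWindow to false only to suppress the training window; neither change should affect the result.

```matlab
% Sketch: copy the learned parameters of net_2 into a network created with init.
D = 1e4;                  % number of training samples
layers_neurons = 64;

% Same sampling as gen_data_sample in the question: x in [-20, 20], y = x.^2
data2 = -20 + (20 - (-20)) * rand(1, D);
target2 = data2.^2;

net_2 = feedforwardnet(layers_neurons);
net_2 = configure(net_2, data2, target2);
net_2.trainParam.showWindow = false;   % no training GUI
net_2 = train(net_2, data2, target2);

% init keeps net_2's architecture and input/output processing settings
% and only reinitializes the weights and biases.
net_1 = init(net_2);

% Transfer the learned weights and biases.
net_1.IW = net_2.IW;
net_1.LW = net_2.LW;
net_1.b  = net_2.b;

% The two networks should now agree.
fprintf('net_1(-2) = %.4f, net_2(-2) = %.4f\n', net_1(-2), net_2(-2));
```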
You can also remove the configure call for net_1 in your code (the net_1 = configure(...) line), which is not needed when you use init, because init(net_2) replaces net_1 entirely.
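If you would rather keep creating net_1 from scratch with feedforwardnet and configure, as in the question, another option (a sketch, not part of the accepted answer) is to configure net_1 with the same data that net_2 was configured with. configure should then derive identical removeconstantrows/mapminmax settings for both networks, so copying the weights and biases again reproduces net_2. The inlined sampling and the showWindow setting are the same simplifications as in the init example.

```matlab
% Sketch: configure net_1 on net_2's data so both share the same
% input/output processing settings, then copy the learned parameters.
D = 1e4;
layers_neurons = 64;

data2 = -20 + (20 - (-20)) * rand(1, D);   % as in gen_data_sample
target2 = data2.^2;

net_2 = feedforwardnet(layers_neurons);
net_2 = configure(net_2, data2, target2);
net_2.trainParam.showWindow = false;
net_2 = train(net_2, data2, target2);

% Untrained network with the same architecture, configured on the SAME data
% (not on a separate 10-sample set), so its processing settings match net_2.
net_1 = feedforwardnet(layers_neurons);
net_1 = configure(net_1, data2, target2);

net_1.IW = net_2.IW;
net_1.LW = net_2.LW;
net_1.b  = net_2.b;

% Check that the input processing settings match, then compare outputs.
disp(isequal(net_1.inputs{1}.processSettings, net_2.inputs{1}.processSettings))
fprintf('net_1(-2) = %.4f, net_2(-2) = %.4f\n', net_1(-2), net_2(-2));
```

The init route from the answer avoids having to keep data2 around; this alternative is mainly useful if net_1 has to be built independently of net_2.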
For more information on configure and init, see: https://www.mathworks.com/help/deeplearning/ug/create-configure-and-initialize-multilayer-neural-networks.html#bss330n-3
Hope this helps!