How can I store the trained model, then call it and output the best action value during testing?
How can I store the trained model, then call it and output the best action value during testing?
ResetFcn = @myResetFunction;
StepFcn = @myStepFunction;
% Create the environment interface and get the observation and action info from the environment
ObservationInfo = rlNumericSpec([7 1]);
ObservationInfo.Name = 'observation';
ActionInfo = rlNumericSpec([7 1],...
'LowerLimit',-1,'UpperLimit',1);
ActionInfo.Name = 'Price';
env = rlFunctionEnv(ObservationInfo,ActionInfo,StepFcn,ResetFcn);
Ts = 0.0005;
Tf = 0.02;
%%
% Create the critic network
rng(0)
statePath = [
featureInputLayer(7,'Normalization','none','Name','observation')
fullyConnectedLayer(256,'Name','CriticStateFC1')
reluLayer('Name', 'CriticRelu1')
fullyConnectedLayer(128,'Name','CriticStateFC2')];
actionPath = [
featureInputLayer(7,'Normalization','none','Name','action')
fullyConnectedLayer(128,'Name','CriticActionFC1','BiasLearnRateFactor',0)];
commonPath = [
additionLayer(2,'Name','add')
reluLayer('Name','CriticCommonRelu')
fullyConnectedLayer(1,'Name','CriticOutput')];
criticNetwork = layerGraph();
criticNetwork = addLayers(criticNetwork,statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
% figure
% plot(criticNetwork)
% Set the critic learning rate and gradient threshold (to prevent exploding gradients)
criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
critic = rlQValueRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'observation'},'Action',{'action'},criticOpts);
%%
% Create the actor network
actorNetwork = [
featureInputLayer(7,'Normalization','none','Name','observation')
fullyConnectedLayer(256,'Name','ActorFC1')
reluLayer('Name','ActorRelu1')
fullyConnectedLayer(128,'Name','ActorFC2')
reluLayer('Name','ActorRelu2')
fullyConnectedLayer(7,'Name','ActorFC3')
tanhLayer('Name','tanh1')];
actorOpts = rlRepresentationOptions('LearnRate',0.0001,'GradientThreshold',1);
actor = rlDeterministicActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'observation'},'Action',{'tanh1'},actorOpts);
%%
% Create the DDPG agent
agentOpts = rlDDPGAgentOptions(...
'SampleTime',0.02,...
'TargetSmoothFactor',0.001,... % smoothing factor for soft updates of the target networks
'ExperienceBufferLength',1000000,...
'DiscountFactor',0.99,...
'MiniBatchSize',64);
agentOpts.NoiseOptions.Variance = 0.002; % add exploration noise to keep the actions stochastic
agentOpts.NoiseOptions.VarianceDecayRate = 1e-6;
agent = rlDDPGAgent(actor,critic,agentOpts);
%%
trainOpts = rlTrainingOptions;
trainOpts.MaxEpisodes = 3000;
trainOpts.MaxStepsPerEpisode = ceil(Tf/Ts);
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 10000;
trainOpts.ScoreAveragingWindowLength = 2;
% Parallel training settings (must be set on trainOpts, the object passed to train)
trainOpts.UseParallel = true;
trainOpts.ParallelizationOptions.Mode = 'async';
trainOpts.ParallelizationOptions.DataToSendFromWorkers = 'experiences';
trainOpts.ParallelizationOptions.StepsUntilDataIsSent = 32;
doTraining = true;
if doTraining
trainingStats = train(agent,env,trainOpts);
end
Answers (1)
Ayush Aniket
2023-9-27
As per my understanding, you want to know how to store a trained RL agent and, at test time, use the policy it learned during training to output the best action.
You can use the `save` function to store the trained model as shown below:
save('trainedAgent.mat', 'agent');
This saves the trained agent to a .mat file. You can read more about saving candidate agents during training in the Reinforcement Learning Toolbox documentation for rlTrainingOptions (the SaveAgentCriteria and SaveAgentValue options). To load the trained agent, use the `load` function:
load('trainedAgent.mat', 'agent');
After loading the trained agent, you can use it to choose the best action for a given observation, based on the policy it learned during training. The agent's `getAction` function can be used for this:
getAction(agent,{rand(obsInfo.Dimension)});
In your case, replace the random input above with an actual observation from your environment.
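For a complete test run, here is a minimal sketch. It assumes the `env`, `Ts`, and `Tf` variables from your script are still in the workspace, that the agent was saved as 'trainedAgent.mat' as above, and that the logged action field is named after your action spec ('Price'); adjust these to your setup.
% Load the stored agent and query the trained policy for one observation.
load('trainedAgent.mat','agent');
obs = reset(env);                    % initial observation from myResetFunction
bestAction = getAction(agent,{obs}); % cell array holding the [7 1] action (index with {1})
disp(bestAction{1})
% Alternatively, simulate a full test episode and log every action taken.
simOpts = rlSimulationOptions('MaxSteps',ceil(Tf/Ts));
experience = sim(env,agent,simOpts);
% experience.Action holds the logged action signal for the episode,
% with one field per action channel (here named 'Price').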
Hope it helps.