How can I store the trained model and call and output the best action value during the test

How can I store the trained model and then call it during testing to output the best action value?
ResetFcn = @myResetFunction;
StepFcn = @myStepFunction;
% Create the environment interface and get the observation and action specifications from the environment
ObservationInfo = rlNumericSpec([7 1]);
ObservationInfo.Name = 'observation';
ActionInfo = rlNumericSpec([7 1],...
    'LowerLimit',-1,'UpperLimit',1);
ActionInfo.Name = 'Price';
env = rlFunctionEnv(ObservationInfo,ActionInfo,StepFcn,ResetFcn);
Ts = 0.0005;
Tf = 0.02;
%%
% Create the critic network
rng(0)
statePath = [
    featureInputLayer(7,'Normalization','none','Name','observation')
    fullyConnectedLayer(256,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(128,'Name','CriticStateFC2')];
actionPath = [
    featureInputLayer(7,'Normalization','none','Name','action')
    fullyConnectedLayer(128,'Name','CriticActionFC1','BiasLearnRateFactor',0)];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','CriticOutput')];
criticNetwork = layerGraph();
criticNetwork = addLayers(criticNetwork,statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
% figure
% plot(criticNetwork)
% Set the critic learning rate and gradient threshold (to prevent exploding gradients)
criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
critic = rlQValueRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'observation'},'Action',{'action'},criticOpts);
%%
% Create the actor network
actorNetwork = [
    featureInputLayer(7,'Normalization','none','Name','observation')
    fullyConnectedLayer(256,'Name','ActorFC1')
    reluLayer('Name','ActorRelu1')
    fullyConnectedLayer(128,'Name','ActorFC2')
    reluLayer('Name','ActorRelu2')
    fullyConnectedLayer(7,'Name','ActorFC3')
    tanhLayer('Name','tanh1')];
actorOpts = rlRepresentationOptions('LearnRate',0.0001,'GradientThreshold',1);
actor = rlDeterministicActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'observation'},'Action',{'tanh1'},actorOpts);
%%
% Create the DDPG agent
agentOpts = rlDDPGAgentOptions(...
    'SampleTime',0.02,...
    'TargetSmoothFactor',0.001,...% target network soft-update smoothing factor
    'ExperienceBufferLength',1000000,...
    'DiscountFactor',0.99,...
    'MiniBatchSize',64);
agentOpts.NoiseOptions.Variance = 0.002;% add exploration noise so the actions remain stochastic
agentOpts.NoiseOptions.VarianceDecayRate = 1e-6;
agent = rlDDPGAgent(actor,critic,agentOpts);
%%
trainOpts = rlTrainingOptions;
trainOpts.MaxEpisodes = 3000;
trainOpts.MaxStepsPerEpisode = ceil(Tf/Ts);
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 10000;
trainOpts.ScoreAveragingWindowLength = 2;
trainOpts.UseParallel = true;
trainOpts.ParallelizationOptions.Mode = 'async';
trainOpts.ParallelizationOptions.DataToSendFromWorkers = 'experiences';
trainOpts.ParallelizationOptions.StepsUntilDataIsSent = 32;
doTraining = true;
if doTraining
trainingStats = train(agent,env,trainOpts);
end

Answers (1)

Ayush Aniket 2023-9-27
As per my understanding, you want to know how to store a trained RL agent and then, during testing, use it to output the best action it learned.
You can use the `save` function to store the trained agent in a .mat file, as shown below:
save('trainedAgent.mat', 'agent');
You can also have training save candidate agents automatically; see the section on saving candidate agents in the Reinforcement Learning Toolbox documentation.
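A minimal sketch of this, assuming the `trainOpts` object from your script (the criterion, reward threshold, and folder name below are only illustrative values):
trainOpts.SaveAgentCriteria = "EpisodeReward";   % save whenever an episode reward reaches the threshold
trainOpts.SaveAgentValue = 5000;                 % hypothetical threshold; match it to your reward scale
trainOpts.SaveAgentDirectory = "savedAgents";    % qualifying agents are written here as .mat files
trainingStats = train(agent,env,trainOpts);
Each qualifying agent is saved to its own .mat file in that folder, so it can be reloaded later just like the manually saved agent above.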
To load the trained agent later, you can use the `load` function:
load('trainedAgent.mat', 'agent');
After loading the trained agent, you can use it to choose the best action it learned during training for a given observation. The `getAction` function of the agent can be used for this:
getAction(agent,{rand(obsInfo.Dimension)});
In your case, the observation input will be the 7-by-1 observation returned by your environment rather than random values.
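If you want to log the action the trained agent chooses at every step of a test episode, a minimal sketch (assuming the `env`, `agent`, `Ts`, and `Tf` from your script; the variable names are only illustrative) would be:
numSteps = ceil(Tf/Ts);
actionLog = zeros(7,numSteps);              % one 7-by-1 'Price' action per step
obs = reset(env);                           % initial observation from your reset function
for k = 1:numSteps
    act = getAction(agent,{obs});           % getAction returns a cell array in recent releases
    actionLog(:,k) = act{1};                % numeric 7-by-1 action from the trained policy
    [obs,~,isDone] = step(env,act{1});      % advance the environment with that action
    if isDone
        break
    end
end
Alternatively, sim(env,agent,rlSimulationOptions('MaxSteps',numSteps)) runs a full episode for you and returns a structure containing the observations, actions, and rewards collected during the simulation.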
Hope it helps.
