How to speed up vectorized operations for dynamic programming

I would like to speed up the following code which solves a discrete dynamic programming problem using the method of successive approximations, as described e.g. in Bertsekas.
The algorithm has two steps. In step 1, I precompute the payoff array R(a',a,z), where a' is the action and (a,z) are the states. In step 2, I compute the value function by successive approximations: I guess V0, compute an updated V1, and check whether ||V1-V0|| is less than a tolerance level. If it is, I stop; otherwise I set V0=V1 and repeat.
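Written out, the step-2 update can be sketched as follows. The notation is read off the MWE rather than taken from Bertsekas: $P$ stands for the transition matrix `PZ`, $\beta$ for the discount factor `beta`, and $R$ for the precomputed payoff array `Ret`.

```latex
V_1(a,z) \;=\; \max_{a'} \Big\{ R(a',a,z) \;+\; \beta \sum_{z'} P(z,z')\, V_0(a',z') \Big\},
\qquad
\text{repeat with } V_0 \leftarrow V_1 \text{ until } \max_{a,z} \big| V_1(a,z) - V_0(a,z) \big| \le \text{tol}.
```

The inner sum is what `EV = V0*PZ'` computes, and the maximization over $a'$ is the `max(RHS,[],1)` call along the first dimension.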
I profiled the code (see below a MWE) and the two most time-consuming lines are the following ones:
(1) RHS = Ret+beta*permute(EV,[1,3,2]);
(2) [max_val,max_ind] = max(RHS,[],1);
Line (1) takes up 63% of the total running time, line (2) takes 32%. As you can see I have already vectorized all loops.
I would be very grateful for any suggestions. I post an MWE below. (Note that I set n_a, the number of grid points, to a low value on purpose so that interested users can run the example quickly. In my actual code, n_a=10000 or more.)
%% Solve income fluctuation problem CPU
clear;clc;close all
%% Economic parameters
sigma = 2;
r = 0.03;
beta = 0.96;
PZ = [0.60 0.40;
0.05 0.95];
z_grid = [0.5 1.0]';
n_z = length(z_grid);
b = 0; %lower bound for asset holdings a
grid_max = 4;
n_a = 500; % IN PRACTICE THIS IS EQUAL TO 5000-10000
R = 1+r;
a_grid = linspace(-b,grid_max,n_a)';
if sigma==1
fun_u = @(c) log(c);
else
fun_u = @(c) c.^(1-sigma)/(1-sigma);
end
%% Computational parameters
verbose = 0;
tiny = 1e-8; %very small positive number
tol = 1e-6; %tolerance for VFI and TI
max_iter = 500; %maximum num. of iterations for both VFI and TI
%% Start timing
tic
%% STEP 1- Precompute current payoff array R(a',a,z)
a_tomorrow = a_grid; %(a',1,1)
a_today = a_grid'; %(1,a,1)
z_today = shiftdim(z_grid,-2); %(1,1,z)
cons = (1+r)*a_today+z_today-a_tomorrow;
Ret = fun_u(cons); %size: [n_a,n_a,n_z]
Ret(cons<=0) = -inf;
%% STEP 2 - Value function iteration
iter = 1;
err = tol+1;
V0 = zeros(n_a,n_z);
while err>tol && iter<=max_iter
EV = V0*PZ'; %(a',z)
RHS = Ret+beta*permute(EV,[1,3,2]);
[max_val,max_ind] = max(RHS,[],1);
V1 = squeeze(max_val);
pol_ind_ap = squeeze(max_ind);
err = max(abs(V0(:)-V1(:)));
if verbose==1
fprintf('iter = %d, err = %f \n',iter,err)
end
iter = iter+1;
V0 = V1;
end
if err>tol
error('VFI did not converge!')
else
fprintf('VFI converged after %d iterations \n',iter-1) % iter was incremented once more after the last update
end
pol_ap = a_grid(pol_ind_ap);
pol_c = (1+r)*a_grid+z_grid'-pol_ap;
%% End timing
toc
%% Figures
figure
plot(a_grid,pol_c(:,1),'linewidth',2)
hold on
plot(a_grid,pol_c(:,2),'linewidth',2)
legend('Low shock','High shock','Location','NorthWest')
xlabel('asset level')
ylabel('consumption')
title('Consumption Policy Function')
figure
plot(a_grid,a_grid,'--','linewidth',2)
hold on
plot(a_grid,pol_ap(:,1),'linewidth',2)
hold on
plot(a_grid,pol_ap(:,2),'linewidth',2)
legend('45 line','Low shock','High shock','Location','NorthWest')
xlabel('Current period assets')
ylabel('Next-period assets')
title('Assets Policy Function')
2 Comments
Torsten 2024-9-14
Edited: Torsten 2024-9-14
You want to speed up a runtime of 0.17 s?
Apart from this: I don't think there is much to optimize in the two commands you listed.
Alessandro 2024-9-14
The running time of the code depends on the number of grid points for assets, n_a. In this example I set n_a=500, but I need something like n_a=10000. I stated this in the comment on the line where I set n_a.
Moreover, the problem shown here is part of a larger project and has to run hundreds of times, so even if it took only 0.17 s it would still be worth speeding up.
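To put the n_a=10000 case in perspective, here is a back-of-the-envelope sketch of the memory taken by the payoff array Ret(a',a,z) in double precision. It assumes n_z=2 as in the MWE; the variable names `bytes_per_double` and `gib` are only illustrative.

```matlab
% Rough memory footprint of Ret (size n_a-by-n_a-by-n_z, double precision).
n_z = 2;                          % number of income states, as in the MWE
bytes_per_double = 8;
n_a_values = [500 5000 10000];    % grid sizes mentioned in the question
gib = zeros(size(n_a_values));
for k = 1:numel(n_a_values)
    n_a = n_a_values(k);
    gib(k) = n_a^2 * n_z * bytes_per_double / 2^30;   % bytes -> GiB
    fprintf('n_a = %5d: Ret needs about %.3f GiB\n', n_a, gib(k));
end
```

At n_a=10000 this comes to 1.6e9 bytes (about 1.49 GiB) for Ret alone, and RHS has the same size, so each iteration streams through gigabytes of data rather than the few megabytes of the n_a=500 example.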


Accepted Answer

Matt J 2024-9-14
Edited: Matt J 2024-9-14
This might be a little faster.
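% Assumes Ret, V0, PZ, beta, tol, max_iter, n_a, n_z, verbose, iter and err are set up as in the question
betaPZtransp=beta*PZ'; % fold the discount factor into the transition matrix once
tic
while err>tol && iter<=max_iter
RHS = Ret + reshape(V0*betaPZtransp,n_a,1,n_z); % discounted continuation value as (a',1,z)
V1 = max(RHS,[],1); % values only; the argmax is deferred until after the loop
err = norm( V0(:)-V1(:) ,inf); % sup-norm distance between iterates
if verbose
fprintf('iter = %d, err = %f \n',iter,err)
end
iter = iter+1;
V0 = reshape(V1,n_a,n_z); % back to (a,z) without calling squeeze
end
toc
[V1,pol_ind_ap]=max(RHS,[],1); % compute the policy indices once, after the loop
pol_ind_ap = reshape(pol_ind_ap, n_a,n_z);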
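As a sanity check, one update step can be computed both ways and compared. The sketch below uses small random arrays as stand-ins for Ret and V0 (they are not the model from the question), and the names `V1_a`, `V1_b`, `same_values` and `same_policy` are introduced only for this illustration.

```matlab
% Compare the permute/squeeze update with the reshape-based update on random data.
rng(0);                                   % reproducible random inputs
n_a = 50; n_z = 2; beta = 0.96;
PZ  = [0.60 0.40; 0.05 0.95];
Ret = randn(n_a, n_a, n_z);               % stand-in for the payoff array
V0  = randn(n_a, n_z);                    % stand-in for the current guess

% Formulation from the question: permute + squeeze
EV   = V0*PZ';
RHS1 = Ret + beta*permute(EV, [1 3 2]);
[v1, i1] = max(RHS1, [], 1);
V1_a  = squeeze(v1);
pol_a = squeeze(i1);

% Formulation from the answer: reshape only
RHS2 = Ret + reshape(V0*(beta*PZ'), n_a, 1, n_z);
[v2, i2] = max(RHS2, [], 1);
V1_b  = reshape(v2, n_a, n_z);
pol_b = reshape(i2, n_a, n_z);

% Values should agree up to floating-point rounding
same_values = max(abs(V1_a(:) - V1_b(:))) < 1e-12;
same_policy = isequal(pol_a, pol_b);
fprintf('values agree: %d, policies agree: %d\n', same_values, same_policy);
```

The two versions multiply by beta at different points, so the values are compared with a small tolerance rather than with exact equality.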
2 Comments
Matt J 2024-9-14
Edited: Matt J 2024-9-14
You should also use gpuArrays if you have appropriate GPU hardware and toolboxes.
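A minimal sketch of what the GPU variant might look like, assuming Parallel Computing Toolbox and a supported GPU. The code falls back to the CPU when `canUseGPU` returns false. The small random payoff array is a stand-in for illustration, not the model from the question, and no speedup is claimed here since that depends on the hardware and on n_a.

```matlab
% Value function iteration on the GPU if one is available, otherwise on the CPU.
rng(0);
n_a = 200; n_z = 2; beta = 0.96; tol = 1e-6; max_iter = 500;
PZ  = [0.60 0.40; 0.05 0.95];
Ret = -rand(n_a, n_a, n_z);               % bounded stand-in payoffs (assumption)
V0  = zeros(n_a, n_z);
betaPZt = beta*PZ';

if canUseGPU()
    % Move the data to device memory once, before the loop
    Ret = gpuArray(Ret);
    V0  = gpuArray(V0);
    betaPZt = gpuArray(betaPZt);
end

err = tol + 1;
iter = 0;
while err > tol && iter < max_iter
    RHS = Ret + reshape(V0*betaPZt, n_a, 1, n_z);
    V1  = reshape(max(RHS, [], 1), n_a, n_z);
    err = gather(max(abs(V1(:) - V0(:))));   % scalar back on the host
    V0  = V1;
    iter = iter + 1;
end
V = gather(V0);                            % bring the value function back to host memory
fprintf('stopped after %d iterations, err = %g\n', iter, err);
```

Only the scalar err is transferred in each iteration; the large arrays stay on the device until the final `gather`.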
Alessandro 2024-9-14
Edited: Alessandro 2024-9-14
Thanks! Your version is indeed a bit faster. After some tests, what really helped was moving the computation of the argmax (i.e. pol_ind_ap) out of the while loop and eliminating the call to squeeze. By contrast, replacing permute with reshape does not seem to affect the timing (maybe because both permute and reshape are built-in functions, while squeeze is an M-file).
I will try the GPU later.
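For anyone who wants to reproduce this kind of comparison, here is a small timing sketch using `timeit` on random data. The array size is arbitrary and the names `t_one`, `t_two`, `t_squeeze` and `t_reshape` are made up for the illustration; the numbers will differ by machine and MATLAB release.

```matlab
% Time max with one vs. two outputs, and squeeze vs. reshape, on random data.
rng(0);
n_a = 300; n_z = 2;
RHS = randn(n_a, n_a, n_z);

t_one = timeit(@() max(RHS, [], 1));        % max values only
t_two = timeit(@() max(RHS, [], 1), 2);     % max values and argmax indices

M = max(RHS, [], 1);                        % size 1-by-n_a-by-n_z
t_squeeze = timeit(@() squeeze(M));
t_reshape = timeit(@() reshape(M, n_a, n_z));

fprintf('max, 1 output : %.3g s\n', t_one);
fprintf('max, 2 outputs: %.3g s\n', t_two);
fprintf('squeeze       : %.3g s\n', t_squeeze);
fprintf('reshape       : %.3g s\n', t_reshape);
```

`timeit` may warn that squeeze and reshape run too fast to be measured accurately on arrays this small, so those two timings are best read as rough indications only.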

Release

R2024a