Vectorize nested loops for performance
3 次查看(过去 30 天)
显示 更早的评论
I have to solve a dynamic programming problem with a finite horizon and I am trying to vectorize as much as possible for speed. I attach here a MWE so that everybody can run it
I have run the code with the MATLAB profiler and identified two bottlenecks, each marked with the comment line % THIS IS SLOW ACCORDING TO PROFILER
The first bottleneck is the call to the function ReturnFn, which builds the n_a*n_a return matrix (RetMat / Ret_mat in the code). I have done this in a vectorized way (inside the function ReturnFn, I do not use loops).
The second bottleneck is the maximization of RHS_mat (the return matrix plus the discounted expected continuation value) along the first dimension.
I'd be very grateful for any comment/suggestion!
P.S. I have an if condition inside the loops to check for bugs, I removed it now that I am confident about the code but the speed improvement is marginal.
clear, clc, close all

% -------------------------------------------------------------------------
% MODEL OVERVIEW
%
% State variables, V(a,h,z,j):
%   a : asset holdings
%   h : human capital (female)
%   z : labor productivity shocks (eta_m and eta_f are shocks, theta is
%       permanent)
%   j : age, running from 1 to N_j
%
% Choice variables:
%   d  : female labor supply (extensive margin only: either 0 or 1)
%   a' : next-period assets
% h' is implied by (d,h) through the law of motion below, and consumption
% is implied by the budget constraint.
%
% Dynamic programming problem:
%   V(a,h,z,j) = max_{d,a'} F(d,a',a,h,z,j) + beta*s_j*E[V(a',h',z',j+1)|z]
%   subject to h' = G(d,h), the law of motion for human capital
% -------------------------------------------------------------------------

verbose = 1;    % 1 = print the current age while iterating backwards

%% Grid sizes
N_j = 80;       % number of ages
n_a = 51;       % points on the asset grid
n_h = 11;       % points on the human capital grid
n_z = 50;       % points on the shock grid
n_d = 2;        % number of discrete labor supply choices

%% Grids (column vectors)
a_grid = linspace(0,450,n_a).';
h_grid = linspace(0,0.72,n_h).';
d_grid = [0;1];
% The same n_z values are replicated into three columns; later code reads
% them as eta_m (column 1), eta_f (column 2) and theta (column 3)
z_grid = repmat(linspace(0.9,1.1,n_z).',1,3);

% Random transition matrix for z, normalised so that every row sums to one
pi_z = rand(n_z);
pi_z = pi_z./sum(pi_z,2);

% Asset grid in the two orientations used to build the (a',a) return matrix
aprime_val = a_grid;      % (a',1)
a_val      = a_grid.';    % (1,a)

%% Parameters that do not depend on age
beta  = 0.98;
r     = 0.04;
w_m   = 1;
w_f   = 0.75;
crra  = 2;
nu    = 0.12;
Jr    = 45;
xi_1  = 0.05312;
xi_2  = -0.00188;
del_h = 0.074;
h_l   = 0;

%% Age-dependent parameters, one entry per age (all set to one in this MWE)
p = struct( ...
    'eff_j',    ones(N_j,1), ...
    'pchild_j', ones(N_j,1), ...
    'pen_j',    ones(N_j,1), ...
    'nchild_j', ones(N_j,1), ...
    's_j',      ones(N_j,1), ...
    'age_j',    (1:N_j).');

%% Output arrays
V      = zeros(n_a,n_h,n_z,N_j);      % value function V(a,h,z,j)
Policy = zeros(2,n_a,n_h,n_z,N_j);    % row 1: optimal d, row 2: optimal a' index
tic
%% Solve problem in the last period (age N_j)
% There is no continuation value at the terminal age, so the problem is a
% static maximisation of the return function over (d,a').

% Set age-dependent parameters
eff_j    = p.eff_j(N_j);
pchild_j = p.pchild_j(N_j);
pen_j    = p.pen_j(N_j);
nchild_j = p.nchild_j(N_j);
age_j    = p.age_j(N_j);
s_j      = p.s_j(N_j);

% V_d(a,h,z,d)          : best value over a', conditional on d
% Pol_aprime_d(a,h,z,d) : index on a_grid of the maximising a', conditional on d
V_d          = zeros(n_a,n_h,n_z,n_d);
Pol_aprime_d = zeros(n_a,n_h,n_z,n_d);

for d_c = 1:n_d
    d_val = d_grid(d_c);
    for z_c = 1:n_z
        eta_m_val = z_grid(z_c,1);
        eta_f_val = z_grid(z_c,2);
        theta_val = z_grid(z_c,3);
        for h_c = 1:n_h
            h_val = h_grid(h_c);
            % RetMat is (a',a): rows index the choice a', columns the state a
            RetMat = ReturnFn(d_val,aprime_val,a_val,h_val,eta_m_val,eta_f_val,theta_val,...
                w_m,eff_j,w_f,pchild_j,pen_j,r,nchild_j,crra,nu,age_j,Jr);
            % Maximise over a' (dimension 1); results are (1,a)
            [max_val,max_ind] = max(RetMat,[],1);
            V_d(:,h_c,z_c,d_c)          = max_val;
            Pol_aprime_d(:,h_c,z_c,d_c) = max_ind;
        end %end h
    end %end z
end %end d

% Maximise over d (dimension 4)
[V(:,:,:,N_j),d_max] = max(V_d,[],4);
Policy(1,:,:,:,N_j) = d_max; % Optimal d

% Optimal a': pick, for every (a,h,z), the entry of Pol_aprime_d that
% belongs to the optimal d. Pol_aprime_d is (a,h,z,d), so the linear index
% of (s,d) is s + (d-1)*n_state, where s is the linear index of (a,h,z).
% This replaces a scalar triple loop over (z,h,a) with one gather.
n_state = n_a*n_h*n_z;
pick    = (1:n_state)' + (d_max(:)-1)*n_state;
Policy(2,:,:,:,N_j) = reshape(Pol_aprime_d(pick),[n_a,n_h,n_z]); % Optimal a'
%% Backward iteration over age
n_state = n_a*n_h*n_z; % number of (a,h,z) states, used for the policy gather

for j = N_j-1:-1:1

    if verbose==1; fprintf('Age %d out of %d \n',j,N_j); end

    V_next = V(:,:,:,j+1); %V(a',h',z')

    % Set age-dependent parameters
    eff_j    = p.eff_j(j);
    pchild_j = p.pchild_j(j);
    pen_j    = p.pen_j(j);
    nchild_j = p.nchild_j(j);
    age_j    = p.age_j(j);
    s_j      = p.s_j(j);

    % Next-period human capital h' and its bracketing position on h_grid
    % depend only on (d,h,age), not on z, so they are computed once per age
    % here instead of once per (z,d,h) inside the loops below.
    ind_l_dh    = zeros(n_h,n_d); % left grid index of h'
    weight_l_dh = zeros(n_h,n_d); % interpolation weight on the left point
    for d_c = 1:n_d
        for h_c = 1:n_h
            hprime_val = f_HC_accum(d_grid(d_c),h_grid(h_c),age_j,xi_1,xi_2,del_h,h_l);
            [ind_l,weight_l] = find_loc(h_grid,hprime_val);
            if ind_l>length(h_grid)-1
                error('ind_l out of bounds')
            end
            ind_l_dh(h_c,d_c)    = ind_l;
            weight_l_dh(h_c,d_c) = weight_l;
        end %end h
    end %end d

    for z_c = 1:n_z
        eta_m_val = z_grid(z_c,1);
        eta_f_val = z_grid(z_c,2);
        theta_val = z_grid(z_c,3);

        % Compute EV(a',h'), given z. This is the vectorized equivalent of
        % accumulating V_next(:,:,zprime_c)*z_prob(zprime_c) over zprime_c.
        z_prob = pi_z(z_c,:)';
        EV = V_next.*shiftdim(z_prob,-2); %V(a',h',z')*Prob(1,1,z')
        EV = sum(EV,3); %EV(a',h')

        for d_c = 1:n_d
            d_val = d_grid(d_c);
            for h_c = 1:n_h
                h_val = h_grid(h_c);

                % Ret_mat is (a',a)
                % THIS IS SLOW ACCORDING TO PROFILER
                Ret_mat = ReturnFn(d_val,aprime_val,a_val,h_val,eta_m_val,eta_f_val,theta_val,...
                    w_m,eff_j,w_f,pchild_j,pen_j,r,nchild_j,crra,nu,age_j,Jr);

                % Linear interpolation of EV(a',.) at h' using the
                % precomputed bracket; EV_interp is (a',1)
                ind_l     = ind_l_dh(h_c,d_c);
                weight_l  = weight_l_dh(h_c,d_c);
                EV_interp = EV(:,ind_l)*weight_l+EV(:,ind_l+1)*(1-weight_l);

                % (a',a) + (a',1) expands along the columns
                RHS_mat = Ret_mat+beta*s_j*EV_interp;

                % THIS IS SLOW ACCORDING TO PROFILER
                [max_val,max_ind] = max(RHS_mat,[],1);
                % max_val and max_ind are (1,a)
                V_d(:,h_c,z_c,d_c)          = max_val; %best V given d
                Pol_aprime_d(:,h_c,z_c,d_c) = max_ind; %best a' given d
            end %end h
        end %end d
    end %end z

    % Maximise over d (dimension 4)
    [V(:,:,:,j),d_max] = max(V_d,[],4);
    Policy(1,:,:,:,j) = d_max; % Optimal d

    % Optimal a': gather Pol_aprime_d at the optimal d for every (a,h,z).
    % Pol_aprime_d is (a,h,z,d), so the linear index of (s,d) is
    % s + (d-1)*n_state, where s is the linear index of (a,h,z).
    pick = (1:n_state)' + (d_max(:)-1)*n_state;
    Policy(2,:,:,:,j) = reshape(Pol_aprime_d(pick),[n_a,n_h,n_z]); % Optimal a'

end %end j
toc
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function F = ReturnFn(l_f,aprime,a,h_f,eta_m,eta_f,theta,w_m,eff_j,w_f,...
    pchild_j,pen_j,r,nchild_j,crra,nu,agej,Jr)
% RETURNFN Period utility for every combination of a' and a.
%
% INPUTS
%   l_f      : female labor supply, either 0 or 1
%   aprime   : next-period assets; the script passes a column (a',1)
%   a        : current assets; the script passes a row (1,a)
%   h_f      : female human capital
%   eta_m, eta_f, theta : productivity shocks / permanent component
%   w_m, w_f : wages of men and women
%   eff_j, pchild_j, pen_j, nchild_j : age-dependent parameters
%   r        : interest rate
%   crra     : curvature of utility (crra == 1 gives log utility)
%   nu       : utility cost of female work
%   agej, Jr : current age and the age from which pen_j replaces earnings
% OUTPUT
%   F : utility, with the size of cash-aprime after implicit expansion
%       ((a',a) when called as in the script). Entries with non-positive
%       consumption are -Inf.

% Calculate earnings (incl. child care costs) of men and women
y_m = w_m*eff_j*theta*eta_m;
y_f = w_f*l_f*(exp(h_f)*theta*eta_f - pchild_j);

% Available resources: earnings before age Jr, pension from Jr onwards
cash = (1+r)*a + pen_j*(agej>=Jr) + (y_m + y_f)*(agej<Jr);
cons = cash-aprime;

% Infeasible choices are masked BEFORE the power is taken. Raising a
% negative number to a non-integer power yields complex values, and a
% result that stays complex would make max() rank entries by magnitude.
% The placeholder value 1 is overwritten with -Inf below.
infeasible = cons<=0;
cons_safe = cons;
cons_safe(infeasible) = 1;

% Consumption scaled by sqrt(2+nchild_j)
c_eq = cons_safe/(sqrt(2+nchild_j));
if crra == 1
    % Limit of the CRRA form as crra -> 1; avoids the division by zero
    F = log(c_eq) - nu*l_f;
else
    F = c_eq.^(1-crra)/(1-crra) - nu*l_f;
end
F(infeasible) = -inf;

end %end function "ReturnFn"
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [h_f_prime] = f_HC_accum(l_f,h_f,age_j,xi_1,xi_2,del_h,h_l)
% F_HC_ACCUM Law of motion for human capital.
%
% INPUTS
%   l_f        : d variable that affects h' (0 or 1 in the calling script)
%   h_f        : current-period value of h
%   age_j      : current age
%   xi_1, xi_2 : intercept and age slope of the gain applied when l_f = 1
%   del_h      : amount subtracted from h when l_f = 0
%   h_l        : lower bound on h'
% OUTPUT
%   h_f_prime  : next-period value of h, never below h_l

gain_if_working = (xi_1 + xi_2*age_j)*l_f;
loss_if_idle    = del_h*(1-l_f);

h_unbounded = h_f + gain_if_working - loss_if_idle;
h_f_prime   = max(h_unbounded, h_l);

end %end function f_HC_accum
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [jl,omega] = find_loc(x_grid,xi)
%-------------------------------------------------------------------------%
% FIND_LOC Bracket a scalar on a grid and return the interpolation weight.
%
% Finds jl in 1,..,N-1 such that x_grid(jl) <= xi < x_grid(jl+1), and the
% weight omega on the left point such that
%   omega*x_grid(jl) + (1-omega)*x_grid(jl+1) = xi
% When xi lies outside the grid, jl is clamped to [1,N-1] and omega is
% clamped to [0,1].
%
% INPUTS
%   x_grid : strictly increasing column vector (nx,1)
%   xi     : scalar
% OUTPUTS
%   jl     : index of the left point (scalar)
%   omega  : weight on the left point (scalar)
% NOTES
%   See find_loc_vec.m for a vectorized version.
%-------------------------------------------------------------------------%

n_rows = size(x_grid,1);

% locate returns 0 or n_rows for out-of-range xi; pull the index back so
% that both jl and jl+1 are valid grid positions
jl = locate(x_grid,xi);
jl = min(jl,n_rows-1);
jl = max(jl,1);

% Weight on the left bracket point, restricted to the unit interval
x_left  = x_grid(jl);
x_right = x_grid(jl+1);
omega = (x_right-xi)/(x_right-x_left);
omega = min(omega,1);
omega = max(omega,0);

end %end function "find_loc"
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function jl = locate(xx,x)
% LOCATE Bisection search for the grid interval that contains x.
%
%   jl = locate(xx,x) returns jl such that x is between xx(jl) and
%   xx(jl+1).
%
%   jl = 0 and jl = n (with n = length(xx)) mean that x is out of range:
%   0 is returned when x < xx(1) and n when x > xx(n).
%
%   xx is assumed to be monotone increasing.

n_pts = length(xx);

% Out-of-range cases are handled first
if x<xx(1)
    jl = 0;
    return
end
if x>xx(n_pts)
    jl = n_pts;
    return
end

% Bisection: shrink [lo,hi] until the two ends are adjacent
lo = 1;
hi = n_pts;
while (hi-lo>1)
    mid = floor((hi+lo)/2);
    if x>=xx(mid)
        lo = mid;
    else
        hi = mid;
    end
end
jl = lo;

end %end function locate
5 个评论
dpb
2024-6-18
I don't know that I'm terribly surprised; for-loop code optimization has advanced by leaps and bounds since the early days of MATLAB. Vectorized code can often be quicker if it can get to a linear addressing mode in the internal code, but if the vectorized code is complex there may be little optimization that can be done, and the combination of memory use and code complexity can make it counterproductive compared with the "dead-ahead" solution (as you've just demonstrated).
You do need to run the process monitor and ensure you really are not page faulting with the larger sizes; it may be you have a lot of installed memory but MATLAB isn't using it all...
采纳的回答
更多回答(0 个)
另请参阅
类别
在 Help Center 和 File Exchange 中查找有关 Oceanography and Hydrology 的更多信息
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!