An issue with matlabFunction

Hello,
I have a question that is best explained with an example. Consider the following:
mu=@(x,par)par(1).*x;sigma=@(x,par)par(2);
syms x0 x dt
par = sym('par', [1, 2]);
LL = -1 / 2 * (log(2 * pi * dt * sigma(x0, par)^2) + ((x - x0 - mu(x0, par) * dt) / (sigma(x0, par) * sqrt(dt)))^2);
h_LL = hessian(LL,par);
h_LL=matlabFunction(h_LL)
h_LL = function_handle with value:
@(dt,par1,par2,x,x0)reshape([-dt.*1.0./par2.^2.*x0.^2,1.0./par2.^3.*x0.*(-x+x0+dt.*par1.*x0).*2.0,1.0./par2.^3.*x0.*(-x+x0+dt.*par1.*x0).*2.0,1.0./par2.^2-(1.0./par2.^4.*(-x+x0+dt.*par1.*x0).^2.*3.0)./dt],[2,2])
Now, assume that V is a given 1-by-n vector with x0 = V(1:n-1) and x = V(2:n), and that dt and the parameters are given scalars. For instance, take dt = 1, x0 = 1:100, x = 2:101 and par1 = par2 = 1. What I want is the following sum over the elements of V:
h_LL(1,1,1,2,1) + h_LL(1,1,1,3,2) + h_LL(1,1,1,4,3) + ... + h_LL(1,1,1,101,100)   (argument order: dt, par1, par2, x, x0)
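For reference, the four entries can be read directly off the generated handle. Writing $p_1, p_2$ for par1, par2, $x_{0,i} = V_i$, $x_i = V_{i+1}$ and $r_i = x_i - x_{0,i} - p_1 x_{0,i}\,\Delta t$ (notation introduced here, not in the original post), the per-pair Hessian and the target sum are:

```latex
H(\Delta t, p_1, p_2, x, x_0) =
\begin{pmatrix}
-\dfrac{\Delta t\, x_0^2}{p_2^2} & -\dfrac{2 x_0 r}{p_2^3} \\[2ex]
-\dfrac{2 x_0 r}{p_2^3} & \dfrac{1}{p_2^2} - \dfrac{3 r^2}{p_2^4\,\Delta t}
\end{pmatrix},
\qquad r = x - x_0 - p_1 x_0\,\Delta t

S = \sum_{i=1}^{n-1} H\!\left(\Delta t, p_1, p_2, V_{i+1}, V_i\right)
  = \begin{pmatrix}
    -\dfrac{\Delta t}{p_2^2}\sum_i x_{0,i}^2 & -\dfrac{2}{p_2^3}\sum_i x_{0,i}\, r_i \\[2ex]
    -\dfrac{2}{p_2^3}\sum_i x_{0,i}\, r_i & \dfrac{n-1}{p_2^2} - \dfrac{3}{p_2^4\,\Delta t}\sum_i r_i^2
    \end{pmatrix}
```

Because dt, par1 and par2 are shared by every term, each entry of $S$ is a sum of elementwise expressions in the vectors x0 and x, which is why evaluating all pairs at once and then summing should be possible.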
Of course, I can calculate this sum with a for-loop. But my actual n and the number of my actual parameters are very large, which makes this approach impractical. The idea that comes to mind is to find a way to remove the reshape from the function handle h_LL above and pass the entire vectors x and x0 at once, rather than element by element.
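The loop mentioned above can be sketched as follows, mainly as a baseline for checking faster versions. This is a minimal sketch assuming the Symbolic Math Toolbox is available; the 'Vars' option is an addition (not in the original post) that pins the argument order to (dt, par1, par2, x, x0), and V = 1:101 reproduces the example x0 = 1:100, x = 2:101.

```matlab
% Baseline: evaluate the 2-by-2 Hessian handle one (x, x0) pair at a time.
% Requires Symbolic Math Toolbox.
mu    = @(x, par) par(1).*x;
sigma = @(x, par) par(2);
syms x0 x dt
par = sym('par', [1, 2]);
LL = -1/2 * (log(2*pi*dt*sigma(x0, par)^2) + ...
     ((x - x0 - mu(x0, par)*dt) / (sigma(x0, par)*sqrt(dt)))^2);

% 'Vars' fixes the input order of the generated handle.
h_LL = matlabFunction(hessian(LL, par), 'Vars', {dt, par(1), par(2), x, x0});

V  = 1:101;                        % example data from the question
X0 = V(1:end-1);  X = V(2:end);    % x0 = V(1:n-1), x = V(2:n)
DT = 1;  PAR1 = 1;  PAR2 = 1;

S = zeros(2);                      % running sum of the 2-by-2 Hessians
for k = 1:numel(X0)
    S = S + h_LL(DT, PAR1, PAR2, X(k), X0(k));
end
disp(S)
```

For this example data the closed-form sums work out to S = [-338350 666600; 666600 -984950], which is a convenient check for any vectorized replacement.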
Unfortunately, I have no control over matlabFunction and cannot tell it to return a matrix instead of applying reshape to scalar entries.
Any idea?
Thanks in advance,
Babak
Torsten 2023-12-11
Edited: Torsten 2023-12-11
I don't understand how you want to define the hessian with respect to "par" for vectors x and x0.
Do you want to create 100 x 100 Hessians of size 2x2, one for each pair (x(i),x0(j)) ?
Mohammad Shojaei Arani
Hi Torsten,
I compute the Hessian with respect to par only.


Accepted Answer

Dyuman Joshi 2023-12-11
Edited: Dyuman Joshi 2023-12-11
mu=@(x,par)par(1).*x;
sigma=@(x,par)par(2);
syms x0 x dt
par = sym('par', [1, 2]);
LL = -1 / 2 * (log(2 * pi * dt * sigma(x0, par)^2) + ((x - x0 - mu(x0, par) * dt) / (sigma(x0, par) * sqrt(dt)))^2);
temp = hessian(LL, par)
%Reshape the hessian as a column vector
h_LL = reshape(temp, [], 1)
%Convert to a function handle
fun=matlabFunction(h_LL);
%Sample inputs
DT=1; X0=1:100; X=2:101; PAR1=1; PAR2=1;
%Provide x0 and x as row vectors
out = fun(DT, PAR1, PAR2, X, X0)
out = 4×100
    -1    -4    -9   -16   -25   -36   -49   -64  ...
     0     4    12    24    40    60    84   112  ...
     0     4    12    24    40    60    84   112  ...
     1    -2   -11   -26   -47   -74  -107  -146  ...
%Convert the output to a 3D array,
%where each 2D matrix is a corresponding output
out = reshape(out, 2, 2, [])
out =
out(:,:,1) =
    -1     0
     0     1
out(:,:,2) =
    -4     4
     4    -2
out(:,:,3) =
    -9    12
    12   -11
...
out(:,:,100) =
  -10000   19800
   19800  -29402
%For comparison
subs(temp, [dt x0 x par], [DT X0(1) X(1) PAR1 PAR2])
ans =
[-1, 0]
[ 0, 1]
subs(temp, [dt x0 x par], [DT X0(2) X(2) PAR1 PAR2])
ans =
[-4,  4]
[ 4, -2]
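The answer stops at the 3-D array of per-pair Hessians; the sum the question asks for needs one more reduction. A minimal sketch, assuming the Symbolic Math Toolbox (the 'Vars' option is an addition that fixes the argument order): summing the 4-by-n output along its second dimension and reshaping gives the same result as sum(out, 3) on the 2-by-2-by-n array.

```matlab
% Vectorized: turn the Hessian into a column so each entry becomes one output row.
% Requires Symbolic Math Toolbox.
mu    = @(x, par) par(1).*x;
sigma = @(x, par) par(2);
syms x0 x dt
par = sym('par', [1, 2]);
LL = -1/2 * (log(2*pi*dt*sigma(x0, par)^2) + ...
     ((x - x0 - mu(x0, par)*dt) / (sigma(x0, par)*sqrt(dt)))^2);

temp = hessian(LL, par);
fun  = matlabFunction(reshape(temp, [], 1), 'Vars', {dt, par(1), par(2), x, x0});

DT = 1;  X0 = 1:100;  X = 2:101;  PAR1 = 1;  PAR2 = 1;
out   = fun(DT, PAR1, PAR2, X, X0);     % 4-by-100, one column per (x, x0) pair
total = reshape(sum(out, 2), 2, 2);     % equals sum(reshape(out, 2, 2, []), 3)
disp(total)
```

One caveat, stated as an expectation rather than a tested fact: this relies on every Hessian entry depending on x or x0. If an entry simplified to an expression in dt, par1 and par2 alone, the generated code would produce a scalar for that row, and the vertical concatenation inside fun would likely fail for vector inputs.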
Mohammad Shojaei Arani
Hi Dyuman,
Thanks a lot!
You elegantly solved my problem. I appreciate that!
Dyuman Joshi 2023-12-12
You're welcome! Glad to have helped!


More Answers (1)

Walter Roberson 2023-12-11
matlabFunction can produce vectorized code only if all of the following are true:
  • the expression is scalar for scalar inputs
  • the expression does not use piecewise()
  • the expression does not involve an integral, unless the integral is one that has a closed-form solution. This applies regardless of whether the symbolic parameters appear in the expression to be integrated, or in the bounds of the integral (but for different reasons for those two situations)
  • the expression does not use symmatrix (matlabFunction refuses to process those)
I might be forgetting some other cases.
"The idea that comes to my mind is to find a way to remove the 'reshape' in the above function handle h_LL"
You would have to do that by post-processing, for example by using the matlabFunction 'File' option and editing the generated file afterwards. You could also put your own reshape.m on your MATLAB path to shadow the built-in one, but that would be distinctly risky.
So what is the official solution for this kind of problem?
  1. You can loop; or
  2. You can use arrayfun() to apply the function to "corresponding" elements of the matrix inputs, which will automatically wrap the results; pass 'uniform', 0 to get a cell array of results. You could reshape() or permute() that cell array to be 1 x 1 x n and then use cell2mat(), or you could use cat(3, ThatCell{:}) to concatenate along the third dimension.
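The arrayfun option can be sketched as follows, applied to the original 2-by-2 handle without reshaping the symbolic result. This is a minimal sketch assuming the Symbolic Math Toolbox; 'UniformOutput', false is the full spelling of the 'uniform', 0 option, and 'Vars' is an addition that fixes the argument order.

```matlab
% arrayfun over corresponding pairs, then stack along dimension 3 and sum.
% Requires Symbolic Math Toolbox.
mu    = @(x, par) par(1).*x;
sigma = @(x, par) par(2);
syms x0 x dt
par = sym('par', [1, 2]);
LL = -1/2 * (log(2*pi*dt*sigma(x0, par)^2) + ...
     ((x - x0 - mu(x0, par)*dt) / (sigma(x0, par)*sqrt(dt)))^2);
h_LL = matlabFunction(hessian(LL, par), 'Vars', {dt, par(1), par(2), x, x0});

DT = 1;  X0 = 1:100;  X = 2:101;  PAR1 = 1;  PAR2 = 1;

% One 2-by-2 Hessian per pair (X(k), X0(k)), collected in a 1-by-100 cell array.
C = arrayfun(@(xk, x0k) h_LL(DT, PAR1, PAR2, xk, x0k), X, X0, 'UniformOutput', false);
H = cat(3, C{:});          % 2-by-2-by-100
total = sum(H, 3);         % the requested sum over all pairs
disp(total)
```

arrayfun still calls the handle once per pair, so it mostly tidies the loop rather than vectorizing it; for large n the column-reshaping approach in the accepted answer is likely to be faster.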
