How to multiply N matrices without a FOR loop? (Slices of 3D array)
I have a 2x2xN 3D array which, for my purposes, is essentially a stack of N 2x2 matrices, and I want to matrix-multiply all of them together to get the following result:
N = 14;
M = rand(2,2,N);
Z = M(:,:,1)*M(:,:,2)* ... *M(:,:,N);
size(Z) == [2 2]
I can do it with a for loop, but I am looking for a single line approach, something like:
prod(M,3);
but with mtimes, so that it does matrix multiplication along the 3rd dimension (not the element-wise product).
I also converted M into an Nx1 cell array of 2x2 matrices, but I could not get the multiplication to work with that approach either.
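For illustration, a minimal sketch of why prod(M,3) is not what is wanted here: it multiplies the pages element by element, while the question asks for the ordered matrix product of the pages. The two small matrices below are made-up example values, not data from the question.

```matlab
% Two hand-picked 2x2 pages (example values only)
M = cat(3, [1 2; 3 4], [5 6; 7 8]);

elementwise = prod(M, 3);            % same as M(:,:,1) .* M(:,:,2)
matrixprod  = M(:,:,1) * M(:,:,2);   % the product the question asks for

disp(elementwise)   % [5 12; 21 32]
disp(matrixprod)    % [19 22; 43 50]
```

Because matrix multiplication is not commutative, the order of the pages matters for the matrix product but not for prod(M,3).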
8 comments
Jan
2017-12-7
Edited: Jan
2017-12-7
Stephen's comment is very good.
To estimate the effect of optimizing the code, the typical input sizes matter: is it really a [2 x 2 x N] array, and what values of N do you have? For larger numbers of rows and columns, the main work is done by mtimes, and the loop overhead does not matter much. mtimes calls optimized BLAS or ATLAS functions, so there is no room for further improvement there. But I do not know whether these library functions handle tiny 2x2 matrices with unrolled loops, so perhaps a C-Mex function could be more efficient.
Answers (5)
Jan
2017-12-7
Edited: Jan
2017-12-7
If you really have 2x2 sub matrices to accumulate, try a C-Mex function:
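#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
const mwSize *size;
mwSize N, ndim;
double *p, *q, q11, q12, q21, q22, t11, t21;
// Exactly one real double input is expected
if (nrhs != 1 || !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0])) {
mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
"1st input must be a real double [2 x 2 x N] array.");
}
p = mxGetPr(prhs[0]);
ndim = mxGetNumberOfDimensions(prhs[0]);
size = mxGetDimensions(prhs[0]);
if (ndim > 3 || size[0] != 2 || size[1] != 2) {
mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
"1st input must be a [2 x 2 x N] array.");
}
// A plain 2x2 matrix has only 2 dimensions, so size[2] does not exist
N = (ndim == 3) ? size[2] : 1;
if (N == 0) { // Empty [2 x 2 x 0] input: nothing to multiply
mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
"1st input must not be empty.");
}
// Start with the first slice (column-major order)
q11 = p[0];
q21 = p[1];
q12 = p[2];
q22 = p[3];
while (--N) { // Unrolled 2x2 matrix multiplication
p += 4;
t11 = q11 * p[0] + q12 * p[1];
t21 = q21 * p[0] + q22 * p[1];
q12 = q11 * p[2] + q12 * p[3];
q22 = q21 * p[2] + q22 * p[3];
q11 = t11;
q21 = t21;
}
plhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
q = mxGetPr(plhs[0]);
q[0] = q11;
q[1] = q21;
q[2] = q12;
q[3] = q22;
return;
}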
[EDITED] This is tested now. The speed is very interesting:
function speed
x = rand(2, 2, 1000);
tic; for k = 1:1000, y = CumMProd2x2(x); end; toc
tic; for k = 1:1000, y = CumMProd2x2_AB(x); end; toc
tic
for k = 1:1000 % Jos (10584)
iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
out = mprodf(mprodf, x, size(x, 3));
end
toc
end
function out = CumMProd2x2_AB(M) % Andrei Bobrov
s = size(M, 3);
out = M(:,:,1);
for ii = 2:s
out = out * M(:,:,ii);
end
end
R2016b/64/Win7:
Elapsed time is 0.011403 seconds. C-mex
Elapsed time is 3.884977 seconds. Loop
Elapsed time is 96.038754 seconds. Recursive anonymous function
I was surprised that Andrei's loop is so slow, although it is clearly the nicest and cleanest solution. Let's try to unroll the loops as in the C code:
function out = CumMProd2x2_unroll(M)
q11 = M(1);
q21 = M(2);
q12 = M(3);
q22 = M(4);
c = 1;
for ii = 2:size(M, 3)
c = c + 4;
t11 = q11 * M(c) + q12 * M(c+1);
t21 = q21 * M(c) + q22 * M(c+1);
q12 = q11 * M(c+2) + q12 * M(c+3);
q22 = q21 * M(c+2) + q22 * M(c+3);
q11 = t11;
q21 = t21;
end
out = [q11, q12; q21, q22];
end
This is about 63 times faster than the direct approach "out * M(:,:,ii)":
Elapsed time is 0.061287 seconds. Unrolled
Obviously Matlab calls very smart, highly optimized libraries for the matrix multiplication, which treat the tiny input with the same sledgehammer approach as a 1000x1000 matrix.
But this unrolled version is so ugly that I would hesitate to use it in production code. For x = rand(2, 2, 100000) I get these timings for 1000 iterations:
Elapsed time is 1.377695 seconds. C-mex
Elapsed time is 2.872356 seconds. M with unrolled mtimes
Only a factor of 2! Another example that loops are not so bad in Matlab compared to C.
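Before trusting timings of an unrolled version, it is worth checking that it returns the same product as the plain loop. The following is a minimal sketch of such a check, written as a script with both variants inlined. The seed, the page count of 50 and the name relErr are arbitrary choices made for this illustration.

```matlab
% Sanity check: unrolled 2x2 update vs. plain loop with mtimes (sketch)
rng(0);                       % fixed seed so the check is repeatable
M = rand(2, 2, 50);

% Reference: straightforward loop
ref = M(:,:,1);
for k = 2:size(M, 3)
    ref = ref * M(:,:,k);
end

% Unrolled: linear index c points at M(1,1,k). Column-major storage means
% M(c), M(c+1), M(c+2), M(c+3) are the (1,1), (2,1), (1,2), (2,2) entries.
q11 = M(1); q21 = M(2); q12 = M(3); q22 = M(4);
c = 1;
for k = 2:size(M, 3)
    c = c + 4;
    t11 = q11 * M(c)   + q12 * M(c+1);
    t21 = q21 * M(c)   + q22 * M(c+1);
    q12 = q11 * M(c+2) + q12 * M(c+3);   % still uses the old q11
    q22 = q21 * M(c+2) + q22 * M(c+3);   % still uses the old q21
    q11 = t11;
    q21 = t21;
end
out = [q11, q12; q21, q22];

% Relative difference between the two results
relErr = norm(out - ref, 'fro') / norm(ref, 'fro');
```

The two results need not be bitwise identical, since the library routine may sum the terms in a different order, so a relative tolerance is a more suitable comparison than isequal.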
2 comments
Jos (10584)
2017-12-7
haha, I really liked my anonymous function approach, and did expect it to perform poorly, but that poorly ... haha
Andrei Bobrov
2017-12-6
s = size(M);
out = M(:,:,1);
for ii = 2:s(3)
out = out*M(:,:,ii);
end
5 comments
Jan
2017-12-7
+1: This is the nicest solution. That the multiplication of 2x2 matrices is much faster with a hard-coded algorithm is not a flaw of this solution.
Although the C-Mex approach is faster, it would be very hard to generalize it to inputs other than 2x2xN arrays.
Matt J
2017-12-7
"Although the C-Mex approach is faster, it would be very hard to generalize it to inputs other than 2x2xN arrays."
Just wanted to note that, while my solution based on MTIMESX is not as fast as Jan's for the 2x2xN case, it is applicable to arbitrary MxMxN arrays.
Matt J
2017-12-6
The following is not a one-line solution (for that, just put it in a function file), and it requires MTIMESX from the File Exchange. However, I do see a severalfold speed-up over a conventional for-loop:
out=M;
while size(out,3)>1
n=size(out,3);
if mod(n,2)
n=n-1;
A=out(:,:,1:2:n);
B=out(:,:,2:2:n);
out=cat(3,mtimesx(A,B),out(:,:,n+1));
else
A=out(:,:,1:2:n);
B=out(:,:,2:2:n);
out=mtimesx(A,B);
end
end
5 comments
James Tursa
2017-12-7
Edited: James Tursa
2017-12-7
Side note: by default, MTIMESX calls BLAS library routines for the matrix multiply so that its result matches the MATLAB for-loop m-code result, whereas MTIMESX with the 'SPEED' option uses hand-coded inline matrix multiply code for slices up to 5x5, which may not match the MATLAB for-loop m-code result exactly.
Some time back I had a beta version of MTIMESX that implemented the matrix equivalents of 'prod' and 'cumprod'. Maybe it is time I dust that off and finish the implementation/testing so I can publish it.
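To make the "matrix equivalent of cumprod" concrete, here is a minimal plain-MATLAB sketch of what such a function would return. It is only an illustration with a for-loop, not the MTIMESX beta code mentioned above, and the array name C is chosen just for this example.

```matlab
% Cumulative matrix product along the 3rd dimension (illustrative sketch):
% C(:,:,k) = M(:,:,1) * M(:,:,2) * ... * M(:,:,k)
M = rand(2, 2, 6);
C = zeros(size(M));
C(:,:,1) = M(:,:,1);
for k = 2:size(M, 3)
    C(:,:,k) = C(:,:,k-1) * M(:,:,k);   % extend the previous partial product
end
total = C(:,:,end);                      % the matrix equivalent of prod(M,3)
```

The last page of C is the full product asked for in the question, so the 'prod' variant is a special case of the 'cumprod' variant.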
Matt J
2017-12-7
That is strange, since I still see significant speed-up even with
mtimesx MATLAB
Steven Lord
2020-9-17
If you're using release R2020b or later, take a look at the pagemtimes function introduced in that release.
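As a rough sketch of how pagemtimes could be used for this problem: since matrix multiplication is associative, neighbouring pages can be multiplied in pairs, halving the number of pages on each pass. This assumes R2020b or later. The 3x3x11 test array is an arbitrary choice for the illustration.

```matlab
% Pairwise (tree) reduction with pagemtimes, R2020b or later (sketch)
M = rand(3, 3, 11);
out = M;
while size(out, 3) > 1
    n = size(out, 3);
    m = n - mod(n, 2);                               % largest even page count <= n
    P = pagemtimes(out(:,:,1:2:m), out(:,:,2:2:m));  % products of pages (1,2), (3,4), ...
    out = cat(3, P, out(:,:,m+1:n));                 % carry an odd trailing page unchanged
end

% Reference result from the plain loop, for comparison
ref = M(:,:,1);
for k = 2:size(M, 3)
    ref = ref * M(:,:,k);
end
relErr = norm(out - ref, 'fro') / norm(ref, 'fro');
```

The pairing keeps the pages in their original order, which matters because matrix multiplication is not commutative. The result can differ from the loop in the last few bits because the products are grouped differently.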
0 comments
Jos (10584)
2017-12-6
Here is one using recursion without a for-loop; not faster though, and somewhat mysterious, but just nice :) ...
M = randi(5,[2 2 4]) ; % data
iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
out = mprodf(mprodf,M,size(M,3)) % voila, it works!
3 comments
Jos (10584)
2017-12-7
It is the inline version of this recursive m-file:
function X = mprod(M,n)
% X = mprod(M) returns M(:,:,1) * M(:,:,2) * ... * M(:,:,end)
% where M is a 3D array
if nargin==1
X = mprod(M,size(M,3)) ;
elseif n < 2
X = M(:,:,1) ;
else
X = mprod(M,n-1) * M(:,:,n) ;
end