Computing mean by group (need help with speed!)
Hi everyone,
I need to demean the columns of matrix Xw (200000 x 24) by group id. To do this, I need to compute means of the columns of Xw for each group identified by the vector id (200000 x 1). I have written the following code, which is the fastest I could do in light of this very similar post:
%DEMEANING
[cid,indx_i,indx_j]=unique(id);
for i=1:size(Xw,2);
bb = accumarray(indx_j, Xw(:,i), [], @mean);
Xw(:,i) = Xw(:,i) - bb(indx_j);
end
This is faster than the code suggested at the post linked above. Importantly, the matrix Xw is rather sparse, and only contains zeros and ones (it is a matrix of dummy variables).
My question: Is there any way to speed up this process further? It is quite time consuming as it stands, and given that this itself is inside of a loop, it is slowing everything else down. Please help!! Creative solutions welcome.
Thanks!
1 Comment
Matthew Eicholtz
2015-9-8
"It is quite time consuming as it stands"...how much time are we talking? I ran your version on my machine and it took around 300 ms. I understand this becomes a problem if it is inside another loop, but I'm just curious what your benchmark is.
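For anyone who wants to reproduce a timing like this, here is a minimal tic/toc sketch around the loop from the question. The data is synthetic: the matrix size comes from the post, but the number of groups (50) and the density of ones are assumptions, so the absolute numbers will differ from the real case.

```matlab
% Synthetic stand-in for the data in the question.
% Assumptions: 50 groups, roughly 10% ones.
n  = 200000;
Xw = double(rand(n, 24) > 0.9);   % 0/1 dummy matrix
id = randi(50, n, 1);             % group labels

tic
[~, ~, indx_j] = unique(id);      % group number of every row
for i = 1:size(Xw, 2)
    bb = accumarray(indx_j, Xw(:, i), [], @mean);   % group means of column i
    Xw(:, i) = Xw(:, i) - bb(indx_j);               % subtract them row by row
end
elapsed = toc;
fprintf('Demeaning took %.1f ms\n', 1000 * elapsed);
```

A single tic/toc run is noisy, so it is worth repeating it a few times before comparing versions.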
Accepted Answer
More Answers (5)
Kelly Kearney
2015-9-8
Edited: Kelly Kearney
2015-9-8
This example uses my aggregate function. It's basically a wrapper around accumarray, but it lets you apply the grouping to multiple columns without calling accumarray again for each one.
% Some fake data
n = 200000;
Xw = rand(n,20);
id = ceil(rand(n,1) * 10);
% Group and subtract mean
[idg, Xwg] = aggregate(id, Xw, @(x) bsxfun(@minus, x, mean(x,1)));
% Put back in original order
[~, srt] = aggregate(id, (1:n)');
[~, isrt] = sort(cat(1, srt{:}));
Xwg = cat(1, Xwg{:});
Xwg = Xwg(isrt,:);
Perhaps I should add the index retrieval and re-sorting to the function itself... I'll add that to my todo list. When I timed this, the reordering was the most time-consuming part, so that step can probably be sped up a bit more.
EDIT:
Returning the indices turned out to be a quick and easy change to the function. Adding that eliminates the second call to aggregate in my example. I've uploaded the change to GitHub, so you'll want to clone the code from there; MatlabCentral may not grab the updates until the end of the day. New example:
[idg, Xwg, idx] = aggregate(id, Xw, @(x) bsxfun(@minus, x, mean(x,1)));
[~, isrt] = sort(cat(1, idx{:}));
Xwg = cat(1, Xwg{:});
Xwg = Xwg(isrt,:);
This version takes about 60% of the time of your original code.
Guillaume
2015-9-8
I have no idea if it would be faster but my suggestion would be to split your Xw into a submatrix for each ID. You can then calculate the mean and subtract it for all the columns all at once:
Xwsplit = arrayfun(@(cid) Xw(id == cid, :), unique(id), 'UniformOutput', false); %split Xw into one submatrix per id
Xwsplit = cellfun(@(xw) bsxfun(@minus, xw, mean(xw, 1)), Xwsplit, 'UniformOutput', false); %subtract the column means within each submatrix
Andrei Bobrov
2015-9-8
Edited: Andrei Bobrov
2015-9-8
[~,~,c]=unique(id);
[ii,jj] = ndgrid(c,1:size(Xw,2));
bb = accumarray([ii(:),jj(:)], Xw(:), [], @mean);
out = Xw - bb(c,:);
5 Comments
John
2015-9-8
Andrei Bobrov
2015-9-8
Xw(:) is a vector.
I can't test right now, but sometimes it is more efficient
- to accumulate data using the native/default sum,
- then to accumulate ones to get counts,
- and to divide the first by the second to get the means,
than to ask ACCUMARRAY to use an accumulation function passed by its handle.
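The three steps above can be sketched as follows. This is an untimed illustration on synthetic data (the sizes, the number of groups, and the variable names sums, counts and means are assumptions), not a benchmarked replacement.

```matlab
% Small synthetic example; sizes and number of groups are assumptions.
n  = 1000;
Xw = double(rand(n, 6) > 0.8);    % 0/1 dummy matrix
id = 10 * randi(7, n, 1);         % arbitrary, non-consecutive group labels

[~, ~, c] = unique(id);           % map labels to group numbers 1..k
[ii, jj]  = ndgrid(c, 1:size(Xw, 2));

sums   = accumarray([ii(:), jj(:)], Xw(:));   % default accumulation is sum
counts = accumarray(c, 1);                    % number of rows in each group
means  = bsxfun(@rdivide, sums, counts);      % k-by-p matrix of group means

out = Xw - means(c, :);           % demeaned data, same size as Xw
```

Both accumarray calls use the built-in sum, so no function handle is evaluated once per group.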
Matthew Eicholtz
2015-9-8
This is the fastest version for me, although there is no need to call unique if id already contains positive integers. Try this instead:
[ii,jj] = ndgrid(id,1:size(Xw,2));
bb = accumarray([ii(:),jj(:)], Xw(:), [], @mean);
out = Xw - bb(id,:);
John
2015-9-9
Vahidreza Jahanmard
2022-5-14
The following commands may work better for your problem. They return the group means bb; subtract bb(Groups,:) from Xw to demean.
[Groups,~] = findgroups(id);
bb = splitapply(@(x) mean(x,1), Xw, Groups);
If you also want to ignore NaN values:
bb = splitapply(@(x) mean(x,1,'omitnan'), Xw, Groups);
Steven Lord
2022-5-14
Since this question was asked, we have introduced the grouptransform function (in release R2018b). I believe using the 'meancenter' method, or specifying the method as a function handle involving detrend, may make this a one-line operation.
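A minimal sketch of that one-liner, assuming R2018b or later and that the array (non-table) syntax of grouptransform is available in your release. The data is synthetic and the sizes are assumptions.

```matlab
% Requires R2018b or later. Synthetic data; sizes are assumptions.
n  = 1000;
Xw = double(rand(n, 6) > 0.8);    % 0/1 dummy matrix
id = randi(7, n, 1);              % group labels

% Subtract each group's column means from the rows of that group.
out = grouptransform(Xw, id, 'meancenter');
```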
