Optimising Nearest Neighbor Program

Hi Guys,
I am trying to optimise this code so that it runs in under 10 seconds for N = 20k.
It currently takes around 40 seconds to run.
I think I need to vectorise some or all of the loops so that the calculations are done at the same time, but I cannot figure out how to do it.
Any help would be much appreciated, as we are learning from home with next to no support from the university.
Here is the code:
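% pos: 3-by-N array of point coordinates in [0,1]; N: number of points
s = zeros(N,N);                 % s(a,b) = distance between points a and b
for n = 1:3                     % loop over the x, y, z coordinates
    for a = 1:N
        for b = 1:N
            % off shifts the difference by +1 or -1 when the two coordinates
            % lie near opposite edges of [0,1]; otherwise it is 0
            off = 0;
            if (pos(n,a) <= 0.25 && pos(n,b) >= 0.75)
                off = 1;
            elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                off = -1;
            end
            s(a,b) = s(a,b) + (pos(n,a) - pos(n,b) + off)^2.0;
            if n == 3
                s(a,b) = sqrt(s(a,b));  % take the root after the last coordinate
            end
        end
    end
end
% For each point a, find the index of its nearest other point
match = zeros(1,N);
for a = 1:N
    mindist = 1e10;
    for c = 1:N
        if (a ~= c)
            mindist = min(s(a,c), mindist);
            if (mindist == s(a,c))
                match(a) = c;
            end
        end
    end
end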
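The `off` term is what makes this more than a plain Euclidean distance. Reading the loop literally, it adds +1 or -1 to a coordinate difference only when one point is within 0.25 of 0 and the other is within 0.25 of 1, so for those pairs the difference is measured across the wrap-around of the unit interval. That looks like a partial periodic-boundary correction, but the interpretation is an assumption, not something the question states. The sketch isolates the rule in a hypothetical anonymous function `offsetDiff` and evaluates it on three pairs:

```matlab
% Hypothetical helper: the per-coordinate term used inside the question's loop.
% off = +1 when pa is near 0 and pb is near 1, -1 in the mirrored case, else 0.
offsetDiff = @(pa, pb) pa - pb + ((pa <= 0.25 & pb >= 0.75) - (pa >= 0.75 & pb <= 0.25));

d1 = offsetDiff(0.1, 0.9);   % opposite edges: 0.1 - 0.9 + 1
d2 = offsetDiff(0.9, 0.1);   % mirrored case:  0.9 - 0.1 - 1
d3 = offsetDiff(0.3, 0.9);   % 0.3 is not <= 0.25, so no shift: 0.3 - 0.9
fprintf('%.2f %.2f %.2f\n', d1, d2, d3);   % prints 0.20 -0.20 -0.60
```

The third pair shows the rule is only partial: 0.3 and 0.9 are 0.4 apart across the wrap-around, yet the loop still uses the direct difference of 0.6.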
Samuel Leeney 2020-10-30
Thanks a lot. I have had a look through the link, and it does help me understand how nearest-neighbor search works. The purpose of my question, though, is to learn about optimisation, so I am really just trying to find a way to improve the efficiency of the code above (or of any future code I write) rather than write a nearest-neighbor program from scratch.


Accepted Answer

Bruno Luong 2020-10-31
I have put my comments here as an answer, so you can accept it if it helps.
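N=20000;
nd=3;
pos=rand(nd,N);                  % nd-by-N point coordinates
pos_r = reshape(pos,[nd 1 N]);   % pos_r(:,:,b) is the nd-by-1 column of point b
s = zeros(N);                    % N-by-N doubles: about 3.2 GB for N = 20000
for b=1:N
    posb = pos_r(:,:,b);
    % Offset rule from the question, for all points a against point b at once
    off = (pos<=0.25 & posb>=0.75) - ...
          (pos>=0.75 & posb<=0.25);
    sb = sum((pos-posb+off).^2,1);   % 1-by-N squared distances to point b
    s(:,b) = sb(:);                  % fill column b (contiguous in memory)
end
s = sqrt(s);
s(1:N+1:end) = Inf;              % exclude each point's distance to itself
[~,match] = min(s,[],2);         % nearest neighbour of each point a (row)
match = match.'; % row vector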
Still, if you don't have to deal with the odd offset, the Delaunay approach is much faster.

More Answers (3)

KSSV 2020-10-30
Edited: KSSV 2020-10-30
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos=rand(3,N);
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;   % a(i,j) = j, b(i,j) = i
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    % Element-wise & is required here; && only accepts scalar operands
    off(P(a)<=0.25 & P(b) >= 0.75) = +1 ;
    off(P(a)>=0.75 & P(b)<=0.25) = -1 ;
    S = S+(P(a)-P(b)+off).^2;
end
S = sqrt(S);
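The meshgrid index arrays are not strictly needed for this per-coordinate idea: with implicit expansion (MATLAB R2016b or later; `bsxfun` does the same job in older releases), an N-by-1 column minus a 1-by-N row already yields the N-by-N difference matrix. The sketch below assumes that feature and reuses the question's offset rule unchanged. Note that each N-by-N double temporary is still about 3.2 GB at N = 20000, so at that size the computation may be limited by memory rather than arithmetic.

```matlab
% Per-coordinate distance matrix without meshgrid, using implicit expansion.
N = 500;
pos = rand(3, N);                  % 3-by-N point coordinates in [0,1)

S = zeros(N, N);
for n = 1:3
    Pa = pos(n, :).';              % N-by-1: coordinate of point a (row index)
    Pb = pos(n, :);                % 1-by-N: coordinate of point b (column index)
    % Offset rule from the question, evaluated for every (a,b) pair at once
    off = (Pa <= 0.25 & Pb >= 0.75) - (Pa >= 0.75 & Pb <= 0.25);
    S = S + (Pa - Pb + off).^2;
end
S = sqrt(S);                       % S(a,b) corresponds to s(a,b) in the question

% Nearest neighbour of each point, excluding the point itself
S(1:N+1:end) = Inf;
[~, match] = min(S, [], 2);
match = match.';                   % row vector, as in the question
```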
KSSV 2020-10-30
How did you measure the time? Check the following:
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos=rand(3,N);
% Version 1: the original triple loop
t1 = tic ;
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        for b=1:N
            off=0;
            if (pos(n,a)<=0.25 && pos(n,b) >= 0.75)
                off=1;
            elseif pos(n,a)>=0.75 && pos(n,b)<=0.25
                off=-1;
            end
            s(a,b)=s(a,b)+(pos(n,a)-pos(n,b)+off)^2.0;
            if n == 3
                s(a,b)=sqrt(s(a,b));
            end
        end
    end
end
t1 = toc(t1) ;
% Version 2: one N-by-N operation per coordinate
t2 = tic ;
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    off(P(a)<=0.25 & P(b) >= 0.75) = +1 ;   % element-wise &, not &&
    off(P(a)>=0.75 & P(b)<=0.25) = -1 ;
    S = S+(P(a)-P(b)+off).^2;
end
S = sqrt(S);
t2 = toc(t2) ;
Now compare t1 and t2 for different input sizes. My bet is that the second version will always be faster. I am comparing only the first part of the code.
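To run that comparison over several sizes in one go, a small harness can time both versions and also check that they produce the same matrix. This is a sketch assuming MATLAB R2016b or later (implicit expansion replaces the meshgrid arrays); the sizes are illustrative only, and the timings it prints depend on the MATLAB version, its JIT and the available memory, so no particular result is implied.

```matlab
% Time the question's triple loop against an implicit-expansion version.
sizes  = [100 200 400];            % illustrative sizes only
tLoop  = zeros(size(sizes));
tVec   = zeros(size(sizes));
maxErr = 0;                        % largest difference seen between the two results
for k = 1:numel(sizes)
    N = sizes(k);
    pos = rand(3, N);

    % Triple loop with the offset rule, as in the question
    t = tic;
    s = zeros(N, N);
    for n = 1:3
        for a = 1:N
            for b = 1:N
                off = 0;
                if pos(n,a) <= 0.25 && pos(n,b) >= 0.75
                    off = 1;
                elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                    off = -1;
                end
                s(a,b) = s(a,b) + (pos(n,a) - pos(n,b) + off)^2;
            end
        end
    end
    s = sqrt(s);
    tLoop(k) = toc(t);

    % Same computation, one N-by-N operation per coordinate
    t = tic;
    S = zeros(N, N);
    for n = 1:3
        Pa = pos(n, :).';          % N-by-1 column: point a
        Pb = pos(n, :);            % 1-by-N row: point b
        off = (Pa <= 0.25 & Pb >= 0.75) - (Pa >= 0.75 & Pb <= 0.25);
        S = S + (Pa - Pb + off).^2;
    end
    S = sqrt(S);
    tVec(k) = toc(t);

    maxErr = max(maxErr, max(abs(S(:) - s(:))));
    fprintf('N = %4d: loop %.4f s, vectorised %.4f s\n', N, tLoop(k), tVec(k));
end
```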
Samuel Leeney 2020-10-30
Thanks,
I calculated the time using the toc function too.
I have also used the 'Run and Time' option, and the vast majority of the time is spent in the for n = 1:3 loop.
The second code is not faster. You can try it yourself.



Bruno Luong 2020-10-30
Edited: Bruno Luong 2020-10-30
I removed your handling of the offset (I'm not sure what its purpose is); here is a much faster method using Delaunay triangulation:
clear
N=20000;
nd=3;
pos=rand(nd,N);
% Brute-force reference: the question's loops with the offset removed
tic
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        for b=1:N
            s(a,b)=s(a,b)+(pos(n,a)-pos(n,b))^2.0;
            if n == 3
                s(a,b)=sqrt(s(a,b));
            end
        end
    end
end
match = zeros(1,N);
for a=1:N
    mindist=1e10;
    for c=1:N
        if (a~=c)
            mindist=min(s(a,c),mindist);
            if (mindist==s(a,c))
                match(a)=c;
            end
        end
    end
end
toc % Elapsed time is 56.414344 seconds.
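% Find the nearest neighbour within the same set of points in 2D or 3D
% INPUT: pos is array of size (nd x N), coordinates of N points in R^nd
tic
T = delaunay(pos.');                  % each row lists the vertices of one simplex
p = nchoosek(1:size(T,2),2);          % all vertex pairs within a simplex
P = T(:,p);
P = reshape(P,[],2);                  % one edge per row
P = unique(sort(P,2),'rows');         % undirected edges, duplicates removed
P1 = P(:,1);
P2 = P(:,2);
d2 = sum((pos(:,P2)-pos(:,P1)).^2,1); % squared edge lengths
A = [P1(:), d2(:), P2(:);
     P2(:), d2(:), P1(:)];            % each edge listed from both endpoints
A = sortrows(A,[1 2]);                % by point index, then by distance
b = [true; diff(A(:,1),1)>0];         % keep the shortest edge of each point
A = A(b,:);
nn = A(:,3).'; % index of nearest neighbour
d = sqrt(A(:,2)).'; % corresponding distance
toc % Elapsed time is 0.349317 seconds.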
isequal(match,nn) % 1
It doesn't help with your practice exercise, but I'm still posting it here for future readers looking for a fast method.
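The Delaunay method relies on a standard property: in 2-D and 3-D, each point's nearest neighbour is joined to it by an edge of the Delaunay triangulation, so distances only have to be computed along those edges instead of for all N^2 pairs. The least obvious step is turning simplices into a list of edges. The worked example below runs that step on a hand-written triangulation `T` (two triangles sharing edge 2-3, chosen for illustration rather than produced by `delaunay`):

```matlab
% Two triangles sharing the edge 2-3, given as rows of vertex indices.
T = [1 2 3;
     2 3 4];
p = nchoosek(1:size(T,2), 2);   % vertex pairs within a simplex: [1 2; 1 3; 2 3]
P = T(:, p);                    % 2-by-6: first endpoints, then second endpoints
P = reshape(P, [], 2);          % one edge per row, both triangles together
E = unique(sort(P, 2), 'rows')  % undirected edges with the shared one kept once
% E = [1 2; 1 3; 2 3; 2 4; 3 4]
```

The shared edge 2-3 appears once per triangle after the reshape; `sort(P,2)` puts each edge in a canonical order so `unique` can drop the duplicate.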
Samuel Leeney 2020-10-30
The code you've run is faster than mine but gives the wrong answer.
Bruno Luong 2020-10-30
Edited: Bruno Luong 2020-10-30
I ran it 10 times and isequal(match,nn) returned true each time, so the two methods agreed on all 10 sets of random points.
> for k=1:10; benchnntest; end
Elapsed time is 49.127755 seconds.
Elapsed time is 0.516644 seconds.
ans =
logical
1
Elapsed time is 51.189459 seconds.
Elapsed time is 0.459055 seconds.
ans =
logical
1
Elapsed time is 50.431960 seconds.
Elapsed time is 0.465887 seconds.
ans =
logical
1
Elapsed time is 50.426246 seconds.
Elapsed time is 0.454885 seconds.
ans =
logical
1
Elapsed time is 50.651649 seconds.
Elapsed time is 0.567084 seconds.
ans =
logical
1
Elapsed time is 50.889422 seconds.
Elapsed time is 0.461514 seconds.
ans =
logical
1
Elapsed time is 50.678441 seconds.
Elapsed time is 0.491820 seconds.
ans =
logical
1
Elapsed time is 50.476219 seconds.
Elapsed time is 0.451430 seconds.
ans =
logical
1
Elapsed time is 52.659327 seconds.
Elapsed time is 0.443114 seconds.
ans =
logical
1
Elapsed time is 52.004992 seconds.
Elapsed time is 0.459873 seconds.
ans =
logical
1
>>



Image Analyst 2020-10-30
MATLAB stores arrays in column-major order, which means that elements next to each other along the leftmost index are adjacent in memory, so loops tend to run faster when the leftmost index changes fastest. MATLAB goes down the rows of the first column, then moves over to the next column and goes down its rows. So this slow code
for row = 1 : rows
for col = 1 : columns
s(row, col) = whatever; % Col iterates fastest
end
end
will (or may) be slower than this fast code
for col = 1 : columns
for row = 1 : rows
s(row, col) = whatever; % row iterates fastest
end
end
Note that in your code n is the leftmost index of your arrays, yet you made the n loop the outermost loop, which is the slowest possible ordering. If possible, see if you can move n to an inner loop. I've had luck in the past speeding up nested loops that way.
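A quick way to see the effect of loop order on your own machine is to fill the same matrix twice, once with each nesting, and time both. This is a sketch: the matrix size is arbitrary, and how large the gap is (if any) depends on the matrix size, cache behaviour and the MATLAB version's JIT, which is why the advice above says "will (or may)".

```matlab
% Fill the same matrix with each loop order and time both.
nRows = 1000;
nCols = 1000;
A = zeros(nRows, nCols);
B = zeros(nRows, nCols);

t = tic;
for row = 1:nRows
    for col = 1:nCols
        A(row, col) = row + col;   % col changes fastest: each step jumps nRows elements
    end
end
tRowOuter = toc(t);

t = tic;
for col = 1:nCols
    for row = 1:nRows
        B(row, col) = row + col;   % row changes fastest: walks down contiguous memory
    end
end
tColOuter = toc(t);

fprintf('row loop outer: %.3f s, column loop outer: %.3f s\n', tRowOuter, tColOuter);
```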
Samuel Leeney 2020-10-30
Bruno, this works perfectly, thank you!
Image Analyst 2020-10-30
Samuel, going by your description, I'd try something like this to find the closest point.
N = 1000;
xyz = rand(N, 3); % Get N randomly located points in 3-D.
for k = 1 : N
    % Get the squared distance of point k to every other point in the array.
    distancesSquared = ((xyz(k, 1) - xyz(:, 1)) .^2 + ...
        (xyz(k, 2) - xyz(:, 2)) .^2 + ...
        (xyz(k, 3) - xyz(:, 3)) .^2);
    % Exclude the point itself. Setting only entry k (rather than every zero)
    % keeps any exact duplicate of point k as a valid neighbour.
    distancesSquared(k) = inf;
    % Find the min value and the index of that min for the other points.
    [minDist2, index] = min(distancesSquared);
    % Print it out.
    fprintf('Point %d at (%.2f, %.2f, %.2f) is closest to point %d at (%.2f, %.2f, %.2f).\n',...
        k, xyz(k, 1), xyz(k, 2), xyz(k, 3), index, xyz(index, 1), xyz(index, 2), xyz(index, 3));
end
It prints out stuff like:
Point 1 at (0.70, 0.37, 0.03) is closest to point 312 at (0.70, 0.37, 0.04).
Point 2 at (0.09, 0.71, 0.91) is closest to point 918 at (0.10, 0.72, 0.94).
Point 3 at (0.53, 0.95, 0.47) is closest to point 50 at (0.54, 0.93, 0.43).
etc.
Point 998 at (0.03, 0.67, 0.34) is closest to point 972 at (0.05, 0.59, 0.39).
Point 999 at (0.99, 0.87, 0.16) is closest to point 592 at (0.94, 0.93, 0.16).
Point 1000 at (0.54, 0.34, 0.42) is closest to point 540 at (0.55, 0.33, 0.42).
It doesn't handle the offset at the edges of the square, though.
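The per-point loop above can take the question's edge offset as well, which also avoids storing an N-by-N matrix (about 3.2 GB at N = 20000): each iteration only builds N-by-3 temporaries. This is a sketch assuming MATLAB R2016b or later for implicit expansion; the offset rule is copied literally from the question, with query point k playing the role of `a`. It still evaluates all N^2 distances, just without keeping them.

```matlab
% Per-point nearest neighbour including the question's offset rule.
N = 1000;
xyz = rand(N, 3);                  % N random points in 3-D, one per row
match   = zeros(N, 1);             % index of the nearest other point
minDist = zeros(N, 1);             % distance to that point
for k = 1:N
    pk = xyz(k, :);                % 1-by-3 query point
    % Offset per coordinate: +1 if pk is near 0 and the other point near 1,
    % -1 in the mirrored case, 0 otherwise (N-by-3 via implicit expansion)
    off = (pk <= 0.25 & xyz >= 0.75) - (pk >= 0.75 & xyz <= 0.25);
    d2 = sum((pk - xyz + off).^2, 2);   % N-by-1 squared distances from point k
    d2(k) = Inf;                   % exclude point k itself
    [minDist(k), match(k)] = min(d2);
end
minDist = sqrt(minDist);
```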

