Create plot with gradient descent vectors for a function of 2 variables
Hello,
I have the following code for a function of one variable, y = f(x), and I would like to extend it to a function of two variables, z = x.^2+y.^2. The gradient is a vector that shows the direction of the fastest rise of the function's value, while the gradient descent vector shows the direction of the steepest fall.
% Define the function
f = @(x) 2*x.^2 - 4*x - 2;
grad_f = @(x) 4*x - 4;
% Plot the function
x = linspace(-5, 5, 100);
y = f(x);
figure;
plot(x, y, 'LineWidth', 1.5);
hold on;
% Set the starting point for gradient descent
start_x = -2;
start_y = f(start_x);
% Plot the starting point
scatter(start_x, start_y, 100, 'r', 'filled');
% Points to show gradient vectors
points_to_show = -2;
% Gradient descent parameters
alpha = 0.5;
deltax = 0.5;
for i = 1:numel(points_to_show)
x = points_to_show(i);
% Plot the point
scatter(x, f(x), 100, 'g', 'filled');
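% Calculate the gradient (the derivative f'(x)) at the current point
gradient_at_x = grad_f(x);
% Plot the descent vector, tangent to the curve: step deltax along x in the
% descent direction (opposite the sign of f'(x)); along the tangent line the
% matching change in y is f'(x) times that step
step_x = -sign(gradient_at_x)*deltax;
quiver(x, f(x), step_x, gradient_at_x*step_x, 'Color', 'r', 'LineWidth', 1.5);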
end
hold off;
% Add labels and title
xlabel('X-axis');
ylabel('Y-axis');
title('Gradient Descent Vectors');
% Adjust axis limits if needed
grid on;
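The single arrow above shows only one step. A minimal sketch of several gradient descent iterations, x_{k+1} = x_k - alpha*f'(x_k), drawing each step as an arrow from one iterate to the next. The names `xs` and `n_steps`, the 8 iterations, and the smaller step size alpha = 0.1 are my assumptions: for this f, f''(x) = 4, so with alpha = 0.5 the update becomes x - 0.5*(4x - 4) = 2 - x, which just bounces between -2 and 4.

```matlab
% Sketch: several gradient descent steps on f(x) = 2x^2 - 4x - 2.
% alpha = 0.1 and n_steps = 8 are illustrative assumptions.
f = @(x) 2*x.^2 - 4*x - 2;
grad_f = @(x) 4*x - 4;

alpha = 0.1;      % step size (assumed)
n_steps = 8;      % number of iterations (assumed)

% Run gradient descent from x = -2 and store every iterate
xs = zeros(1, n_steps + 1);
xs(1) = -2;
for k = 1:n_steps
    xs(k+1) = xs(k) - alpha*grad_f(xs(k));
end

% Plot the curve and the iterates
x_plot = linspace(-5, 5, 100);
figure;
plot(x_plot, f(x_plot), 'LineWidth', 1.5);
hold on;
scatter(xs, f(xs), 40, 'g', 'filled');

% Arrow from each iterate to the next; scale 0 turns off autoscaling
% so each arrow ends exactly at the next point on the curve
u = diff(xs);
v = diff(f(xs));
quiver(xs(1:end-1), f(xs(1:end-1)), u, v, 0, 'Color', 'r', 'LineWidth', 1.5);
hold off;
xlabel('X-axis');
ylabel('Y-axis');
title('Gradient Descent Iterates');
grid on;
```

With alpha = 0.1 the update is x_{k+1} = 0.6*x_k + 0.4, so the iterates approach the minimum at x = 1 from the left and the arrows get shorter each step.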
Below is what I have done so far for the two-variable case, but I can't figure out how to plot the vector. Can you give advice or hints? From searching the web, I noticed that gradient descent vectors are usually plotted on contour plots rather than on the surface itself.
% Define the function
z = @(x, y) x.^2 + y.^2;
% Point of interest
point_x = -1;
point_y = -1;
point_z = z(point_x, point_y);
deltax = 0.5;
deltay = 0.5;
% Calculate the gradient analytically
gradient_x = 2*point_x;
gradient_y = 2*point_y;
gradient_z = 0;
% Plot the surface
[x_vals, y_vals] = meshgrid(linspace(-2, 2, 50), linspace(-2, 2, 50));
z_vals = z(x_vals, y_vals);
surf(x_vals, y_vals, z_vals);
hold on;
% Plot the point of interest
scatter3(point_x, point_y, point_z, 100, 'r', 'filled');
% Plot the gradient vector at the specified point
quiver3(point_x, point_y, point_z, gradient_x, gradient_y, gradient_z, 'Color', 'r', 'LineWidth', 2);
hold off;
% Add labels and title
xlabel('X-axis');
ylabel('Y-axis');
zlabel('Z-axis');
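Regarding contour plots: the gradient of z = x^2 + y^2 is (2x, 2y), a vector in the xy-plane, so its negative can be drawn with quiver on top of contour. A minimal sketch, assuming the same point (-1, -1), a step size alpha = 0.25, and grid sizes of my choosing (all illustrative). The grey arrows form a field of steepest-descent directions; they are perpendicular to the contour lines (visible with axis equal) and point toward the minimum at the origin.

```matlab
% Sketch: steepest descent vectors on a contour plot of z = x^2 + y^2.
% Grid sizes and the step size alpha are illustrative assumptions.
z = @(x, y) x.^2 + y.^2;

% Fine grid for the contour lines
[x_vals, y_vals] = meshgrid(linspace(-2, 2, 50), linspace(-2, 2, 50));
z_vals = z(x_vals, y_vals);

% Coarse grid for a field of descent arrows: -grad z = (-2x, -2y)
[xq, yq] = meshgrid(linspace(-2, 2, 9), linspace(-2, 2, 9));
descent_u = -2*xq;
descent_v = -2*yq;

% Point of interest and one gradient descent step from it
point_x = -1;
point_y = -1;
alpha = 0.25;                  % step size (assumed)
step_x = -alpha*2*point_x;     % -alpha * dz/dx
step_y = -alpha*2*point_y;     % -alpha * dz/dy

figure;
contour(x_vals, y_vals, z_vals, 15);
hold on;
% Field of arrows, autoscaled to fit the grid
quiver(xq, yq, descent_u, descent_v, 'Color', [0.5 0.5 0.5]);
scatter(point_x, point_y, 100, 'r', 'filled');
% Scale 0 disables autoscaling: the arrow ends at the next iterate
quiver(point_x, point_y, step_x, step_y, 0, 'Color', 'r', 'LineWidth', 2);
hold off;
axis equal;
xlabel('X-axis');
ylabel('Y-axis');
title('Steepest Descent Vectors on Contours');
```

Note that the sign matters: quiver3 with (gradient_x, gradient_y) as in the code above draws the ascent direction, whereas the descent vector is its negative.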
Thank you.
Answers (1)
Austin M. Weber
2024-2-1
My apologies for the double-response. I accidentally answered your question as a comment rather than as an answer.
If I am understanding correctly, I think I have found a simple way of doing what you want.
First, here is your code without the quiver3 call, with a few plotting changes that I explain below:
% Define the function
z = @(x, y) x.^2 + y.^2;
% Point of interest
point_x = -1;
point_y = -1;
point_z = z(point_x, point_y);
deltax = 0.5;
deltay = 0.5;
% Calculate the gradient analytically
gradient_x = 2*point_x;
gradient_y = 2*point_y;
gradient_z = 0;
% Plot the surface
[x_vals, y_vals] = meshgrid(linspace(-2, 2, 50), linspace(-2, 2, 50));
z_vals = z(x_vals, y_vals);
surfc(x_vals, y_vals, z_vals,'EdgeColor','none','FaceColor','interp','FaceAlpha',0.7);
hold on;
% Plot the point of interest
scatter3(point_x, point_y, point_z, 30, 'r', 'filled');
% Add labels and title
xlabel('X-axis');
ylabel('Y-axis');
zlabel('Z-axis');
view(-22,6)
I swapped the surf function for surfc, which plots a contour map underneath the surface plot. I also removed the edge lines and interpolated the colors to make the visualization less busy; you can revert to your original surf plot if you prefer. Finally, I used the view function to set the azimuth and elevation angles, which gives a different perspective on the axes.
Now, to add the vector arrow, I calculate the change in z produced by a very small step (the same size in x and in y) away from your point of interest. That small change (dx, dy, dz) follows the slope of the surface at this point, so I scale it up and use it to define the vector arrow.
% Infinitesimally small change
infchange = 0.0000000001;
point_xinf = point_x + infchange;
point_yinf = point_y + infchange;
point_zinf = z(point_xinf, point_yinf);
% Calculate the difference relative to the original point
dx = point_xinf - point_x;
dy = point_yinf - point_y;
dz = point_zinf - point_z;
% Add vector arrow to map; 20e8 (= 2e9) scales the tiny step back up to a visible size
quiver3(point_x,point_y,point_z,...
dx*20e8,dy*20e8,dz*20e8,'Color','r',...
'LineWidth',1,...
'ShowArrowHead','on',...
'MaxHeadSize',0.6)
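Why the arrow above points downhill: the step adds the same amount to x and y, so it moves along the direction (1, 1), and at (-1, -1) the negative gradient -(2x, 2y) = (2, 2) happens to point that way too. At a point off the line y = x the two directions differ. A minimal sketch of a more general variant, assuming you want the arrow to follow the surface along the steepest-descent direction at any point: estimate each partial derivative with its own central difference, take d = -grad as the horizontal direction, and use the change in z along the tangent plane, grad·d = -|grad|^2, as the vertical component. The point (0.5, -1.5), the step h = 1e-5 and the arrow scale are my assumptions; a step as tiny as 1e-10 tends to amplify floating-point rounding error.

```matlab
% Sketch: 3-D steepest-descent arrow tangent to z = x^2 + y^2.
% The point (0.5, -1.5), step h and arrow scale are illustrative assumptions.
z = @(x, y) x.^2 + y.^2;
point_x = 0.5;
point_y = -1.5;
point_z = z(point_x, point_y);

% Central-difference estimate of each partial derivative separately
h = 1e-5;
grad_x = (z(point_x + h, point_y) - z(point_x - h, point_y)) / (2*h);
grad_y = (z(point_x, point_y + h) - z(point_x, point_y - h)) / (2*h);

% Steepest descent direction in the xy-plane ...
d_x = -grad_x;
d_y = -grad_y;
% ... and the matching change in z along the tangent plane: grad . d
d_z = grad_x*d_x + grad_y*d_y;

% Surface plus the arrow (scaled down so it stays near the point)
scale = 0.2;
[x_vals, y_vals] = meshgrid(linspace(-2, 2, 50), linspace(-2, 2, 50));
figure;
surfc(x_vals, y_vals, z(x_vals, y_vals), 'EdgeColor', 'none', ...
    'FaceColor', 'interp', 'FaceAlpha', 0.7);
hold on;
scatter3(point_x, point_y, point_z, 30, 'r', 'filled');
% Scale 0 disables autoscaling so the arrow keeps its computed length
quiver3(point_x, point_y, point_z, scale*d_x, scale*d_y, scale*d_z, 0, ...
    'Color', 'r', 'LineWidth', 1.5, 'MaxHeadSize', 0.6);
hold off;
xlabel('X-axis');
ylabel('Y-axis');
zlabel('Z-axis');
view(-22, 6);
```

At (-1, -1) this construction gives d = (2, 2) and d_z = -8, the same direction as the (dx, dy, dz) arrow above; off the diagonal it still points down the steepest slope. Because the surface curves upward, the tangent arrow ends slightly below the surface.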