How to get Shapley values for a neural network trained in MATLAB? It keeps throwing errors...
Hi there,
I want to compute Shapley values for my pre-trained ANN.
It is a regression model.
Its input is a 7-by-5120 double array
and its output is a 1-by-5120 double array.
I'm confused by the idea of Shapley values, sorry.
3 comments
Angelo Yeo
2024-8-26
Can you be more specific about your model and the error message? It would be best if you could share your model (code and data) and the steps to reproduce the error.
Answers (1)
Angelo Yeo
2024-8-26
I do not have your model or dataset, so I used random samples. The key is to use yticklabels. Would this work for you?
clc;
clear;
%% Compute Shapley values
% demo neural network
x = randn(7, 150);
t = randn(1, 150);
net = fitnet(10);
net = configure(net,x,t); % configure only sets the input/output sizes; the demo network is not trained
% view(net)
f = @(x) net(x')'; % wrap the neural network as a function that takes one sample per row
x_veri_shapley = x(:,101:end)'; % transpose so that each row is one sample
x_train_shapley = x(:, 1:100)'; % transpose so that each row is one sample
% Sampling example
num_samples = size(x_veri_shapley,1); % number of query points to sample
idx = randperm(size(x_veri_shapley, 1), num_samples);
x_veri_shapley_sampled = x_veri_shapley(idx, :);
% Parallel processing is off here; set 'UseParallel' to true to enable it
explainer = shapley(f, x_train_shapley, 'QueryPoints', x_veri_shapley_sampled,'UseParallel', false);
%%
% plot(explainer)
%%
% Copy the MeanAbsoluteShapley table
shapley_table = explainer.MeanAbsoluteShapley;
% Rename the predictors
desired_variable_names = ["PGA", "Dur_{sig}", "Sa_{max}", "Tm", "CAV_{max}", "Arias_{max}", "f_{1}"];
shapley_table.Predictor = desired_variable_names(:); % replace with the new variable names
% Sort the Shapley values and names in ascending order
% (barh draws the first entry at the bottom, so the largest value ends up on top)
[sorted_values, sort_index] = sort(shapley_table.ShapleyValue, 'ascend');
sorted_names = shapley_table.Predictor(sort_index);
% Bar chart (largest value at the top)
% figure;
% barh(sorted_values);
% set(gca, 'YTickLabel', sorted_names);
% xlabel('Mean absolute Shapley value');
% ylabel('Predictor');
% title('Shapley importance plot');
%%
close all;
figure(10);
plot(explainer,QueryPointIndices=30);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
yticklabels(hAxes, sorted_names) % use "yticklabels" to change the YTickLabels
figure(11);
plot(explainer);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
yticklabels(hAxes, sorted_names) % use "yticklabels" to change the YTickLabels
figure(12);
swarmchart(explainer);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
yticklabels(hAxes, sorted_names) % use "yticklabels" to change the YTickLabels
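For the shapes described in the question (7-by-5120 inputs, 1-by-5120 outputs), the same pattern might be adapted roughly as follows. This is only a sketch under assumptions: X_all, T_all, predictFcn, the 500-row reference set, and the 20 query points are placeholder choices, and the small fitnet trained here merely stands in for the actual pre-trained network. The point it illustrates is the orientation: the network takes one sample per column, while shapley takes one observation per row, so the wrapper transposes in both directions.

```matlab
% Placeholder data with the shapes from the question
% (assumption: each column is one sample).
rng(0);
X_all = randn(7, 5120);   % 7 predictors x 5120 samples
T_all = randn(1, 5120);   % 1 response   x 5120 samples

% Stand-in for the pre-trained network (hypothetical; use your own net here).
net = fitnet(10);
net.trainParam.showWindow = false;   % do not open the training window
net = train(net, X_all, T_all);

% The network expects samples in columns, shapley expects them in rows,
% so transpose on the way in and on the way out.
predictFcn = @(Xrows) net(Xrows')';

X_rows      = X_all';                % 5120-by-7, one sample per row
X_reference = X_rows(1:500, :);      % reference data (arbitrary subset)
queryPoints = X_rows(501:520, :);    % 20 query points to keep the run short

explainer = shapley(predictFcn, X_reference, 'QueryPoints', queryPoints);

% One row per predictor, averaged over the query points.
disp(explainer.MeanAbsoluteShapley);
```

If the wrapper returns anything other than a column vector with one prediction per input row, shapley is likely to error, so calling predictFcn(queryPoints) on its own and checking the size is a cheap first test.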