graph convolution nueral network

23 次查看(过去 30 天)
NGR MNFD
NGR MNFD 2024-11-14,14:40
评论: William Rose 2024-11-14,16:50
Hello and courtesy. I have implemented the following code in the MATLAB 2021 program, but I get an error to implement the stgcn layer and attention layers. I don't know what I should do to implement this layer in MATLAB? Also, in this code, I want to use time and place data at the same time. What changes should I make in the code to include the time of each sample? thanks
clc; clear all; close all;
names = {'lurch', 'antalgic', 'normal', 'steppage', 'trendelenburg', 'stiff_legged'};
num_names = length(names);
basePath = 'C:\Users\HAMSHAHRI\Documents\MATLAB\movementdisorderopen';
data_cells = cell(1, num_names);
% بارگذاری داده‌ها
for h = 1:10 % برای هر نمونه انسان
for n = 1:num_names % برای هر کلاس بیماری
for s = 1:6 % برای هر نمونه اسکلت
filename = fullfile(basePath, sprintf('human%d_%s%d_SkeletonData%d.csv', h, names{n}, s, s));
data = readmatrix(filename);
if isempty(data_cells{n})
data_cells{n} = data; % اگر آرایه خالی است، داده را قرار می‌دهیم
else
data_cells{n} = [data_cells{n}; data]; % در غیر این صورت داده‌ها را الحاق می‌کنیم
end
end
end
end
% جداسازی کلاسها
for n = 1:num_names
assignin('base', names{n}, data_cells{n});
end
% حالا یک حلقه برای تمام گروه‌ها می‌زنیم
for n = 1:num_names
% انتخاب نام گروه و متغیر مربوطه
group_name = names{n};
x = eval(group_name); % استفاده از eval برای فراخوانی متغیر با نام داینامیک
eval([group_name '_X = [];']);
eval([group_name '_Y = [];']);
eval([group_name '_Z = [];']);
[nRows, nCols] = size(x);
% پردازش داده‌ها برای X
for i = 3:4:nCols
if i <= nCols
x(:, i) = x(:, i) - x(:, 7);
eval([group_name '_X = [', group_name '_X; x(:,i).'';];']);
end
end
% پردازش داده‌ها برای Y
for j = 4:4:nCols
if j <= nCols
x(:, j) = x(:, j) - x(:, 8);
eval([group_name '_Y = [', group_name '_Y; x(:,j).'';];']);
end
end
% پردازش داده‌ها برای Z
for k = 5:4:nCols
if k <= nCols
x(:, k) = x(:, k) - x(:, 9);
eval([group_name '_Z = [', group_name '_Z; x(:,k).'';];']);
end
end
end
% creat new matrix after egocentric
ClassMatrices = cell(1, length(names));
for classIdx = 1:length(names)
currentClass = names{classIdx};
eval(['LURCH_X_T = ' currentClass '_X'';']); % taranahade
eval(['LURCH_Y_T = ' currentClass '_Y'';']); % taranahade
eval(['LURCH_Z_T = ' currentClass '_Z'';']); % taranahade
NewMatrix = [];numCols = size(LURCH_X_T, 2);
for i = 1:numCols
col_X = LURCH_X_T(:, i);
col_Y = LURCH_Y_T(:, i);
col_Z = LURCH_Z_T(:, i);
NewMatrix = [NewMatrix, col_X, col_Y, col_Z];
end
ClassMatrices{classIdx} = NewMatrix;
end
lurchData = ClassMatrices{1};antalgicData = ClassMatrices{2};normalData = ClassMatrices{3};
steppageData = ClassMatrices{4};trendelenburgData = ClassMatrices{5};stiff_leggedData = ClassMatrices{6};
% حذف ستون های صفر که مربوط به مفصل میانی شانه هستmid-spine
lurchData(:, [4, 5, 6]) = [];antalgicData(:, [4, 5, 6]) = [];normalData(:, [4, 5, 6]) = [];
steppageData(:, [4, 5, 6]) = [];trendelenburgData(:, [4, 5, 6]) = [];stiff_leggedData(:, [4, 5, 6]) = [];
%% velocity & acceleration for 6 classes
% سرعت: مشتق موقعیت نسبت به زمان و شتاب: مشتق سرعت نسبت به زمان
% lurch
V_X1=[];Velocity_lurch=[];A_X1=[];ACC_lurch=[];time = lurch(:, 1);
for i = 1:72
velocity_X = diff(lurchData(:, i)) ./ diff(time);
velocity_X(end+1) = NaN;V_X1=[V_X1;velocity_X'];Velocity_lurch=V_X1';
acceleration_X = diff(velocity_X) ./ diff(time);
acceleration_X(end+1) = NaN;A_X1=[A_X1;acceleration_X'];ACC_lurch=A_X1';
end
% antalgic
V_X2=[];Velocity_antalgic=[];A_X2=[];ACC_antalgic=[];time = antalgic(:, 1);
for i = 1:72
velocity_X = diff(antalgicData(:, i)) ./ diff(time);
velocity_X(end+1) = NaN;V_X2=[V_X2;velocity_X'];Velocity_antalgic=V_X2';
acceleration_X = diff(velocity_X) ./ diff(time);
acceleration_X(end+1) = NaN;A_X2=[A_X2;acceleration_X'];ACC_antalgic=A_X2';
end
% normal
V_X3=[];Velocity_normal=[];A_X3=[];ACC_normal=[];time = normal(:, 1);
for i = 1:72
velocity_X = diff(normalData(:, i)) ./ diff(time);
velocity_X(end+1) = NaN;V_X3=[V_X3;velocity_X'];Velocity_normal=V_X3';
acceleration_X = diff(velocity_X) ./ diff(time);
acceleration_X(end+1) = NaN;A_X3=[A_X3;acceleration_X'];ACC_normal=A_X3';
end
% steppage
V_X4=[];Velocity_steppage=[];A_X4=[];ACC_steppage=[];time = steppage(:, 1);
for i = 1:72
velocity_X = diff(steppageData(:, i)) ./ diff(time);
velocity_X(end+1) = NaN;V_X4=[V_X4;velocity_X'];Velocity_steppage=V_X4';
acceleration_X = diff(velocity_X) ./ diff(time);
acceleration_X(end+1) = NaN;A_X4=[A_X4;acceleration_X'];ACC_steppage=A_X4';
end
% trendelenburg
V_X5=[];Velocity_trendelenburg=[];A_X5=[];ACC_trendelenburg=[];time = trendelenburg(:, 1);
for i = 1:72
velocity_X = diff(trendelenburgData(:, i)) ./ diff(time);
velocity_X(end+1) = NaN;V_X5=[V_X5;velocity_X'];Velocity_trendelenburg=V_X5';
acceleration_X = diff(velocity_X) ./ diff(time);
acceleration_X(end+1) = NaN;A_X5=[A_X5;acceleration_X'];ACC_trendelenburg=A_X5';
end
% stiff_legged
V_X6=[];Velocity_stiff_legged=[];A_X6=[];ACC_stiff_legged=[];time = stiff_legged(:, 1);
for i = 1:72
velocity_X = diff(stiff_leggedData(:, i)) ./ diff(time);
velocity_X(end+1) = NaN;V_X6=[V_X6;velocity_X'];Velocity_stiff_legged=V_X6';
acceleration_X = diff(velocity_X) ./ diff(time);
acceleration_X(end+1) = NaN;A_X6=[A_X6;acceleration_X'];ACC_stiff_legged=A_X6';
end
%% Creat CNN
positionData = [lurchData; antalgicData; normalData; steppageData; trendelenburgData; stiff_leggedData];
velocityData = [Velocity_lurch; Velocity_antalgic; Velocity_normal; Velocity_steppage; Velocity_trendelenburg; Velocity_stiff_legged];
accelerationData = [ACC_lurch; ACC_antalgic; ACC_normal; ACC_steppage; ACC_trendelenburg; ACC_stiff_legged];
% آماده‌سازی ورودی‌ها برای شبکه CNN
X = cat(3, positionData, velocityData, accelerationData); % ترکیب داده‌ها به ابعاد صحیح
% برچسب‌ها (Y)
Y = categorical([
1*ones(size(lurchData, 1), 1);
2*ones(size(antalgicData, 1), 1);
3*ones(size(normalData, 1), 1);
4*ones(size(steppageData, 1), 1);
5*ones(size(trendelenburgData, 1), 1);
6*ones(size(stiff_leggedData, 1), 1)
]);
numSamples = size(X, 1); % تعداد نمونه‌ها در X
rng(0); % تنظیم رندوم برای تکرارپذیری
randomIndices = randperm(numSamples); % ایجاد اندیس‌های تصادفی
% تعداد نمونه‌های آموزش و تست
numTrain = floor(0.7 * numSamples); % 70% برای آموزش
numTest = numSamples - numTrain; % 30% برای تست
XTrain = X(randomIndices(1:numTrain), :, :, :); % انتخاب 70% داده‌ها برای آموزش
YTrain = Y(randomIndices(1:numTrain)); % برچسب‌های آموزشی
XTest = X(randomIndices(numTrain+1:end), :, :, :); % انتخاب 30% داده‌ها برای تست
YTest = Y(randomIndices(numTrain+1:end)); % برچسب‌های تست
% تغییر ابعاد XTrain و XTest به [تعداد نمونه‌ها, 72, 3, 3]
XTrain = cat(4, XTrain, zeros(size(XTrain, 1), size(XTrain, 2), size(XTrain, 3), 1));
XTest = cat(4, XTest, zeros(size(XTest, 1), size(XTest, 2), size(XTest, 3), 1));
% افزودن بعد چهارم به yTrain و yTest
YTrain = double(YTrain); % تبدیل YTrain به double
YTest = double(YTest); % تبدیل YTest به double
YTrain = cat(4, YTrain, zeros(size(YTrain, 1), 1, 1, 1));
YTest = cat(4, YTest, zeros(size(YTest, 1), 1, 1, 1));
imageInputLayer([72, 3, 3], 'Normalization', 'none', 'Name', 'Input');
% ساختار شبکه CNN با Mean Pooling (Average Pooling)
layers = [
imageInputLayer([72, 3, 3], 'Normalization', 'none', 'Name', 'Input') % ورودی
convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'Conv1') % لایه کانولوشن 1
batchNormalizationLayer('Name', 'BatchNorm1') % نرمال‌سازی دسته‌ای
reluLayer('Name', 'ReLU1') % فعال‌سازی ReLU
averagePooling2dLayer(2, 'Stride', 1, 'Name', 'AvgPool1') % لایه MeanPooling با stride=1
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'Conv2') % لایه کانولوشن 2
batchNormalizationLayer('Name', 'BatchNorm2') % نرمال‌سازی دسته‌ای
reluLayer('Name', 'ReLU2') % فعال‌سازی ReLU
averagePooling2dLayer(2, 'Stride', 1, 'Name', 'AvgPool2') % لایه MeanPooling با stride=1
convolution2dLayer(3, 128, 'Padding', 'same', 'Name', 'Conv3') % لایه کانولوشن 3
batchNormalizationLayer('Name', 'BatchNorm3') % نرمال‌سازی دسته‌ای
reluLayer('Name', 'ReLU3') % فعال‌سازی ReLU
fullyConnectedLayer(6, 'Name', 'FC') % لایه Fully Connected
softmaxLayer('Name', 'Softmax') % لایه Softmax برای طبقه‌بندی
classificationLayer('Name', 'Output') % لایه نهایی
];
% گزینه‌های آموزش
options = trainingOptions('sgdm', ...
'MaxEpochs', 50, ...
'InitialLearnRate', 0.01, ...
'Verbose', false, ...
'Plots', 'training-progress');
% YTest = categorical(YTest); % تبدیل YTest به categorical
% YTrain = squeeze(YTrain); % حذف ابعاد اضافی
% YTrain = categorical(YTrain); % تبدیل به categorical
% آموزش شبکه
net = trainNetwork(XTrain, YTrain, layers, options);
% ارزیابی مدل
YPred = classify(net, XTest); % پیش‌بینی برچسب‌ها برای داده‌های تست
accuracy = sum(YPred == YTest) / numel(YTest); % محاسبه دقت مدل
disp(['Accuracy: ', num2str(accuracy)]);
%% Creat Graph
numClasses = 6; % تعداد کلاس‌ها
numNodes = 25; % تعداد گره‌ها (مفاصل)
numFeatures = 72; % تعداد ویژگی‌ها (ستون‌ها در داده‌های ورودی)
% اتصالات بین مفاصل (یال‌ها)
edges = [3 4; 3 21; 21 5; 5 6; 6 7; 7 8; 8 22; 7 23;
21 9; 9 10; 10 11; 11 12;11 25; 12 24; 21 2; 2 1;
1 13; 13 14; 14 15; 15 16; 1 17;17 18; 18 19; 19 20];
% ایجاد گراف
G = graph(edges(:,1), edges(:,2));
% % % ایجاد گراف جهت‌دار
% % G = digraph(edges(:,1), edges(:,2)); % گراف جهت‌دار
% % weights = rand(size(edges, 1), 1); % وزن‌های تصادفی برای یال‌ها
% % G = digraph(edges(:,1), edges(:,2), weights); % گراف جهت‌دار با وزن
position=[lurchData;antalgicData;normalData;steppageData;...
trendelenburgData;stiff_leggedData];
velocityData = [ Velocity_lurch;Velocity_antalgic;Velocity_normal;Velocity_steppage;...
Velocity_trendelenburg;Velocity_stiff_legged];
accelerationData = [ACC_lurch;ACC_antalgic;ACC_normal;ACC_steppage;...
ACC_trendelenburg;ACC_stiff_legged];
% ورودی‌ها برای شبکه‌ها
positionInput = imageInputLayer([72, 1, 1], 'Normalization', 'none', 'Name', 'PositionyInput'); % سرعت
velocityInput = imageInputLayer([72, 1, 1], 'Normalization', 'none', 'Name', 'VelocityInput'); % سرعت
accelerationInput = imageInputLayer([72, 1, 1], 'Normalization', 'none', 'Name', 'AccelerationInput'); % شتاب
% Batch Normalization
positionNorm = batchNormalizationLayer('Name', 'PositionBatchNorm');
velocityNorm = batchNormalizationLayer('Name', 'VelocityBatchNorm');
accelerationNorm = batchNormalizationLayer('Name', 'AccelerationBatchNorm');
% ST-GCN برای موقعیت
stgcnLayer1_position = stgcnLayer(6, 64, 'Name', 'ST-GCNLayer1-Position');
stgcnLayer2_position = stgcnLayer(64, 48, 'Name', 'ST-GCNLayer2-Position');
attentionLayer1_position = attentionLayer('Name', 'AttentionLayer1-Position');
% استفاده از self-attention برای موقعیت
positionWithAttention = custom_attention_layer1(positionInput, 3); % 3 یعنی d_k
stgcnLayer3_position = stgcnLayer(64, 16, 'Name', 'ST-GCNLayer3-Position');
attentionLayer2_position = attentionLayer('Name', 'AttentionLayer2-Position');
% ST-GCN برای سرعت
stgcnLayer1_velocity = stgcnLayer(6, 64, 'Name', 'ST-GCNLayer1-Velocity');
stgcnLayer2_velocity = stgcnLayer(64, 48, 'Name', 'ST-GCNLayer2-Velocity');
attentionLayer1_velocity = attentionLayer('Name', 'AttentionLayer1-Velocity');
velocityWithAttention = custom_attention_layer(velocityInput, 3); % 3 یعنی d_k
stgcnLayer3_velocity = stgcnLayer(64, 16, 'Name', 'ST-GCNLayer3-Velocity');
attentionLayer2_velocity = attentionLayer('Name', 'AttentionLayer2-Velocity');
velocityWithAttention = custom_attention_layer(velocityInput, 3); % 3 یعنی d_k
% ST-GCN برای شتاب
stgcnLayer1_acceleration = stgcnLayer(6, 64, 'Name', 'ST-GCNLayer1-ACC');
stgcnLayer2_acceleration = stgcnLayer(64, 48, 'Name', 'ST-GCNLayer2-ACC');
attentionLayer1_acceleration = attentionLayer('Name', 'AttentionLayer1-ACC');
accelerationWithAttention = custom_attention_layer(accelerationInput, 3); % 3 یعنی d_k
stgcnLayer3_acceleration = stgcnLayer(64, 16, 'Name', 'ST-GCNLayer3-ACC');
attentionLayer2_acceleration = attentionLayer('Name', 'AttentionLayer2-ACC');
accelerationWithAttention = custom_attention_layer(accelerationInput, 3); % 3 یعنی d_k
% Concatenate ویژگی‌های موقعیت و سرعت و شتاب
concatLayer = concatenationLayer(1, 2, 3, 'Name', 'Concatenate');
% ST-GCN لایه 4
stgcnLayer4 = stgcnLayer(48, 64, 'Name', 'ST-GCNLayer4');
attentionLayer = attentionLayer('Name', 'AttentionLayer');
% ST-GCN لایه 5
stgcnLayer5 = stgcnLayer(64, 128, 'Name', 'ST-GCNLayer5');
attentionLayer2 = attentionLayer('Name', 'AttentionLayer2');
% Global Average Pooling (GAP)
gapLayer = globalAveragePooling2dLayer('Name', 'GAP');
% Fully Connected Layer برای پیش‌بینی کلاس‌ها
fcLayer = fullyConnectedLayer(6, 'Name', 'FC');
% شبکه برای موقعیت
layers_position = [
positionInput
positionNorm
stgcnLayer1_position
stgcnLayer2_position
attentionLayer1_position
stgcnLayer3_position
attentionLayer2_position
];
% شبکه برای سرعت
layers_velocity = [
velocityInput
velocityNorm
stgcnLayer1_velocity
stgcnLayer2_velocity
attentionLayer1_velocity
stgcnLayer3_velocity
attentionLayer2_velocity
];
% شبکه برای شتاب
layers_acceleration = [
accelerationInput
accelerationNorm
stgcnLayer1_acceleration
stgcnLayer2_acceleration
attentionLayer1_acceleration
stgcnLayer3_acceleration
attentionLayer2_acceleration
];
% ترکیب خروجی‌ها
lgraph_position = layerGraph(layers_position);
lgraph_velocity = layerGraph(layers_velocity);
lgraph_acceleration = layerGraph(layers_acceleration);
% ترکیب خروجی‌ها برای هر ویژگی
lgraph = layerGraph([
lgraph_position
lgraph_velocity
lgraph_acceleration
concatLayer
stgcnLayer4
attentionLayer
stgcnLayer5
attentionLayer2
gapLayer
fcLayer
]);
% آموزش مدل
options = trainingOptions('sgdm', ...
'MaxEpochs', 50, ...
'InitialLearnRate', 0.01, ...
'Verbose', false, ...
'Plots', 'training-progress');
net = trainNetwork(X_train, Y_train, lgraph, options);
% ارزیابی مدل
YPred = classify(net, XTest); % پیش‌بینی برچسب‌ها برای داده‌های تست
accuracy = sum(YPred == YTest) / numel(YTest); % محاسبه دقت مدل
disp(['Accuracy: ', num2str(accuracy)]);
  1 个评论
William Rose
William Rose 2024-11-14,16:50
I understand that you are trying to train a network to recognize normal gait and movement disorders based on motion capture data. You want to use a spatio-temporal graph convolutional network.
You have posted hundreds of lines of code. Please post the simplest possible example that illustrates the problem you want to solve. Then you will be more likely to get help on this site.

请先登录,再进行评论。

回答(0 个)

产品


版本

R2021b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by