Training a model using different shapes

I have training data for around 1000 shapes of different sizes and dimensions. The data is in a cell array where each cell holds one shape as an n-by-2 array: n is the number of points that trace the shape, and the two columns hold the x and y coordinates of those points. In the training data the points are ordered, so connecting them with straight lines in array order draws the desired shape accurately.
I would like to train a model on those 1000 shapes so that, given a new shape whose points are not in order, it can reorder the points and draw the shape based on what it learned from the training shapes.
I am very new to training models. What I have done in MATLAB so far is give a neural network a set of inputs and outputs and let it learn what it can, but here there are many separate cases to learn from. I'm not sure that appending all the points into one long array of coordinates is the right approach, because that loses both the distinction between shapes and the order of their points. Any advice is appreciated.
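The concern about concatenation can be illustrated with a toy dataset (three made-up shapes, not the real data): after vertcat, nothing in the result marks where one shape ends and the next begins, so the per-shape lengths would have to be stored separately to undo it. Keeping one shape per cell avoids the problem.

```matlab
% Toy stand-in for the real data: three shapes with different numbers of
% points, each an n-by-2 array of ordered [x y] coordinates.
t1 = linspace(0, 2*pi, 9)';  t1(end) = [];   % 8 points on a unit circle
t2 = linspace(0, 2*pi, 13)'; t2(end) = [];   % 12 points on a circle of radius 2
shapes = { [cos(t1) sin(t1)], ...
           2*[cos(t2) sin(t2)], ...
           [0 0; 1 0; 1 1; 0 1] };           % 4-point square

% Concatenating gives one long array; the shape boundaries are gone.
data = vertcat(shapes{:});                   % 24-by-2

% The boundaries can be recovered only if the lengths are kept separately.
lens = cellfun(@(s) size(s, 1), shapes);     % one length per shape
recovered = mat2cell(data, lens, 2);         % back to one shape per cell
```

trainNetwork accepts sequence data as a cell array with one sequence per cell (each stored as features-by-time-steps), so the data never has to be flattened in the first place.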
% Step 1: Prepare the data
% Load the shapes (a cell array of n-by-2 [x y] arrays)
load('shapes.mat');
% Concatenate the x and y coordinates of all shapes into one array
data = [];
for i = 1:numel(shapes)
    data = [data; shapes{i}];
end
% Step 2: Define the recurrent (LSTM) network architecture
layers = [
    sequenceInputLayer([size(data, 1) 2])
    lstmLayer(64,'OutputMode','sequence')
    dropoutLayer(0.1)
    lstmLayer(64,'OutputMode','sequence')
    dropoutLayer(0.1)
    fullyConnectedLayer(size(data, 1)*2)
    regressionLayer];
% Step 3: Train the model on all shapes
% Split the data into training and test sets
% (split_data is a user-defined helper; it is not shown)
[XTrain,XTest,YTrain,YTest] = split_data(data, 0.8);
options = trainingOptions('adam', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs',4, ...
    'Shuffle','every-epoch', ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,options);
% Step 4: Use the trained model to make predictions on new shapes
predictedCoordinates = predict(net,XTest);
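The post does not show split_data. Whatever it does, the split should be made over whole shapes rather than over rows of the concatenated array; otherwise one shape can end up partly in the training set and partly in the test set. A minimal sketch of such a split, using made-up toy shapes and a hypothetical trainFraction parameter:

```matlab
% Hypothetical replacement for split_data: split the cell array of shapes,
% not the concatenated rows, so each shape lands entirely in one set.
rng(0);                                            % reproducible shuffle
shapes = arrayfun(@(n) rand(n, 2), randi([5 20], 1, 10), ...
                  'UniformOutput', false);         % 10 toy shapes of random size
trainFraction = 0.8;                               % hypothetical parameter

numShapes = numel(shapes);
order = randperm(numShapes);                       % shuffle shape indices
numTrain = round(trainFraction * numShapes);
trainShapes = shapes(order(1:numTrain));
testShapes  = shapes(order(numTrain+1:end));
```

With the real data, shapes would be the loaded cell array, and the same index split would be applied to both the predictors and the targets.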
Ahmed 2023-1-27
I want to train the model on these shapes first, before testing it on randomly ordered shapes.
Ahmed 2023-1-27
@KSSV, is there any resource you could point me to where I can learn how to train a model to order points based on different examples, please?


Answers (1)

Conor Daly 2023-3-28
To train a model that can unscramble the order of the points, the model needs to be trained specifically for that task. One way to do this is to use scrambled copies of the shapes as predictors and the original, ordered shapes as targets.
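Because each shape can be scrambled in many different ways, the same predictors/targets idea can also enlarge the training set by making several scrambled copies of every shape. A minimal sketch, assuming the shapes are already in the 2-by-numPoints layout used in the code below; the toy shapes are random placeholders and numCopies is a hypothetical parameter.

```matlab
% Toy stand-in: 3 shapes stored as 2-by-numPoints matrices
% (x in row 1, y in row 2).
rng(0);
shapes = {rand(2, 6), rand(2, 9), rand(2, 4)};
numCopies = 5;                               % hypothetical: scrambled copies per shape

numShapes = numel(shapes);
X = cell(numShapes * numCopies, 1);          % scrambled predictors
T = cell(numShapes * numCopies, 1);          % ordered targets
k = 0;
for n = 1:numShapes
    for c = 1:numCopies
        k = k + 1;
        idx = randperm(size(shapes{n}, 2));  % a new random order for each copy
        X{k} = shapes{n}(:, idx);
        T{k} = shapes{n};
    end
end
```

When splitting such an augmented set into training and test data, keep all copies of a shape on the same side; otherwise the test set contains shapes the network has already seen in a different order.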
Here's an example to get you started. The model doesn't train very well, but it's just an example.
% Load the data (a cell array named shapes; each cell is an n-by-2 array).
load('shapes_2.mat');
% Transpose each shape to 2-by-numPoints (features-by-time-steps).
shapes = cellfun(@transpose, shapes, UniformOutput=false);
% Standardize the data using the mean and standard deviation of all points.
M = mean( cat(2, shapes{:}), 2 );
S = std( cat(2, shapes{:}), [], 2 );
shapes = cellfun(@(x)(x-M)./S, shapes, UniformOutput=false);
% Create training predictors/targets by scrambling the order of the
% predictors; the targets keep the original order.
X = shapes;
T = shapes;
for n = 1:numel(X)
    idx = randperm(size(X{n},2));
    X{n} = X{n}(:, idx);
end
% Split into train/test sets (first 150 shapes for training, the rest for testing).
XTrain = X(1:150);
TTrain = T(1:150);
XTest = X(151:end);
TTest = T(151:end);
% Define network architecture.
layers = [
    sequenceInputLayer(2)
    bilstmLayer(64)
    dropoutLayer(0.1)
    bilstmLayer(64)
    dropoutLayer(0.1)
    fullyConnectedLayer(2)
    regressionLayer ];
% Train the network.
options = trainingOptions("adam", ...
    MiniBatchSize=50, ...
    MaxEpochs=300, ...
    Shuffle="every-epoch", ...
    ValidationData={XTest,TTest}, ...
    Verbose=false, ...
    OutputNetwork="best-validation-loss", ...
    Plots="training-progress" );
net = trainNetwork(XTrain, TTrain, layers, options);
% Test the trained network and compute the mean absolute error.
YTest = predict(net, XTest);
meanAbsError = mean( cellfun(@(y,t)mean(abs(y - t),'all'), YTest, TTest ));
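The network regresses coordinates, so its output points will not coincide exactly with the input points. One way (not part of the answer above) to turn a prediction into an actual reordering of the given points is greedy nearest-neighbour matching: walk through the predicted sequence and, at each step, take the closest input point that has not been used yet. A minimal sketch on a made-up circle, with added noise standing in for the output of predict:

```matlab
% Toy stand-in for one test case: scrambled input points x (2-by-numPoints)
% and a noisy, network-style prediction y of the ordered shape.
rng(0);
t = linspace(0, 2*pi, 11); t(end) = [];
ordered = [cos(t); sin(t)];                % true ordered shape, 2-by-10
perm = randperm(10);
x = ordered(:, perm);                      % scrambled input
y = ordered + 0.05*randn(size(ordered));   % stand-in for predict() output

% Greedy nearest-neighbour matching: for each predicted point, in sequence
% order, take the closest input point that has not been used yet.
numPoints = size(x, 2);
order = zeros(1, numPoints);
used = false(1, numPoints);
for k = 1:numPoints
    d = sum((x - y(:, k)).^2, 1);          % squared distances to all input points
    d(used) = Inf;                         % each input point is used only once
    [~, j] = min(d);
    order(k) = j;
    used(j) = true;
end
reordered = x(:, order);                   % input points in predicted order
```

In the answer's code, x would be XTest{n} and y would be YTest{n}; p = reordered.*S + M converts the points back to the original units, and plot(p(1,:), p(2,:)) draws the shape. Greedy matching can make poor choices when the prediction is far off; an optimal one-to-one assignment is a possible refinement.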

Category: Deep Learning Toolbox

Release: R2021b
