Training a model using different shapes

Ahmed on 27 Jan 2023
Answered: Conor Daly on 28 Mar 2023
I have training data of around 1000 shapes of different sizes and dimensions. The data is stored in a cell array in which each cell is one shape: an n-by-2 array, where n is the number of points that trace the shape and the two columns hold the x and y coordinates of those points. In the training data the points are ordered, so connecting them with straight lines in array order draws the desired shape accurately.
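For illustration, here is a minimal sketch of the layout described above, built from two made-up shapes. The variable name `shapes` and the values are hypothetical. The anonymous helper `pathLength` sums the lengths of the straight segments between consecutive points. It shows why row order matters: the same four points of a square give a longer path, with two crossing segments, when the rows are shuffled.

```matlab
% Hypothetical toy data in the layout described above: a cell array where
% each cell is an n-by-2 array of [x y] points in drawing order.
shapes = { [0 0; 1 0; 1 1; 0 1], ...   % unit square, n = 4
           [0 0; 2 0; 1 1.5] };        % triangle, n = 3

% n can differ from shape to shape.
numPoints = cellfun(@(s) size(s, 1), shapes);

% Total length of the straight segments joining consecutive points.
pathLength = @(s) sum(sqrt(sum(diff(s).^2, 2)));

orderedLen   = pathLength(shapes{1});               % square traced in order
scrambledLen = pathLength(shapes{1}([1 3 2 4], :)); % same points, shuffled rows
```

Here `orderedLen` is 3 (three unit sides), while the shuffled order gives 1 + 2*sqrt(2).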
I would like to train a model on those 1000 shapes so that, given a new shape whose points are not in order, it can reorder the points and draw the shape based on what it learned from the training shapes.
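Put another way, a scrambled shape is the ordered shape with its rows permuted, and reordering it means undoing that permutation. The minimal sketch below uses a toy square with made-up values. When the permutation `idx` is known, the second output of `sort` recovers its inverse directly. A trained model has no access to `idx` and would have to infer the order from the coordinates alone.

```matlab
% Toy ordered shape (made-up values): a square, one [x y] point per row.
s = [0 0; 1 0; 1 1; 0 1];

% Scrambling applies a random permutation to the rows.
idx = randperm(size(s, 1));
scrambled = s(idx, :);

% Reordering undoes that permutation. The second output of sort(idx) is the
% inverse permutation invIdx, satisfying idx(invIdx) == 1:n.
[~, invIdx] = sort(idx);
restored = scrambled(invIdx, :);
```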
I am very new to training models. So far in MATLAB I have given a neural network a set of inputs and outputs and let it learn what it can. Here, though, there are many separate cases to learn from. I'm not sure that appending all those points into one long array of coordinates is the right approach, because that defeats the purpose of having distinct shapes, each with its own point order. Any advice is appreciated.
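A short sketch of the difference, using two made-up shapes. `vertcat` produces a single 7-by-2 array that no longer records where each shape starts. Keeping one cell per shape preserves the boundaries. Transposing each cell to 2-by-n matches the features-by-time-steps layout that `trainNetwork` accepts for cell arrays of vector sequences.

```matlab
% Two made-up shapes, each stored as an n-by-2 array of [x y] points.
shapes = { [0 0; 1 0; 1 1; 0 1], [0 0; 2 0; 1 1.5] };

% Concatenation: one 7-by-2 array; where one shape ends is no longer recorded.
data = vertcat(shapes{:});

% One cell per shape keeps the boundaries. Transposing gives 2-by-n
% (features-by-time-steps) sequences, one per shape.
sequences = cellfun(@transpose, shapes, 'UniformOutput', false);
```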
% Step 1: Prepare the data
% Load the x and y coordinates of your shapes (assumes shapes.mat contains
% a cell array named shapes, each cell an n-by-2 array)
load('shapes.mat');
% Concatenate the x and y coordinates of all shapes into one N-by-2 array.
% Note: this discards the boundaries between shapes.
data = vertcat(shapes{:});
% Step 2: Define the LSTM network architecture
layers = [
    sequenceInputLayer(2)                   % 2 features per time step: x and y
    lstmLayer(64,'OutputMode','sequence')
    dropoutLayer(0.1)
    lstmLayer(64,'OutputMode','sequence')
    dropoutLayer(0.1)
    fullyConnectedLayer(2)                  % one (x, y) pair per time step
    regressionLayer];
% Step 3: Train the model on all shapes
% Split the data into training and test sets.
% split_data is not a built-in MATLAB function; it must be user-defined.
[XTrain,XTest,YTrain,YTest] = split_data(data, 0.8);
options = trainingOptions('adam', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs',4, ...
    'Shuffle','every-epoch', ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,options);
% Step 4: Use the trained model to make predictions on new shapes
predictedCoordinates = predict(net,XTest);
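The question does not define `split_data`. One possible approach is to split per shape rather than per point. In this sketch, `X`, `T`, and the placeholder cell contents are hypothetical stand-ins, and the 0.8 fraction mirrors the call above. Shuffling the shape indices and indexing `X` and `T` with the same vectors keeps each predictor paired with its own target.

```matlab
% Placeholder predictors and targets, one cell per shape (hypothetical; in
% practice each cell would hold a whole shape, not a single number).
X = num2cell(1:10);
T = num2cell(101:110);
trainFraction = 0.8;

% Shuffle shape indices, then take the first 80% for training.
numShapes = numel(X);
perm = randperm(numShapes);
numTrain = round(trainFraction * numShapes);
trainIdx = perm(1:numTrain);
testIdx  = perm(numTrain+1:end);

% Index X and T with the same vectors so each predictor keeps its target.
XTrain = X(trainIdx);  TTrain = T(trainIdx);
XTest  = X(testIdx);   TTest  = T(testIdx);
```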
  7 Comments
Ahmed on 27 Jan 2023
I want to train the model on these shapes first, before testing it on randomly ordered shapes.
Ahmed on 27 Jan 2023
@KSSV, is there a resource you could direct me to where I can learn how to train a model to order points based on different examples, please?


Answers (1)

Conor Daly on 28 Mar 2023
To train a model that can unscramble the order of the data, the model needs to be trained specifically for this task. One way to do this is to create a set of scrambled predictors and use the unscrambled data as targets.
Here's an example to get you started. The model doesn't train very well, but it's just an example.
% Load the data.
load('shapes_2.mat');
% Transpose each shape to 2x(numPoints).
shapes = cellfun(@transpose, shapes, UniformOutput=false);
% Standardize data.
M = mean( cat(2, shapes{:}), 2 );
S = std( cat(2, shapes{:}), [], 2 );
shapes = cellfun(@(x)(x-M)./S, shapes, UniformOutput=false);
% Create training predictors/targets by scrambling the order of the
% predictors.
X = shapes;
T = shapes;
for n = 1:numel(X)
    idx = randperm(size(X{n},2));   % random order of the points in shape n
    X{n} = X{n}(:, idx);
end
% Split into train/test sets (first 80% of the shapes for training).
numTrain = round(0.8*numel(X));
XTrain = X(1:numTrain);
TTrain = T(1:numTrain);
XTest = X(numTrain+1:end);
TTest = T(numTrain+1:end);
% Define network architecture.
layers = [
sequenceInputLayer(2)
bilstmLayer(64)
dropoutLayer(0.1)
bilstmLayer(64)
dropoutLayer(0.1)
fullyConnectedLayer(2)
regressionLayer ];
% Train the network.
options = trainingOptions("adam", ...
MiniBatchSize=50, ...
MaxEpochs=300, ...
Shuffle="every-epoch", ...
ValidationData={XTest,TTest}, ...
Verbose=false, ...
OutputNetwork="best-validation-loss", ...
Plots="training-progress" );
net = trainNetwork(XTrain, TTrain, layers, options);
% Test the trained network.
YTest = predict(net, XTest);
meanAbsError = mean( cellfun(@(y,t)mean(abs(y - t),'all'), YTest, TTest ));
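The network above outputs coordinates in standardized units rather than an index order, so its predictions will generally not coincide exactly with the input points. The sketch below shows one way to turn a prediction back into an ordering of the original points, using made-up values. First it undoes the standardization with `y .* S + M`, where `M` and `S` stand in for the mean and standard deviation computed above. Then it assigns each predicted point, in order, to the nearest input point that has not been used yet. This greedy nearest-neighbour matching is an assumption chosen for simplicity and is not part of the answer. It is also not guaranteed to find the best overall assignment.

```matlab
% Made-up example: hypothetical mean/std, a scrambled square x, and an
% imperfect network prediction y, both 2-by-n in standardized units.
M = [0.5; 0.5];
S = [0.5; 0.5];
x    = ([1 1; 0 0; 1 0; 0 1].' - M) ./ S;      % scrambled input points
yRaw = [0 0; 1 0.1; 1 1; 0.1 1].';              % prediction in original units
y    = (yRaw - M) ./ S;                         % as the network would output it

% 1) Undo the standardization.
yOriginal = y .* S + M;

% 2) Greedy matching: for each predicted step, in order, take the nearest
%    input point that has not been used yet, so the result is a permutation.
n = size(x, 2);
order = zeros(1, n);
used = false(1, n);
for k = 1:n
    d = sum((x - y(:, k)).^2, 1);   % squared distance to every input point
    d(used) = Inf;                  % each input point is used only once
    [~, order(k)] = min(d);
    used(order(k)) = true;
end

% Input points in the recovered order, back in original units.
reordered = x(:, order) .* S + M;
```

For an optimal one-to-one assignment instead of the greedy one, `matchpairs` can solve the linear assignment problem on a cost matrix of point-to-point distances.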

Categories

Deep Learning Toolbox

Release

R2021b
