How does feedforwardnet generate this graph?

When creating neural networks, how does the feedforwardnet function generate the plot below (bottom left) for the test data?
The y-axis shows the values it predicts during testing, but how does it map these to x-values? It does not know the target values for the test data.

Answers (1)

Binaya 2023-11-2
Hi Emer,
Based on your description, you would like to understand how “plotregression” can show a Test plot at all, since that plot needs the target values of the test samples.
Here is an explanation:
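  1. A network created with “feedforwardnet” is trained on a dataset that the “train” function divides into the following 3 subsets:
a. Train:
i. The subset on which the network is trained.
ii. All inputs and targets are passed to the “train” function together; “train” then divides the samples into the three subsets according to the network's “divideFcn” and “divideParam” properties (by default “dividerand”, a random 70/15/15 split).
iii. The targets of the Train subset are used to update the network's weights and biases.
b. Validation:
i. The subset used to decide when to stop the training iterations.
ii. A performance measure (mean squared error by default) compares the predicted and target values of the Validation subset; training stops early when this measure fails to improve for “net.trainParam.max_fail” consecutive epochs.
iii. Its share of the data is set through “net.divideParam.valRatio”. (“trainingOptions”, whose “ValidationData” option supplies a validation set, applies to deep learning networks trained with “trainNetwork”, not to “feedforwardnet”.)
c. Test:
i. This subset is not used during training; after training is complete, it is used to measure how well the network generalizes.
ii. It is used for the final performance metric, for example in a loss function or in the Test panel of “plotregression”.
  2. The target values for all samples in the three subsets are known, because all three subsets come from the inputs and targets you passed to “train”. They are used at different stages of training the feed-forward network. The Test panel therefore plots the network outputs for the test samples (y-axis) against their known targets (x-axis). The indices of the test samples are returned in the training record, for example “tr.testInd” from “[net, tr] = train(net, x, t)”.
  3. However, when the trained network is applied to new samples, their target values are usually unknown, so regression plots cannot be generated for these samples.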
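The idea that test targets are known because the test set is carved out of your own data can be illustrated outside MATLAB. Below is a minimal Python sketch under stated assumptions, not MATLAB's implementation: `divide_rand` is a hypothetical stand-in for “dividerand” (random 70/15/15 split), a straight-line least-squares fit stands in for the trained network, and the hypothetical helpers `fit_line` and `pearson_r` compute the R value and fitted line of the kind a regression plot shows for the Test subset.

```python
import math
import random


def divide_rand(n, train_ratio=0.70, val_ratio=0.15, test_ratio=0.15, seed=0):
    """Randomly split sample indices 0..n-1 into train/val/test index lists.

    Hypothetical stand-in for MATLAB's dividerand; MATLAB's exact rounding may differ.
    """
    total = train_ratio + val_ratio + test_ratio
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = round(n * train_ratio / total)
    n_val = round(n * val_ratio / total)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


def fit_line(xs, ys):
    """Ordinary least-squares fit ys ~ slope * xs + intercept."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((v - mx) ** 2 for v in xs)
    sxy = sum((u - mx) * (v - my) for u, v in zip(xs, ys))
    slope = sxy / sxx
    return slope, my - slope * mx


def pearson_r(xs, ys):
    """Correlation coefficient R between xs and ys."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((v - mx) ** 2 for v in xs)
    syy = sum((v - my) ** 2 for v in ys)
    sxy = sum((u - mx) * (v - my) for u, v in zip(xs, ys))
    return sxy / math.sqrt(sxx * syy)


# Toy data set: every sample has an input x AND a known target t.
rng = random.Random(1)
x = [i / 10 for i in range(100)]
t = [2.0 * xi + 1.0 + rng.gauss(0.0, 0.5) for xi in x]

# The whole data set is divided; the test targets stay in our hands.
train_idx, val_idx, test_idx = divide_rand(len(x))

# Stand-in for training: fit the model using ONLY the training subset.
a, b = fit_line([x[i] for i in train_idx], [t[i] for i in train_idx])

# Model outputs for every sample, test samples included.
y = [a * xi + b for xi in x]

# Test panel: x-axis = known test targets, y-axis = outputs for the same samples.
t_test = [t[i] for i in test_idx]
y_test = [y[i] for i in test_idx]
slope, intercept = fit_line(t_test, y_test)
print(f"Test: R={pearson_r(t_test, y_test):.3f}  Output ~= {slope:.2f}*Target + {intercept:.2f}")
```

A slope and R close to 1 mean the outputs track the held-out targets. In MATLAB, the Test panel should correspond to calling plotregression(t(tr.testInd), y(tr.testInd), 'Test') with y = net(x) after [net, tr] = train(net, x, t).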
Please refer to the following documentation of related functions for more details:
  1. feedforwardnet (create a feed-forward neural network with the given hyperparameters): https://www.mathworks.com/help/deeplearning/ref/feedforwardnet.html
  2. trainingOptions (options for training a deep learning network): https://www.mathworks.com/help/deeplearning/ref/trainingoptions.html
  3. trainNetwork (train a deep learning network for classification or regression): https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html
  4. train (train a shallow neural network): https://www.mathworks.com/help/deeplearning/ref/network.train.html
  5. plotregression (plot the linear regression of outputs against targets): https://www.mathworks.com/help/deeplearning/ref/plotregression.html
I hope this helps.
Regards 
Binaya 

Release

R2023b
