How can I define a custom loss function using trainnet?
Hello,
I am trying to define a custom loss function using trainnet. The documentation says:
If the trainnet function does not provide the loss function that you need for your task, then you can specify a custom loss function to the trainnet as a function handle. The function must have the syntax loss = f(Y,T), where Y and T are the predictions and targets, respectively.
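As a minimal illustration of the documented `loss = f(Y,T)` syntax, the handle below computes a plain mean squared error over every element of the mini-batch. The name `lossFcn` is hypothetical, and this sketch is not claimed to reproduce the exact normalization of the built-in `"mse"` option. The important property is that the handle returns a single scalar.

```matlab
% Hypothetical custom loss: mean squared error over all elements.
% Y = network predictions, T = targets (same size). mean(...,"all")
% collapses every dimension, including the batch dimension, to one
% scalar, which is what trainnet differentiates.
lossFcn = @(Y,T) mean((Y - T).^2, "all");

% Quick check on ordinary numeric arrays (no training involved):
Y = [1 2; 3 4];
T = [1 2; 3 2];
loss = lossFcn(Y, T)   % (4-2)^2 / 4 elements = 1
```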
However, I am not sure how the predictions and targets are defined here. I am currently using trainnet as follows:
trainedNet = trainnet(dsTrain,layers,"mse",options);
dsTrain is a datastore containing the input and target images for the regression problem. But I would like to change the loss to a custom function involving ssim. I would like something similar to the following, although I know this isn't quite right:
trainedNet = trainnet(dsTrain,layers,@(Y,targets) 1-ssim(Y,targets),options);
I get the following error message:
Error using trainnet
Value to differentiate is non-scalar. It must be a traced real dlarray scalar.
Thanks!
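The error message indicates that the loss handle returns more than one value. A likely cause, assuming your release's `ssim` accepts the formatted `dlarray` mini-batches that `trainnet` passes in and returns one SSIM value per observation, is that `1-ssim(Y,targets)` is a vector rather than a scalar. A sketch of one possible fix is to average over the batch; the name `ssimLoss` and the random test batch are illustrative only.

```matlab
% Custom loss: 1 - SSIM, averaged over the mini-batch.
% Assumption: ssim (Image Processing Toolbox) accepts formatted dlarray
% inputs and returns one SSIM value per observation; mean(...,"all")
% reduces that result to the scalar trainnet must differentiate.
ssimLoss = @(Y,T) 1 - mean(ssim(Y,T), "all");

% Sanity check on a random batch of four 32x32 grayscale images
% ("SSCB" = spatial, spatial, channel, batch). Identical images have
% SSIM = 1, so the loss should be (close to) zero.
X = dlarray(rand(32, 32, 1, 4, 'single'), "SSCB");
loss = ssimLoss(X, X);
disp(extractdata(loss))
```

The handle would then be passed in place of `"mse"`, for example `trainedNet = trainnet(dsTrain,layers,ssimLoss,options);`. If the images are not scaled to [0, 1], the `"DynamicRange"` name-value argument of `ssim` may also need to match the range of your data.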
Answers (1)