Gradients not recorded for a dlnetwork VAE

I have created a VAE using dlnetwork: an encoder, a decoder, and a classifier. The loss function will not record gradients for the update via dlgradient. Can you have a look and find the cause?
% === Setup ===
numEpochs = 5;
miniBatchSize = 32;
learnRate = 1e-3;
latentDim = 32;
hiddenUnits = 64;
inputSize = size(XTrain{1}, 1); % Number of features
sequenceLength = size(XTrain{1}, 2); % Time steps
% === Encoder ===
encoderLayers = [
    sequenceInputLayer(inputSize, 'Name', 'input')
    lstmLayer(64, 'OutputMode', 'last', 'Name', 'lstm_enc')
    fullyConnectedLayer(latentDim, 'Name', 'fc_mu')
    fullyConnectedLayer(latentDim, 'Name', 'fc_logvar')
    ];
encoderNet = dlnetwork(layerGraph(encoderLayers));
% === Decoder ===
decoderLayers = [
    sequenceInputLayer(latentDim, 'Name', 'latent_input')
    fullyConnectedLayer(64, 'Name', 'fc_latent')
    lstmLayer(64, 'OutputMode', 'sequence', 'Name', 'lstm_dec')
    fullyConnectedLayer(inputSize, 'Name', 'fc_recon')
    ];
decoderNet = dlnetwork(layerGraph(decoderLayers));
% === Classifier ===
classifierLayers = [
    featureInputLayer(latentDim, 'Name', 'class_input')
    fullyConnectedLayer(hiddenUnits, 'Name', 'fc_class_hidden')
    reluLayer('Name', 'relu_class')
    fullyConnectedLayer(numClasses, 'Name', 'fc_class')
    softmaxLayer('Name', 'softmax_out')
    ];
classifierNet = dlnetwork(layerGraph(classifierLayers));
function loss = computeTotalLoss(dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength)
    % All operations inside this function are traced
    [mu, logvar] = encodeLatents(dlX, encoderNet);
    z = sampleLatents(mu, logvar); % Sampling with reparameterization
    loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength);
end

function [mu, logvar] = encodeLatents(dlX, encoderNet)
    mu = forward(encoderNet, dlX, 'Outputs', 'fc_mu');
    logvar = forward(encoderNet, dlX, 'Outputs', 'fc_logvar');
end

function z = sampleLatents(mu, logvar)
    eps = dlarray(randn(size(mu), 'like', mu)); % Traced random noise
    z = mu + exp(0.5 * logvar) .* eps;
end

function loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength)
    zRepeated = repmat(z, 1, 1, sequenceLength);
    dlZSeq = dlarray(zRepeated, 'CBT');
    reconOut = forward(decoderNet, dlZSeq, 'Outputs', 'fc_recon');
    classOut = forward(classifierNet, dlarray(z, 'CB'), 'Outputs', 'softmax_out');
    reconLoss = mse(reconOut, dlX);
    classLoss = crossentropy(classOut, dlY);
    klLoss = -0.5 * sum(1 + logvar - mu.^2 - exp(logvar), 'all');
    loss = reconLoss + classLoss + klLoss;
end
for epoch = 1:numEpochs
    idx = randperm(numel(XTrain));
    totalEpochLoss = 0;
    numBatches = 0;
    for i = 1:miniBatchSize:numel(XTrain)
        batchIdx = idx(i:min(i+miniBatchSize-1, numel(XTrain)));
        XBatch = XTrain(batchIdx);
        YBatch = YTrain(batchIdx, :);
        % === Format Inputs ===
        XCat = cat(3, XBatch{:});
        XCat = permute(XCat, [1, 3, 2]);
        dlX = dlarray(XCat, 'CBT');
        dlY = dlarray(YBatch', 'CB');
        totalLoss = dlfeval(@computeTotalLoss, dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength);
        gradients = dlgradient(totalLoss, encoderNet.Learnables);
This fails at the dlgradient call with:
Error using dlarray/dlgradient (line 105)
Value to differentiate is not traced. It must be a traced real dlarray scalar. Use dlgradient inside a function called by dlfeval to trace the variables.

Accepted Answer

Torsten on 22 Aug 2025
Did you read the documentation on how "dlgradient" is to be applied?
You will have to call "dlgradient" inside the function "computeTotalLoss", not from outside it.
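
A minimal sketch of that change, reusing the helper functions from the question (the name modelLoss and the three-gradient output are illustrative, not from the original post):

function [loss, gradEnc, gradDec, gradCls] = modelLoss(dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength)
    % Forward pass; everything here is traced because the function runs under dlfeval
    [mu, logvar] = encodeLatents(dlX, encoderNet);
    z = sampleLatents(mu, logvar); % Reparameterization trick
    loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength);
    % dlgradient must be called here, while the trace is still alive;
    % it can return one gradient table per set of learnables
    [gradEnc, gradDec, gradCls] = dlgradient(loss, ...
        encoderNet.Learnables, decoderNet.Learnables, classifierNet.Learnables);
end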
2 Comments
SIMON on 23 Aug 2025
OK, thanks, this is news to me. I will do that. Thanks so much.
Catalytic on 23 Aug 2025
@SIMON, you should click accept on Torsten's answer if it solves your problem.


More Answers (1)

Matt J on 22 Aug 2025 (edited)
As the error message says, the call to dlgradient must occur within the function (in this case computeTotalLoss) called by dlfeval, but that is not what you've done.
More confusingly, it appears that computeTotalLoss and computeLoss call each other in a recursive loop. That will create problems as well, I imagine. You are not meant to be calling dlfeval inside the loss function. It is meant to be called externally, in your training loop.
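
In pattern form, with a single generic network (lossFcn, net, X, and T are placeholder names; a sketch of the pattern rather than the asker's model):

function [loss, gradients] = lossFcn(net, X, T)
    Y = forward(net, X);
    loss = mse(Y, T);
    gradients = dlgradient(loss, net.Learnables); % inside: the trace is alive here
end

% In the training loop, not inside the loss function:
[loss, gradients] = dlfeval(@lossFcn, net, X, T);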
2 Comments
SIMON on 23 Aug 2025
The two functions are not recursive: computeLoss is called by computeTotalLoss, which is inside dlfeval. However, the gradients are calculated outside of dlfeval, so I will try putting the call inside it. Thanks.
Torsten on 23 Aug 2025
You have to put "dlgradient" inside "computeTotalLoss". Then you can call "dlfeval" from somewhere outside "computeTotalLoss" and "computeLoss".
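
Putting that together with the modelLoss sketch above, the batch loop would then look something like this (assuming the Adam state variables are initialized to [] before the loop and iteration is a counter incremented once per batch; neither appears in the original code):

% Inside the batch loop:
iteration = iteration + 1;
[totalLoss, gradEnc, gradDec, gradCls] = dlfeval(@modelLoss, dlX, dlY, ...
    encoderNet, decoderNet, classifierNet, sequenceLength);
% Update each network with its own gradients and Adam state
[encoderNet, avgGradEnc, avgSqGradEnc] = adamupdate(encoderNet, gradEnc, ...
    avgGradEnc, avgSqGradEnc, iteration, learnRate);
[decoderNet, avgGradDec, avgSqGradDec] = adamupdate(decoderNet, gradDec, ...
    avgGradDec, avgSqGradDec, iteration, learnRate);
[classifierNet, avgGradCls, avgSqGradCls] = adamupdate(classifierNet, gradCls, ...
    avgGradCls, avgSqGradCls, iteration, learnRate);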

