Knowledge Distillation for Regression

33 views (last 30 days)
Hi everyone,
Recently, MathWorks published a code example for knowledge distillation in classification: Train Smaller Neural Network Using Knowledge Distillation - MATLAB & Simulink (mathworks.com)
My question is: for a regression problem, how can I train a smaller student network using knowledge distillation from a teacher network that is also trained for regression?
Best Regards,
Luís Moreira

Answers (1)

Shubham on 3 Dec 2023
I understand that you want to train a small neural network from a teacher neural network using knowledge distillation for a regression problem.
The idea of knowledge distillation is to reduce the size of a neural network while maintaining its accuracy. As in the classification case, the student network is trained with a distillation loss that combines the loss between the student's and the teacher's predictions with the loss between the student's predictions and the true targets.
The key difference between the classification and regression tasks is the loss function: for regression, mean squared error (MSE) can be used. Also, since a regression network outputs a continuous value rather than a discrete class, the output layer changes as well (a regression output instead of a softmax over classes).
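In equation form, a common way to write the distillation loss (the symbol names here are illustrative, not fixed API identifiers) is

L = alpha * MSE(yStudent, yTeacher) + (1 - alpha) * MSE(yStudent, yTrue)

where yStudent is the student's prediction, yTeacher the teacher's prediction, yTrue the ground-truth target, and alpha in [0, 1] a weighting hyperparameter.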
Refer to a simple code snippet below:
% Generate synthetic regression data
numObservations = 1000;
XTrain = linspace(-10, 10, numObservations)';
YTrain = sin(XTrain) + 0.1 * randn(numObservations, 1); % noisy sine wave as target

% Define and train the teacher network
teacherLayers = [
    featureInputLayer(1)
    fullyConnectedLayer(50)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];
teacherOptions = trainingOptions('adam', ...
    'MaxEpochs', 300, ...
    'MiniBatchSize', 32, ...
    'Plots', 'training-progress', ...
    'Verbose', false);
teacherNet = trainNetwork(XTrain, YTrain, teacherLayers, teacherOptions);

% Define the student network with a smaller architecture
studentLayers = [
    featureInputLayer(1)
    fullyConnectedLayer(10)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];
studentOptions = trainingOptions('adam', ...
    'MaxEpochs', 300, ...
    'MiniBatchSize', 32, ...
    'Plots', 'training-progress', ...
    'Verbose', false);

% Baseline: train the student directly on the true targets (for comparison)
studentBaselineNet = trainNetwork(XTrain, YTrain, studentLayers, studentOptions);

% Knowledge distillation: blend the teacher's predictions with the true targets
alpha = 0.5; % weighting between teacher predictions and true targets
teacherPred = predict(teacherNet, XTrain);
combinedTargets = alpha * teacherPred + (1 - alpha) * YTrain;

% Train the distilled student on the blended targets
studentDistilledNet = trainNetwork(XTrain, combinedTargets, studentLayers, studentOptions);
With 'Plots', 'training-progress' enabled, you can compare the training curves of the teacher network, the baseline student network, and the distilled student network.
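To quantify the comparison beyond the training plots, a quick held-out evaluation could look like the sketch below (XTest and YTest are illustrative names here, generated the same way as the training data):

% Evaluate all three networks on fresh data from the same distribution
XTest = linspace(-10, 10, 200)';
YTest = sin(XTest); % noise-free targets for evaluation
mseOf = @(net) mean((predict(net, XTest) - YTest).^2);
fprintf('Teacher: %.4f | Student baseline: %.4f | Student distilled: %.4f\n', ...
    mseOf(teacherNet), mseOf(studentBaselineNet), mseOf(studentDistilledNet));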
The blended-target snippet above is a deliberately simple illustration of how MSE can be used for knowledge distillation in a regression task; the overall recipe mirrors the classification case.
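Note that for MSE specifically, training on the blended target alpha * teacherPred + (1 - alpha) * YTrain produces the same gradients as minimizing the explicit weighted loss alpha * MSE(student, teacher) + (1 - alpha) * MSE(student, truth), because the two objectives differ only by a term that does not depend on the student's output. If you prefer the loss to be explicit (for example, to add further terms), a custom training loop is an option. Below is a minimal sketch using dlnetwork, dlfeval, mse, and adamupdate from Deep Learning Toolbox; modelLoss is a hypothetical helper name, and the full-batch update per iteration is a simplification:

% Minimal sketch: explicit distillation loss in a custom training loop.
% Assumes XTrain, YTrain, and teacherPred exist as in the snippet above.
studentDistillLayers = [
    featureInputLayer(1)
    fullyConnectedLayer(10)
    reluLayer
    fullyConnectedLayer(1)]; % no regressionLayer with dlnetwork
net = dlnetwork(layerGraph(studentDistillLayers));

alpha = 0.5;
X  = dlarray(XTrain', 'CB');     % channel x batch
T  = dlarray(YTrain', 'CB');     % true targets
Tt = dlarray(teacherPred', 'CB'); % teacher predictions

avgGrad = []; avgSqGrad = []; % Adam state
for iteration = 1:300
    [loss, gradients] = dlfeval(@modelLoss, net, X, T, Tt, alpha);
    [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, ...
        avgGrad, avgSqGrad, iteration);
end

function [loss, gradients] = modelLoss(net, X, T, Tt, alpha)
    Y = forward(net, X);
    % mse here is the Deep Learning Toolbox half-mean-squared-error for
    % dlarrays; the constant factor of 1/2 does not change the weighting.
    loss = alpha * mse(Y, Tt) + (1 - alpha) * mse(Y, T);
    gradients = dlgradient(loss, net.Learnables);
end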
I hope this helps!!
