Training with trainNetwork failed. Invalid transform function defined on datastore.
I'm trying to perform transfer learning starting from the pretrained YAMNet network. This is the code to split and preprocess the data:
%Create datastore, label data and split data for training, validation and testing
pathToAudio = "Results\AudioFiles3";
AudioDS = audioDatastore(pathToAudio,"IncludeSubfolders",true,"FileExtensions",".flac","LabelSource","foldernames");
[AudioTrain,AudioValidation,AudioTest] = splitEachLabel(AudioDS,0.7,0.2,0.1,"randomized");
% Use the transform function to preprocess the data using the function audioPreprocess, found at the end of this example. For each signal:
% Use yamnetPreprocess (Audio Toolbox) to generate mel spectrograms suitable for training using YAMNet. Each audio signal produces multiple spectrograms.
% Duplicate the class label for each of the spectrograms.
tdsTrain = transform(AudioTrain,@audioPreprocess,IncludeInfo=true);
tdsValidation = transform(AudioValidation,@audioPreprocess,IncludeInfo=true);
tdsTest = transform(AudioTest,@audioPreprocess,IncludeInfo=true);
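Before training, it can help to read one transformed observation and check its shape and label type, and to confirm how many files each class has in the training split. This is a minimal sketch. It assumes `tdsTrain` and `AudioTrain` were created as above and that `audioPreprocess` is accessible (for example, as a local function at the end of the same script).

```matlab
% Read one transformed observation, then rewind the datastore.
reset(tdsTrain);
sample = read(tdsTrain);     % cell array: one row per spectrogram, {spectrogram, label}
reset(tdsTrain);

disp(size(sample))           % number of spectrograms from this file, by 2 columns
disp(size(sample{1,1}))      % size of one spectrogram
disp(class(sample{1,2}))     % data type of the label

% Number of files per class in the training split
disp(countEachLabel(AudioTrain))
```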
This is the supporting function:
% The function audioPreprocess uses yamnetPreprocess (Audio Toolbox) to generate mel spectrograms
% from audioIn that you can feed to the YAMNet pretrained network. Each input signal generates multiple spectrograms,
% so the labels must be duplicated to create a one-to-one correspondence with the spectrograms.
function [data,info] = audioPreprocess(audioIn,info)
    class = info.Label;
    fs = info.SampleRate;
    features = yamnetPreprocess(audioIn,fs);
    numSpectrograms = size(features,4);
    data = cell(numSpectrograms,2);
    for index = 1:numSpectrograms
        data{index,1} = features(:,:,:,index);
        data{index,2} = class;
    end
end
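One way to make a failure point to a specific file is a variant of the supporting function that reports the file name when something goes wrong. The name `audioPreprocessChecked` is hypothetical and not part of the original code. The sketch assumes two things: that `audioDatastore`'s `read` fills `info.FileName`, and that mixing multichannel files down to mono is acceptable for your data. If some files turn out to be stereo, that mixdown is only one possible choice. Any error from `yamnetPreprocess`, or an empty result, is rethrown with the offending file name.

```matlab
% Hypothetical variant of audioPreprocess that names the failing file.
% Assumption: averaging channels to mono is acceptable for this data set.
function [data,info] = audioPreprocessChecked(audioIn,info)
    label = info.Label;                 % renamed from "class" to avoid shadowing class()
    fs = info.SampleRate;
    if size(audioIn,2) > 1
        audioIn = mean(audioIn,2);      % mix multichannel audio down to mono
    end
    try
        features = yamnetPreprocess(audioIn,fs);
    catch err
        error('audioPreprocessChecked:failed', ...
            'yamnetPreprocess failed on %s: %s', info.FileName, err.message);
    end
    if isempty(features)
        error('audioPreprocessChecked:empty', ...
            'No spectrograms were produced for %s.', info.FileName);
    end
    numSpectrograms = size(features,4);
    data = cell(numSpectrograms,2);
    for index = 1:numSpectrograms
        data{index,1} = features(:,:,:,index);   % one spectrogram per row
        data{index,2} = label;                   % duplicate the label for each spectrogram
    end
end
```

To use it, pass `@audioPreprocessChecked` to `transform` instead of `@audioPreprocess`.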
Afterwards, I open Deep Network Designer, select the YAMNet network, and replace the last fullyConnected layer (so that it outputs 2 classes) and the classification layer. After I start training, I get the following error after a few iterations:
"Training with trainNetwork failed.
Invalid transform function defined on datastore."
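The layer replacement done in Deep Network Designer can also be sketched programmatically. Training from the command line may print more detail about the underlying cause in the Command Window. This is an illustrative sketch that assumes Audio Toolbox with the YAMNet support files is installed, that `AudioTrain`, `tdsTrain` and `tdsValidation` exist as above, and that the layers to replace are the last fully connected layer and the classification output layer. These layers are located by type rather than by hard-coded names. The training options are placeholders, not tuned values.

```matlab
net = yamnet;                            % pretrained YAMNet
layers = net.Layers;

% Find the last fully connected layer and the classification output layer by type
fcName = '';
outName = '';
for k = 1:numel(layers)
    if isa(layers(k),"nnet.cnn.layer.FullyConnectedLayer")
        fcName = layers(k).Name;         % overwritten, so the last one wins
    elseif isa(layers(k),"nnet.cnn.layer.ClassificationOutputLayer")
        outName = layers(k).Name;
    end
end

numClasses = numel(categories(AudioTrain.Labels));   % 2 in the question

% Replace the layers, keeping their original names
lgraph = layerGraph(net);
lgraph = replaceLayer(lgraph,fcName,fullyConnectedLayer(numClasses,Name=fcName));
lgraph = replaceLayer(lgraph,outName,classificationLayer(Name=outName));

% Placeholder options; adjust for your data
options = trainingOptions("adam", ...
    ValidationData=tdsValidation, ...
    Verbose=true);
trainedNet = trainNetwork(tdsTrain,lgraph,options);
```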
Basically, I am trying to reproduce what is done on this page (https://it.mathworks.com/help/deeplearning/ug/transfer-learning-with-audio-networks-in-deep-network-designer.html) with my own data.
Does anybody have an idea of what I could change to solve this?
Thanks in advance
Answers (1)
Milan Bansal
2023-12-22
Hi Roberto Andreotti,
I understand that you are facing an error when trying to perform transfer learning from the pretrained YAMNet Network using your data.
This error is generally caused by a mismatch between the shape of the data returned by the transform function and the shape of the input expected by the network. The output format of the transform function must match the expected input of the network. Because training runs for a few iterations before failing, the problem is likely limited to particular files rather than affecting every observation. To find it, test the transform function on a small subset of the data before applying it to the entire datastore.
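A minimal sketch of such a test applies `audioPreprocess` to every file in the underlying datastore and collects the files that either throw an error or produce spectrograms of an unexpected size. It assumes `audioPreprocess` is accessible from where this runs and uses 96-by-64 as the expected spectrogram size. Repeat it for `AudioValidation` and `AudioTest`.

```matlab
ds = AudioTrain;                 % repeat with AudioValidation and AudioTest
reset(ds);
badFiles = strings(0,1);
badReasons = strings(0,1);

while hasdata(ds)
    [audioIn,info] = read(ds);   % info includes FileName, SampleRate and Label
    try
        data = audioPreprocess(audioIn,info);
        if isempty(data) || ~isequal(size(data{1,1}),[96 64])
            error('No spectrograms, or unexpected spectrogram size.');
        end
    catch err
        badFiles(end+1,1) = string(info.FileName);
        badReasons(end+1,1) = string(err.message);
    end
end
reset(ds);                       % rewind before training

disp(table(badFiles,badReasons))
```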
Ensure that the "yamnetPreprocess" function outputs the mel spectrograms in the format expected by the YAMNet network. Check the size of the "features" variable. It should be a 96-by-64-by-1-by-K array of K spectrograms, and each "features(:,:,:,index)" slice should match the input size of YAMNet.
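The size check can be sketched by comparing the input layer of YAMNet with the output of `yamnetPreprocess` for one file. This assumes Audio Toolbox with the YAMNet support files is installed and that the first layer of the network is its input layer. The choice of the first training file is arbitrary. The script also prints the channel count, which may help when some files are not mono.

```matlab
net = yamnet;
inputSize = net.Layers(1).InputSize;          % input size expected by YAMNet

[audioIn,fs] = audioread(AudioTrain.Files{1});
features = yamnetPreprocess(audioIn,fs);

fprintf("YAMNet input size: %s\n", mat2str(inputSize));
fprintf("features size:     %s\n", mat2str(size(features)));
fprintf("Channels in file:  %d\n", size(audioIn,2));
```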
Please refer to the documentation of the "yamnetPreprocess" function to learn more.
Hope it helps!