audioPretrainedNetwork

Pretrained audio neural networks

Since R2024a

    Description

    net = audioPretrainedNetwork(name) returns the specified pretrained audio neural network.

    net = audioPretrainedNetwork(name,Name=Value) specifies options using one or more name-value arguments.

    [net,classNames] = audioPretrainedNetwork("yamnet",___) also returns the class names for the pretrained YAMNet network.

    This function requires Deep Learning Toolbox™.

    Examples

    You may need to download the pretrained network if it is not currently installed.

    Call audioPretrainedNetwork with the desired network name in the Command Window. If the required model is not installed, the function throws an error and provides a download link. Click the link, and unzip the downloaded file to a location on the MATLAB path.

    Alternatively, execute these commands to download and unzip the model to your temporary directory.

    % Download the model zip file to a temporary folder.
    modelName = "yamnet";
    downloadFolder = fullfile(tempdir,"pretrainedNetDownload");
    downloadURL = sprintf("https://ssd.mathworks.com/supportfiles/audio/%s.zip",modelName);
    loc = websave(downloadFolder,downloadURL);

    % Unzip the model files and add the model folder to the MATLAB path.
    modelsLocation = tempdir;
    unzip(loc,modelsLocation)
    addpath(fullfile(modelsLocation,modelName))

    Load a pretrained network using audioPretrainedNetwork. This example loads the YAMNet network and displays the properties of the returned dlnetwork object.

    net = audioPretrainedNetwork("yamnet")
    net = 
      dlnetwork with properties:
    
             Layers: [85×1 nnet.cnn.layer.Layer]
        Connections: [84×2 table]
         Learnables: [110×3 table]
              State: [54×3 table]
         InputNames: {'input_1'}
        OutputNames: {'softmax'}
        Initialized: 1
    
      View summary with summary.
    
    

    View the first few layers of the network.

    head(net.Layers)
      8×1 Layer array with layers:
    
         1   'input_1'            Image Input               96×64×1 images
         2   'conv2d'             2-D Convolution           32 3×3×1 convolutions with stride [2  2] and padding 'same'
         3   'b'                  Batch Normalization       Batch normalization with 32 channels
         4   'activation'         ReLU                      ReLU
         5   'depthwise_conv2d'   2-D Grouped Convolution   32 groups of 1 3×3×1 convolutions with stride [1  1] and padding 'same'
         6   'b_1'                Batch Normalization       Batch normalization with 32 channels
         7   'activation_1'       ReLU                      ReLU
         8   'conv2d_1'           2-D Convolution           64 1×1×32 convolutions with stride [1  1] and padding 'same'
    

    Read in an audio signal to classify it.

    [audioIn,fs] = audioread("TrainWhistle-16-44p1-mono-9secs.wav");

    Plot and listen to the audio signal.

    t = (0:numel(audioIn)-1)/fs;
    plot(t,audioIn)
    xlabel("Time (s)")
    ylabel("Ampltiude")
    axis tight

    Figure contains an axes object. The axes object with xlabel Time (s), ylabel Amplitude contains an object of type line.

    sound(audioIn,fs)

    YAMNet requires you to preprocess the audio signal to match the input format used to train the network. The preprocessing steps include resampling the audio signal and computing an array of mel spectrograms. To learn more about mel spectrograms, see melSpectrogram. Use yamnetPreprocess to preprocess the signal and extract the mel spectrograms to pass to YAMNet. Visualize one of these spectrograms chosen at random.

    spectrograms = yamnetPreprocess(audioIn,fs);
    
    arbitrarySpect = spectrograms(:,:,1,randi(size(spectrograms,4)));
    surf(arbitrarySpect,EdgeColor="none")
    view([90 -90])
    xlabel("Mel Band")
    ylabel("Frame")
    title("Mel Spectrogram for YAMNet")
    axis tight

    Figure contains an axes object. The axes object with title Mel Spectrogram for YAMNet, xlabel Mel Band, ylabel Frame contains an object of type surface.

    Create a YAMNet neural network using the audioPretrainedNetwork function. Use predict to run the network on the preprocessed mel spectrogram images. Convert the network output to class labels using scores2label.

    [net,classNames] = audioPretrainedNetwork("yamnet");
    scores = predict(net,spectrograms);
    classes = scores2label(scores,classNames);

    The classification step returns a label for each of the spectrogram images in the input. Classify the sound as the most frequently occurring label in the output.

    mySound = mode(classes)
    mySound = categorical
         Whistle 
    
    

    Download and unzip the air compressor data set [1]. This data set consists of recordings from air compressors in a healthy state or one of 7 faulty states.

    url = "https://www.mathworks.com/supportfiles/audio/AirCompressorDataset/AirCompressorDataset.zip";
    downloadFolder = fullfile(tempdir,"aircompressordataset");
    datasetLocation = tempdir;
    
    if ~exist(fullfile(tempdir,"AirCompressorDataSet"),"dir")
        loc = websave(downloadFolder,url);
        unzip(loc,fullfile(tempdir,"AirCompressorDataSet"))
    end

    Create an audioDatastore object to manage the data and split it into train and validation sets.

    ads = audioDatastore(fullfile(tempdir,"AirCompressorDataSet"),IncludeSubfolders=true,LabelSource="foldernames");
    
    [adsTrain,adsValidation] = splitEachLabel(ads,0.8,0.2);
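
    As an optional check (this step is not part of the original example), you can inspect the label distribution of the training split with countEachLabel.

    % Optional: count the files per label in the training split.
    countEachLabel(adsTrain)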

    Read an audio file from the datastore and save the sample rate for later use. Reset the datastore to return the read pointer to the beginning of the data set. Listen to the audio signal and plot the signal in the time domain.

    [x,fileInfo] = read(adsTrain);
    fs = fileInfo.SampleRate;
    
    reset(adsTrain)
    
    sound(x,fs)
    
    figure
    t = (0:size(x,1)-1)/fs;
    plot(t,x)
    xlabel("Time (s)")
    title("State = " + string(fileInfo.Label))
    axis tight

    Figure contains an axes object. The axes object with title State = Bearing, xlabel Time (s) contains an object of type line.

    Extract mel spectrograms from the train set using yamnetPreprocess. There are multiple spectrograms for each audio signal. Replicate the labels so that they are in one-to-one correspondence with the spectrograms.

    emptyLabelVector = adsTrain.Labels;
    emptyLabelVector(:) = [];
    
    trainFeatures = [];
    trainLabels = emptyLabelVector;
    while hasdata(adsTrain)
        [audioIn,fileInfo] = read(adsTrain);
        features = yamnetPreprocess(audioIn,fileInfo.SampleRate);
        numSpectrums = size(features,4);
        trainFeatures = cat(4,trainFeatures,features);
        trainLabels = cat(2,trainLabels,repmat(fileInfo.Label,1,numSpectrums));
    end
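
    Optionally (this check is not in the original example), confirm that the spectrograms and the replicated labels are in one-to-one correspondence.

    % Optional check: one label per spectrogram.
    isequal(size(trainFeatures,4),numel(trainLabels))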

    Extract features from the validation set and replicate the labels.

    validationFeatures = [];
    validationLabels = emptyLabelVector;
    while hasdata(adsValidation)
        [audioIn,fileInfo] = read(adsValidation);
        features = yamnetPreprocess(audioIn,fileInfo.SampleRate);
        numSpectrums = size(features,4);
        validationFeatures = cat(4,validationFeatures,features);
        validationLabels = cat(2,validationLabels,repmat(fileInfo.Label,1,numSpectrums));
    end

    The air compressor data set has only 8 classes. Call audioPretrainedNetwork with NumClasses set to 8 to load a pretrained YAMNet network with the desired number of output classes for transfer learning.

    classNames = unique(adsTrain.Labels);
    numClasses = numel(classNames);
    
    net = audioPretrainedNetwork("yamnet",NumClasses=numClasses);

    To define training options, use trainingOptions.

    miniBatchSize = 128;
    validationFrequency = floor(numel(trainLabels)/miniBatchSize);
    options = trainingOptions("adam", ...
        InitialLearnRate=3e-4, ...
        MaxEpochs=2, ...
        MiniBatchSize=miniBatchSize, ...
        Shuffle="every-epoch", ...
        Plots="training-progress", ...
        Metrics="accuracy", ...
        Verbose=false, ...
        ValidationData={single(validationFeatures),validationLabels'}, ...
        ValidationFrequency=validationFrequency);

    To train the network, use trainnet.

    airCompressorNet = trainnet(trainFeatures,trainLabels',net,"crossentropy",options);
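
    As an optional follow-up (not part of the original example), you can gauge the fine-tuned network on the validation spectrograms using predict and scores2label.

    % Optional check: classify the validation spectrograms with the fine-tuned
    % network and compute the per-spectrogram accuracy.
    scores = predict(airCompressorNet,single(validationFeatures));
    predictedLabels = scores2label(scores,string(classNames));
    accuracy = mean(predictedLabels(:) == validationLabels(:))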

    Save the trained network to airCompressorNet.mat. You can now use the fine-tuned network by loading the airCompressorNet.mat file.

    save airCompressorNet.mat airCompressorNet 

    References

    [1] Verma, Nishchal K., et al. “Intelligent Condition Based Monitoring Using Acoustic Signals for Air Compressors.” IEEE Transactions on Reliability, vol. 65, no. 1, Mar. 2016, pp. 291–309. https://doi.org/10.1109/TR.2015.2459684.

    Input Arguments

    name — Name of pretrained network, specified as "yamnet", "vggish", "openl3", "crepe", or "vadnet".

    audioPretrainedNetwork Model Name Argument | Neural Network Name                     | Input Shape                                             | Preprocessing and Postprocessing Functions
    "yamnet"                                   | YAMNet                                  | 96-by-64-by-1-by-T                                      | yamnetPreprocess
    "vggish"                                   | VGGish                                  | 96-by-64-by-1-by-T                                      | vggishPreprocess
    "openl3"                                   | OpenL3                                  | N-by-M-by-1-by-T, where N and M depend on SpectrumType  | openl3Preprocess
    "crepe"                                    | CREPE                                   | 1024-by-1-by-1-by-T                                     | crepePreprocess, crepePostprocess
    "vadnet"                                   | Voice activity detection (VAD) network  | 40-by-T                                                 | vadnetPreprocess, vadnetPostprocess

    For the network input shapes, T depends on the length of the audio signal.

    Data Types: char | string
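
    For example, this illustrative snippet (not from the original documentation, and assuming the YAMNet model files are installed) shows T growing with the duration of the input signal.

    % Illustrative only: T (the 4th dimension) grows with signal duration.
    fs = 16e3;
    T1 = size(yamnetPreprocess(randn(fs,1),fs),4)     % ~1 s of noise
    T5 = size(yamnetPreprocess(randn(5*fs,1),fs),4)   % ~5 s of noise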

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: net = audioPretrainedNetwork(name,Weights="none")

    Weights — Neural network weights, specified as one of these values:

    • "pretrained" — Return the neural network with its pretrained weights.

    • "none" — Return the uninitialized neural network architecture only.

    • "env" — Return an OpenL3 network that was trained on environmental sound data. This value applies only when name is "openl3".

    • "music" — Return an OpenL3 network that was trained on music data. This value applies only when name is "openl3".

    Data Types: char | string
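
    For example, this usage sketch (based only on the values listed above) requests the music-trained OpenL3 weights and an uninitialized YAMNet architecture.

    % Usage sketch for the Weights option.
    netMusic = audioPretrainedNetwork("openl3",Weights="music");  % music-trained OpenL3
    netBare  = audioPretrainedNetwork("yamnet",Weights="none");   % uninitialized architecture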

    NumClasses — Number of classes for classification tasks, specified as a positive integer or []. This argument applies only when name is "yamnet".

    If NumClasses is an integer, then the audioPretrainedNetwork function adapts the pretrained YAMNet network for classification tasks with the specified number of classes by replacing the learnable layer in the classification head of the network.

    If you specify the NumClasses option, then NumResponses must be [] and the function must not output the classNames argument.

    Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
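
    For example, this sketch requests a 10-class head; inspecting the last layers (the exact indices of the replaced head are an assumption about the network layout) shows the adapted output size.

    % Sketch: adapt YAMNet for a 10-class task and inspect the network tail.
    net = audioPretrainedNetwork("yamnet",NumClasses=10);
    net.Layers(end-1:end)   % assumed layout: replaced learnable layer, then softmax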

    NumResponses — Number of responses for regression tasks, specified as a positive integer or []. This argument applies only when name is "yamnet".

    If NumResponses is an integer, then the audioPretrainedNetwork function adapts the pretrained YAMNet network for regression tasks with the specified number of responses by replacing the classification head of the network with a head for regression tasks.

    If you specify the NumResponses option, then NumClasses must be [] and the function must not output the classNames argument.

    Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
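
    For example, this usage sketch adapts YAMNet for a single-response regression task.

    % Sketch: replace the classification head with a one-response regression head.
    % Note that classNames cannot be requested in this configuration.
    net = audioPretrainedNetwork("yamnet",NumResponses=1);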

    SpectrumType — Input spectrum type for the OpenL3 network, specified as one of these values:

    • "mel128" — The network accepts mel spectrograms with 128 mel bands.

    • "mel256" — The network accepts mel spectrograms with 256 mel bands.

    • "linear" — The network accepts positive one-sided spectrograms with an FFT length of 512.

    This argument applies only when name is "openl3".

    Data Types: char | string

    EmbeddingLength — Length of the output embedding for the OpenL3 network, specified as 512 or 6144.

    This argument applies only when name is "openl3".

    Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
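
    For example, this usage sketch combines the two OpenL3 options described above.

    % Sketch: OpenL3 that expects 256-band mel spectrograms and returns
    % 6144-element embeddings.
    net = audioPretrainedNetwork("openl3",SpectrumType="mel256",EmbeddingLength=6144);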

    ModelCapacity — Model capacity of the CREPE network, specified as "tiny", "small", "medium", "large", or "full". The higher the model capacity, the greater the number of learnables in the model.

    This argument applies only when name is "crepe".

    Data Types: char | string
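
    For example, this illustrative sketch compares the smallest and largest CREPE variants using the summary function from Deep Learning Toolbox.

    % Sketch: lower ModelCapacity values have far fewer learnable parameters.
    summary(audioPretrainedNetwork("crepe",ModelCapacity="tiny"))
    summary(audioPretrainedNetwork("crepe",ModelCapacity="full"))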

    Output Arguments

    net — Neural network, returned as a dlnetwork (Deep Learning Toolbox) object.

    classNames — Class names, returned as a string array.

    The function returns class names only when name is "yamnet" and both the NumClasses and NumResponses options are [].
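
    For example, this quick sketch retrieves the YAMNet class names and displays the first few.

    % Sketch: fetch the class names returned alongside YAMNet.
    [~,classNames] = audioPretrainedNetwork("yamnet");
    classNames(1:5)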

    References

    [1] Gemmeke, Jort F., Daniel P. W. Ellis, Dylan Freedman, Aren Jansen, Wade Lawrence, R. Channing Moore, Manoj Plakal, and Marvin Ritter. 2017. “Audio Set: An Ontology and Human-Labeled Dataset for Audio Events.” In 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 776–80. New Orleans, LA: IEEE. https://doi.org/10.1109/ICASSP.2017.7952261.

    [2] Hershey, Shawn, et al. “CNN Architectures for Large-Scale Audio Classification.” 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), IEEE, 2017, pp. 131–35. https://doi.org/10.1109/ICASSP.2017.7952132.

    [3] Cramer, Jason, et al. “Look, Listen, and Learn More: Design Choices for Deep Audio Embeddings.” ICASSP 2019 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), IEEE, 2019, pp. 3852–56. https://doi.org/10.1109/ICASSP.2019.8682475.

    [4] Kim, Jong Wook, Justin Salamon, Peter Li, and Juan Pablo Bello. “Crepe: A Convolutional Representation for Pitch Estimation.” In 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 161–65. Calgary, AB: IEEE, 2018. https://doi.org/10.1109/ICASSP.2018.8461329.

    [5] Ravanelli, Mirco, et al. “SpeechBrain: A General-Purpose Speech Toolkit.” arXiv, 8 June 2021. http://arxiv.org/abs/2106.04624.

    Version History

    Introduced in R2024a
