Deep Learning with MATLAB: Transfer Learning in 10 Lines of MATLAB Code
From the series: Deep Learning with MATLAB
Use MATLAB® for transfer learning, and see how it is a practical way to apply deep learning to your problems.
This demo uses transfer learning to retrain AlexNet, a pretrained deep convolutional neural network (CNN or ConvNet), to recognize snack foods such as hot dogs, cupcakes, and apple pie.
Recorded: 8 Feb 2017
Hi. My name is Joe Hicklin. I'm a senior developer at the MathWorks. I'm going to show you how to do transfer learning. Transfer learning can be a very practical way to apply deep learning to your problems.
With transfer learning, you take a preexisting neural net, modify it slightly, and then retrain it on your images. This can produce excellent results and is far, far easier than designing a network from scratch and training it yourself.
In my work, I need to be able to distinguish hamburgers from hot dogs, cupcakes, apple pie, and ice cream. As far as I know, there's no network that'll do that for me. So I'm going to start with a preexisting network, AlexNet. AlexNet has been trained on over a million images to classify 1,000 different categories of images.
So here I am. I'm going to start out by loading AlexNet, and I'm going to get the layers out of it so I can see the parts. If you look down here, you can see that AlexNet has 25 layers. Most of these layers perform general-purpose feature extraction that will work for my images as well as for AlexNet's. I'm going to leave those alone.
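The loading step might look like the following minimal sketch. It assumes Deep Learning Toolbox and the Deep Learning Toolbox Model for AlexNet Network support package are installed; the variable names are illustrative and not taken from the downloadable code.

```matlab
% Assumes Deep Learning Toolbox plus the "Deep Learning Toolbox Model for
% AlexNet Network" support package are installed.
net = alexnet;                       % load the pretrained AlexNet
layers = net.Layers                  % no semicolon: display the layer array
numLayers = numel(layers)            % 25
inputSize = layers(1).InputSize      % [227 227 3]: the image size AlexNet expects
fc8Outputs = layers(23).OutputSize   % 1000: one neuron per original class
```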
But the 23rd layer, a fully connected layer, has 1,000 neurons in it, because AlexNet classifies images into 1,000 different categories. I'm only going to classify five kinds of images, so I'm going to replace that layer with a fully connected layer that has only five neurons. Finally, I'm going to replace the output layer as well. The last layer of AlexNet holds AlexNet's classifications, those 1,000 different class names. I don't want that. I'm going to replace it with a new, empty classification layer that will pick up my five classes from the training data.
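Swapping the two layers can be sketched as below. `fullyConnectedLayer` and `classificationLayer` are Deep Learning Toolbox functions; the raised learning-rate factors on the new layer are an optional, commonly used tweak and an assumption here, not something stated in the video.

```matlab
net = alexnet;
layers = net.Layers;
numClasses = 5;   % apple pie, cupcakes, hamburgers, hot dogs, ice cream

% Replace layer 23 (1,000 neurons) with a fresh, randomly initialized
% 5-neuron layer. The learn-rate factors (optional, assumed values) let the
% new layer learn faster than the pretrained layers.
layers(23) = fullyConnectedLayer(numClasses, ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);

% Replace the output layer; its class names are set from the training
% labels when the network is trained.
layers(25) = classificationLayer;
```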
So now I've got my network set up. It's time to deal with the data. You don't need a million images like AlexNet was trained on, but you still need a good number of them to get good results. I've made a folder with five subfolders in it, one for each of my classes. So there's one called Apple Pie, one called Cupcakes, and so on. And inside each of these folders are 1,000 images of the appropriate food.
I've resized these images to the size AlexNet expects, 227 by 227 pixels, and you'll have to do that, too. If you arrange your data like this, you can use MATLAB's imageDatastore object, because it understands that folder structure: it will load all the images and label each one with the name of its folder. So that's what I'm doing here.
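Loading the labeled folders can be sketched like this. The folder name `snackImages` is hypothetical; `imageDatastore` with `'IncludeSubfolders'` and `'LabelSource','foldernames'` labels each image by its subfolder. If your images are not already 227 by 227, the `ReadFcn` line is one assumed way to resize them as they are read rather than on disk, and it also assumes RGB images.

```matlab
% 'snackImages' is a hypothetical folder with one subfolder per class,
% e.g. snackImages/Apple Pie, snackImages/Cupcakes, ...
imds = imageDatastore('snackImages', ...
    'IncludeSubfolders', true, ...     % search every class subfolder
    'LabelSource', 'foldernames');     % label = name of the containing folder

% Optional: resize on the fly if the files are not already 227x227.
% Assumes RGB images; grayscale files would need converting to 3 channels.
imds.ReadFcn = @(file) imresize(imread(file), [227 227]);

countEachLabel(imds)   % table showing how many images each class has
```

In newer releases, `augmentedImageDatastore` is another way to resize images during training without touching the files.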
As soon as I've got my images, I need to separate them into two sets. Most of them I'll use for training, but I'll hold a few back to test for accuracy later. So let's do that. Now I'm all set to train my network. I've got to set up a few training options here. I've chosen values that work well for this problem.
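A sketch of the split and the training options follows, assuming `imds` is the labeled image datastore. The 80/20 split and the option values are illustrative assumptions, not necessarily what the video uses; `splitEachLabel` keeps the class proportions the same in both sets.

```matlab
% Hold out 20% of each class for testing (assumed ratio); 'randomized'
% shuffles the files before splitting.
[trainImds, testImds] = splitEachLabel(imds, 0.8, 'randomized');

% Stochastic gradient descent with momentum; values are assumed starting points.
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...   % small rate so pretrained weights change gently
    'MaxEpochs', 20, ...             % full passes over the training set
    'MiniBatchSize', 64);            % images per gradient step
```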
You can change these if you like and see what happens. And then I'm ready to train the network. Training has started. It's going to take five or six minutes to do its job. I have a fairly powerful GPU in my computer, so it's pretty quick. Your mileage may vary. All right, the network's done training. The first thing we're going to do now is see how accurate it is.
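The training call itself is a single line. In this sketch, `trainImds`, `layers`, and `opts` are assumed names for the training datastore, the modified layer array, and the training options. `trainNetwork` uses a supported GPU automatically when one is available, which is why training time depends so much on hardware.

```matlab
tic
% Retrain the modified AlexNet on the snack images.
myNet = trainNetwork(trainImds, layers, opts);
fprintf('Training took %.1f minutes\n', toc/60);
```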
We're going to ask the network to classify the test images, the images we left out of our training set, and then see what fraction of those it gets right. We were 84% accurate. Pretty good for five minutes of work. Let's try it now with the webcam on some real food. I just happen to have some food on my desk. There are hamburgers, apple pie, hot dogs, and ice cream.
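Measuring accuracy and running the webcam test could look like the sketch below. It assumes `myNet` is the trained network and `testImds` the held-out datastore; the live loop additionally assumes the MATLAB Support Package for USB Webcams is installed, and the 100-frame loop length is arbitrary.

```matlab
% Classify every held-out image and compare with the true labels.
predictedLabels = classify(myNet, testImds);
accuracy = mean(predictedLabels == testImds.Labels)   % fraction correct, 0 to 1

% Live test: grab frames from the webcam and label each one.
cam = webcam;                                  % first available camera
for k = 1:100                                  % arbitrary number of frames
    img = imresize(snapshot(cam), [227 227]);  % match AlexNet's input size
    label = classify(myNet, img);
    image(img); axis image off
    title(char(label)); drawnow
end
clear cam                                      % release the camera
```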
So overall, it works pretty well, and it's fairly robust for a lot of these, even from different angles. So there we go. That worked better than I expected, really. I simplified this demo as much as I could, but in the download, we'll include a second file with a lot more comments and some additional code to handle situations that might arise.
I've shown you how to do classification with transfer learning, but if you need numeric values out rather than class labels, you can also do regression with transfer learning. Well, I hope I've shown you enough to get you interested in transfer learning, so grab some snacks and give it a go.
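For regression, one assumed way to adapt AlexNet is to keep the feature layers, end with a one-neuron fully connected layer, and replace the softmax and classification layers with a `regressionLayer`. This is a sketch under that assumption; training would then use numeric responses, for example `trainNetwork(X, Y, layers, opts)` with a 227-by-227-by-3-by-N image array `X` and an N-by-1 response vector `Y` (hypothetical names).

```matlab
net = alexnet;
layers = net.Layers;

% Keep layers 1-22 (everything up to and including the last dropout layer),
% then predict a single real number instead of a class.
layers = [layers(1:22)
          fullyConnectedLayer(1)   % one numeric output (assumed scalar target)
          regressionLayer];        % mean-squared-error loss; no softmax layer
```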