Deep Learning with MATLAB: Training a Neural Network from Scratch with MATLAB
From the series: Deep Learning with MATLAB
This demo uses MATLAB® to train a CNN from scratch for classifying images of four different animal types: cat, dog, deer, and frog. The images come from the CIFAR-10 dataset.
Recorded: 12 Apr 2017
Hi. My name is Gabriel Ha, and I'm here to show you how MATLAB makes it straightforward to create a deep neural network from scratch. Our demo applies specifically to image processing and recognition, but we feel images are pretty easy to relate to, and image recognition is a fairly well-known application of neural networks. Most importantly, we want to make deep learning accessible to everyone: you'll be able to get your hands on everything we show you, build on it, and start training networks of your own.
So for those of you who are super familiar with training networks, along with techniques to make them more accurate, MATLAB is going to be great for you because, as you would expect, we provide intuitive syntax and functions that let you easily implement your refinements. For those of you who are new to the deep learning field and want to get your feet wet with this technology, what you can do right away might be limited to image recognition, but I'm confident it will give you more than enough material to get started and have a lot of fun with neural networks.
So here's what we're going to do. We want to train a network to recognize four different animals: cats, dogs, frogs, and deer. To do that, we're going to show our network images of each animal, define the network's layers, and then, using a single line of code, tell MATLAB to train our network from scratch. Then we'll test the network by showing it new images it hasn't seen before and checking its accuracy.
To set things up, we've put 5,000 images of each animal into separate folders in this directory. Now, if you're doing the math, that's 20,000 images total. And for those of you who are just interested in trying this out, you might be thinking, “Wait, so you expect me to watch this video and then go curate 20,000 images before I can even get started?” Well, you can if you want, or you can do what we did: take advantage of work that's already been done.
In this case, we got all of our images from the publicly available CIFAR-10 dataset, which just involves downloading and extracting one big compressed archive. So thankfully, setting up this demo depends only on your network speed for the download and your processor power for the extraction. That being said, training a network from scratch does require quite a bit of data, so always look for opportunities to build on previous work like this demo.
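The download-and-extract step can also be scripted. The sketch below is a hypothetical setup script, not the demo's own code: it assumes the MATLAB version of the archive published on the CIFAR-10 site (the URL below), and the output layout `cifar10/train/<label>/` and `cifar10/test/<label>/` is our own choice. It keeps only the four animal classes, which yields 5,000 training images and 1,000 test images per animal.

```matlab
% Hypothetical setup script (not the demo's own code): download the MATLAB
% version of CIFAR-10 and write the cat, deer, dog, and frog images to
% cifar10/train/<label>/ and cifar10/test/<label>/ as PNG files.
url     = 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz';
archive = fullfile(tempdir, 'cifar-10-matlab.tar.gz');
if ~isfile(archive)
    websave(archive, url);                      % download: limited by network speed
end
untar(archive, tempdir);                        % extract: limited by processor/disk
srcDir  = fullfile(tempdir, 'cifar-10-batches-mat');
rootDir = fullfile(pwd, 'cifar10');             % assumed output folder

meta    = load(fullfile(srcDir, 'batches.meta.mat'));   % field: label_names
animals = {'cat', 'dog', 'deer', 'frog'};

% Five training batches and one test batch, 10,000 images each.
files  = {'data_batch_1.mat', 'data_batch_2.mat', 'data_batch_3.mat', ...
          'data_batch_4.mat', 'data_batch_5.mat', 'test_batch.mat'};
splits = {'train', 'train', 'train', 'train', 'train', 'test'};

for f = 1:numel(files)
    batch = load(fullfile(srcDir, files{f}));   % fields: data (N x 3072 uint8), labels (N x 1)
    for i = 1:size(batch.data, 1)
        label = meta.label_names{double(batch.labels(i)) + 1};  % labels are 0-based
        if ~ismember(label, animals)
            continue                            % skip airplanes, trucks, etc.
        end
        % Each row holds the R, G, and B planes of a 32x32 image in row-major
        % order; MATLAB reshapes column-major, so swap rows and columns.
        img = permute(reshape(batch.data(i, :), 32, 32, 3), [2 1 3]);
        outDir = fullfile(rootDir, splits{f}, label);
        if ~isfolder(outDir)
            mkdir(outDir);
        end
        imwrite(img, fullfile(outDir, sprintf('%d_%05d.png', f, i)));
    end
end
```

Writing one PNG per image keeps the folder-per-class layout that the rest of the workflow relies on, at the cost of a slower one-time export.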
Let's take a look at the core code required to run our training. You can see this part, which specifies the animal names, and then this part, which points MATLAB to the folder containing the training data. And as far as setup is concerned, that's it.
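A minimal sketch of that setup code, assuming the training images sit in `cifar10/train/<animal>/` folders (the folder names are an assumption, not taken from the demo). `imageDatastore` with `'LabelSource','foldernames'` labels each image by the name of its parent folder, and `countEachLabel` lets you confirm the 4 × 5,000 = 20,000 training images.

```matlab
% Assumed layout: cifar10/train/cat, cifar10/train/dog, and so on.
rootFolder = fullfile('cifar10', 'train');
categories = {'cat', 'dog', 'deer', 'frog'};   % the animal names

% One datastore over all four folders; each image is labeled by its folder.
imds = imageDatastore(fullfile(rootFolder, categories), ...
    'LabelSource', 'foldernames');

% Table with one row per label and the number of images for each.
tbl = countEachLabel(imds)
```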
So now we're going to tell MATLAB how we want the deep network to be trained. Every neural network has a series of layers, and the more layers it has, the deeper the network. Now each layer takes in data from the previous layer, transforms the data, and then passes it on. So the first layer takes in the raw input image, and by the time we get to the last layer, it's going to hopefully spit out the correct name of the animal in the original image.
So here are the layers that we've chosen to implement for this example. For those of you completely new to this field, you wouldn't be expected to come up with all these layers from scratch. At the other end of the spectrum, if you're a deep learning expert, we provide you with the tools to precisely implement your layers.
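For illustration, here is one small convolutional network for 32-by-32-by-3 inputs and four classes, built from Neural Network Toolbox (now Deep Learning Toolbox) layer functions. It is a sketch of the kind of stack described above, not necessarily the exact layers used in the video; the filter counts, filter sizes, and pooling choices are assumptions.

```matlab
% Illustrative CNN for 32x32 RGB images and four classes; not necessarily
% the exact layer stack used in the video.
numClasses = 4;
layers = [
    imageInputLayer([32 32 3])                  % raw 32x32x3 image in
    convolution2dLayer(5, 32, 'Padding', 2)     % 32 filters, 5x5, same spatial size
    reluLayer
    maxPooling2dLayer(3, 'Stride', 2)           % 32x32 -> 15x15
    convolution2dLayer(5, 32, 'Padding', 2)
    reluLayer
    averagePooling2dLayer(3, 'Stride', 2)       % 15x15 -> 7x7
    convolution2dLayer(5, 64, 'Padding', 2)
    reluLayer
    averagePooling2dLayer(3, 'Stride', 2)       % 7x7 -> 3x3
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(numClasses)             % one output per animal
    softmaxLayer                                % scores -> class probabilities
    classificationLayer];                       % picks the most likely label
```

Each convolution/ReLU/pooling group transforms the data and hands it to the next group, and the last three layers turn the final features into one of the four animal names.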
But in either case, if you want to build off this example, just replace the training data with your own, tweak the layers if you feel like you're up to the task, and with one line of code, MATLAB will give you a neural network trained on whatever you want, whether it's animals or faces of your friends, which is totally not a creepy thing that I did on company time.
So of course, it'll take some time to train. If you only have a CPU, it'll take a while, but with a wicked decked-out GPU like the one in this machine, it takes about 45 seconds. Once it's done, we can move on to testing our network.
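The training call itself can be wrapped so that the CPU/GPU choice is explicit. `trainFromScratch` is a hypothetical helper name, and the solver settings are illustrative assumptions rather than the values used in the video; `'ExecutionEnvironment'` is the `trainingOptions` setting that selects where training runs.

```matlab
function net = trainFromScratch(imds, layers, env)
% Hypothetical wrapper around the one-line training call.
%   imds   - imageDatastore of labeled training images
%   layers - layer array ending in softmaxLayer and classificationLayer
%   env    - 'auto', 'cpu', or 'gpu'
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-3, ...    % illustrative hyperparameters,
    'MaxEpochs', 10, ...             % not the values used in the video
    'MiniBatchSize', 128, ...
    'ExecutionEnvironment', env);    % 'gpu' needs a supported NVIDIA GPU
                                     % and Parallel Computing Toolbox
net = trainNetwork(imds, layers, opts);   % trains all weights from scratch
end
```

For example, `net = trainFromScratch(imds, layers, 'auto');` trains on a supported GPU when one is available and falls back to the CPU otherwise.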
Let's start off super basic. We have a test set of images containing 1,000 of each animal, again conveniently obtained from CIFAR-10. And as you can see, it's set up the exact same way directory-wise as the training set. But most importantly, the network was not trained on these images.
We'll display an image along with what the network thinks it is. That's a deer. Correct. That's a dog. Also correct. That's a frog. The network thought it was a cat.
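Showing one test image next to the network's guess takes only a few calls. `showPrediction` is a hypothetical helper; it assumes a trained network `net` and a labeled test datastore built the same way as the training one.

```matlab
function showPrediction(net, imdsTest, idx)
% Hypothetical helper: display test image idx with the predicted and true labels.
img       = readimage(imdsTest, idx);   % read a single image from the datastore
predicted = classify(net, img);         % network's guess (categorical)
actual    = imdsTest.Labels(idx);       % label taken from the folder name
imshow(img);
title(sprintf('Predicted: %s   Actual: %s', char(predicted), char(actual)));
end
```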
I'm pretty sure you get what's going on by now, so let's speed up this process. We'll have MATLAB run this code, which tests all the images on our test set. And then it'll tell us percentage-wise how the network did overall. And the number is—drum roll, please—about 75% accuracy. And hey, for 45 seconds of training, that's not too bad.
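Scoring the whole test set amounts to classifying every image and comparing the predictions against the folder labels. The helper name `testAccuracy` and its folder argument are assumptions; the accuracy is simply the fraction of predictions that match.

```matlab
function accuracy = testAccuracy(net, testFolder, categories)
% Hypothetical helper: classify every test image and return the fraction correct.
imdsTest = imageDatastore(fullfile(testFolder, categories), ...
    'LabelSource', 'foldernames');             % labels come from folder names
predicted = classify(net, imdsTest);           % one predicted label per image
accuracy  = mean(predicted == imdsTest.Labels);
fprintf('Test accuracy: %.1f%%\n', 100 * accuracy);
end
```

With 1,000 test images per animal, an accuracy of 0.75 means 3,000 of the 4,000 test images were labeled correctly.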
As a caveat, you'll notice that the CIFAR-10 images are really small, and the first layer of our network requires images that are 32 by 32 by 3. While our code does resize your images, you'll have to decide whether that makes sense for your data. But if you have a whole bunch of images that you want to classify with a neural network, here's how to do it with MATLAB, and you can get started right away.
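If your own images are larger than 32 by 32 or are grayscale, they have to be converted before they reach the input layer. `prepareForNetwork` is a hypothetical helper, not the demo's resizing code: it squashes an image to the input size without preserving the aspect ratio and replicates a single gray channel into three. Whether that distortion is acceptable depends on your data.

```matlab
function imgOut = prepareForNetwork(imgIn, inputSize)
% Hypothetical helper: fit an arbitrary image to the input layer size,
% e.g. inputSize = [32 32 3].
imgOut = imgIn;
if size(imgOut, 3) == 1                       % grayscale -> three identical channels
    imgOut = repmat(imgOut, 1, 1, 3);
end
imgOut = imresize(imgOut, inputSize(1:2));    % aspect ratio is NOT preserved
end
```

For whole datastores, newer releases also provide `augmentedImageDatastore`, for example `augmentedImageDatastore([32 32], imds, 'ColorPreprocessing', 'gray2rgb')`, which applies the same kind of resizing on the fly.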
Click the links in the description below to get your hands on the code and check out the documentation for Neural Network Toolbox. Don't hesitate to leave us a question or comment. And as always, thanks for watching.
Hey, check out this app that Gabriel created. Point it at me.
Dude, this is creepy.