QuickDraw: Final Project implementing LSTM
(Mind you, this is not a tutorial. I will be posting one soon, though!)
Before the NUS internship (click here to read about it!), I was just tinkering with data analysis of very common datasets like Boston Housing, Crime Rates and Iris, and applying basic regression or classification algorithms. I had never really applied advanced neural network architectures like LSTM (Long Short-Term Memory) and CNN (Convolutional Neural Network) to a project of my own. All this changed during my internship.
My group mates and I sat for an hour or two and brainstormed ideas ranging from very basic to very advanced topics that none of us had personally implemented. Our project had to have a novel business application along with some concepts taught in class. Keeping all this in mind, we decided to use Google's publicly available dataset, which contains NumPy arrays of hand-drawn objects from lots of different classes. It's called Quick Draw, so we went with the same name.
Now moving on to some project details. Things might get complicated, but I'll try my best to explain every step in layman's terms, hoping to reach a larger audience.
What our project is about:
Our project QuickDraw essentially predicts the image that is being drawn/doodled in real time. It starts predicting what you are trying to draw once 100 points of your strokes have been mapped, and then gives a real-time output which changes as you keep drawing. This project has a lot of scope: digital art, helping people with disabilities who aspire to draw (or to convey messages with really easy strokes), and teaching a computer complicated, heavily stroke-dependent scripts like Arabic.
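To make the "wait for 100 points, then keep predicting" idea concrete, here is a minimal sketch in Python. It is not the project's actual code: the DoodleSession class, the classify callback and dummy_classifier are hypothetical names made up for illustration, and the dummy classifier only stands in for a trained model.

```python
from typing import Callable, List, Optional, Tuple

Point = Tuple[int, int]

MIN_POINTS = 100  # threshold from the text: prediction starts after this many points


class DoodleSession:
    """Collects drawn points and asks a classifier for a label once enough exist."""

    def __init__(self, classify: Callable[[List[Point]], str],
                 min_points: int = MIN_POINTS):
        self.classify = classify      # hypothetical: any function mapping points -> label
        self.min_points = min_points
        self.points: List[Point] = []

    def add_point(self, x: int, y: int) -> Optional[str]:
        """Record one point; return the current guess, or None if it is too early."""
        self.points.append((x, y))
        if len(self.points) < self.min_points:
            return None
        # From here on, every new point triggers a fresh prediction,
        # so the output can change while the user is still drawing.
        return self.classify(self.points)


def dummy_classifier(points: List[Point]) -> str:
    # Placeholder for the trained model: it only compares width and height.
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return "wide doodle" if max(xs) - min(xs) >= max(ys) - min(ys) else "tall doodle"
```

In a real front end, add_point would be called from the mouse-tracking callback, and classify would wrap the trained network.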
Find the code here!
Our Project Workflow:
Download the data for 5 classes of objects with each class containing 10,000 labeled images.
Since the data was consistent and reliable, we didn't have much of a cleaning job to do.
The pictures were in colour, so we converted them to grayscale and resized them to 28x28. Our final image was a one-channel 28x28 array of numbers, each number describing the intensity of one pixel: 0 for black and 255 for white.
We were a team of 5, and each of us took up one algorithm to run against this dataset. We tested 5 algorithms: SVM (Support Vector Machine), K-Nearest Neighbours, Feed-Forward Neural Network, Convolutional Neural Network (CNN) and Long Short-Term Memory (LSTM).
The highest accuracies were obtained with CNN (97%) and LSTM (92%), so we decided to implement these in our final project.
We made a front end using OpenCV which takes user input by tracking the mouse coordinates and mapping them. This serves as the input to our LSTM model.
We then did some trial runs and made a presentation for our final Project Review.
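The preprocessing step in the workflow above (colour to grayscale, then resize to 28x28) can be sketched roughly like this. This is an illustration using plain NumPy, not the code we ran: the function names are made up, the grayscale weights are the commonly used luma coefficients, and nearest-neighbour sampling is just the simplest resizing method to show.

```python
import numpy as np


def to_grayscale(rgb: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) RGB image to (H, W) using common luma weights."""
    weights = np.array([0.299, 0.587, 0.114])
    return rgb.astype(np.float64) @ weights


def resize_nearest(gray: np.ndarray, size: int = 28) -> np.ndarray:
    """Resize an (H, W) image to (size, size) with nearest-neighbour sampling."""
    h, w = gray.shape
    rows = np.arange(size) * h // size   # source row picked for each output row
    cols = np.arange(size) * w // size   # source column picked for each output column
    return gray[np.ix_(rows, cols)]


def preprocess(rgb: np.ndarray) -> np.ndarray:
    """Colour image -> one-channel 28x28 uint8 array (0 = black, 255 = white)."""
    gray = resize_nearest(to_grayscale(rgb))
    return np.clip(np.rint(gray), 0, 255).astype(np.uint8)
```

A library such as OpenCV or Pillow would normally do both steps with better interpolation; the point here is only to show what "a one-channel 28x28 array of numbers" means.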
Now let's understand what exactly LSTMs are. For this, we first have to understand recurrent neural networks (RNNs).
Humans don’t start their thinking from scratch every second. As you read this, you understand each word based on your understanding of previous words. You don’t throw everything away and start thinking from scratch again. Your thoughts have persistence. Traditional neural networks can’t do this, and it seems like a major shortcoming. Recurrent neural networks address this issue. They are networks with loops in them, allowing information to persist.
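The "loop" is easier to see in code. Below is a minimal sketch of a plain recurrent network in NumPy, with random, untrained weights chosen only for illustration (the function names and sizes are assumptions, not part of our project). The hidden state h is fed back in at every step, which is how earlier inputs keep influencing later outputs.

```python
import numpy as np


def rnn_step(x, h_prev, W_xh, W_hh, b_h):
    """One recurrent step: the new state mixes the current input with the old state."""
    return np.tanh(x @ W_xh + h_prev @ W_hh + b_h)


def run_rnn(inputs, hidden_size, seed=0):
    """Feed a (T, input_size) sequence through the loop; return all hidden states."""
    rng = np.random.default_rng(seed)
    input_size = inputs.shape[1]
    # Random weights for illustration only; a real network would learn these.
    W_xh = rng.normal(scale=0.5, size=(input_size, hidden_size))
    W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
    b_h = np.zeros(hidden_size)

    h = np.zeros(hidden_size)          # the network's memory starts empty
    states = []
    for x in inputs:                   # the loop: h is carried from step to step
        h = rnn_step(x, h, W_xh, W_hh, b_h)
        states.append(h)
    return np.stack(states)
```

If two sequences differ only in their first element, their later hidden states still differ, even though the later inputs are identical. That carried-over difference is the persistence described above.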
“LSTM” is a very special kind of recurrent neural network which works, for many tasks, much much better than the standard version. Almost all exciting results based on recurrent neural networks are achieved with them.
Sometimes, we only need to look at recent information to perform the present task. For example, consider a language model trying to predict the next word based on the previous ones. If we are trying to predict the last word in “the clouds are in the sky,” we don’t need any further context – it’s pretty obvious the next word is going to be sky. In such cases, where the gap between the relevant information and the place that it’s needed is small, RNNs can learn to use the past information.
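But there are also cases where we need more context, and the gap between the relevant information and the point where it is needed becomes very large. As that gap grows, standard RNNs struggle to connect the information. This is the long-term dependency problem. LSTMs are explicitly designed to avoid it. Remembering information for long periods of time is practically their default behaviour, not something they struggle to learn!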
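For the curious, here is a rough NumPy sketch of a single step of the standard LSTM cell, following the usual gate formulation (forget gate, input gate, candidate values, output gate). The function names and the way the four gates are packed into one weight matrix W are choices made for this illustration, not our project code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step. W has shape (input_size + hidden_size, 4 * hidden_size)."""
    n = h_prev.shape[0]
    z = np.concatenate([x, h_prev]) @ W + b
    f = sigmoid(z[0 * n:1 * n])   # forget gate: how much old memory to keep
    i = sigmoid(z[1 * n:2 * n])   # input gate: how much new information to write
    g = np.tanh(z[2 * n:3 * n])   # candidate values to write
    o = sigmoid(z[3 * n:4 * n])   # output gate: how much memory to expose
    c = f * c_prev + i * g        # cell state: the long-term memory
    h = o * np.tanh(c)            # hidden state passed on to the next step
    return h, c


# Illustration: with the forget gate pushed open and the input gate pushed shut,
# the cell state passes through the step almost unchanged.
hidden_size, input_size = 3, 2
W = np.zeros((input_size + hidden_size, 4 * hidden_size))
b = np.zeros(4 * hidden_size)
b[0:hidden_size] = 20.0                  # forget gate close to 1
b[hidden_size:2 * hidden_size] = -20.0   # input gate close to 0
c0 = np.array([0.5, -0.3, 0.9])
h1, c1 = lstm_step(np.ones(input_size), np.zeros(hidden_size), c0, W, b)
```

The line c = f * c_prev + i * g is the important one: when the forget gate stays near 1, the old memory is simply carried forward, which is why holding on to information over many steps comes naturally to an LSTM.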
Therefore, we used an LSTM so that our network remembers the previous strokes as well as the new strokes being drawn. Without remembering the previous strokes, our model would have struggled to identify the doodle. It "connected the dots" to find the bigger picture, and did it with ease! LSTMs are very useful in speech recognition, language modelling, translation and image captioning, all of which require the network to retain information and carry it forward over long sequences.
Shoutout to Colah's blog from which I've used some excerpts and diagrams.
Check out Understanding LSTM Networks by Colah. It dives deeper into the mathematical representation of these networks, with lots of calculus, so go ahead and read more if you like!
That's it for this blog!
Click the heart button there and share it! :)