What are recurrent neural networks and how can you use them

Recurrent neural networks (RNNs) are helpful for specific time-series problems.

A time series problem is one where the x-axis is time, and the y-axis is something you want to measure.

Example of a time series graph

A typical time-series problem might be 'guess the next value'. Another might be 'classify this time series'.

'Are you saying I can use this to predict the future?'

Yes.

'Can I use this to predict the stock market and make millions?'

Maybe. But probably not :)

A simple problem that multi-layer perceptrons can't solve

Consider the following time-series classification problem:

  • We have a set of time series of arbitrary lengths.
  • Each time series element is either a 1 or a 0.
  • If a time series has two 1's next to each other, it's considered to be in the 'positive' class.
  • Otherwise, it's in the negative class.
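To make this concrete, here's a quick sketch in Python of how such a dataset might be generated and labelled. The series lengths and the number of examples are arbitrary; this is only to illustrate the problem, not code from any particular library.

```python
import random

def label(series):
    """Positive (1) if the series has two 1's next to each other, else negative (0)."""
    return int(any(a == 1 and b == 1 for a, b in zip(series, series[1:])))

def make_series(max_len=10):
    """Generate a random 0/1 time series of arbitrary length."""
    length = random.randint(2, max_len)
    return [random.randint(0, 1) for _ in range(length)]

for series in (make_series() for _ in range(5)):
    print(series, "->", "positive" if label(series) else "negative")
```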

There are two reasons why this problem cannot be solved using multi-layer perceptrons (MLPs). Firstly, such models only work with input vectors of fixed length, and we want the model to take in vectors of arbitrary length. We could pad the inputs with 0's and pass the model the same-size vector every time... but how would we know how long to make it?

The second problem is even more serious: multi-layer perceptron models don't capture the temporal relationship between the input variables. They have no way of knowing which element of the input vector came 'before' another.

Now, it might be possible to build an MLP model big enough so that it overfits the training data and scores perfectly, but such a model is of limited practical use.

We'd rather have a framework specifically built for temporal data.

Enter recurrent neural networks.

Some advantages of RNNs

The first advantage of RNNs over other types of networks is that they can take in inputs of any length. You don't need to do anything different for a vector of length 10 and one of length 100. Put them both in, and the model will be fine.

That's a nice property to have.

The second main advantage is that RNNs encode the temporal relationship between the variables in the time series. That means they can 'learn' how earlier values influence later ones.

So, how do RNNs work?

I thought you'd never ask.

Here's a simple neural network:

A simple neural network

The input is 0.6. The first weight is 0.2, and the first bias is 0.7. Then there's a ReLU activation (if you don't know what that is, don't worry too much for now - it's the bit in the middle). The second weight and bias (0.1, 0.2) turn the hidden neuron into an output of 0.28.

All these numbers are arbitrary – I just made them up for this example.  
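If you want to check the arithmetic, here's the whole forward pass in a few lines of Python, using the numbers above:

```python
def relu(z):
    # ReLU simply clips negative values to zero
    return max(0.0, z)

x = 0.6            # input
w1, b1 = 0.2, 0.7  # first weight and bias
w2, b2 = 0.1, 0.2  # second weight and bias

hidden = relu(x * w1 + b1)  # 0.6 * 0.2 + 0.7 = 0.82
output = hidden * w2 + b2   # 0.82 * 0.1 + 0.2 = 0.282
print(round(output, 2))     # 0.28
```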

So far, so good.

The clever bit

The real power of RNNs lies in the fact that the output of one neuron can be passed to the next.

Let's see this in another diagram.

Look closely at the diagram above. The first two elements of the time series are the 'inputs' on the left (0.5 and 0.6). The output of the first neuron (the top line) is ignored, but the model can 'pass' information about the previous neuron down to the next one. This is then added to the value of the second neuron to get a new output.

This 'unravelling' of time series makes it easy to understand how a network can pass information about past events into the future.
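Here's a rough sketch of that unravelling in code. The weights acting on the input and the output are the made-up numbers from before; the recurrent weight (the one that scales the value passed between time steps) is another made-up number, and exactly where the passed value gets added in varies between formulations.

```python
def relu(z):
    return max(0.0, z)

# The same weights are reused at every time step.
w_in, b_in = 0.2, 0.7    # act on the current input
w_rec = 0.5              # scales the value passed from the previous step (made up)
w_out, b_out = 0.1, 0.2  # turn the final hidden value into an output

def run_rnn(series):
    hidden = 0.0
    for x in series:
        # The previous hidden value is scaled and added to the current neuron's input.
        hidden = relu(x * w_in + b_in + hidden * w_rec)
    return hidden * w_out + b_out

print(run_rnn([0.5, 0.6]))            # the two-element example from the diagram
print(run_rnn([0.5, 0.6, 0.1, 0.9]))  # the same weights handle any length
```

Notice that the loop reuses exactly the same weights at every step, so the same function handles a series of length 2 or length 200.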

Now, some important points to make.

Firstly, the weights in the unravelled neural network are all the same. It isn't that each node has its own weight: for example, all the weights acting on the input variables are identical. That's what allows the network to work on inputs of arbitrary length. If all the weights were different, the model would need to know how long the input vector was, and we'd lose one of the valuable properties of the network.

In real life, of course, you're probably not going to have just one recurrent neuron working on a problem but many, each one searching for different features. These will likely be followed by a fully connected layer (or several such layers).
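As a rough illustration (not something from the diagrams above), here's what that kind of stack might look like in Keras: a layer of recurrent units followed by a single fully connected output neuron for our positive/negative classification. The layer sizes are arbitrary, and in practice the variable-length series would need to be padded or fed in one at a time.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(None, 1)),         # sequences of any length, one value per step
    tf.keras.layers.SimpleRNN(16),                   # 16 recurrent units, each learning its own features
    tf.keras.layers.Dense(1, activation="sigmoid"),  # fully connected layer giving P(positive class)
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()
```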

Problems with RNNs

Alas, RNNs have a fatal flaw. This is known as either the vanishing gradient problem or the exploding gradient problem (depending on which flavour of problem you have).

Now, vanishing gradients are a problem with many types of deep learning algorithms. However, the problem is particularly pronounced in RNNs.

The problem is the weight that passes each neuron's output from one element of the time series to the next.

Suppose this weight were 2. That means we multiply the result of the first neuron (in our case 0.82) by 2 and pass it on. That's not so much of a problem by itself. But now suppose you multiply the next one by 2, and the next one by 2, and so on. If your input vector had 100 elements, you'd do this multiplication 100 times. That's a factor of 2^100, an astronomically large number.

Now, imagine that you want to change that weight. You nudge it up a little bit. But because of the architecture of the network, that weight appears in the output raised to the power of 100, so a tiny nudge produces an enormous change. The gradient has exploded, and that makes it very difficult to take the small steps needed to find the optimal weight.
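You can see the scale of the problem in a couple of lines of Python. Repeatedly multiplying by that recurrent weight of 2, once per time step, produces an astronomical number, and a tiny nudge to the weight changes the result enormously (while a weight below 1 makes the signal vanish instead):

```python
def propagate(weight, steps=100, value=0.82):
    """Multiply a value by the recurrent weight once per time step."""
    for _ in range(steps):
        value *= weight
    return value

print(propagate(2.0))   # ~1e30: the signal (and the gradient) explodes
print(propagate(2.01))  # a tiny nudge to the weight changes the output massively
print(propagate(0.5))   # vanishingly small: with a weight below 1, the signal shrinks towards zero
```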

Long short-term memory networks save the day

Long short-term memory (LSTM) networks are recurrent neural networks that are cleverly designed to avoid the exploding and vanishing gradient problems.

However, how they work is a bit more complicated and will require its own post to explain.