A gentle introduction to neural networks

Image by [kjpargeter](https://www.freepik.com/free-photo/network-connections_2052331.htm#query=neural%20network&position=1&from_view=keyword) on Freepik

I will continue my posts on machine-learning techniques by going 'under the hood' and describing how neural networks work.

The target audience is people who want a high(ish)-level understanding of how neural networks operate. Not to the level of mathematics or code. But enough to have sensible conversations about them.

Firstly, why call it a 'neural' network?

Roughly speaking, neural networks are based on theoretical models of how the human brain processes and stores information.

Now, I say 'roughly speaking' partly because a brain is a complex object (possibly the most complex we know of) and partly because neural networks are quite different from brains.

But it's the name we've got, and we're stuck with it now.

A 'real' neuron has three main parts to it:

  1. Dendrites
  2. The cell body (the soma)
  3. Axons

These are descriptions of the physical parts of a neuron. However, what's important to us is their function. Here they are:

  1. The input signal(s)
  2. The calculating machine (the cell body)
  3. The output signal
A neuron

In human brains, the input signals are electrical impulses. But in computers, they're numbers. For example, input 1 could be the number 1, and input 2 could be the number 0.5.

The lines tell us how the input affects the neuron. For example, the line joining input 1 to the neuron could have the rule 'multiply the input by 3'. The line joining input 2 to the neuron could have the rule 'multiply by -1', and so on.

All of these 'rules' are 'multiply the input by a number.' The number we multiply by is called a 'weight'.

At the neuron, we collect these calculations together and add them. So the neuron's value is the sum of the inputs multiplied by their weights. That gets sent to the output.
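The weighted-sum idea above can be sketched in a few lines of Python. The function name and the example numbers are mine, purely for illustration:

```python
# A minimal sketch of a single neuron: multiply each input by its
# weight, then add everything up.
def neuron(inputs, weights):
    return sum(x * w for x, w in zip(inputs, weights))

# Input 1 is 1 and input 2 is 0.5 (as in the text);
# the weights on the lines are 3 and -1.
output = neuron([1, 0.5], [3, -1])
print(output)  # 1*3 + 0.5*(-1) = 2.5
```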

That's it!

Activation function

Ok, I lied; that's not quite it.

Neurons don't normally output the raw value described above. Instead, they feed that value into an activation function and output the result.

What kind of activation functions are there?

There are many, but a sigmoid is a popular one (that isn't too difficult to understand). So is a threshold function (either 1 or 0, depending on the input).
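Here's a rough sketch of those two activation functions, assuming the common convention that the threshold sits at zero:

```python
import math

def sigmoid(x):
    """Squashes any number into the range (0, 1)."""
    return 1 / (1 + math.exp(-x))

def threshold(x):
    """Outputs 1 if the input is positive, otherwise 0."""
    return 1 if x > 0 else 0

print(sigmoid(0))      # 0.5 -- halfway between the two extremes
print(threshold(2.5))  # 1
```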

This is the first of many nuances that are common in machine learning. While such details matter to an engineer building a machine-learning model, they aren't essential for understanding how neural networks work.

So feel free to imagine the neuron as just adding up the numbers.

A simple model

Suppose we're trying to build a model that predicts the price of a house.

The input layer of this model consists of some facts about that house:

  1. The area of the house
  2. Number of bedrooms
  3. Distance to the city
  4. Age

We could imagine a simple model that looks like this:

A simple model for predicting a house price

Remember that each line represents a weight (a number we multiply the input by). So the prediction would be:

\[ \text{predicted price} = w_1 \cdot \text{area} + w_2 \cdot \text{bedrooms} + w_3 \cdot \text{distance} + w_4 \cdot \text{age} \]

Astute readers will have noticed that this is just linear regression!
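As a worked example, the formula above might look like this in code. The weights here are made up; a real model would learn them from data:

```python
# Hypothetical weights -- in reality these would be learned.
w1, w2, w3, w4 = 2000.0, 15000.0, -500.0, -300.0

def predict_price(area, bedrooms, distance, age):
    # The linear model from the formula above.
    return w1 * area + w2 * bedrooms + w3 * distance + w4 * age

# 120 sq m, 3 bedrooms, 10 km from the city, 25 years old.
print(predict_price(area=120, bedrooms=3, distance=10, age=25))  # 272500.0
```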

So what's the point of neural networks?

Neural networks can have hidden layers which, combined with the activation functions mentioned earlier, can uncover non-linear relationships. Such a network might look like this:

Neural network with a hidden layer

In the diagram, I've omitted some lines from the input layer to some elements of the hidden layer, to represent weights that are weak (close to zero).

Each neuron works by the same logic as the single neuron above. There are just lots of them.

I think of this as a battleship powered by thousands of tiny propellers.  

So how do neural networks learn?

Here's where we talk about data.

For our network to learn, it has to learn from something. That something is called training data.

We pass instances of our training data into a neural network, and our model transforms that into a prediction.

Then, we compare the prediction of our model to the truth. For example, if we're trying to predict a house price like the model above, then we compare the number our model predicted with the actual price of the house (we would call this a 'regression' type problem when we're trying to guess a number).

Now we perform what's called backpropagation (short for 'backward propagation'). We can think of it as going to each weight in the model and asking it the following question: if I could tweak only your value, and I wanted to improve our guess, should I make you slightly larger, slightly smaller, or leave you roughly the same?

Then we tweak that weight in whichever direction it suggested.

Repeat this for every weight in the model, then move on to the next instance from the training data.

Disclaimer: I could point out all kinds of nuances here (e.g. training in batches rather than one data point at a time). But none of that helps us understand how the model learns.
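The tweak-and-check idea can be sketched as follows. Everything here is made up for illustration (the model, the data, the step size), and a real implementation would use backpropagation rather than probing each weight one at a time — but the spirit is the same:

```python
# A toy two-weight model and its cost (squared error).
def predict(weights, inputs):
    return sum(w * x for w, x in zip(weights, inputs))

def cost(weights, inputs, truth):
    return (predict(weights, inputs) - truth) ** 2

def training_step(weights, inputs, truth, step=0.001):
    """Ask each weight: would a slightly larger value reduce the cost?
    Nudge it up if so, down if not."""
    new_weights = []
    for i, w in enumerate(weights):
        probe = weights.copy()
        probe[i] = w + step
        if cost(probe, inputs, truth) < cost(weights, inputs, truth):
            new_weights.append(w + step)
        else:
            new_weights.append(w - step)
    return new_weights

# One training instance: inputs [1.0, 2.0], true answer 5.0.
weights, inputs, truth = [0.0, 0.0], [1.0, 2.0], 5.0
for _ in range(2000):
    weights = training_step(weights, inputs, truth)
print(cost(weights, inputs, truth))  # far smaller than the starting cost of 25
```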

Differentiation

In the previous section, I spoke about asking each weight in which direction it should move. Practically speaking, how do we do that?

As it happens, the algorithm is easy to understand. I taught it to an intelligent A-Level student who understood it in less than an hour.

We differentiate the cost function (a function that gets larger the further our prediction is from the correct answer). Then we move each weight a small step downhill, in the direction that decreases the cost.
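In symbols, each weight \(w\) gets nudged against the slope of the cost function \(C\), where \(\eta\) (the 'learning rate', a small number we choose) controls the size of the step:

\[ w \leftarrow w - \eta \frac{\partial C}{\partial w} \]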

Gradient descent

The technical name given to this is gradient descent.

Over time, the cost function decreases, which is the same thing as the model getting better at predicting the correct house price (in our example).

How do we check that our model works?

When we have a model, it's time to check it. How do we do that?

Simple. We apply the model to new data that it's never seen before. This new data is called testing data.

Then we observe how well the model performs on the testing data. Simple.
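Here's a sketch of what that check might look like. The model and the numbers are invented for illustration:

```python
# A (hypothetical) trained one-input model: price per square metre.
def model(area):
    return 2000.0 * area

# Testing data the model has never seen: (area, actual price) pairs.
testing_data = [(100, 210000.0), (80, 155000.0), (120, 235000.0)]

# How far off is the model, on average?
errors = [abs(model(area) - price) for area, price in testing_data]
mean_error = sum(errors) / len(errors)
print(mean_error)  # the average miss, in the same units as the price
```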

Now, we can use all kinds of fancy statistical tests, scores and measurements to determine this, but ultimately, we want to know how well the model does.

That's it!

Conclusion

I have a theory that there's no such thing as a 'complicated idea'.

That's because whenever I move into a new field of study (and I've moved into several), I always find, at their heart, simple ideas.

Neural networks are a simple idea. It might take a few passes through the theory to slot everything together. But, ultimately, there's very little to them.