In my previous article, I mentioned that data augmentation helps deep learning models generalize well. That was on the data side of things. What about the model side? What can we do while training our models that will help them generalize even better?
We use weight decay.
Before you read this article, make sure you understand my article on learning rates, because that article is the starting point for this one.
Parameters of a model
We start by looking at the image above. We see a bunch of data points that cannot be fit well with a straight line, so we use a 2nd-degree polynomial instead. We also notice that if we increase the degree of the polynomial beyond a certain point, our model becomes too complex and starts to overfit.
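To make this concrete, here is a small synthetic sketch (numpy only; the quadratic data and the degree-15 choice are illustrative assumptions, not from the notebook) of how polynomial degree controls complexity:

```python
import numpy as np

# Synthetic sketch: degree 1 underfits, degree 2 matches the quadratic
# data well, and a very high degree chases the noise and overfits.
rng = np.random.default_rng(0)

def make_data(n):
    x = rng.uniform(-1, 1, n)
    y = 2 * x**2 + 0.5 * x + rng.normal(0, 0.1, n)  # quadratic + noise
    return x, y

x_train, y_train = make_data(20)
x_test, y_test = make_data(20)

for degree in (1, 2, 15):
    coeffs = np.polyfit(x_train, y_train, degree)
    train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    test_mse = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)
    print(f"degree {degree:2d}: train MSE {train_mse:.4f}, test MSE {test_mse:.4f}")
```

The degree-15 fit drives the training error down while the error on held-out points gets worse, which is exactly the overfitting described above.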
This means that to prevent overfitting, we shouldn't allow our models to get too complex. Unfortunately, this has led to a misconception in deep learning that we should avoid using a lot of parameters.
Origin of weight decay
First of all, real-world data is not going to be as simple as the example shown above. Real-world data is complex, and to solve complex problems, we need complex solutions.
Having fewer parameters is only one way of preventing our model from getting overly complex, and it is actually a very limiting strategy. More parameters mean more interactions between various parts of our neural network, and more interactions mean more non-linearities. These non-linearities help us solve complex problems.
However, we don't want these interactions to get out of hand. So what if we penalize complexity instead? We still get to use a lot of parameters, but we prevent our model from getting too complex. This is how the idea of weight decay came up.
We’ve seen weight decay in my article on collaborative filtering. In fact, every learner in the fastai library has a parameter called weight decay.
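For instance, here is a hedged sketch of setting it (fastai v2-style API with a small built-in dataset, used only so the snippet runs; the exact names differ across fastai versions):

```python
from fastai.vision.all import *

# Every fastai Learner accepts a `wd` argument; the fit methods accept
# one too, which overrides the Learner-level value for that run.
dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_SAMPLE))
learn = vision_learner(dls, resnet18, metrics=accuracy, wd=0.1)  # set at creation...
learn.fit_one_cycle(1, wd=0.1)                                   # ...or per training run
```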
This thing called weight decay
We've come to the conclusion that to prevent our models from overfitting, we need to penalize complexity.
One way to do so would be to add all our parameters (weights) to our loss function. That won't quite work, because some parameters are positive and some are negative, so they would cancel each other out. So what if we add the squares of all the parameters to our loss function instead? We can do that, but it might make the loss so huge that the best model would be to set all the parameters to 0.
To prevent that from happening, we multiply the sum of squares by another, smaller number. This number is called weight decay, or wd.
Our loss function now looks as follows:
Loss = MSE(y_hat, y) + wd * sum(w^2)
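In plain PyTorch, that loss looks like the following (a minimal sketch, not the fastai internals; the tiny linear model and random data are placeholders):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)                   # placeholder model
x, y = torch.randn(64, 10), torch.randn(64, 1)   # placeholder data

wd = 0.1  # the weight decay factor

y_hat = model(x)
mse = F.mse_loss(y_hat, y)                             # MSE(y_hat, y)
l2 = sum((p ** 2).sum() for p in model.parameters())  # sum(w^2)
loss = mse + wd * l2                                   # the full penalized loss

loss.backward()  # each gradient now includes an extra 2 * wd * w term
```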
Now what should the value of weight decay be?
wd = 0.1 works pretty well. However, the folks at fastai have been a little conservative in this respect, so the default value of weight decay in fastai is actually 0.01.
The reason for this choice: if you have too much weight decay, then no matter how long you train, the model never quite fits well enough; whereas if you have too little, you can still train well, you just have to stop a little early.
I’ve demonstrated this concept in this jupyter notebook.
It is a multi-class (and not a multi-label) classification problem where we try to predict the class of plant seedlings.
I've used 3 values for weight decay: the default 0.01, the best value of 0.1, and a large value of 10. In the first case our model takes more epochs to fit. In the second case it works best, and in the final case it never quite fits well, even after 10 epochs (see the difference between training and validation loss).
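If you want to reproduce the shape of that experiment outside the notebook, here is a hedged sketch (fastai v2-style API; the plant-seedlings data is swapped for a small built-in dataset so the sketch runs end to end):

```python
from fastai.vision.all import *

# Train the same architecture with the three weight decay values and
# compare the gap between training and validation loss in each run.
dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_SAMPLE))
for wd in (0.01, 0.1, 10):
    learn = vision_learner(dls, resnet18, metrics=accuracy, wd=wd)
    learn.fit_one_cycle(10)  # watch how quickly (or whether) each run fits
```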
In my next article, we will dive a little deeper into the math of weight decay and learn some more concepts like momentum and the Adam optimizer, with code samples. So stay tuned!
If you liked this article, give it at least 50 claps :p