Let’s write a function for matrix multiplication in Python.
We start by finding the shapes of the two matrices and checking whether they can be multiplied at all: the number of columns of matrix_1 must equal the number of rows of matrix_2.
Then we write three nested loops to multiply the matrices element-wise. The shape of the final matrix will be (number of rows of matrix_1) by (number of columns of matrix_2).
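A minimal sketch of that triple-loop function might look like this (the name `matmul` and the exact loop order are my choices, not necessarily the article's original code):

```python
import torch

def matmul(a, b):
    # naive matrix multiplication with three nested Python loops
    ar, ac = a.shape  # rows and columns of matrix_1
    br, bc = b.shape  # rows and columns of matrix_2
    assert ac == br, "columns of matrix_1 must equal rows of matrix_2"
    c = torch.zeros(ar, bc)  # result: (rows of a) x (columns of b)
    for i in range(ar):          # over rows of a
        for j in range(bc):      # over columns of b
            for k in range(ac):  # multiply-accumulate one pair at a time
                c[i, j] += a[i, k] * b[k, j]
    return c
```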
Now let’s create a basic neural net where we will use this function.
In this article we will be using the MNIST dataset for demonstration purposes. It contains 50,000 samples of handwritten digits. These digits are originally 28×28 matrices (or 784-value vectors once flattened).
Hence our neural net takes 784 values as input and gives the 10 classes as output.
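As a sketch, such a model could be two linear layers with a ReLU in between; the hidden size of 50 is an assumption for illustration, and `@` stands in here for the hand-written matmul function:

```python
import torch

def model(x, w1, b1, w2, b2):
    # 784 inputs -> hidden layer -> 10 class scores
    # ('@' is a stand-in for the loop-based matmul above)
    h = (x @ w1 + b1).clamp_min(0.0)  # linear layer + ReLU
    return h @ w2 + b2                # output layer: 10 classes

w1 = torch.randn(784, 50); b1 = torch.zeros(50)
w2 = torch.randn(50, 10);  b2 = torch.zeros(10)
out = model(torch.randn(5, 784), w1, b1, w2, b2)
print(out.shape)  # torch.Size([5, 10])
```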
Let’s now grab 5 elements from the MNIST validation set and run them through this model.
We see that for a mere 5 elements, it took us about 650 milliseconds to perform the matrix multiplications. This is very slow. Let’s try to speed it up.
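The timing can be reproduced with a sketch like the one below; the shapes are stand-ins for a batch of 5 flattened MNIST images going through a single 784-to-10 weight matrix, and the exact milliseconds you see will depend on your machine:

```python
import time
import torch

def matmul(a, b):
    # the loop-based version from earlier in the article
    ar, ac = a.shape
    br, bc = b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac):
                c[i, j] += a[i, k] * b[k, j]
    return c

x = torch.randn(5, 784)   # stand-in for 5 MNIST images
w = torch.randn(784, 10)  # stand-in for a weight matrix

start = time.perf_counter()
c = matmul(x, w)
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"{elapsed_ms:.0f} ms")  # hundreds of ms on a typical machine
```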
Why is speed important?
Matrix multiplication forms the basis of neural networks. Most operations while training a neural network require some form of matrix multiplication. Hence doing it well and doing it fast is really important.
We will speed up our matrix multiplication by eliminating loops and replacing them with PyTorch operations. This will give us C speed (what PyTorch runs underneath) instead of Python speed. Let’s see how that works.
Eliminating the innermost loop
We start by eliminating the innermost loop. The idea behind eliminating this loop is that instead of doing operations on one element at a time, we can do them on one row (or column) at a time. Take a look at the image below.
We have 2 tensors and we want to add their elements together. We can write a loop to do so or we can make use of PyTorch’s elementwise operations (a + b directly) to do the same.
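Here is a small sketch of that comparison, with made-up values:

```python
import torch

a = torch.tensor([1., 2., 3.])
b = torch.tensor([10., 20., 30.])

# loop version: one Python iteration per element
c_loop = torch.zeros(3)
for i in range(3):
    c_loop[i] = a[i] + b[i]

# elementwise version: a single C-speed call
c_fast = a + b
print(c_fast)  # tensor([11., 22., 33.])
```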
Using the same idea, we will eliminate the innermost loop: instead of accumulating `c[i,j] += a[i,k] * b[k,j]` one `k` at a time, we directly compute `c[i,j] = (a[i,:] * b[:,j]).sum()`.
Our function now looks as follows, and takes about 1.55 milliseconds to run, which is a massive improvement!
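A sketch of the updated two-loop function:

```python
import torch

def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            # the innermost k-loop is replaced by one elementwise
            # multiply of a row and a column, followed by a sum
            c[i, j] = (a[i, :] * b[:, j]).sum()
    return c
```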
If you are not familiar with the indexing syntax, `a[i,:]` means “select the ith row and all columns”, while `b[:,j]` means “select all rows and the jth column”.
We can write a little test to confirm that our updated function gives the same output as our original function.
And it does.
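Such a test might look like the sketch below; the function name and tolerances are my own choices, and I compare against PyTorch’s built-in result as a stand-in for the original loop version:

```python
import torch

def matmul_rowcol(a, b):
    # the updated two-loop version
    c = torch.zeros(a.shape[0], b.shape[1])
    for i in range(a.shape[0]):
        for j in range(b.shape[1]):
            c[i, j] = (a[i, :] * b[:, j]).sum()
    return c

a = torch.randn(5, 784)
b = torch.randn(784, 10)
# allclose instead of == because float results differ by tiny rounding errors
assert torch.allclose(matmul_rowcol(a, b), a @ b, rtol=1e-3, atol=1e-5)
print("test passed")
```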
Eliminating the second loop
We can now move on to eliminating the second loop. And this is the most exciting part, because this time we will go from milliseconds down to microseconds.
To do so, we need to know something known as broadcasting.
Suppose you want to subtract the mean from every data point in your dataset. Once again you can write a loop to do so or you can make use of broadcasting.
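A quick sketch of that mean-subtraction example (the shapes are made up for illustration):

```python
import torch

data = torch.randn(100, 3)  # 100 data points, 3 features each
mean = data.mean(dim=0)     # per-feature mean, shape (3,)

# the (3,)-shaped mean is broadcast across all 100 rows
centered = data - mean
print(centered.mean(dim=0))  # approximately zero in every column
```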
In broadcasting, we take the smaller tensor, and broadcast it across the larger tensor so that they have comparable shapes. Once they have comparable shapes we can perform elementwise operations on them. Let’s see another example.
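For instance, adding a 3-element tensor `c` to a 3×3 matrix `m` (the values here are made up):

```python
import torch

m = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])
c = torch.tensor([10., 20., 30.])

# c is broadcast to match m's shape, then added row by row
print(c + m)
# tensor([[11., 22., 33.],
#         [14., 25., 36.],
#         [17., 28., 39.]])
```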
Do you see what happened there? Tensor c got broadcast so that it had the same number of rows as m. We can find out what a tensor will look like after being broadcast with the `expand_as` method.
And here’s the best part. PyTorch does not actually duplicate the values. It just pretends to. Let’s take a look at the storage and shape of a broadcast tensor t = [10, 20, 30]. The tensor t is still stored as only [10, 20, 30], but it knows that its shape is supposed to be 3×3. This makes broadcasting memory-efficient.
Using broadcasting, we will broadcast the ith row of matrix_1 over the whole of matrix_2, one row per iteration. Our function now looks as follows:
and takes only 402 microseconds to run!
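A sketch of that single-loop, broadcast version:

```python
import torch

def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br
    c = torch.zeros(ar, bc)
    for i in range(ar):
        # a[i] has shape (ac,); unsqueeze makes it (ac, 1) so it
        # broadcasts across every column of b in one operation
        c[i] = (a[i].unsqueeze(-1) * b).sum(dim=0)
    return c
```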
This is about the best we can do with flexible, hand-written code. If you want to do even better, you can use Einstein summation.
But the fastest way would be to use PyTorch’s built-in `matmul` function (or the `@` operator). The reason it’s so fast is that it dispatches to heavily optimized low-level linear algebra routines underneath.
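Both options fit in one line each; the shapes below are stand-ins:

```python
import torch

a = torch.randn(5, 784)
b = torch.randn(784, 10)

c1 = torch.einsum('ik,kj->ij', a, b)  # Einstein summation
c2 = torch.matmul(a, b)               # built-in matmul, same as a @ b
assert torch.allclose(c1, c2, rtol=1e-4, atol=1e-5)
```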
That would be it for this article.
If you liked this article, give it at least 50 claps :p
If you want to learn more about deep learning you can check out my deep learning series below.
Deep learning from the foundations: fastai.