A Smooth Introduction to Linear Regression and its Implementation in PyTorch (Part-II)

In Part-I, I gave a simple introduction to what linear regression is and how we can find the equation of the best fit line for our data. In this post, I will show you how to implement the same task in PyTorch. So, let’s get started!

The first step is to import the libraries we will be using:
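A minimal sketch of these imports, assuming the program relies on PyTorch and NumPy (a plotting library such as Matplotlib would also be needed for the final plot):

```python
import numpy as np        # array handling for the raw training data
import torch              # tensors and automatic differentiation
import torch.nn as nn     # layers (nn.Linear) and loss functions (nn.MSELoss)
```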

Now, we need to set the hyperparameters, as follows:

The input size is set to 1 because each input to the model is a single scalar h (the hour of the day). The output size is also set to 1 because the model returns a single value r (the number of pages read). In other words, we let the program learn the values of B_0 and B_1 that we calculated by hand in the previous part of this tutorial.
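As a rough sketch, the settings described above might look like the following. The variable names and the learning rate are assumptions made for illustration; only the sizes and the number of epochs come from this post.

```python
# Hyperparameters for a one-feature linear regression.
input_size = 1          # one scalar input: h, the hour of the day
output_size = 1         # one scalar output: r, the number of pages read
num_epochs = 5          # number of passes over the training data
learning_rate = 0.001   # assumed value; the step size used by the optimizer
```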

Let’s now enter the training data, which will represent the values of h and r that we already have.
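A minimal sketch of this step is shown below. The numbers are placeholders chosen purely for illustration, not the actual h and r values from Part-I, and the variable names are assumptions.

```python
import numpy as np
import torch

# Placeholder values only: substitute the real (h, r) pairs from Part-I.
h_values = [1.0, 2.0, 3.0, 4.0, 5.0]   # hour of the day (hypothetical)
r_values = [2.0, 3.0, 5.0, 4.0, 6.0]   # pages read (hypothetical)

# nn.Linear expects inputs of shape (num_samples, num_features),
# so each list is reshaped into a single column of float32 values.
x_train = np.array(h_values, dtype=np.float32).reshape(-1, 1)
y_train = np.array(r_values, dtype=np.float32).reshape(-1, 1)

# Convert the NumPy arrays to tensors that the model can consume.
inputs = torch.from_numpy(x_train)
targets = torch.from_numpy(y_train)
```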

To perform linear regression, we have to define three things: the model (linear regression), the loss function, and the optimizer. Once these are in place, we can train the model on our data. The figure below depicts this process.

Let’s take that step-by-step in PyTorch. So, first we define our linear regression model:
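A minimal sketch of this step, assuming the model is a single `nn.Linear` layer with one input and one output feature:

```python
import torch
import torch.nn as nn

input_size = 1    # scalar input h
output_size = 1   # scalar output r

# A single linear layer computes r = B_1 * h + B_0.
model = nn.Linear(input_size, output_size)

# The layer's weight plays the role of B_1 (slope) and its bias
# the role of B_0 (intercept). Both start at random values.
print(model.weight.shape, model.bias.shape)
```

Both parameters are initialised randomly, so the line the model describes before training is arbitrary.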

Then, define the loss function (mean squared error):
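A minimal sketch using `nn.MSELoss`, with two made-up predictions and targets so the arithmetic can be checked by hand (the name `criterion` is an assumption):

```python
import torch
import torch.nn as nn

# Mean squared error: the average of the squared differences
# between predictions and targets.
criterion = nn.MSELoss()

# Toy values for illustration only.
predictions = torch.tensor([[1.0], [2.0]])
targets = torch.tensor([[1.0], [4.0]])

# ((1 - 1)^2 + (2 - 4)^2) / 2 = (0 + 4) / 2 = 2.0
loss = criterion(predictions, targets)
print(loss.item())
```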

And the optimizer (stochastic gradient descent):
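A minimal sketch using `torch.optim.SGD`. The learning rate is an assumed value, and the model is re-created here only so the snippet runs on its own.

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)   # the linear regression model
learning_rate = 0.001     # assumed value

# SGD updates every parameter p as: p = p - learning_rate * gradient(p).
# model.parameters() hands the optimizer the weight (B_1) and bias (B_0).
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```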

We can now train our model with the number of epochs specified (i.e. 5).
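The training loop might look roughly like the sketch below. The data values are placeholders rather than the real values from Part-I, and the learning rate is assumed. Each epoch runs a forward pass, computes the loss, backpropagates, and takes one optimizer step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the random initialisation repeatable

# Placeholder data (hypothetical h and r values), shape (num_samples, 1).
inputs = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
targets = torch.tensor([[2.0], [3.0], [5.0], [4.0], [6.0]])

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # assumed lr
num_epochs = 5

losses = []
for epoch in range(num_epochs):
    outputs = model(inputs)             # forward pass: predicted r for each h
    loss = criterion(outputs, targets)  # mean squared error against the targets

    optimizer.zero_grad()               # clear gradients left from the last step
    loss.backward()                     # compute gradients of the loss
    optimizer.step()                    # update weight (B_1) and bias (B_0)

    losses.append(loss.item())
```

Because the whole dataset is passed to the model at once, every step here uses the full-batch gradient. With a small learning rate and only 5 epochs the parameters move only a little, so many more epochs may be needed before the learned line approaches the values computed by hand.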

Oh, and to get an idea of what is going on after each epoch, we can add the following statement inside our training for-loop:
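One possible form of such a statement is sketched below. The helper `format_epoch_log` is a hypothetical name introduced here for illustration, and it assumes the loop counts epochs from zero.

```python
def format_epoch_log(epoch, num_epochs, loss_value):
    """Return a one-line progress message for a zero-based epoch index."""
    return f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss_value:.4f}"

# Inside the training loop this would be called as:
#     print(format_epoch_log(epoch, num_epochs, loss.item()))
print(format_epoch_log(0, 5, 1.23456))  # prints "Epoch [1/5], Loss: 1.2346"
```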

That’s it! You have just written a PyTorch program that will find the best fit line (i.e. linear regression) for our data, which describes how many pages the person reads at each hour of the day.

Let’s go ahead and plot our best fit line against the original data we provided the model with, as follows:
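A self-contained sketch of the plotting step, assuming Matplotlib is the plotting library. The data are placeholder values, the learning rate is assumed, and the output file name is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder data (hypothetical h and r values).
inputs = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
targets = torch.tensor([[2.0], [3.0], [5.0], [4.0], [6.0]])

# Train the model briefly so there is a fitted line to draw.
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # assumed lr
for epoch in range(5):
    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Predict without tracking gradients, then flatten to 1-D for plotting.
with torch.no_grad():
    predicted = model(inputs).squeeze(1).numpy()

h = inputs.squeeze(1).numpy()
r = targets.squeeze(1).numpy()

fig, ax = plt.subplots()
ax.plot(h, r, "ro", label="Original data")    # red dots: the data points
ax.plot(h, predicted, label="Fitted line")    # the model's predictions
ax.set_xlabel("h (hour of the day)")
ax.set_ylabel("r (pages read)")
ax.legend()
fig.savefig("best_fit_line.png")              # hypothetical file name
```

In an interactive session, `plt.show()` can be used in place of `fig.savefig(...)` once the `matplotlib.use("Agg")` line is removed.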

And this is what our best fit line looks like!

If you would like the full code, you can find it here.



Research Fellow @ Massachusetts General Hospital/Harvard Medical School | https://abder.mgh.harvard.edu
