Linear regression is one of the most fundamental algorithms in machine learning. Understanding its internal workflow helps in grasping the main concepts of other algorithms in data science. Linear regression has a wide range of applications in which it is used to predict a continuous variable.
Before diving into the inner workings of linear regression, let us first understand what a regression problem is.
Regression is a machine learning problem that aims to predict the value of a continuous variable given a feature vector, usually denoted as x = <x₁, x₂, x₃, …, xₙ>, where xᵢ represents the value of the i-th feature. In order to make predictions, a model has to be trained on a dataset containing mappings from feature vectors x to the corresponding values of a target variable y. The learning process depends on the type of algorithm used for a given task.
In the case of linear regression, the model learns a vector of weights w = <w₁, w₂, w₃, …, wₙ> and a bias parameter b that approximate the target value y as <w, x> + b = x₁ * w₁ + x₂ * w₂ + x₃ * w₃ + … + xₙ * wₙ + b as well as possible for every dataset observation (x, y).
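The following is a minimal sketch of the prediction formula above. It assumes that observations are stored row-wise in a NumPy matrix, and the function name `predict` is illustrative rather than taken from the article's source code. With that layout, ŷ = <w, x> + b can be computed for every observation at once:

```python
import numpy as np

def predict(X, w, b):
    """Return ŷ = <w, x> + b for every row x of the feature matrix X."""
    # X has shape (num_observations, num_features); w has shape (num_features,).
    return X @ w + b

X = np.array([[1.0, 2.0, 3.0],
              [0.0, -1.0, 4.0]])  # two observations with three features each
w = np.array([0.5, -1.0, 2.0])    # one weight per feature
b = 1.0                           # bias term

y_hat = predict(X, w, b)          # values: 5.5 and 10.0
```

The matrix product `X @ w` computes all inner products <w, x> in one call. This is the usual way to vectorize the sum x₁ * w₁ + … + xₙ * wₙ.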
When building a linear regression model, the ultimate goal is to find a vector of weights w and a bias term b that bring the predicted value ŷ as close as possible to the real target value y for all the inputs:

$$\hat{y} = \langle w, x \rangle + b \approx y$$
To keep things simple, the example we are going to look at uses a dataset with a single feature x. Therefore, x and w are one-dimensional vectors. For simplicity, let us drop the inner product notation and rewrite the equation above as follows:

$$\hat{y} = wx + b$$
In order to train an algorithm, a loss function has to be chosen. The loss function measures how well or badly the algorithm predicted the targets for a set of objects at a single training iteration. Based on its value, the algorithm adjusts the parameters of the model in the hope that the model will produce fewer errors in the future.
One of the most popular loss functions is Mean Squared Error (MSE), which measures the average squared deviation between predicted and true values over n objects:

$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2$$
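A direct NumPy translation of the MSE formula might look as follows. This is a sketch, and `mse` is a hypothetical helper name:

```python
import numpy as np

def mse(y_true, y_pred):
    """Mean squared error: the average of (y - ŷ)² over all objects."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))

loss = mse([3.0, -0.5, 2.0], [2.5, 0.0, 2.0])  # (0.25 + 0.25 + 0) / 3 ≈ 0.1667
```

Because each deviation is squared, large errors dominate the average. This is why MSE is sensitive to outliers.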
Gradient descent is an iterative algorithm that updates the weight vector to minimize a given loss function by searching for a local minimum. On each iteration, gradient descent applies the following formula:

$$\langle w \rangle' = \langle w \rangle - \alpha \nabla f$$

where:
- <w> is the vector of model weights on the current iteration, and the updated weights are assigned to <w>'. Before the first iteration of the algorithm, the weights are usually initialized randomly, but other strategies exist as well.
- α (alpha), also known as the learning rate, is usually a small positive value. It is a hyperparameter that controls the size of each step and therefore how quickly the algorithm approaches a local minimum.
- ∇ (nabla) denotes the gradient, the vector of partial derivatives of the loss function f. In the current example, the weight vector consists of two components, w and b, so computing the gradient requires two partial derivatives:

  $$\nabla f = \left( \frac{\partial f}{\partial w}, \frac{\partial f}{\partial b} \right)$$
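The generic update rule can be sketched in a few lines. This example is an illustration under simple assumptions. The gradient is supplied as a function `grad_f` (a hypothetical name). The loss is the toy quadratic f(<w>) = ‖<w> − c‖², whose gradient is 2(<w> − c), not the regression loss itself:

```python
import numpy as np

def gradient_descent(grad_f, w0, alpha=0.1, n_iters=100):
    """Repeatedly apply the update <w>' = <w> - alpha * ∇f(<w>)."""
    w = np.asarray(w0, dtype=float)  # initial weights
    for _ in range(n_iters):
        w = w - alpha * grad_f(w)    # step against the gradient
    return w

# Toy loss f(w) = ||w - c||^2 with its minimum at c; its gradient is 2 * (w - c).
c = np.array([3.0, -1.0])
w_min = gradient_descent(lambda w: 2.0 * (w - c), w0=[0.0, 0.0])
```

For this particular quadratic, each step shrinks the distance to c by a factor of |1 − 2α|. So α = 0.1 converges quickly, while any α above 1 makes the iterates diverge. This shows in miniature why the learning rate has to be chosen carefully.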
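The update formulas can therefore be rewritten as follows:

$$w' = w - \alpha \frac{\partial f}{\partial w}, \qquad b' = b - \alpha \frac{\partial f}{\partial b}$$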
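Now the objective is to find the partial derivatives of f. Assuming that MSE is chosen as the loss function, let us compute it for a single observation (n = 1 in the MSE formula), so f = (y − ŷ)² = (y − wx − b)². Applying the chain rule gives:

$$\frac{\partial f}{\partial w} = -2x(y - wx - b), \qquad \frac{\partial f}{\partial b} = -2(y - wx - b)$$

Substituting these derivatives into the update formulas gives w' = w + 2αx(y − wx − b) and b' = b + 2α(y − wx − b).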
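One way to double-check these partial derivatives is to compare them with central finite differences at an arbitrary point. This is a sketch; the helper names and the test point are illustrative:

```python
def loss(w, b, x, y):
    """Squared error f = (y - wx - b)^2 for a single observation."""
    return (y - w * x - b) ** 2

def analytic_grad(w, b, x, y):
    """Analytic partial derivatives (-2x(y - wx - b), -2(y - wx - b))."""
    residual = y - w * x - b
    return -2 * x * residual, -2 * residual

def numeric_grad(w, b, x, y, eps=1e-6):
    """Central finite-difference approximation of the same partial derivatives."""
    dw = (loss(w + eps, b, x, y) - loss(w - eps, b, x, y)) / (2 * eps)
    db = (loss(w, b + eps, x, y) - loss(w, b - eps, x, y)) / (2 * eps)
    return dw, db

point = dict(w=0.5, b=-1.0, x=2.0, y=3.0)
exact = analytic_grad(**point)    # residual = 3 - 1 + 1 = 3, so (-12.0, -6.0)
approx = numeric_grad(**point)    # agrees with exact up to rounding error
```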
Adjusting the model's weights based on a single object at a time is called stochastic gradient descent.
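The single-observation update rules can be sketched as a small training loop. This illustration makes several assumptions: a single feature, noiseless synthetic data generated with w = 3 and b = 2, random initialization, and observations visited in shuffled order each epoch. The function names are hypothetical:

```python
import numpy as np

def sgd_step(w, b, x, y, alpha):
    """One stochastic gradient descent update using a single observation (x, y)."""
    residual = y - (w * x + b)   # y - ŷ
    grad_w = -2 * x * residual   # ∂f/∂w
    grad_b = -2 * residual       # ∂f/∂b
    return w - alpha * grad_w, b - alpha * grad_b

def train_sgd(x, y, alpha=0.01, epochs=50, seed=0):
    """Run SGD over the dataset for a fixed number of epochs."""
    rng = np.random.default_rng(seed)
    w, b = rng.normal(), rng.normal()          # random initialization
    for _ in range(epochs):
        for i in rng.permutation(len(x)):      # shuffled order each epoch
            w, b = sgd_step(w, b, x[i], y[i], alpha)
    return w, b

x = np.linspace(-1.0, 1.0, 50)
y = 3.0 * x + 2.0                # noiseless data: true w = 3, true b = 2
w, b = train_sgd(x, y)           # w and b end up very close to 3 and 2
```

With noisy real data, SGD would keep fluctuating around the optimum instead of settling exactly on it. Because of this, a decaying learning rate is often used in practice.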
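In the section above, the model parameters were updated by calculating MSE for a single object (n = 1). In fact, it is possible to perform a gradient descent step on several objects in a single iteration. This way of updating weights is called batch gradient descent. When the batch is a subset of the dataset rather than the whole dataset, it is also commonly called mini-batch gradient descent.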
The formulas for updating weights in this case can be obtained in much the same way as for stochastic gradient descent in the previous section. The only difference is that the number of objects n now has to be taken into account. The terms of all objects in the batch are summed and then divided by n, the batch size:

$$\frac{\partial f}{\partial w} = -\frac{2}{n} \sum_{i=1}^{n} x_i (y_i - w x_i - b), \qquad \frac{\partial f}{\partial b} = -\frac{2}{n} \sum_{i=1}^{n} (y_i - w x_i - b)$$
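A vectorized sketch of the batch update follows. It rests on the same assumptions as before (a single feature and noiseless synthetic data with w = 3 and b = 2). The batch size of 16 and the function names are arbitrary illustrative choices:

```python
import numpy as np

def batch_gradients(w, b, x, y):
    """MSE partial derivatives averaged over a batch of n observations."""
    n = len(x)
    residual = y - (w * x + b)
    grad_w = -2.0 / n * np.sum(x * residual)
    grad_b = -2.0 / n * np.sum(residual)
    return grad_w, grad_b

def train_batch_gd(x, y, alpha=0.1, epochs=100, batch_size=16, seed=0):
    """Gradient descent over shuffled batches of at most batch_size objects."""
    rng = np.random.default_rng(seed)
    w, b = rng.normal(), rng.normal()          # random initialization
    for _ in range(epochs):
        order = rng.permutation(len(x))
        for start in range(0, len(x), batch_size):
            batch = order[start:start + batch_size]
            grad_w, grad_b = batch_gradients(w, b, x[batch], y[batch])
            w -= alpha * grad_w
            b -= alpha * grad_b
    return w, b

x = np.linspace(-1.0, 1.0, 100)
y = 3.0 * x + 2.0
w, b = train_batch_gd(x, y)
```

Setting `batch_size=len(x)` gives full-batch gradient descent, and `batch_size=1` recovers the stochastic variant.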
When dealing with a dataset consisting of only a single feature, the regression results can easily be visualized on a 2D plot. The horizontal axis represents the values of the feature, while the vertical axis represents the target values.
The quality of a linear regression model can be evaluated visually by how closely the line fits the dataset points. The smaller the average distance from the points to the line, the better the model.
If a dataset contains more features, it can be visualized by first applying a dimensionality reduction technique such as PCA or t-SNE, which represents the features in a lower dimension. The resulting features are then plotted on 2D or 3D plots as usual.
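As a sketch of the dimensionality reduction step, PCA can be implemented with NumPy's singular value decomposition. The random data here only illustrates the shapes involved. In practice, a library implementation such as scikit-learn's `PCA` would typically be used instead:

```python
import numpy as np

def pca_project(X, n_components=2):
    """Project the rows of X onto its top principal components."""
    X_centered = X - X.mean(axis=0)                  # PCA works on centered data
    # Rows of Vt are the principal directions, ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ Vt[:n_components].T

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))       # 200 observations with 5 features
Z = pca_project(X, n_components=2)  # shape (200, 2), ready for a 2D scatter plot
```

The two columns of `Z` can then be plotted against each other, for example with matplotlib's `scatter`, with points colored by their target value.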
Linear regression has several advantages:
- Training speed. Due to the simplicity of the algorithm, linear regression can be trained rapidly compared to more complex machine learning algorithms. Moreover, it can be solved in closed form with the least squares method (LSM), which is also relatively fast and easy to understand.
- Interpretability. A linear regression equation built for several features can easily be interpreted in terms of feature importance. The larger the absolute value of a feature's coefficient, the more effect that feature has on the final prediction, provided the features are on comparable scales.
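The least squares solution can be computed directly with NumPy's `np.linalg.lstsq`. It solves the same problem as the normal equation w = (XᵀX)⁻¹Xᵀy, but without explicitly forming XᵀX, which is numerically more robust. This sketch appends a column of ones so that the bias is estimated together with the weights. The data is synthetic and noiseless, generated with w = [2, −3] and b = 0.5, and `fit_least_squares` is a hypothetical helper name:

```python
import numpy as np

def fit_least_squares(X, y):
    """Closed-form least squares fit of y ≈ X @ w + b; returns (w, b)."""
    # A column of ones lets the bias be estimated together with the weights.
    X_aug = np.hstack([X, np.ones((X.shape[0], 1))])
    coef, *_ = np.linalg.lstsq(X_aug, y, rcond=None)
    return coef[:-1], coef[-1]

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))          # 100 observations, 2 features
y = X @ np.array([2.0, -3.0]) + 0.5    # noiseless targets

w, b = fit_least_squares(X, y)         # w ≈ [2.0, -3.0], b ≈ 0.5
```

No iterations or learning rate are needed here. With real, noisy data, the returned coefficients are the values that minimize the squared error rather than an exact recovery.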
On the other hand, it comes with several disadvantages:
- Data assumptions. Before fitting a linear regression model, it is important to check the type of dependency between the output and input features. If it is linear, there should not be any issue with fitting the model. Otherwise, the model usually cannot fit the data well, since the equation contains only linear terms. It is possible to add higher-degree terms to the equation, for instance turning the algorithm into polynomial regression. However, in practice, without a lot of domain knowledge it is often difficult to correctly foresee the type of dependency. This is one of the reasons why linear regression might not adapt to the given data.
- Multicollinearity problem. Multicollinearity occurs when two or more predictors are highly correlated with each other. Imagine a situation where a change in one variable influences another variable, but the trained model has no information about this relationship. When these changes are large, it is difficult for the model to remain stable during the inference phase on unseen data, which can lead to overfitting. Furthermore, the regression coefficients themselves can become unstable, which makes them unreliable for interpretation.
- Data normalization. In order to use linear regression as a feature importance tool, the data has to be normalized or standardized. This ensures that all of the regression coefficients are on the same scale and can be correctly interpreted.
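Two of the remedies above can be sketched with the same closed-form least squares helper, re-defined here so that the snippet runs on its own. Adding a squared feature turns the model into polynomial regression, which is still linear in its parameters. Standardizing features puts their coefficients on a comparable scale. The data is synthetic and noiseless, the coefficients are arbitrary illustrative choices, and the helper names are hypothetical:

```python
import numpy as np

def fit_least_squares(X, y):
    """Closed-form least squares fit of y ≈ X @ w + b; returns (w, b)."""
    X_aug = np.hstack([X, np.ones((X.shape[0], 1))])   # column of ones for the bias
    coef, *_ = np.linalg.lstsq(X_aug, y, rcond=None)
    return coef[:-1], coef[-1]

def standardize(X):
    """Rescale every feature column to zero mean and unit standard deviation."""
    return (X - X.mean(axis=0)) / X.std(axis=0)

rng = np.random.default_rng(0)

# 1) Polynomial regression: a quadratic dependency, fitted by adding x**2 as an extra feature.
x = np.linspace(-2.0, 2.0, 50)
y_quad = 1.0 + 2.0 * x - 0.5 * x ** 2
w_poly, b_poly = fit_least_squares(np.column_stack([x, x ** 2]), y_quad)
# w_poly ≈ [2.0, -0.5] and b_poly ≈ 1.0

# 2) Standardization: two features with the same effect per standard deviation,
#    but measured on very different scales.
n = 1000
x1 = rng.normal(0.0, 1.0, n)
x2 = rng.normal(0.0, 1000.0, n)
y = 3.0 * x1 + 0.003 * x2
X = np.column_stack([x1, x2])
w_raw, _ = fit_least_squares(X, y)                 # ≈ [3.0, 0.003]
w_std, _ = fit_least_squares(standardize(X), y)    # both entries close to 3
```

The raw coefficients suggest that the first feature is about a thousand times more important. The standardized coefficients show that both features shift the target by roughly the same amount per standard deviation.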
We have looked at linear regression, a simple but very popular algorithm in machine learning. Its core principles are used in more complex algorithms.
Though linear regression is rarely used in modern production systems, its simplicity makes it a common baseline for regression problems, against which more sophisticated solutions are then compared.
The source code used in the article can be found here:
All images unless otherwise noted are by the author.