Motivation
Suppose you have a loss function \(\mathcal{L} : \mathbb{R}^n \to \mathbb{R}\) that you want to minimize with respect to some model parameters \(\beta\). You understand how gradient descent works and you have a correct implementation of \(\mathcal{L}\), but you aren't sure whether you derived the gradient correctly or implemented it correctly in code.
Solution
We can compare our implementation of the gradient of \(\mathcal{L}\) to a finite difference approximation of the gradient. Recall that the directional derivative of \(\mathcal{L}\) at a point \(x \in \mathbb{R}^n\) in a direction \(d \in \mathbb{R}^n\) can be written in terms of the gradient \(\nabla \mathcal{L}\) as
\[d^T \nabla \mathcal{L}(x) = \lim_{\epsilon \to 0} \frac{\mathcal{L}(x + \epsilon \cdot d) - \mathcal{L}(x - \epsilon \cdot d)}{2 \epsilon}\]
If we take \(\epsilon\) to be fixed and small, we can use this formula to approximate the directional derivative in any direction \(d\). By doing this along each unit direction, we recover each component of the gradient, and thus construct an approximation of the full gradient of \(\mathcal{L}\) at a particular point \(x\).
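As a quick standalone illustration (the square function and the point x0 = 3 here are made up for this purpose, not part of the example below), the central difference recovers the known derivative of a one-dimensional function:

square <- function(x) x^2 # derivative is 2 * x
x0 <- 3
eps <- 1e-8

(square(x0 + eps) - square(x0 - eps)) / (2 * eps) # approximately 6, i.e. 2 * x0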
Example: Checking the gradient of linear regression
Suppose that we have \(n = 20\) data points in \(\mathbb{R}^2\) with responses \(y \in \mathbb{R}^n\). Linear regression assumes the responses \(y\) are related linearly to the data matrix \(X\) via the equation
\[y = X \beta + \epsilon\]
We want to find an estimate \(\hat \beta\) that minimizes the sum of squared errors of the predicted values \(\hat y = X \hat \beta\)
\[\mathcal{L}(\beta) = \frac{1}{2n} \sum_i (y_i - \hat y_i)^2 = \frac{1}{2n} \sum_i (y_i - x_i \beta)^2 = \frac{1}{2n} (y - X \beta)^T (y - X \beta)\]
In the final step above we recognize that the sum of squared residuals can be written as a dot product. Next we'd like to take the gradient of this dot product. There's a beautiful explanation of how to take the gradient of a quadratic form here. The gradient (in matrix notation) is
\[\nabla \mathcal{L}(\beta) = -\frac{1}{n} (y - X \beta)^T X\]
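As a check on that formula (this short derivation is ours, not from the linked explanation): expanding the quadratic form and differentiating term by term, writing the gradient as a row vector as above, gives

\[\mathcal{L}(\beta) = \frac{1}{2n}\left(y^T y - 2 \beta^T X^T y + \beta^T X^T X \beta\right), \qquad \nabla \mathcal{L}(\beta) = \frac{1}{n}\left(\beta^T X^T X - y^T X\right) = -\frac{1}{n}(y - X \beta)^T X\]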
We can now implement an analytical version of \(\nabla \mathcal{L}(\beta)\) and compare it to a finite difference approximation. First we simulate and visualize some data:

library(tidyverse)

n <- 20
X <- matrix(rnorm(n * 2), ncol = 2)
y <- rnorm(n)

ggplot(NULL, aes(x = X[, 1], y = y)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE) +
  labs(title = "One dimension of the simulated data", x = expression(X[1])) +
  theme_classic()
Next we implement our loss and gradient functions. We assume the loss
function is implemented correctly but want to check the analytical_grad
implementation.
loss <- function(beta) {
  resid <- y - X %*% beta
  sum(resid^2) / (2 * n)
}

analytical_grad <- function(beta) {
  grad <- -t(y - X %*% beta) %*% X / n
  as.vector(grad)
}
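As an additional sanity check (this step is ours, not from the original post), the gradient should be numerically zero at the ordinary least squares solution \(\hat\beta = (X^T X)^{-1} X^T y\), since that point minimizes \(\mathcal{L}\):

beta_hat <- drop(solve(crossprod(X), crossprod(X, y))) # closed-form OLS estimate
analytical_grad(beta_hat) # both components should be near zero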
To perform the finite difference check, we first need to approximate the gradient in a direction \(d\):
#' @param f function that takes a single vector argument x
#' @param x point at which to evaluate derivative of f (vector)
#' @param d direction in which to take derivative of f (vector)
#' @param eps epsilon to use in the gradient approximation
numerical_directional_grad <- function(f, x, d, eps = 1e-8) {
  (f(x + eps * d) - f(x - eps * d)) / (2 * eps)
}
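As a quick usage example (the point c(2, 3) is an arbitrary choice of ours), the directional derivative of the loss along the first unit direction should match the first component of the analytical gradient:

numerical_directional_grad(loss, c(2, 3), c(1, 0))
analytical_grad(c(2, 3))[1] # should agree to several decimal places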
And then to approximate the entire gradient, we need to combine directional derivatives in each of the unit directions:
zeros_like <- function(x) {
  rep(0, length(x))
}

numerical_grad <- function(f, x, eps = 1e-8) {
  grad <- zeros_like(x)
  for (dim in seq_along(x)) {
    unit <- zeros_like(x)
    unit[dim] <- 1
    grad[dim] <- numerical_directional_grad(f, x, unit, eps)
  }
  grad
}
relative_error <- function(want, got) {
  (want - got) / want # assumes want is not zero
}
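This simple version divides by want, so it misbehaves when a true gradient component is zero or very small. A sketch of a more robust variant (one common convention, not the one used in the check below) scales by the larger magnitude of the two values:

symmetric_relative_error <- function(want, got) {
  abs(want - got) / pmax(abs(want), abs(got), 1e-12) # guard against division by zero
}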
Now we can check the relative error between our analytical implementation of the gradient and the numerical approximation.
b <- c(2, 3) # point in parameter space to check gradient at

num_grad <- numerical_grad(loss, b)
ana_grad <- analytical_grad(b)

num_grad
[1] 2.112107 3.286946
ana_grad
[1] 2.112107 3.286946
relative_error(num_grad, ana_grad)
[1] -2.810374e-08 1.553777e-08
The relative error is small, and we can feel confident that our implementation of the gradient is correct.
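To see what a failing check looks like, we can deliberately break the gradient, for example by dropping the 1/n factor (buggy_grad is our own illustrative name); the relative error then blows up to order one rather than order 1e-8:

buggy_grad <- function(beta) {
  as.vector(-t(y - X %*% beta) %*% X) # missing the 1/n factor
}

relative_error(num_grad, buggy_grad(b)) # roughly -(n - 1) = -19 in each component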
This post is based off of Tim Vieira’s fantastic post on how to use numerical gradient checks in practice, but with R
code. See also the numDeriv package.
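If the numDeriv package is installed, its grad() function (which uses Richardson extrapolation by default) provides an independent cross-check of the same calculation:

numDeriv::grad(loss, b) # should agree with num_grad and ana_grad above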