So, why automatic differentiation?

Deep what?

This post is meant as an introduction and, as an exception, will briefly go over some basic neural network topics. If you’re already familiar with the basics of deep learning and the math involved, feel free to skip this post and jump straight here.

Deep learning is concerned with deep artificial neural networks. In essence, a neural network is determined by its parameters (values stored within the network that are used for computation) and architecture (how these parameters interact with each other).

The workflow usually goes like this: we have a problem we want to solve with deep learning, we choose an architecture that we think suits the task, and then we must find parameter values that make the network produce the correct output in general. We call the process of finding these values “training”, and it almost always involves minimizing a loss function.

As an example, our network may be trying to classify images of cats and dogs by producing an output between $0$ and $1$ and applying a threshold at $0.5$. Let’s say we put an image of a cat through our network, and the range we have assigned to this category is $[0.5, 1]$. If the network outputs $0.2$, we would want to identify the parameters that increase the output the most, and modify them accordingly so the output falls into the range for cats. Or, from a loss function perspective, we have a measure of how far we are from the desired output (1 for cats, 0 for dogs) and want to nudge the parameters towards values that will minimize it.
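To make “a measure of how far we are from the desired output” concrete, here is a minimal sketch. The choice of squared error as the loss, and the names `squared_error` and `classify`, are assumptions for illustration; real classifiers often use other losses, such as cross-entropy.

```python
CAT, DOG = 1.0, 0.0  # desired outputs, following the convention above

def classify(output: float, threshold: float = 0.5) -> str:
    # Outputs in [0.5, 1] are read as "cat", anything below as "dog".
    return "cat" if output >= threshold else "dog"

def squared_error(output: float, target: float) -> float:
    # Hypothetical loss: squared distance between the output and the desired label.
    return (output - target) ** 2

output = 0.2  # the network's (wrong) answer for an image of a cat
print(classify(output))                        # dog
print(round(squared_error(output, CAT), 4))    # 0.64
```

Training would then try to change the parameters so that this number shrinks, which pushes the output towards 1 and into the cat range.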

The loss function perspective is particularly interesting because it lets us treat training as an optimization problem, where we’re simply trying to find the minimum of a function. To do this, we could just randomly sample parameter values and pick the ones that gave us the smallest loss. Or we could be smart about it: find how much each parameter impacts the loss and in which direction, and use that information to steadily and iteratively move towards a minimum. This paradigm is called gradient descent, and some variant of it is what virtually all neural networks use to minimize their loss functions.
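The contrast between blind sampling and following the slope shows up even on a toy problem. This sketch assumes a made-up one-parameter loss $(w - 3)^2$ whose derivative we write by hand; `random_search` and `gradient_descent` are hypothetical helper names, and the learning rate and step count are arbitrary choices.

```python
import random

def loss(w: float) -> float:
    # Toy one-parameter loss (an assumption for illustration), minimized at w = 3.
    return (w - 3.0) ** 2

def loss_grad(w: float) -> float:
    # Its derivative, worked out by hand: dL/dw = 2 (w - 3).
    return 2.0 * (w - 3.0)

def random_search(n_samples: int, seed: int = 0) -> float:
    # Sample parameter values at random and keep the one with the smallest loss.
    rng = random.Random(seed)
    candidates = [rng.uniform(-10.0, 10.0) for _ in range(n_samples)]
    return min(candidates, key=loss)

def gradient_descent(w: float, lr: float = 0.1, steps: int = 100) -> float:
    # Repeatedly step against the derivative, i.e. downhill on the loss.
    for _ in range(steps):
        w -= lr * loss_grad(w)
    return w

print(random_search(100))      # somewhere near 3, depending on the samples drawn
print(gradient_descent(-8.0))  # very close to 3
```

Random search only gets as close as its luckiest sample, while here each gradient step shrinks the distance to the minimum by a fixed factor of $0.8$.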

Math refresher: the gradient

This is nice and all, but… how exactly are we meant to find out how each parameter affects the loss? Well, the term “gradient descent” should have already given you a hint (and the title of this section is not trying to hide anything): we will use the gradient of the loss function.

The gradient of a function $f$ at a point $p$ gives us the direction of steepest ascent. To get some intuition about why, let’s take a look at how it’s defined:

$$ \nabla f(p)^T = \left[ \frac{\partial f}{\partial x_1}(p), \frac{\partial f}{\partial x_2}(p), \ldots, \frac{\partial f}{\partial x_n}(p) \right] $$

As we can see, each coordinate holds the partial derivative with respect to the corresponding variable. Let us focus on the first one. $\frac{\partial f}{\partial x_1}(p)$ tells us how much nudging $x_1$ affects the output of $f$ (through its magnitude), and in which direction the two move together (a negative sign indicates that increasing $x_1$ would decrease the value of $f$). This is all summarized by the fact that a derivative (and a partial derivative, for that matter) can be seen as the best linear approximation of the function near a point. Taken together over all coordinates, these partials give us a full guide on how to locally modify the $x_i$ so as to increase $f$ as quickly as possible.
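One crude way to see the gradient numerically is to nudge each coordinate separately and measure how $f$ responds. This sketch uses central finite differences, which is not automatic differentiation (it is only approximate and needs two evaluations of $f$ per variable), on an example function $f(x_1, x_2) = x_1^2 + 3x_2$ chosen for illustration.

```python
from typing import Callable, List

def numerical_gradient(f: Callable[[List[float]], float], p: List[float], h: float = 1e-6) -> List[float]:
    # Central differences: df/dx_i ≈ (f(p + h e_i) - f(p - h e_i)) / (2h).
    grad = []
    for i in range(len(p)):
        plus = list(p)
        plus[i] += h
        minus = list(p)
        minus[i] -= h
        grad.append((f(plus) - f(minus)) / (2 * h))
    return grad

def f(x: List[float]) -> float:
    # Example function f(x1, x2) = x1^2 + 3 x2; its exact gradient is [2 x1, 3].
    return x[0] ** 2 + 3 * x[1]

p = [1.0, 2.0]
g = numerical_gradient(f, p)
print(g)  # approximately [2.0, 3.0]

# A small step along the gradient increases f; the negative gradient would decrease it.
uphill = [pi + 0.01 * gi for pi, gi in zip(p, g)]
print(f(uphill) > f(p))  # True
```

With billions of parameters, two function evaluations per parameter quickly becomes prohibitive, which is one reason to look for a smarter way of getting partial derivatives.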

In our case, $f$ is our loss function, each $x_i$ corresponds to a parameter in the network and we just want to flip the sign and go in the direction of the negative gradient instead. Given our current understanding, training is something like:

  1. Start with some initial parameters
  2. Calculate the loss function on given data
  3. Obtain partial derivatives for the loss with respect to all the parameters
  4. Nudge all parameters in the direction of the negative gradient
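The four steps can be sketched on the smallest possible “network”: a single parameter $w$ in the model $\hat{y} = w x$. The data, the mean squared error loss, the learning rate and the hand-derived gradient are all assumptions for illustration; step 3 is the part an automatic differentiation package would take over.

```python
# Hypothetical toy data following y = 2x, so the best parameter is w = 2.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

def loss(w: float) -> float:
    # Mean squared error of the one-parameter model y_hat = w * x.
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def grad(w: float) -> float:
    # dL/dw derived by hand: the mean of 2 (w x - y) x.
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def train(w0: float, lr: float = 0.01, steps: int = 200):
    w = w0                        # 1. start with some initial parameter
    history = []
    for _ in range(steps):
        history.append(loss(w))   # 2. calculate the loss on the data
        g = grad(w)               # 3. partial derivative of the loss w.r.t. w
        w -= lr * g               # 4. nudge w in the direction of -g
    return w, history

w, history = train(0.0)
print(w)  # very close to 2.0
```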

Although there is much to say about all of these, today we are interested in point $3$. And for that we’ll need a way to get partial derivatives, fast.

Math refresher: the chain rule

Let us consider a toy computation:

$$ (a + b) \cdot c = d $$

How could we get, say, $\frac{\partial d}{\partial a}$? Of course, we could analytically solve it by just looking at it… But that’s easy to do now, when there are only three variables $d$ depends on. What happens when we have neural networks with billions of parameters? What kind of program would be able to reduce such massive expressions to the analytic partial derivative?

The trick here is, again, to go slow and steady. We can hardcode how to get partial derivatives involving a single operation, and use the chain rule to build up the rest. But let’s not get ahead of ourselves; what exactly is the chain rule?

Given two real functions $f$ and $g$, we may be interested in how $z = (f \circ g)(x) = f(g(x))$ behaves as we nudge $x$ or, in other words, in $\frac{dz}{dx}$. Since $f$ and $g$ are functions in their own right, we can also ask how each of their outputs reacts to small changes in its input. In particular, for $z = f(y)$ we can find $\frac{dz}{dy}$, and for $y = g(x)$ we can find $\frac{dy}{dx}$. The chain rule tells us how these individual derivatives combine to produce the derivative of the composed function $f \circ g$. And it turns out that we just need to multiply them:

$$ \frac{dz}{dx} = \frac{dz}{dy}\frac{dy}{dx} $$

This is one of those cases where common sense is right. Intuitively, we want to multiply rates of change: “If a car travels twice as fast as a bicycle and the bicycle is four times as fast as a walking man, then the car travels 2 × 4 = 8 times as fast as the man.”
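The chain rule is easy to check numerically. In this sketch, the choice of $g(x) = x^2$ and $f(y) = \sin y$ is arbitrary; we compute $\frac{dz}{dx}$ by multiplying the two hand-written derivatives and compare it against a finite-difference estimate of the composed function.

```python
import math

def g(x: float) -> float:
    return x ** 2            # y = g(x)

def f(y: float) -> float:
    return math.sin(y)       # z = f(y)

def dy_dx(x: float) -> float:
    return 2 * x             # derivative of g, written by hand

def dz_dy(y: float) -> float:
    return math.cos(y)       # derivative of f, written by hand

def chain_rule(x: float) -> float:
    # dz/dx = dz/dy * dy/dx, with dz/dy evaluated at y = g(x)
    return dz_dy(g(x)) * dy_dx(x)

def finite_difference(x: float, h: float = 1e-6) -> float:
    # Numerical estimate of the derivative of the composition f(g(x))
    return (f(g(x + h)) - f(g(x - h))) / (2 * h)

x = 1.3
print(chain_rule(x), finite_difference(x))  # agree to many decimal places
```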

In our toy example, this suggests splitting the partial derivative as:

$$ \frac{\partial d}{\partial a} = \frac{\partial d}{\partial h} \frac{\partial h}{\partial a} $$

where we have introduced the intermediate variable $h = a + b$. Note that both partials on the right hand side involve a single operation; multiplication in the first case and addition in the second. Hence, our machine would be able to compute:

$$ \frac{\partial d}{\partial h} = \frac{\partial}{\partial h} \big(h \cdot c\big) = c $$

and

$$ \frac{\partial h}{\partial a} = \frac{\partial}{\partial a} \big( a + b \big) = 1 $$

The last step of multiplying the individual partials (applying the chain rule) gives us our result $\frac{\partial d}{\partial a} = c$. Note that, in practice, our computer will not be storing the symbolic expression “$c$”, but rather the actual value of the partial derivative, since we are evaluating it at a point of interest (the current parameter values).
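The same bookkeeping for $(a + b) \cdot c = d$ can be written out in Python. This is a hand-rolled sketch rather than a general tool: each local derivative is hard-coded, as described above, and the input values are arbitrary.

```python
def forward_and_backward(a: float, b: float, c: float):
    # Forward pass: evaluate the computation, keeping the intermediate h.
    h = a + b
    d = h * c
    # Local derivatives, one per elementary operation:
    dd_dh = c      # d = h * c  ->  dd/dh = c
    dh_da = 1.0    # h = a + b  ->  dh/da = 1
    # Chain rule: multiply the local derivatives along the chain.
    dd_da = dd_dh * dh_da
    return d, dd_da

d, dd_da = forward_and_backward(a=2.0, b=3.0, c=4.0)
print(d, dd_da)  # 20.0 4.0
```

Note that `dd_da` comes out as the number $4.0$, the value of $c$ at this point, not the symbol $c$.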

Multiple chains

There is one final thing to consider: what if the same variable is present more than once? A single chain of products would not be able to account for all the impact that this variable has on the output. For example, let us modify our toy computation:

$$ d = (a + b) \cdot (a \cdot c) $$

Sure, $\frac{\partial d}{\partial h} \frac{\partial h}{\partial a}$ (with $h = a + b$ as before) gives us some information about the derivative of $d$ with respect to $a$, but what about the second $a$, the one inside $a \cdot c$? Shouldn’t that have an impact on $d$ as well? Of course it does. In fact, all we need to do is add up the chains through every intermediate result that depends on $a$. Given $i = a \cdot c$, we obtain:

$$ \frac{\partial d}{\partial a} = \frac{\partial d}{\partial h} \frac{\partial h}{\partial a} + \frac{\partial d}{\partial i} \frac{\partial i}{\partial a} $$
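Here is that two-chain sum for $d = (a + b) \cdot (a \cdot c)$ evaluated at arbitrary values, checked against the derivative $c(2a + b)$ obtained by expanding $d = a^2 c + abc$ by hand. The function name is hypothetical.

```python
def dd_da_two_chains(a: float, b: float, c: float) -> float:
    h = a + b                   # first intermediate result depending on a
    i = a * c                   # second intermediate result depending on a
    # d = h * i, so dd/dh = i and dd/di = h
    chain_through_h = i * 1.0   # dd/dh * dh/da
    chain_through_i = h * c     # dd/di * di/da
    return chain_through_h + chain_through_i

print(dd_da_two_chains(2.0, 3.0, 4.0))  # 28.0
print(4.0 * (2 * 2.0 + 3.0))            # 28.0, from c (2a + b)
```

Keeping only the chain through $h$ would give $8$, missing the contribution of $20$ that flows through $i$.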

In general, we can express this by considering a vector function that wraps together all intermediate results that depend on the variable we are differentiating with respect to. Given $f : \mathbb{R}^n \to \mathbb{R}$ and $g: \mathbb{R} \to \mathbb{R}^n$ such that $z = f(\bm{y})$ and $\bm{y} = g(x)$, we have the additive relationship:

$$ \frac{\partial z}{\partial x} = \sum_i \frac{\partial z}{\partial y_i} \frac{\partial y_i}{\partial x} $$
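To preview how a program could do this summation on its own, here is a minimal sketch of reverse-mode accumulation. The `Value` class and its methods are hypothetical, written only for this illustration: each node stores its local derivatives, and `backward` walks the graph from the output, adding `node.grad * local` into each input so that every chain reaching a variable is summed.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical minimal sketch)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # inputs this value was computed from
        self._local_grads = local_grads  # d(self)/d(parent) for each parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def backward(self):
        # Topological order: every node appears after the nodes it was computed from.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        # Walk backwards so a node's grad is complete before it is propagated.
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                # "+=" sums the contributions of every chain that reaches parent.
                parent.grad += node.grad * local

a, b, c = Value(2.0), Value(3.0), Value(4.0)
d = (a + b) * (a * c)
d.backward()
print(d.data, a.grad)  # 40.0 28.0
```

The `+=` plays the role of the sum over $i$ in the formula above: $a$ receives one contribution through $a + b$ and another through $a \cdot c$.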

Summary

Hopefully the question in the title of this blog post is now answered. An automatic differentiation package calculates the gradient of the loss function with respect to the parameters of a neural network, which can then be used to minimize the loss and train the network. Such a package relies on the chain rule to build partial derivatives out of simple, single-operation ones, reusing intermediate results it has already computed.

We now have all the conceptual pieces we need to start building such a package. Let’s complete the puzzle!
