Implementing automatic differentiation from scratch

So here you are, writing some PyTorch when you mindlessly call loss.backward() and, out of nowhere, you get the gradient of the loss with respect to all your parameters: just what you needed to improve your model. A bit fishy, isn’t it? What exactly is going on here?

Well, long story short: you invoked autograd, PyTorch’s automatic differentiation package, and it took care of all the computations needed. In fact, it started taking care of them long before you realized! Today we will build a simple version of autograd to understand what kind of magic it is using. Let’s go!

Ingredients

Let us first clarify what we need the package to do, and which tools are at our disposal. We may be working with an arbitrary number of variables, performing computations on them and obtaining a result variable. autograd must be able to find the partial derivatives of the result with respect to all other variables it depends on.

In my previous post I already hinted that the best way to go about this is to hardcode partial derivatives that only involve single operations and use the chain rule to build up the rest. Let us see a simple example given a toy computation:

$$ e = (a + b) \cdot c + d, \qquad h_1 = (a+b) \cdot c, \qquad h_2 = a + b $$

We can compute partials by going one level deeper each time.

  1. $\frac{\partial e}{\partial h_1}$ involves a single operation, so we can directly compute it and see that it equals $1$.
  2. By the chain rule we have $\frac{\partial e}{\partial h_2} = \frac{\partial e}{\partial h_1}\frac{\partial h_1}{\partial h_2}$: the first term we just computed, while the second one involves a single operation, so we just need to multiply. We get $\frac{\partial e}{\partial h_2} = 1 \cdot c = c$.
  3. Finally, and again by the chain rule, we have $\frac{\partial e}{\partial a} = \frac{\partial e}{\partial h_2}\frac{\partial h_2}{\partial a}$. As in the previous point, the first term has already been computed, and we can obtain the second one from our hardcoded partials, giving us the final result $c$. And this holds no matter how deep we go! The full chain is written out below.
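
Putting the three steps together, the full chain for $\frac{\partial e}{\partial a}$ reads:

$$ \frac{\partial e}{\partial a} = \frac{\partial e}{\partial h_1} \cdot \frac{\partial h_1}{\partial h_2} \cdot \frac{\partial h_2}{\partial a} = 1 \cdot c \cdot 1 = c $$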

Of course, this means that we will need to be careful about the order in which we compute these partials. Given a variable, its partial must only be calculated once the partials of all the variables that depend on it have also been obtained.

More importantly, though, it means that we must keep track of the dependencies between the variables while they are being used to create new results: we need to build a computation graph. Such a graph will tell us which variables a result immediately depends on and through which operation.

With this we now have an idea of which kind of features our package has to support, and even some hints about how to implement it. We need to:

  • store the variables’ values
  • maintain a computation graph
    • store which operation generates a result
    • store which variables a result depends on
  • compute and store partials of the variables with respect to a chosen result
    • hardcode partials for individual operations

The Value wrapper and the computation graph

Let us first focus on creating a structure that will let us store the computation graph in some way. Then we will incrementally add more features until we have a fully functional package.

The way we will go about this is by wrapping all of our operations in a class called Value. Each new variable we create will have to be initialized through the Value constructor, and every time a result is obtained from one or more variables, a new Value object will be created to keep track of which operation was used and which variables it depends on.

Let’s look at how we can define the constructor for now:

class Value:
    def __init__(self, data, deps=(), op=''):
        self.data = data
        self._deps = set(deps)
        self._op = op

Every new Value gets initialized with some data. Optionally, we can specify its dependencies as a tuple of other Value instances it may depend on, and the operation used to obtain said Value, encoded as a character.
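
As a quick illustration (the names here are just for the example; soon the operators will create results for us):

a = Value(3)                             # a leaf variable: no deps, no op
b = Value(4)
s = Value(a.data + b.data, (a, b), '+')  # a result, built by hand for now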

So now our variables are wrapped in this container… but, as you may have guessed, we cannot perform any operations on them! In Python, when we do 2 + 2 this is just syntactic sugar for (2).__add__(2). In other words, if we want an arbitrary class to support addition with +, it must have __add__ defined. If you hadn’t already, you have now come to the fatal realization: we will have to reimplement all basic operations. Luckily for us, a lot of them can be repurposed (subtraction is addition after negation, for example), plus we will only focus on addition and multiplication in this post; consider the rest homework! ;)
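
You can check the syntactic-sugar claim for yourself:

print(2 + 2)            # 4
print((2).__add__(2))   # 4: the method call Python makes under the hood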

Back to the matter at hand, we must implement __add__ and __mul__ (the equivalent for the product) in our Value class. There’s no secret to it, we just return a new Value with the correct deps and op:

def __add__(self, other):
    result = Value(self.data + other.data, (self, other), '+')
    return result

def __mul__(self, other):
    result = Value(self.data * other.data, (self, other), '*')
    return result

This is all we need to maintain a computation graph of any result involving addition and multiplication. Convince yourself! You can define the function __repr__ inside Value to tell Python how to display an instance when printing it. Then, play around trying different combinations and check if the graph is built correctly. Here’s an example of how your __repr__ function could look:

def __repr__(self):
    deps_data = [x.data for x in self._deps]
    extra = f", deps_data={deps_data}, op={self._op}" if self._op else ""
    return f"Value(data={self.data}{extra})"

What do you get when you execute print(Value(3) + Value(4))? What about combining addition and multiplication? If all is correct, you should be able to trace any result you get all the way back to the first variable you used.
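
For reference, with the __repr__ above you should see something along these lines (the order inside deps_data may vary, since _deps is a set):

print(Value(3) + Value(4))
# Value(data=7, deps_data=[3, 4], op=+)

print((Value(3) + Value(4)) * Value(2))
# Value(data=14, deps_data=[7, 2], op=*)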

Backpropagation

We now have a way to build and maintain computation graphs. How can we use this to calculate the partial derivatives that we care about?

As we mentioned, a good approach is to hardcode partials involving one operation and obtaining the rest by applying the chain rule. Let us focus on variable $a$ from our previous example:

$$ e = (a + b) \cdot c + d, \qquad h_1 = (a+b) \cdot c, \qquad h_2 = a + b $$

We’re looking for $\frac{\partial e}{\partial a}$ and we conveniently thought of expressing this as $\frac{\partial e}{\partial h_2}\frac{\partial h_2}{\partial a}$. For now, we assume that we have already computed $\frac{\partial e}{\partial h_2}$, and that it is stored somewhere in the instance of Value representing $h_2$. The other partial evaluates to $1$, and it’s the one we’ve been saying we want to hardcode somewhere.

Where can we fit these operations within the code, and when should they be executed? Note that our computation graph has edges from results to the variables they depend on. This means that we would not be able to use it to access $h_2$’s partial from the Value representing $a$ (as we can get to $a$ from $h_2$, but not vice versa). But there is a trick. When $h_2$ is created, inside $a$’s __add__ function, $a$ has access to it. Of course, we cannot apply the chain rule just yet, because we do not have $\frac{\partial e}{\partial h_2}$; but we can define what we would like to do, and execute it once $h_2$’s partial has been calculated. Moreover, we have been forced to reimplement all operations inside Value, which makes them the perfect place to incorporate the hardcoded partial for each operation.

Going back to our example, we would like to modify __add__ so that it stores a function detailing how to update the gradient of $a$ once the gradient of $h_2$ is computed. This function should incorporate the derivative of addition and the partial $\frac{\partial e}{\partial h_2}$. Of course, we will need to modify the constructor to add member variables that can store this function and the eventual gradient result. The updated class looks like this:

class Value:
    def __init__(self, data, deps=(), op=''):
        self.data = data
        self.grad = 0
        self._deps = set(deps)
        self._op = op
        self._backward = lambda: None

    def __add__(self, other):
        result = Value(self.data + other.data, (self, other), '+')

        def backward():
            self.grad += 1 * result.grad
            other.grad += 1 * result.grad

        result._backward = backward
        return result

Look closely at what backward is doing; there are a number of things to talk about here. First of all, we end up storing this backward function in the instance of the result (so the result, not the leaves, is responsible for propagating its gradient to its dependencies). As a consequence, backward has to take care of both variables participating in the binary operation, which is why we set both self.grad and other.grad. For each of them it adds the same quantity: 1 is the derivative of addition, while result.grad is $\frac{\partial e}{\partial h_2}$ in our example; multiplying the two applies the chain rule, giving $\frac{\partial e}{\partial a}$ (and $\frac{\partial e}{\partial b}$ through other). Finally, note how it’s not just assigning these gradients; it’s adding them. The reason for this was clarified in this section of the previous post and has to do with a variable participating more than once in a computation.
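
In the notation of our example, where result plays the role of $h_2 = a + b$, the quantity that backward adds to a.grad is exactly the chain rule term from before:

$$ \underbrace{1}_{\partial h_2 / \partial a} \cdot \underbrace{\frac{\partial e}{\partial h_2}}_{\text{result.grad}} = \frac{\partial e}{\partial a} $$

and analogously for b.grad.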

What do you say, wanna have a go at modifying __mul__ yourself? Just remember that now the derivative of the product is not the same for self and other anymore! Here’s the solution in case you want to compare:

def __mul__(self, other):
    result = Value(self.data * other.data, (self, other), '*')

    def _backward():
        self.grad += other.data * result.grad
        other.grad += self.data * result.grad

    result._backward = _backward
    return result

With this we have established how the partials should be calculated. Now we just need to actually calculate them! Let us reuse our example, but this time we give it some values:

a = Value(0)
b = Value(1)
c = Value(2)
d = Value(3)
h_2 = a + b
h_1 = h_2 * c
e = h_1 + d

To obtain $\frac{\partial e}{\partial a}$ we will have to work our way through all the variables that depend on $a$ first. Obviously, $\frac{\partial e}{\partial e} = 1$, so we can already set e.grad = 1. Then it’s just a matter of calling _backward in an order where a variable never goes before the results that depend on it. Overall, our operations would look like this:

e.grad = 1
e._backward()
h_1._backward()
h_2._backward()

Now check what you got for a.grad! You may also find it useful to modify the __repr__ function so it incorporates information about grad.
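
If you want to check your result against the math (with the values a = 0, b = 1, c = 2, d = 3 from above), this is what the gradients should come out as:

print(a.grad)  # 2  (= c)
print(b.grad)  # 2  (= c)
print(c.grad)  # 1  (= a + b)
print(d.grad)  # 1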

What we did just now is a manual re-enactment of the back-propagation algorithm (backprop for short). It gets its name from “propagating” the gradients backwards through the computation graph (basically applying the chain rule a bunch of times). Backprop always needs a forward propagation pass to happen first, which is when the graph gets built.

Getting the order right

We’re almost there! The only thing left to do is to formally define the order in which we will call _backward. Given a variable, we want all the results that depend on it to be processed before it. This constraint, together with the fact that the computation graph is a directed acyclic graph, generates a partial ordering on the variables. We can then use any kind of topological sorting to obtain the ordering we want.

I know that was a lot of heavy words and Wikipedia links, but I do not want to spend too much time on the technicalities of topological ordering; this post is already long enough (thanks for reading, by the way! Glad you made it this far!). In any case, the intuition is really not that hard to understand. The code below builds a topological ordering given a result res:

topo = []
visited = set()
def build_topo(v):
    if v not in visited:
        visited.add(v)
        for dep in v._deps:
            build_topo(dep)
        topo.append(v)
build_topo(res)

Take a look at what is happening here. topo is where we are building the ordering; as we will see, it will come out reversed. We also keep track of a set of visited variables so that each one is added to topo only once. Our recursive function build_topo takes a variable v and (if it has not been visited yet) calls itself on all the dependencies of v; only after that do we append v to topo. This means that the calls of build_topo on the dependencies will append them before v is appended, and this invariant holds for all variables.
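
To make this concrete: for our running example (so res is e), one ordering you might get is sketched below; the exact order among siblings such as $a$ and $b$ depends on how the deps sets happen to be iterated.

# One possible outcome of build_topo(e):
#   topo           -> [a, b, h_2, c, h_1, d, e]
#   reversed(topo) -> [e, d, h_1, c, h_2, b, a]   (the order we will call _backward in)
# Every variable appears after all of its dependencies in topo.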

As mentioned, the result of this is that the last variable generated by our forward pass will sit at the end of topo, and we will want to call _backward on the variables in the reverse order of topo. So why don’t we put all of this together in a function? This way, whenever we generate a result and want all the partials of that result with respect to the variables it depends on, we can just call this one function.

def backward(self):
    topo = []
    visited = set()
    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for dep in v._deps:
                build_topo(dep)
            topo.append(v)
    build_topo(self)

    self.grad = 1
    for value in reversed(topo):
        value._backward()

We literally just took the code to build a (reversed) topological ordering, set the gradient of our result to 1, and called _backward in the correct order. How beautiful!
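
With backward now a method of Value, the manual sequence of calls from before collapses into a single line (reusing the same example values):

a, b, c, d = Value(0), Value(1), Value(2), Value(3)
e = (a + b) * c + d
e.backward()

print(a.grad, b.grad, c.grad, d.grad)  # 2 2 1 1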

Final thoughts

And that’s really all there is to it! Of course, we just covered the backbone of what a package like autograd looks like, but filling in the gaps shouldn’t be too much of an issue. In fact, why don’t you try? You can find the code we covered in this post on my GitHub; it’s a good exercise to try and implement some of the missing operations yourself. Some can be recycled, such as negation and subtraction, but some will require a bit more thinking, such as raising to a power.

You will see I have added a number of quality-of-life improvements. For example, there is now an extra line in each operation that converts other to a Value if it isn’t one already. Furthermore, we have implemented the reverse operators for addition and multiplication; these are what Python falls back to when the second element in a binary operation is a Value but the first one isn’t. All in all, this makes it so that combining a Value with a number that is not wrapped in our class will not throw an error. Why don’t you try and implement another basic operation like subtraction or division?
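
For reference, here is a minimal sketch of what those two additions could look like for addition (the version in the repository may differ slightly); __mul__ gets the analogous treatment:

def __add__(self, other):
    # Wrap plain numbers so that e.g. Value(3) + 4 works.
    other = other if isinstance(other, Value) else Value(other)
    result = Value(self.data + other.data, (self, other), '+')

    def backward():
        self.grad += 1 * result.grad
        other.grad += 1 * result.grad

    result._backward = backward
    return result

def __radd__(self, other):
    # Called for e.g. 4 + Value(3), where the left operand is not a Value.
    return self + other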

I hope you learned something with this post and that you’re as excited about the future as I am. In the following months I plan to go through some major milestones in deep computer vision; don’t miss it!
