
Use Einstein Notation to derive Backpropagation


Recently, I came across a great book published by Hugging Face about scaling LLMs. One of its figures got me thinking about how to derive the gradients that flow back through the compute graph.

Compute Graph BP

I have done this before with matrix calculus, but it can be tedious. Is there a simpler way to do the derivation? Einstein notation comes to the rescue. If you are not familiar with Einstein notation, you can check this website for more details.

For the MLP layer shown in the figure, the forward pass is pretty straightforward:

$$
z = Wx \\[1ex]
y = \sigma(z) \\[1ex]
L = \| y_{\text{true}} - y \|_2^2 \\[1ex]
W \in \mathbb{R}^{M \times N} \\[1ex]
x \in \mathbb{R}^{N \times 1} \\[1ex]
z \in \mathbb{R}^{M \times 1} \\[1ex]
y \in \mathbb{R}^{M \times 1}
$$

Here, the loss $L$ is a scalar. We can express it using Einstein notation by first defining the error term:

$$e_i = y_{\text{true}, i} - y_i$$

so that, with summation over the repeated index $i$ implied,

$$L = e_i e_i$$
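As a rough sketch of the forward pass in index form, the snippet below uses NumPy's `einsum`, whose subscript strings mirror Einstein notation. The sigmoid chosen for $\sigma$, the sizes `M = 3` and `N = 4`, the random inputs, and the helper names `sigmoid` and `forward` are all assumptions made for illustration; the post does not fix any of them. Vectors are stored as 1-D arrays rather than $N \times 1$ columns.

```python
import numpy as np

def sigmoid(z):
    # Assumed activation; the post leaves sigma unspecified.
    return 1.0 / (1.0 + np.exp(-z))

def forward(W, x, y_true):
    # z_i = W_ij x_j: the repeated index j is summed over.
    z = np.einsum("ij,j->i", W, x)
    y = sigmoid(z)
    # e_i = y_true_i - y_i
    e = y_true - y
    # L = e_i e_i: the repeated index i is summed over, leaving a scalar.
    L = np.einsum("i,i->", e, e)
    return z, y, L

# Hypothetical sizes and random data, only to exercise the function.
rng = np.random.default_rng(0)
M, N = 3, 4
W = rng.normal(size=(M, N))
x = rng.normal(size=N)
y_true = rng.normal(size=M)

z, y, L = forward(W, x, y_true)
```

The `einsum` results agree with the usual matrix expressions `W @ x` and `np.sum((y_true - y) ** 2)`; the subscript form just makes the summed indices explicit.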

Differentiating $L$ with respect to $y_i$ yields

$$\frac{\partial L}{\partial y_i} = 2 e_i (-1)$$

where the negative sign appears because the error is defined as $e_i = y_{\text{true}, i} - y_i$.

Next, applying the chain rule, the gradient with respect to $z$ is given below. There is no sum over $i$ here: $\sigma$ acts elementwise, so $y_i$ depends only on $z_i$.

$$\frac{\partial L}{\partial z_i} = \frac{\partial L}{\partial y_i} \cdot \frac{\partial \sigma(z_i)}{\partial z_i}$$
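The two gradients so far can be checked numerically. The sketch below assumes $\sigma$ is the sigmoid, for which $\sigma'(z_i) = \sigma(z_i)(1 - \sigma(z_i))$; that choice, the vector length of 3, and the helper names `loss_from_z` and `grad_z` are illustrative assumptions, not something the post specifies. It compares the analytic gradient with a central finite difference.

```python
import numpy as np

def sigmoid(z):
    # Assumed activation; the post leaves sigma unspecified.
    return 1.0 / (1.0 + np.exp(-z))

def loss_from_z(z, y_true):
    # L = e_i e_i with e_i = y_true_i - sigma(z_i)
    e = y_true - sigmoid(z)
    return np.einsum("i,i->", e, e)

def grad_z(z, y_true):
    y = sigmoid(z)
    dL_dy = -2.0 * (y_true - y)   # dL/dy_i = 2 e_i (-1)
    dy_dz = y * (1.0 - y)         # sigma'(z_i), valid for the sigmoid only
    return dL_dy * dy_dz          # elementwise product, no sum over i

rng = np.random.default_rng(0)
z = rng.normal(size=3)
y_true = rng.normal(size=3)
analytic = grad_z(z, y_true)

# Central finite differences, one component of z at a time.
eps = 1e-6
numeric = np.zeros_like(z)
for i in range(z.size):
    step = np.zeros_like(z)
    step[i] = eps
    numeric[i] = (loss_from_z(z + step, y_true)
                  - loss_from_z(z - step, y_true)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-6)
```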

Now, let’s derive the gradient with respect to the weight matrix $W$. Since

$$z = Wx \quad \Longrightarrow \quad z_i = W_{ij} x_j,$$

the derivative with respect to an element $W_{ij}$ is

$$\text{grad}(W_{ij}) = \frac{\partial L}{\partial W_{ij}} = \frac{\partial L}{\partial z_i} \cdot \frac{\partial z_i}{\partial W_{ij}} = \frac{\partial L}{\partial z_i} x_j$$
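The middle step is compressed, so here it is spelled out. The dummy indices $k$ and $l$ and the Kronecker delta $\delta_{ki}$ are notation introduced here for illustration; they do not appear in the original derivation. Writing $z_k = W_{kl} x_l$, only the entry of $W$ with $k = i$ and $l = j$ survives the derivative:

$$
\frac{\partial z_k}{\partial W_{ij}} = \frac{\partial (W_{kl} x_l)}{\partial W_{ij}} = \delta_{ki} \delta_{lj} x_l = \delta_{ki} x_j
$$

Summing over $k$ in the chain rule then picks out the single term $k = i$:

$$
\frac{\partial L}{\partial W_{ij}} = \frac{\partial L}{\partial z_k} \frac{\partial z_k}{\partial W_{ij}} = \frac{\partial L}{\partial z_k} \delta_{ki} x_j = \frac{\partial L}{\partial z_i} x_j
$$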

In matrix form, this can be neatly written as

$$\text{grad}(W) = \frac{\partial L}{\partial z} x^T$$
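In `einsum` terms, the weight gradient has two free indices and no repeated one, which makes it an outer product. The sketch below checks this against finite differences; the sigmoid activation, the $3 \times 4$ shape, the random data, and the helper names `loss` and `grad_W` are assumptions for illustration only.

```python
import numpy as np

def sigmoid(z):
    # Assumed activation; the post leaves sigma unspecified.
    return 1.0 / (1.0 + np.exp(-z))

def loss(W, x, y_true):
    y = sigmoid(np.einsum("ij,j->i", W, x))   # z_i = W_ij x_j
    e = y_true - y
    return np.einsum("i,i->", e, e)           # L = e_i e_i

def grad_W(W, x, y_true):
    y = sigmoid(np.einsum("ij,j->i", W, x))
    # dL/dz_i = dL/dy_i * sigma'(z_i), using the sigmoid derivative
    dL_dz = -2.0 * (y_true - y) * y * (1.0 - y)
    # grad(W_ij) = dL/dz_i * x_j: i and j are both free, so nothing is summed
    return np.einsum("i,j->ij", dL_dz, x)

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
x = rng.normal(size=4)
y_true = rng.normal(size=3)
gW = grad_W(W, x, y_true)

# Central finite differences, one entry of W at a time.
eps = 1e-6
numeric = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    step = np.zeros_like(W)
    step[idx] = eps
    numeric[idx] = (loss(W + step, x, y_true)
                    - loss(W - step, x, y_true)) / (2 * eps)

assert np.allclose(gW, numeric, atol=1e-6)
```

The subscript string `"i,j->ij"` is the component formula read off directly, which is the appeal of working index by index.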

Using Einstein notation, we focus on individual scalar components rather than entire vectors or matrices, which simplifies the derivation. You can apply the same process to the input x to compute its gradient.
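As a sketch of that last step under the same definitions, $z_i = W_{ij} x_j$ gives $\partial z_i / \partial x_j = W_{ij}$. This time the index $i$ is repeated, so it is summed over:

$$\text{grad}(x_j) = \frac{\partial L}{\partial x_j} = \frac{\partial L}{\partial z_i} \cdot \frac{\partial z_i}{\partial x_j} = \frac{\partial L}{\partial z_i} W_{ij}$$

In matrix form, the sum over the row index of $W$ corresponds to multiplying by the transpose:

$$\text{grad}(x) = W^T \frac{\partial L}{\partial z}$$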