# Simple Neural Net Backward Pass

Deriving the math of the backward pass for a simple neural net.
Categories: neural-nets

Published: November 13, 2022

# Motivation

This post connects the math to the code in Jeremy Howard's 03_backprop.ipynb notebook from the 2022 part 2 course. The network for this post is deliberately simple: a single neuron followed by a Rectified Linear Unit (ReLU), with the Mean Squared Error (MSE) loss function at the end.

I know we want to get to many neurons (stacked in multiple layers), but even this simple setting will take us quite far.

# Notation

• $$N$$ : number of training examples
• $$d$$ : number of features
• $$\mathbf{x}^{(i)}$$ is the $$i$$-th training example and can be represented as a column vector $\mathbf{x}^{(i)} = \begin{bmatrix} x^{(i)}_1 \\ x^{(i)}_2 \\ \vdots \\ x^{(i)}_d \end{bmatrix}$
• $$\mathbf{w}$$ is the vector of weights for a single neuron $\mathbf{w} = \begin{bmatrix} w_1 \\ w_2 \\ \vdots \\ w_d \end{bmatrix}$
• $$b$$ is the bias term for a single neuron
• $$\mathbf{X}$$ is an $$N \times d$$ matrix with the training examples stacked in rows $\mathbf{X} = \begin{bmatrix} x^{(1)}_1 & x^{(1)}_2 & \cdots & x^{(1)}_d\\ x^{(2)}_1 & x^{(2)}_2 & \cdots & x^{(2)}_d\\ \vdots & \vdots & \ddots & \vdots\\ x^{(N)}_1 & x^{(N)}_2 & \cdots & x^{(N)}_d \end{bmatrix}$
• $$\{y^{(i)}\}_{i=1}^{N}$$ are the targets for each of the $$N$$ training examples
• $$z^{(i)}$$ is the output of our single neuron when the $$i$$-th training example is passed through it
• Specifically, $$z^{(i)} = \mathbf{w}^T\mathbf{x}^{(i)} + b = b + \sum_{j=1}^{d}w_{j}x^{(i)}_{j} = b + w_{1}x^{(i)}_{1} + \ldots + w_{j}x^{(i)}_{j} + \ldots + w_{d}x^{(i)}_{d}$$. Additionally we see that,
• $$\frac{\partial z^{(i)}}{\partial w_{j}} = x^{(i)}_{j}$$
• $$\frac{\partial z^{(i)}}{\partial b} = 1$$
• $$\frac{\partial z^{(i)}}{\partial x^{(i)}_{j}} = w_{j}$$
• $$a^{(i)} = \phi(z^{(i)})$$ is the activation when an input $$z^{(i)}$$ is passed through an activation function $$\phi$$
• $$\{a^{(i)}\}_{i=1}^{N}$$ are the activations for each of the $$N$$ training examples when passed through a single neuron followed by the application of the activation function
• For the purposes of this page the activation function is considered to be a ReLU so $a^{(i)} = \phi(z^{(i)}) = \max\{0, z^{(i)}\} = \begin{cases} z^{(i)} & z^{(i)} \gt 0 \\ 0 & z^{(i)} \leq 0 \end{cases}$
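To make the notation concrete, here is a minimal forward-pass sketch. The lin, relu, and mse helpers below mirror what the Lin, Relu, and Mse classes later in this post call; the exact definitions in the notebook may differ slightly, so treat this as my own sketch rather than the notebook's code.

```python
import torch

def lin(x, w, b):                        # single neuron: z = Xw + b
    return x @ w + b

def relu(x):                             # ReLU: a = max(0, z)
    return x.clamp_min(0.)

def mse(output, targ):                   # mean squared error over the N examples
    return (output.squeeze(-1) - targ).pow(2).mean()

N, d = 5, 3                              # N training examples, d features
X = torch.randn(N, d)                    # one example per row
y = torch.randn(N)                       # targets
w = torch.randn(d, 1)                    # weights of the single neuron
b = torch.randn(1)                       # bias

z = lin(X, w, b)                         # linear outputs, shape (N, 1)
a = relu(z)                              # activations, shape (N, 1)
J = mse(a, y)                            # scalar loss
```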

Erik Learned-Miller’s Vector, Matrix, and Tensor Derivatives was very helpful for this section, in particular the discussion on page 7. Another useful reference is Terence Parr and Jeremy Howard’s Matrix Calculus You Need For Deep Learning.

After reading this post a good next stop would be Justin Johnson’s Backpropagation for a Linear Layer.

## Gradient of loss with respect to the Activations $$a^{(i)}$$

$$J\left(\{y^{(i)}\}_{i=1}^{N},\{a^{(i)}\}_{i=1}^{N}\right) = \frac{1}{N}\sum_{i=1}^{N}(y^{(i)}-a^{(i)})^{2}$$ is the loss (mean squared error) across the $$N$$ training examples

The backward function of the Mse class computes how the loss changes as the input activations change, i.e. the gradient of the loss with respect to the activations.

```python
class Mse():
    def __call__(self, inp, targ):
        self.inp,self.targ = inp,targ
        self.out = mse(inp, targ)
        return self.out

    def backward(self):
        N = self.targ.shape[0]
        A = self.inp                 # activations a
        Y = self.targ                # targets y
        dJ_dA = (2./N) * (A.squeeze() - Y).unsqueeze(-1)
        self.inp.g = dJ_dA           # stored for the previous layer's backward
```

The change in the loss as the $$i$$-th activation changes is given by

$\frac{\partial J}{\partial a^{(i)}} = \frac{1}{N}\sum_{k=1}^{N} \frac{\partial (y^{(k)}-a^{(k)})^{2}}{\partial a^{(i)}} = \frac{1}{N}\frac{\partial (y^{(i)}-a^{(i)})^{2}}{\partial a^{(i)}} = \frac{2}{N}(y^{(i)}-a^{(i)})\frac{\partial (y^{(i)}-a^{(i)}) }{\partial a^{(i)}} = \frac{2}{N}(a^{(i)}-y^{(i)})$

where the second equality holds because only the $$i$$-th term of the sum depends on $$a^{(i)}$$, and the last step follows because $$\frac{\partial (y^{(i)}-a^{(i)}) }{\partial a^{(i)}} = 0-1 =-1$$.

The change in the loss as a function of the change in activations from our training examples is captured by the $$N \times 1$$ matrix:

$\frac{\partial J}{\partial \mathbf{a}} = \begin{bmatrix} \frac{\partial J}{\partial a^{(1)}} \\ \frac{\partial J}{\partial a^{(2)}} \\ \vdots \\ \frac{\partial J}{\partial a^{(N)}} \end{bmatrix} = \begin{bmatrix} \frac{2}{N}\left(a^{(1)}-y^{(1)}\right) \\ \frac{2}{N}\left(a^{(2)}-y^{(2)}\right) \\ \vdots \\ \frac{2}{N}\left(a^{(N)}-y^{(N)}\right) \end{bmatrix}$

From the implementation perspective, the activations $$\{a^{(i)}\}_{i=1}^{N}$$ and the targets $$\{y^{(i)}\}_{i=1}^{N}$$ are passed in and stored during the forward pass (specifically in the dunder __call__ method). In the backward pass these are retrieved and $$\frac{\partial J}{\partial \mathbf{a}}$$ is computed and stored for access by the backward function of the prior layer. Hopefully the backward method for the Mse class makes sense now.
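As a quick sanity check (my own addition, not from the notebook), the gradient that Mse.backward stores in inp.g should match what PyTorch's autograd computes for the same loss:

```python
N = 5
a = torch.randn(N, 1, requires_grad=True)   # pretend these are the activations
y = torch.randn(N)

loss_fn = Mse()
J = loss_fn(a, y)
loss_fn.backward()                          # manual gradient stored in a.g
J.backward()                                # autograd gradient stored in a.grad

print(torch.allclose(a.g, a.grad))          # should print True
```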

## Gradient of loss with respect to the Linear Output $$z^{(i)}$$

```python
class Relu():
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)
        return self.out

    def backward(self):
        dJ_dA = self.out.g
        dA_dZ = (self.inp>0).float()

        # Note this is an elementwise multiplication
        dJ_dZ = dJ_dA * dA_dZ

        self.inp.g = dJ_dZ
```

How does the loss change as the output of the linear unit changes?

$\frac{\partial J}{\partial z^{(i)}} = \frac{\partial J}{\partial a^{(i)}} \frac{\partial a^{(i)}}{\partial z^{(i)}} = \frac{2}{N}\left(a^{(i)}-y^{(i)}\right)\frac{\partial a^{(i)}}{\partial z^{(i)}}$

For the ReLU activation function we have that, $\frac{\partial a^{(i)}}{\partial z^{(i)}} = \begin{cases} 1 & z^{(i)} \gt 0 \\ 0 & z^{(i)} \leq 0 \end{cases}$

and hence

$\frac{\partial \mathbf{a}}{\partial \mathbf{z}} = \begin{bmatrix} \frac{\partial a^{(1)}}{\partial z^{(1)}} \\ \frac{\partial a^{(2)}}{\partial z^{(2)}} \\ \vdots \\ \frac{\partial a^{(N)}}{\partial z^{(N)}} \end{bmatrix}$

Thus the change in the loss as a function of the change in the output from the linear unit on our training examples is given by the $$N \times 1$$ matrix:

$\frac{\partial J}{\partial \mathbf{z}} = \begin{bmatrix} \frac{\partial J}{\partial z^{(1)}} \\ \frac{\partial J}{\partial z^{(2)}} \\ \vdots \\ \frac{\partial J}{\partial z^{(N)}} \end{bmatrix} = \begin{bmatrix} \frac{\partial J}{\partial a^{(1)}}\frac{\partial a^{(1)}}{\partial z^{(1)}} \\ \frac{\partial J}{\partial a^{(2)}}\frac{\partial a^{(2)}}{\partial z^{(2)}} \\ \vdots \\ \frac{\partial J}{\partial a^{(N)}}\frac{\partial a^{(N)}}{\partial z^{(N)}} \end{bmatrix}$

So $$\frac{\partial J}{\partial \mathbf{z}}$$ ends up being an elementwise product between the corresponding entries of $$\frac{\partial J}{\partial \mathbf{a}}$$ and $$\frac{\partial \mathbf{a}}{\partial \mathbf{z}}$$.

From the implementation perspective, in the backward pass $$\frac{\partial \mathbf{a}}{\partial \mathbf{z}}$$ will be computed locally and multiplied elementwise with $$\frac{\partial J}{\partial \mathbf{a}}$$ (this will have been computed in the backward pass in the Mse class and will be available to access when the backward function of the Relu function is called).
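Chaining the two backward calls should reproduce autograd's $$\frac{\partial J}{\partial \mathbf{z}}$$. A quick check (again my own addition):

```python
N = 5
z = torch.randn(N, 1, requires_grad=True)   # pretend these are the linear outputs
y = torch.randn(N)

act, loss_fn = Relu(), Mse()
J = loss_fn(act(z), y)

loss_fn.backward()                          # stores dJ/da in act.out.g
act.backward()                              # elementwise product, stores dJ/dz in z.g
J.backward()                                # autograd baseline in z.grad

print(torch.allclose(z.g, z.grad))          # should print True
```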

## Gradient of loss with respect to $$w_{j}, b$$ and $$X$$

The next three subsections will explain the backward function of the Lin class.

```python
class Lin():
    def __init__(self, w, b): self.w,self.b = w,b

    def __call__(self, inp):
        self.inp = inp
        self.out = lin(inp, self.w, self.b)
        return self.out

    def backward(self):
        # See Gradient of loss with respect to w_j
        dJ_dZ = self.out.g
        X = self.inp
        dJ_dW = X.t() @ dJ_dZ
        self.w.g = dJ_dW

        # See Gradient of loss with respect to the bias b
        dJ_db = dJ_dZ.sum(0)
        self.b.g = dJ_db

        # See Gradient of loss with respect to X
        dJ_dX = dJ_dZ @ self.w.t()
        self.inp.g = dJ_dX
```

### Gradient of loss with respect to $$w_{j}$$

How does the loss react when we wiggle $$w_{j}$$? $\frac{\partial J}{\partial w_{j}} = \frac{\partial J}{\partial \mathbf{z}} \frac{\partial \mathbf{z}}{\partial w_{j}}= \sum_{i=1}^{N}\frac{\partial J}{\partial z^{(i)}} \frac{\partial z^{(i)}}{\partial w_{j}} = \frac{\partial J}{\partial z^{(1)}}x^{(1)}_{j} + \frac{\partial J}{\partial z^{(2)}}x^{(2)}_{j} + \ldots + \frac{\partial J}{\partial z^{(N)}}x^{(N)}_{j}$

Thus,

$\frac{\partial J}{\partial \mathbf{w}} = \begin{bmatrix} \frac{\partial J}{\partial w_{1}} \\ \frac{\partial J}{\partial w_{2}} \\ \vdots \\ \frac{\partial J}{\partial w_{d}} \end{bmatrix} = \begin{bmatrix} x^{(1)}_{1}\frac{\partial J}{\partial z^{(1)}} + x^{(2)}_{1}\frac{\partial J}{\partial z^{(2)}} + \ldots + x^{(N)}_{1}\frac{\partial J}{\partial z^{(N)}} \\ x^{(1)}_{2}\frac{\partial J}{\partial z^{(1)}} + x^{(2)}_{2}\frac{\partial J}{\partial z^{(2)}} + \ldots + x^{(N)}_{2}\frac{\partial J}{\partial z^{(N)}} \\ \vdots \\ x^{(1)}_{d}\frac{\partial J}{\partial z^{(1)}} + x^{(2)}_{d}\frac{\partial J}{\partial z^{(2)}} + \ldots + x^{(N)}_{d}\frac{\partial J}{\partial z^{(N)}} \end{bmatrix}$

This is a matrix multiplication in disguise (each row is a dot product) and can be written more compactly as:

$\frac{\partial J}{\partial \mathbf{w}} = \begin{bmatrix} x^{(1)}_1 & x^{(2)}_1 & \cdots & x^{(N)}_1\\ x^{(1)}_2 & x^{(2)}_2 & \cdots & x^{(N)}_2\\ \vdots & \vdots & \ddots & \vdots\\ x^{(1)}_d & x^{(2)}_d & \cdots & x^{(N)}_d \end{bmatrix} \begin{bmatrix} \frac{\partial J}{\partial z^{(1)}} \\ \frac{\partial J}{\partial z^{(2)}} \\ \vdots \\ \frac{\partial J}{\partial z^{(N)}} \end{bmatrix} =\mathbf{X}^{T}\frac{\partial J}{\partial \mathbf{z}}$

From the implementation perspective $$\mathbf{X}^{T}$$ is computed locally in the backward function of the Lin class while $$\frac{\partial J}{\partial \mathbf{z}}$$ is ready and waiting to be accessed. Recall the latter was computed in the backward function of the Relu class.
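To see the matrix-multiplication-in-disguise claim numerically, here is a small check (my own) that $$\mathbf{X}^{T}\frac{\partial J}{\partial \mathbf{z}}$$ matches the explicit per-feature sums written out above:

```python
N, d = 5, 3
X = torch.randn(N, d)
dJ_dZ = torch.randn(N, 1)                   # stand-in for the gradient from Relu.backward

compact = X.t() @ dJ_dZ                     # shape (d, 1)

explicit = torch.zeros(d, 1)
for j in range(d):
    for i in range(N):
        explicit[j] += X[i, j] * dJ_dZ[i]   # x_j^(i) * dJ/dz^(i)

print(torch.allclose(compact, explicit))    # should print True
```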

In the lin_grad function, earlier in the notebook, we see Jeremy computing the gradient with respect to the weights as w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0).

```python
def lin_grad(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()

    # What's going on here?
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)

    b.g = out.g.sum(0)
```

This uses broadcasting to achieve the same computation as the matrix multiplication between $$X^{T}$$ and $$\frac{\partial J}{\partial \mathbf{z}}$$.

The inp.unsqueeze(-1) takes inp (our $$X$$) and turns it into a tensor of shape $$N \times d \times 1$$. This is then multiplied with out.g.unsqueeze(1) (out.g is $$\frac{\partial J}{\partial \mathbf{z}}$$), which has shape $$N \times 1 \times 1$$. Broadcasting makes $$d$$ “copies” of $$\frac{\partial J}{\partial \mathbf{z}}$$, and each copy is multiplied elementwise by a column of $$X$$ (each column of $$X$$ holds one feature's values across the training examples). Finally, the sum(0) at the end sums over the training-example axis, completing the dot products and giving the $$d \times 1$$ shaped output.
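A short check (my own) that the broadcasting version and the matrix product give the same numbers:

```python
X = torch.randn(5, 3)                                        # N = 5, d = 3
dJ_dZ = torch.randn(5, 1)                                    # stand-in for dJ/dz

broadcast = (X.unsqueeze(-1) * dJ_dZ.unsqueeze(1)).sum(0)    # (N,d,1) * (N,1,1) -> (N,d,1) -> (d,1)
matmul = X.t() @ dJ_dZ                                       # (d,N) @ (N,1) -> (d,1)

print(torch.allclose(broadcast, matmul))                     # should print True
```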

### Gradient of loss with respect to the bias $$b$$

$\frac{\partial J}{\partial b} = \frac{\partial J}{\partial \mathbf{z}} \frac{\partial \mathbf{z}}{\partial b}= \sum_{i=1}^{N}\frac{\partial J}{\partial z^{(i)}} \frac{\partial z^{(i)}}{\partial b}= \sum_{i=1}^{N}\frac{\partial J}{\partial z^{(i)}} 1 = \frac{\partial J}{\partial z^{(1)}} + \frac{\partial J}{\partial z^{(2)}} + \ldots + \frac{\partial J}{\partial z^{(N)}}$

From the implementation perspective we need to access $$\frac{\partial J}{\partial \mathbf{z}}$$ and sum across the first axis. The local derivative computation of $$\frac{\partial \mathbf{z}}{\partial b}$$ is particularly simple (since each $$\frac{\partial z^{(i)}}{\partial b}$$ is just $$1$$).
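A finite-difference check (my own addition) that summing $$\frac{\partial J}{\partial \mathbf{z}}$$ over the training examples really does give $$\frac{\partial J}{\partial b}$$; the lin, relu, and mse helpers are the sketch versions from the Notation section:

```python
N, d = 5, 3
X, y = torch.randn(N, d), torch.randn(N)
w, b = torch.randn(d, 1), torch.randn(1)

linear, act, loss_fn = Lin(w, b), Relu(), Mse()
J = loss_fn(act(linear(X)), y)
loss_fn.backward(); act.backward()          # dJ/dz is now stored in linear.out.g

analytic = linear.out.g.sum(0)              # dJ/db = sum_i dJ/dz^(i)

eps = 1e-4                                  # numerical derivative of the loss w.r.t. b
fd = (mse(relu(lin(X, w, b + eps)), y) - mse(relu(lin(X, w, b - eps)), y)) / (2 * eps)

print(torch.allclose(analytic.squeeze(), fd, atol=1e-3))   # approximately True
```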

### Gradient of loss with respect to $$x^{(i)}_{j}$$

Let’s understand how the loss will change as we twiddle the $$j$$-th feature of the $$i$$-th training example.

$\frac{\partial J}{\partial x^{(i)}_{j}} = \frac{\partial J}{\partial \mathbf{z}} \frac{\partial \mathbf{z}}{\partial x^{(i)}_{j}}= \sum_{k=1}^{N}\frac{\partial J}{\partial z^{(k)}} \frac{\partial z^{(k)}}{\partial x^{(i)}_{j}} = \frac{\partial J}{\partial z^{(i)}}\frac{\partial z^{(i)}}{\partial x^{(i)}_{j}} + \sum_{k: k \neq i}^{N} \frac{\partial J}{\partial z^{(k)}}\frac{\partial z^{(k)}}{\partial x^{(i)}_{j}}$

Since $$\frac{\partial z^{(k)}}{\partial x^{(i)}_{j}} = 0$$ for any $$k \neq i$$ we get

$\frac{\partial J}{\partial x^{(i)}_{j}} = \frac{\partial J}{\partial z^{(i)}}w_{j}$

Thus the $$N \times d$$ matrix of these gradients is:

$\frac{\partial J}{\partial \mathbf{X}} = \begin{bmatrix} \frac{\partial J}{\partial x^{(1)}_{1}} & \frac{\partial J}{\partial x^{(1)}_{2}} & \cdots & \frac{\partial J}{\partial x^{(1)}_{d}} \\ \frac{\partial J}{\partial x^{(2)}_{1}} & \frac{\partial J}{\partial x^{(2)}_{2}} & \cdots & \frac{\partial J}{\partial x^{(2)}_{d}} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial J}{\partial x^{(N)}_{1}} & \frac{\partial J}{\partial x^{(N)}_{2}} & \cdots & \frac{\partial J}{\partial x^{(N)}_{d}} \end{bmatrix}= \begin{bmatrix} \frac{\partial J}{\partial z^{(1)}}w_{1} & \frac{\partial J}{\partial z^{(1)}}w_{2} & \cdots & \frac{\partial J}{\partial z^{(1)}}w_{d} \\ \frac{\partial J}{\partial z^{(2)}}w_{1} & \frac{\partial J}{\partial z^{(2)}}w_{2} & \cdots & \frac{\partial J}{\partial z^{(2)}}w_{d} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial J}{\partial z^{(N)}}w_{1} & \frac{\partial J}{\partial z^{(N)}}w_{2} & \cdots & \frac{\partial J}{\partial z^{(N)}}w_{d} \end{bmatrix}$

More compactly this can be represented as an outer product,

$\frac{\partial J}{\partial \mathbf{X}} = \begin{bmatrix} \frac{\partial J}{\partial z^{(1)}} \\ \frac{\partial J}{\partial z^{(2)}} \\ \vdots \\ \frac{\partial J}{\partial z^{(N)}} \end{bmatrix} \begin{bmatrix} w_{1} & w_{2} & \cdots & w_{d} \end{bmatrix} =\frac{\partial J}{\partial \mathbf{z}}\mathbf{w}^{T}$
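
Finally, putting all the pieces together, here is a rough end-to-end check (my own, not part of the notebook) that the manual backward pass agrees with autograd for $$\mathbf{w}$$, $$b$$, and $$\mathbf{X}$$:

```python
N, d = 5, 3
X = torch.randn(N, d, requires_grad=True)
y = torch.randn(N)
w = torch.randn(d, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

linear, act, loss_fn = Lin(w, b), Relu(), Mse()
J = loss_fn(act(linear(X)), y)

loss_fn.backward(); act.backward(); linear.backward()   # manual gradients -> .g
J.backward()                                            # autograd gradients -> .grad

for t in (w, b, X):
    print(torch.allclose(t.g, t.grad))                  # should print True three times
```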