I present a derivation of efficient backpropagation equations for batch-normalization layers.
Introduction
A batch normalization layer is given a batch of $n$ examples, each of which is a $d$-dimensional vector. We can represent the inputs as a matrix $X \in \mathbb{R}^{n \times d}$ where each row $x_i$ is a single example. Each example is normalized by

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

where $\mu, \sigma^2 \in \mathbb{R}^d$ are the mean and variance, respectively, of each input dimension across the batch. $\epsilon$ is some small constant that prevents division by 0. The mean and variance are computed by

$$\mu = \frac{1}{n}\sum_{i=1}^n x_i, \qquad \sigma^2 = \frac{1}{n}\sum_{i=1}^n (x_i - \mu)^2$$

where the square in $(x_i - \mu)^2$ is taken element-wise. An affine transform is then applied to the normalized rows to produce the final output

$$y_i = \gamma \odot \hat{x}_i + \beta$$

where $\gamma, \beta \in \mathbb{R}^d$ are learnable scale and shift parameters for each input dimension. For notational simplicity, we can express the entire layer as

$$Y = \mathrm{BN}_{\gamma,\beta}(X) = \gamma \odot \hat{X} + \beta = \gamma \odot \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where the subtraction of $\mu$ and the division by $\sqrt{\sigma^2 + \epsilon}$ are broadcast across the rows of $X$.

Notation: $\odot$ denotes the Hadamard (element-wise) product. In the case of $v \odot M$, where $v$ is a row vector and $M$ is a matrix, each row of $M$ is multiplied element-wise by $v$.
Gradient Notes: Several times throughout this post, I mention my “gradient notes,” which refer to this document.
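To make the forward pass concrete, here is a minimal NumPy sketch of the layer as defined above. The function name `batchnorm_forward`, the shapes, and the default `eps` are my own choices for illustration, not something prescribed by the layer itself.

```python
import numpy as np

def batchnorm_forward(X, gamma, beta, eps=1e-5):
    """Batch-norm forward pass for X of shape (n, d).

    gamma, beta: learnable (d,) parameter vectors.
    Returns the output Y plus the values needed for the backward pass.
    """
    mu = X.mean(axis=0)                    # per-dimension mean, shape (d,)
    var = X.var(axis=0)                    # per-dimension variance, shape (d,)
    X_hat = (X - mu) / np.sqrt(var + eps)  # normalized inputs, shape (n, d)
    Y = gamma * X_hat + beta               # affine transform, broadcast over rows
    cache = (X_hat, gamma, var, eps)       # saved for the backward pass derived below
    return Y, cache
```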
Backpropagation Basics
Let $L$ be the training loss. We are given $\frac{\partial L}{\partial Y}$, the gradient signal with respect to the layer output $Y$. Our goal is to calculate three gradients:
- $\frac{\partial L}{\partial \gamma}$, to perform a gradient descent update on $\gamma$
- $\frac{\partial L}{\partial \beta}$, to perform a gradient descent update on $\beta$
- $\frac{\partial L}{\partial X}$, to pass on the gradient signal to lower layers

Both $\frac{\partial L}{\partial \gamma}$ and $\frac{\partial L}{\partial \beta}$ are straightforward. Let $y_i$ be the $i$-th row of $Y$. We refer to our gradient notes to get

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^n \frac{\partial L}{\partial y_i} \odot \hat{x}_i, \qquad \frac{\partial L}{\partial \beta} = \sum_{i=1}^n \frac{\partial L}{\partial y_i}$$

Deriving $\frac{\partial L}{\partial X}$ requires backpropagating through the affine transform $Y = \gamma \odot \hat{X} + \beta$, which yields

$$\frac{\partial L}{\partial \hat{X}} = \gamma \odot \frac{\partial L}{\partial Y}$$
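In code, these three quantities are one vectorized line each. A sketch under the same shape assumptions as the forward pass above (the helper name is mine):

```python
import numpy as np

def batchnorm_param_grads(dY, X_hat, gamma):
    """Gradients through the affine part of batch norm.

    dY:    (n, d) upstream gradient dL/dY
    X_hat: (n, d) normalized inputs from the forward pass
    gamma: (d,)   scale parameters
    """
    dgamma = np.sum(dY * X_hat, axis=0)  # dL/dgamma = sum_i dL/dy_i ⊙ x_hat_i
    dbeta = np.sum(dY, axis=0)           # dL/dbeta  = sum_i dL/dy_i
    dX_hat = dY * gamma                  # dL/dX_hat = gamma ⊙ dL/dY, row-wise
    return dgamma, dbeta, dX_hat
```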
Next we have to backpropagate through $\hat{X} = \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}}$. Because both $\mu$ and $\sigma^2$ are functions of $X$, finding the gradient of $L$ with respect to $X$ is tricky. There are two approaches to break this down:
- Take the gradient of $L$ with respect to each row (example) in $X$. This approach is complicated by the fact that the values of each row in $X$ influence the values of all rows in $\hat{X}$ (i.e. $\frac{\partial \hat{x}_j}{\partial x_i} \neq 0$ even when $i \neq j$). By properly considering how changes in $x_i$ influence $\mu$ and $\sigma^2$, this is doable, as explained here.
- Take the gradient of $L$ with respect to each column (input dimension) in $X$. I find this more intuitive because batch normalization operates independently for each column: $\mu$, $\sigma^2$, $\gamma$, and $\beta$ are all calculated per column. This method is explained below.
Column-wise Gradient
Since we are taking the gradient of $L$ with respect to each column in $X$, we can start by considering the case where $X$ is just a single column vector $x \in \mathbb{R}^n$. Thus, each example is a single number, and $\mu$ and $\sigma^2$ are scalar real numbers. This makes the math much easier. Later on, we generalize to $d$-dimensional input examples.
Lemma
Let $f$ be a real-valued function of a vector $v \in \mathbb{R}^n$. Suppose $\frac{\partial f}{\partial v}$ is known. If $v = c \cdot u$, where $u \in \mathbb{R}^n$ and $c \in \mathbb{R}$ is a scalar-valued function of $u$, then

$$\frac{\partial f}{\partial u} = \left(\frac{\partial c}{\partial u}\,u^T + c\,I\right)\frac{\partial f}{\partial v}$$
Proof
First we compute the gradient of $v$ for a single element $u_j$ in $u$:

$$\frac{\partial v_i}{\partial u_j} = \frac{\partial c}{\partial u_j}\,u_i + c\,\delta_{ij}$$

where $\delta_{ij}$ equals 1 if $i = j$ and 0 otherwise. We apply the chain rule to obtain the gradient of $f$ for a single element $u_j$ in $u$:

$$\frac{\partial f}{\partial u_j} = \sum_{i=1}^n \frac{\partial f}{\partial v_i}\,\frac{\partial v_i}{\partial u_j} = \frac{\partial c}{\partial u_j}\left(u^T\frac{\partial f}{\partial v}\right) + c\,\frac{\partial f}{\partial v_j}$$

Now we can write the gradient for all elements in $u$, where $I$ is the identity matrix:

$$\frac{\partial f}{\partial u} = \left(u^T\frac{\partial f}{\partial v}\right)\frac{\partial c}{\partial u} + c\,\frac{\partial f}{\partial v} = \left(\frac{\partial c}{\partial u}\,u^T + c\,I\right)\frac{\partial f}{\partial v}$$

This result is a generalization of the “product rule” in the completely scalar case. For a function $f(v)$ where $v = c(u)\,u$ and $u$, $v$, $c$ are all scalars, we have

$$\frac{\partial f}{\partial u} = \frac{\partial f}{\partial v}\,\frac{\partial v}{\partial u} = \frac{\partial f}{\partial v}\left(\frac{\partial c}{\partial u}\,u + c\right)$$
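The lemma is easy to sanity-check numerically. The sketch below is my own construction, with arbitrary choices of $f$ and $c$ that are not taken from the derivation; it compares the lemma's formula against a finite-difference gradient.

```python
import numpy as np

# Sanity check of the lemma with arbitrary choices:
#   f(v) = sum(sin(v))  =>  df/dv = cos(v)
#   c(u) = u.T @ u      =>  dc/du = 2u
rng = np.random.default_rng(0)
n = 5
u = rng.standard_normal(n)

c = u @ u                       # scalar c(u)
v = c * u                       # v = c(u) * u
df_dv = np.cos(v)               # known gradient of f with respect to v
dc_du = 2 * u                   # known gradient of c with respect to u

# Lemma: df/du = (dc/du u^T + c I) df/dv
lemma_grad = (np.outer(dc_du, u) + c * np.eye(n)) @ df_dv

# Finite-difference gradient of g(u) = f(c(u) * u)
def g(u):
    return np.sum(np.sin((u @ u) * u))

h = 1e-6
fd_grad = np.array([(g(u + h * e) - g(u - h * e)) / (2 * h) for e in np.eye(n)])

print(np.allclose(lemma_grad, fd_grad, atol=1e-5))  # expect True
```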
Getting a single expression for $\frac{\partial L}{\partial x}$
We want a single expression for $\frac{\partial L}{\partial x}$, which we will derive in two steps.
- Rewrite $\hat{x}$ in the form $\hat{x} = c \cdot u$ for some choice of scalar $c$ and vector $u$. This enables us to use the lemma above to obtain $\frac{\partial L}{\partial u}$.
- Rewrite $u$ in the form $u = Ax$ for some choice of matrix $A$. This enables us to use our gradient notes to obtain $\frac{\partial L}{\partial x}$.

We choose $u = x - \mu$, the vector of mean-centered inputs. Then $\hat{x}$ can be expressed in terms of $u$ as follows:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} = \frac{u}{\sqrt{\frac{1}{n}u^T u + \epsilon}} = c \cdot u$$

where $c = \left(\frac{1}{n}u^T u + \epsilon\right)^{-1/2} = \left(\sigma^2 + \epsilon\right)^{-1/2}$ is a scalar-valued function of $u$, since $\sigma^2 = \frac{1}{n}u^T u$. Now we apply our lemma above:

$$\frac{\partial L}{\partial u} = \left(\frac{\partial c}{\partial u}\,u^T + c\,I\right)\frac{\partial L}{\partial \hat{x}}$$

$u$ can be written as a matrix multiplication with $x$, where $\mathbf{1}_{n\times n}$ is the $n \times n$ matrix of all ones:

$$u = x - \mu = x - \frac{1}{n}\mathbf{1}_{n\times n}\,x = \left(I - \frac{1}{n}\mathbf{1}_{n\times n}\right)x$$

Using our gradient rules, we get

$$\frac{\partial L}{\partial x} = \left(I - \frac{1}{n}\mathbf{1}_{n\times n}\right)^T\frac{\partial L}{\partial u} = \left(I - \frac{1}{n}\mathbf{1}_{n\times n}\right)\frac{\partial L}{\partial u}$$

where the transpose can be dropped because $I - \frac{1}{n}\mathbf{1}_{n\times n}$ is symmetric.
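As a quick numerical check of these two steps (my own sketch, not part of the derivation): centering really is multiplication by $I - \frac{1}{n}\mathbf{1}_{n\times n}$, and backpropagating through a linear map $u = Ax$ multiplies the upstream gradient by $A^T$.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 6
x = rng.standard_normal(n)

A = np.eye(n) - np.ones((n, n)) / n   # centering matrix I - (1/n) * ones
u = A @ x
print(np.allclose(u, x - x.mean()))   # centering as a matrix multiply: expect True

# Gradient rule for a linear map u = A x: dL/dx = A^T dL/du.
# Check it on L(u) = 0.5 * ||u||^2, whose gradient with respect to u is u itself.
dL_du = u
dL_dx = A.T @ dL_du

def L(x):
    return 0.5 * np.sum((A @ x) ** 2)

h = 1e-6
fd = np.array([(L(x + h * e) - L(x - h * e)) / (2 * h) for e in np.eye(n)])
print(np.allclose(dL_dx, fd, atol=1e-6))  # expect True
```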
Simplifying the expression
First, we calculate $\frac{\partial c}{\partial u}$:

$$\frac{\partial c}{\partial u} = \frac{\partial}{\partial u}\left(\frac{1}{n}u^T u + \epsilon\right)^{-1/2} = -\frac{1}{2}\left(\frac{1}{n}u^T u + \epsilon\right)^{-3/2}\cdot\frac{2}{n}\,u = -\frac{c^3}{n}\,u$$

We plug this into our equation for $\frac{\partial L}{\partial x}$ and rewrite $c$ and $u$ in terms of $\sigma^2$ and $\hat{x}$ (using $\hat{x} = c\,u$):

$$\begin{aligned}
\frac{\partial L}{\partial x}
&= \left(I - \frac{1}{n}\mathbf{1}_{n\times n}\right)\left(-\frac{c^3}{n}\,u u^T + c\,I\right)\frac{\partial L}{\partial\hat{x}} \\
&= \frac{c}{n}\left(I - \frac{1}{n}\mathbf{1}_{n\times n}\right)\left(n I - c^2\,u u^T\right)\frac{\partial L}{\partial\hat{x}} \\
&= \frac{1}{n\sqrt{\sigma^2+\epsilon}}\left(I - \frac{1}{n}\mathbf{1}_{n\times n}\right)\left(n I - \hat{x}\hat{x}^T\right)\frac{\partial L}{\partial\hat{x}} \\
&= \frac{1}{n\sqrt{\sigma^2+\epsilon}}\left(n I - \hat{x}\hat{x}^T - \mathbf{1}_{n\times n} + \frac{1}{n}\mathbf{1}_{n\times n}\hat{x}\hat{x}^T\right)\frac{\partial L}{\partial\hat{x}} \\
&= \frac{1}{n\sqrt{\sigma^2+\epsilon}}\left(n I - \hat{x}\hat{x}^T - \mathbf{1}_{n\times n}\right)\frac{\partial L}{\partial\hat{x}}
\end{aligned}$$

The last step above is because $\mathbf{1}_{n\times n}\hat{x}$ is the 0-vector: every entry of $\mathbf{1}_{n\times n}\hat{x}$ equals

$$\sum_{i=1}^n\hat{x}_i = \sum_{i=1}^n\frac{x_i - \mu}{\sqrt{\sigma^2+\epsilon}} = \frac{n\mu - n\mu}{\sqrt{\sigma^2+\epsilon}} = 0$$

Note that when the input examples are scalars, $\frac{\partial L}{\partial\hat{x}} = \gamma\,\frac{\partial L}{\partial y}$, where $\gamma$ is a scalar and $\frac{\partial L}{\partial y}$ is a column vector. Thus,

$$\begin{aligned}
\frac{\partial L}{\partial x}
&= \frac{\gamma}{n\sqrt{\sigma^2+\epsilon}}\left(n I - \hat{x}\hat{x}^T - \mathbf{1}_{n\times n}\right)\frac{\partial L}{\partial y} \\
&= \frac{\gamma}{n\sqrt{\sigma^2+\epsilon}}\left(n\,\frac{\partial L}{\partial y} - \hat{x}\left(\hat{x}^T\frac{\partial L}{\partial y}\right) - \mathbf{1}_n\left(\mathbf{1}_n^T\frac{\partial L}{\partial y}\right)\right) \\
&= \frac{\gamma}{n\sqrt{\sigma^2+\epsilon}}\left(n\,\frac{\partial L}{\partial y} - \hat{x}\,\frac{\partial L}{\partial\gamma} - \mathbf{1}_n\,\frac{\partial L}{\partial\beta}\right)
\end{aligned}$$

where $\mathbf{1}_n$ is an $n$-dimensional column vector of ones. The last line uses the fact that when the input examples are scalars, the derivatives of $\gamma$ and $\beta$ simplify to

$$\frac{\partial L}{\partial\gamma} = \hat{x}^T\frac{\partial L}{\partial y}, \qquad \frac{\partial L}{\partial\beta} = \mathbf{1}_n^T\frac{\partial L}{\partial y}$$
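Before generalizing, here is a finite-difference check of the single-column formula above. This is my own sketch; the toy loss $L = \sum_i \sin(y_i)$ and the parameter values are chosen arbitrarily.

```python
import numpy as np

rng = np.random.default_rng(2)
n, eps = 8, 1e-5
x = rng.standard_normal(n)
gamma, beta = 1.7, -0.3

def forward(x):
    mu, var = x.mean(), x.var()
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta, x_hat, var

y, x_hat, var = forward(x)
dy = np.cos(y)                       # upstream gradient dL/dy for L = sum(sin(y))
dgamma = x_hat @ dy                  # dL/dgamma in the scalar case
dbeta = dy.sum()                     # dL/dbeta in the scalar case
dx = gamma / (n * np.sqrt(var + eps)) * (n * dy - x_hat * dgamma - dbeta)

def loss(x):
    return np.sum(np.sin(forward(x)[0]))

h = 1e-6
fd = np.array([(loss(x + h * e) - loss(x - h * e)) / (2 * h) for e in np.eye(n)])
print(np.allclose(dx, fd, atol=1e-5))  # expect True
```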
Finally, we generalize to the case when the input examples are $d$-dimensional vectors by applying the expression above independently to each column of $X$:

$$\frac{\partial L}{\partial X} = \frac{1}{n}\,\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\odot\left(n\,\frac{\partial L}{\partial Y} - \frac{\partial L}{\partial\gamma}\odot\hat{X} - \mathbf{1}_n\,\frac{\partial L}{\partial\beta}\right)$$

Here $\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}$, $\frac{\partial L}{\partial\gamma}$, and $\frac{\partial L}{\partial\beta}$ are row vectors in $\mathbb{R}^d$ (the division is element-wise), and $\mathbf{1}_n\,\frac{\partial L}{\partial\beta}$ is the $n \times d$ matrix whose every row equals $\frac{\partial L}{\partial\beta}$.
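Putting everything together, here is a sketch of the full backward pass in NumPy, reusing the cache from the `batchnorm_forward` sketch above (again, the function name is my own).

```python
import numpy as np

def batchnorm_backward(dY, cache):
    """Efficient batch-norm backward pass.

    dY:    (n, d) upstream gradient dL/dY
    cache: (X_hat, gamma, var, eps) saved by batchnorm_forward
    Returns dX, dgamma, dbeta.
    """
    X_hat, gamma, var, eps = cache
    n = dY.shape[0]

    dgamma = np.sum(dY * X_hat, axis=0)  # (d,)
    dbeta = np.sum(dY, axis=0)           # (d,)
    # dL/dX = gamma / (n sqrt(var + eps)) ⊙ (n dL/dY - dL/dgamma ⊙ X_hat - 1 dL/dbeta)
    dX = (gamma / (n * np.sqrt(var + eps))) * (n * dY - X_hat * dgamma - dbeta)
    return dX, dgamma, dbeta
```

Broadcasting handles the row-vector Hadamard products: `gamma / (n * np.sqrt(var + eps))`, `dgamma`, and `dbeta` are `(d,)` vectors broadcast across the `n` rows, so the whole backward pass is a handful of vectorized operations with no loops over the batch.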
References
- Batch Normalization
  - the original paper by Sergey Ioffe and Christian Szegedy
- Efficient Batch Normalization
  - row-wise derivation of $\frac{\partial L}{\partial X}$
- Deriving the Gradient for the Backward Pass of Batch Normalization
  - another take on the row-wise derivation of $\frac{\partial L}{\partial X}$
- Understanding the backward pass through Batch Normalization Layer
  - (slow) step-by-step backpropagation through the batch normalization layer
- Batch Normalization - What the Hey?
  - explains some intuition behind batch normalization
  - clarifies the difference between using batch statistics during training and sample statistics during inference