Batch Normalization详解

2023-11-25发布于深度学习 | 最后更新于2023-12-04 16:12:00

Pytorch Normalization

标准化的基本步骤

1. 计算统计量——均值与方差

假设数据为\(X=\{x_1, x_2, \dots,x_N\}\),则均值\(\mathrm{E}(X)\)与方差\(\mathrm{Var}(X)\)分别为:

$$ \mathrm{E}(X)=\frac{1}{N}\sum_{i=1}^N{x_i}\\ \mathrm{Var}(X)=\frac{1}{N}\sum_{i=1}^N{(x_i-\mathrm{E}(X))^2} $$

2. 进行标准化

使用下式对\(x_i\)进行标准化:

$$ \hat{x_i}=\frac{x_i-\mathrm{E}(X)}{\sqrt{\mathrm{Var}(X)+\epsilon}} $$

其中\(\epsilon\)为一个很小的正实数,用于避免计算中出现分母为零的情况

3. 放缩与平移

得到上面的\(\hat{x_i}\)后,还会进行放缩与平移:

$$ y=\gamma \cdot \hat{x_i}+\beta $$

其中的\(\gamma, \beta\)均为实数,并且是可以训练的。

Batch Normalization

在一个batch中,跨各个特征向量对每个channel各自进行标准化。以下是torch.nn模块中各种BN层的详解。

torch.nn.BatchNorm1d

顾名思义,对一维的的特征向量进行Batch Normalization。输入可以是2D或者3D的,会将输入数据的大小视为\((N,C)\)\((N,C,L)\)\(N\)为一个batch中的特征向量数量,\(C\)为通道数,\(L\)为序列长度。需要记住的是,BN永远是对各通道进行计算的,可以结合例子进行理解:

例子1

假设某batch的数据如下,size为\((3, 4)\),第1维(即每一行)表示单词的词向量,此batch中共有三个单词,通道数为4:

[[1, 3, 2, 4],
 [2, 2, 1, 1],
 [6, 2, 4, 1]]

分别计算四个通道的均值与方差:

$$ \begin{align*} \mathrm{E}_1&=\frac{1}{3}\cdot (1+2+6)&=3,\mathrm{Var}_1 &=\frac{1}{3}\cdot [(1-3)^2+(2-3)^2+(6-3)^2]&=\frac{14}{3}\\ \mathrm{E}_2&=\frac{1}{3}\cdot (3+2+2)&=\frac{7}{3},\mathrm{Var}_2 &=\frac{1}{3}\cdot [(\frac{2}{3})^2+(-\frac{1}{3})^2+(-\frac{1}{3})^2]&=\frac{2}{9}\\ \mathrm{E}_3&=\frac{1}{3}\cdot (2+1+4)&=\frac{7}{3},\mathrm{Var}_3 &=\frac{1}{3}\cdot [(-\frac{1}{3})^2+(-\frac{4}{3})^2+(\frac{5}{3})^2]&=\frac{14}{9}\\ \mathrm{E}_4&=\frac{1}{3}\cdot (4+1+1)&=2,\mathrm{Var}_4 &=\frac{1}{3}\cdot [(\frac{5}{3})^2+(-\frac{4}{3})^2+(-\frac{4}{3})^2]&=2 \end{align*} $$

根据标准化公式\(\hat{x_i}=\frac{x_i-\mathrm{E}(X)}{\sqrt{\mathrm{Var}(X)+\epsilon}}\)计算即可,此处省略\(\epsilon\)进行验证:

data_np = np.array([[1, 3, 2, 4],
                    [2, 2, 1, 1],
                    [6, 2, 4, 1]], dtype=np.float32)
data_torch = torch.from_numpy(data_np)
bn1 = torch.nn.BatchNorm1d(4, track_running_stats=False, affine=False)

avg = np.average(data_np, axis=0)
var = np.var(data_np, axis=0)

(data_np-avg)/np.sqrt(var)
array([[-0.9258,  1.4142, -0.2673,  1.4142],
       [-0.4629, -0.7071, -1.069 , -0.7071],
       [ 1.3887, -0.7071,  1.3363, -0.7071]])
bn1(data_torch)
tensor([[-0.9258,  1.4142, -0.2673,  1.4142],
        [-0.4629, -0.7071, -1.0690, -0.7071],
        [ 1.3887, -0.7071,  1.3363, -0.7071]], dtype=torch.float64)

在创建BatchNorm1d层时,指定了三个参数:

  • 第一个参数是num_features,即通道数\(C\)
  • track_running_stats决定了模型中此BatchNorm1d层会不会在不同batch之间跟踪均值和方差,若启用,则模型在训练时会有running_meanrunning_var用于保存全局的均值和方差。
  • affine决定了是否进行基本步骤中的第三步——放缩与平移

例子2

假设某batch的数据如下,size为\((3, 2, 2)\),第1维(即每一行)是表示一个句子的向量,此batch中:

  • 共有三个句子,例如第一个句子为[[1, 3], [2, 4]]
  • 通道数为2:对于一个句子来说,其中的一个词就是该句子的一个特征分量,例如第一个句子的第一个通道(词)为[1, 3]
  • 序列长度为2:词向量是2维的
[[[1, 3], [2, 4]],
 [[2, 2], [1, 1]],
 [[6, 2], [4, 1]]]

分别计算两个通道的均值和方差,第一个通道的数据为1, 3, 2, 2, 6, 2,第二个通道的数据为2, 4, 1, 1, 4, 1

$$ \begin{align*} \mathrm{E}_1&=\frac{8}{3},&\mathrm{Var}_1&=\frac{23}{9}\\ \mathrm{E}_2&=\frac{13}{6},&\mathrm{Var}_2&=\frac{65}{36} \end{align*} $$

验证如下(主要看结果,numpy的过程用了些小技巧)

data_np = np.array([[[1, 3], [2, 4]],
                    [[2, 2], [1, 1]],
                    [[6, 2], [4, 1]]], dtype=np.float32)
data_torch = torch.from_numpy(data_np)
bn1 = torch.nn.BatchNorm1d(2, track_running_stats=False, affine=False)
avg = np.array([data_np[:,0,:].mean(), data_np[:,1,:].mean()])
avg = np.tile(avg.repeat(2), 3).reshape((3, 2, 2))
var = np.array([data_np[:,0,:].var(), data_np[:,1,:].var()])
var = np.tile(var.repeat(2), 3).reshape((3, 2, 2))

(data_np-avg)/np.sqrt(var)
array([[[-1.0426,  0.2085],
        [-0.124 ,  1.3644]],

       [[-0.417 , -0.417 ],
        [-0.8682, -0.8682]],

       [[ 2.0851, -0.417 ],
        [ 1.3644, -0.8682]]], dtype=float32)
bn1(data_torch)
tensor([[[-1.0426,  0.2085],
         [-0.1240,  1.3644]],

        [[-0.4170, -0.4170],
         [-0.8682, -0.8682]],

        [[ 2.0851, -0.4170],
         [ 1.3644, -0.8682]]])

torch.nn.BatchNorm2d

输入为4维tensor,大小为\((N,C,H,W)\)\(N\)为一个batch中特征向量的数量,\(C\)为一个特征向量的分量数,特征向量一个分量的大小则为\((H,W)\)。可以从图片的角度来理解:N为一个batch中图片的张数,C为图片的R、G、B三个通道,H、W则为图片的高和宽。

测试数据如下,注释中说明了理解方式

[
    [
        [[3, 4],
         [1, 2]],   # 第一张图片的一通道
        [[3, 3],
         [1, 2]],   # 第一张图片的二通道
        [[4, 1],
         [4, 1]],   # 第一张图片的三通道
    ],  #batch中的第一张图片
    [
        [[2, 2],
         [1, 4]],
        [[3, 3],
         [1, 4]],
        [[2, 1],
         [2, 3]],
    ],  #batch中的第二张图片
    [
        [[4, 2],
         [4, 1]],
        [[1, 1],
         [2, 1]],
        [[3, 2],
         [2, 4]],
    ],  #batch中的第三张图片
]
此时,第一个通道的数据就是3, 4, 1, 2, 2, 2, 1, 4, 4, 2, 4, 1;以此类推

# 限于篇幅,略去数据初始化部分,np.array与tensor分别存于data_np与data_torch中
avg = []
var = []
for i in range(data_np.shape[1]):
    avg.append(data_np[:,i,:,:].mean())
    var.append(data_np[:,i,:,:].var())
avg = np.array(avg)
avg = np.tile(avg, 3).repeat(4).reshape((3, 3, 2, 2))
var = np.array(var)
var = np.tile(var, 3).repeat(4).reshape((3, 3, 2, 2))
bn2 = torch.nn.BatchNorm2d(3, track_running_stats=False, affine=False)

(data_np-avg)/np.sqrt(var)
array([[[[ 0.4201,  1.2603],
         [-1.2603, -0.4201]],

        [[ 0.8835,  0.8835],
         [-1.0442, -0.0803]],

        [[ 1.4201, -1.2706],
         [ 1.4201, -1.2706]]],

            ...

       [[[ 1.2603, -0.4201],
         [ 1.2603, -1.2603]],

        [[-1.0442, -1.0442],
         [-0.0803, -1.0442]],

        [[ 0.5232, -0.3737],
         [-0.3737,  1.4201]]]], dtype=float32)
bn2(data_torch)
tensor([[[[ 0.4201,  1.2602],
          [-1.2602, -0.4201]],

         [[ 0.8835,  0.8835],
          [-1.0442, -0.0803]],

         [[ 1.4201, -1.2706],
          [ 1.4201, -1.2706]]],

            ...

        [[[ 1.2602, -0.4201],
          [ 1.2602, -1.2602]],

         [[-1.0442, -1.0442],
          [-0.0803, -1.0442]],

         [[ 0.5232, -0.3737],
          [-0.3737,  1.4201]]]])

torch.nn.BatchNorm3d

输入的是5维tensor,大小会被视为\((N,C,D,H,W)\),即一个channel中就包含了多张单色图。限于篇幅,下面只写出用numpy手工计算BN的过程

var = []
avg = []
for i in range(data_np.shape[1]):
    var.append(data_np[:,i,:,:,:].var())
    avg.append(data_np[:,i,:,:,:].mean())

var = np.array(var)
avg = np.array(avg)
avg = np.tile(avg, 2).repeat(32).reshape((2, 3, 2, 4, 4))
var = np.tile(var, 2).repeat(32).reshape((2, 3, 2, 4, 4))
res_np = (data_np-avg)/np.sqrt(var)

总结

至此,已经可以很清晰的看出Batch Normalization的通用算法了,大致分为下面几步:

  1. 计算各通道的均值和方差,有几个通道就有几个均值、几个方差
  2. 输入数据中的某一个数进行BN后的结果,其实就是先减去这个数所在通道的均值,再除以所在通道方差的平方根

BN算法伪代码