批归一化(Batch Normalization)原理分析和代码实现

本文瀏覽次數

发布日期

2024年3月18日

1 简介

批归一化(Batch Normalization,BN)是在2015年由Sergey Loffe和Christan Szegedy[1]提出的一种加速深度学习模型收敛的方法。

在模型的训练过程中,每个深度学习模型每一层模块的输出分布都在不断变化,后续的模块需要不断适应新的输入模式,这个问题在Batch Normalization[1]中被称为internal covariate shift。为克服这个问题,Batch Normalization提出在模型内部加入归一化层。归一化层的引入使得模型的训练更加稳定,允许使用更大的学习率,使得模型对参数的初始化没那么敏感。

本文主要谈一谈BN的运行原理和实现细节,最后分享一些BN的使用经验和容易踩到的坑。

2 实现方法

BN用如下公式作输入样本的归一化: \[ \hat x = \frac{x - E(x)}{\sqrt{Var(x) + \epsilon}}, \tag{1}\] 其中\(x\)为BN模块的输入。\(E(x)\)\(x\)的数学期望,\(Var(x)\)\(x\)的方差,\(\epsilon\)是防止除零异常的一个接近\(0\)的正数。

随后,归一化的样本经过一层线性层得到BN的输出: \[ y = \hat x\gamma + \beta. \]

2.1 Batch size

实际训练时,输入的tensor是N维的。以图像为例,一般图像特征\(x\)是4维,形状可能是\(b\times d\times h \times w\),其中\(b\)为batch size,\(d\)为特征维度,\(h\)为图像的长宽。在这个例子中,BN对所有\(d\)维的特征向量作归一化(共\(b\times h\times w\)个向量)。

为了获得尽量准确的统计,batch size最好取尽量大些。如果batch size太小,那么\(E(x)\)\(Var(x)\)的估计不准确,模型的最终性能便可能下降。

2.2 测试和训练阶段的行为不一致

测试推理阶段,模型往往一次只接受一个数据:\(\text{batch size}=1\)。BN不能像训练时那样在大batch size下估计\(E(x)\)\(Var(x)\). 为了应对这个问题,BN的对策是moving average。在训练阶段,BN使用Batch内统计的均值和方差作归一化,同时使用moving average方法维护均值和方差预备测试时使用。设BN的momentum参数等于0.1,\(m\)是moving average方法跟踪的一个统计量,那么其更新方法为: \[ \hat m_{t} = \hat m_{t-1} \cdot (1 - \text{momentum}) + m_t \cdot \text{momentum}. \]

测试时用事先统计的均值和方差的moving average,带入公式 1中作归一化。

3 代码实现

在实现BatchNorm之前,我们不妨先看看pytorch官方的BatchNorm2d模块,观察BatchNorm层要有哪些参数:

import torch 
import torch.nn as nn 
batch, ch, h, w = 2, 32, 128, 128
torch_batch_norm = nn.BatchNorm2d(num_features=ch)
for k, v in torch_batch_norm.named_parameters():
    print(f'parameter: {k}', v.shape)
for k, v in torch_batch_norm.named_buffers():
    print(f'buffer: {k}', v.shape)
parameter: weight torch.Size([32])
parameter: bias torch.Size([32])
buffer: running_mean torch.Size([32])
buffer: running_var torch.Size([32])
buffer: num_batches_tracked torch.Size([])

注意到BatchNorm的参数有两种。一种是parameterweightbias),一种是bufferrunning_meanrunning_varnum_batches_tracked)。对于parameter,torch默认其参数是需要梯度反传的;而buffer则用于存储一些不需要梯度反传的模型参数。与parameter一样,在保存模型时,buffer参数也会存储到state_dict中。

下面是本文提供的BN实现:

import torch
import torch.nn as nn 
class MyBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        # weight初始化为1,bias初始化为0.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        self.momentum = 0.1 
        self.eps = eps 

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0.))

    def forward(self, tensor): 
        bs, ch, h, w = tensor.shape 
        # 变换一下tensor的尺寸,方便处理
        tensor_flatten = tensor.permute(0, 2, 3, 1).flatten(0, 2)  # bs * h * w, ch 
        # 求均值和方差
        mean = torch.mean(tensor_flatten, 0)
        # 注意方差有biased和unbiased两种。
        var = torch.var(tensor_flatten, 0, unbiased=False)
        var_unbiased = torch.var(tensor_flatten, 0, unbiased=True)

        if self.training:
            # 训练时,我们要执行moving average,统计
            # running_mean和running_var,注意此时应
            # 使用unbiased版本的方差。
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * var_unbiased)
            self.num_batches_tracked.add_(1)

        # 训练时用batch内的统计量,测试时用moving average
        # 保存的统计量。
        if self.training:
            tensor_flatten = (tensor_flatten - mean) / torch.sqrt(var + self.eps)
        else: 
            tensor_flatten = (tensor_flatten - self.running_mean) / torch.sqrt(self.running_var + self.eps)

        # 归一化完成后,做线性变换。
        ret = tensor_flatten * self.weight + self.bias 
        ret = ret.view(bs, h, w, ch).permute(0, 3, 1, 2)
        return ret 

接下来验证看看MyBatchNorm的行为和torch.nn.BatchNorm2d是否完全一致。

我们先检查训练模式下两者的行为:

my_batch_norm = MyBatchNorm(num_features=ch)
# 因为BN涉及running_mean和running_var的更新,所以我们要多跑几轮来检查moving average的正确性。
for _ in range(10):
    a = torch.rand(batch, ch, h, w)
    ret1 = torch_batch_norm(a)
    ret2 = my_batch_norm(a)
    diff = torch.mean(torch.abs(ret1 - ret2)).item()
    
    running_mean1 = torch_batch_norm.running_mean 
    running_mean2 = my_batch_norm.running_mean 
    diff_mean = torch.mean(torch.abs(running_mean1 - running_mean2)).item()

    running_var1 = torch_batch_norm.running_var 
    running_var2 = my_batch_norm.running_var
    diff_var = torch.mean(torch.abs(running_var1 - running_var2)).item()
    print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

可以看到,输出的所有误差项都为0!这表明MyBatchNorm的实现和torch的BN相吻合。

不要忘记BN在训练时的行为和测试时的行为不同。我们需要再检查一遍测试阶段下MyBatchNorm的行为。

# .eval()开启测试模型
torch_batch_norm.eval()  
my_batch_norm.eval()
for _ in range(10):
    a = torch.rand(batch, ch, h, w)
    ret1 = torch_batch_norm(a)
    ret2 = my_batch_norm(a)
    diff = torch.mean(torch.abs(ret1 - ret2)).item()
    
    running_mean1 = torch_batch_norm.running_mean 
    running_mean2 = my_batch_norm.running_mean 
    diff_mean = torch.mean(torch.abs(running_mean1 - running_mean2)).item()

    running_var1 = torch_batch_norm.running_var 
    running_var2 = my_batch_norm.running_var
    diff_var = torch.mean(torch.abs(running_var1 - running_var2)).item()
    print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

至此我们已经完成了全部检查,实验结果表明MyBatchNorm的实现是正确的。

4 正确地使用BN

BN作为一种实用、应用广泛的归一化模块,是计算机视觉领域的一座里程碑。尽管BN的应用确实解决了一些实际问题,但它也存在一些“坑”,是在使用BN时应当注意的。

4.1 BN在训练阶段与测试阶段的行为差异

在训练阶段,BN使用batch内统计的均值和方差作归一化,并记录它们的moving average;而在测试阶段,BN不再统计新数据的均值和方差,也不再更新moving average。这种不一致性(inconsistency)在后续工作[2]中被认为是一种影响性能的潜在因素。

4.2 如何正确地冻结BN模块

设想我们有一个模型经过了充分的预训练,现在我们希望在一个小数据集上微调它。一般步骤包括(以pytorch为例):

  1. 阻止梯度反传。这可以通过使用torch.no_grad()或将该各参数的requires_grad属性设置为False做到;
  2. 调用module.eval(),关闭train模式;

针对第2点,一般人们有两种意见。一种看法认为不开BN的eval模式更好,这有助于让模型学习如何对新数据做归一化。而我倾向于采取的做法则是开启eval。在我的经验中,如果BN处于训练状态,而模型的其它层则冻结著,那么模型可能因为不适应BN在新数据上归一化参数的改变而引发训练不稳定。

总而言之,BN在训练、迁移学习、测试时的行为不一致有时确实是一个麻烦的问题。如果遇到了这个问题,我建议考虑一下是否要开启BN的eval模式,或者试试后来的Group Normalization[2]

4.3 分布式训练

在训练参数量较大的模型时,可以用分布式训练,利用多个进程和多个计算设备执行计算。这种情况下,每张卡只需负责比较小的batch。注意原始的BN在batch size较小时,所产生的均值/方差的统计量不准确。因此,在分布式训练时,我们最好将原BatchNorm模块替换为torch.SyncBatchNorm。后者能同步所有计算设备,在更大的batch size上统计均值和方差。

4.4 不要递归地使用BN

最后介绍一个我踩过的,印象深刻的坑。 假如有这样一段代码:

x2 = batch_norm(x1)
x3 = batch_norm(x2)

x2 = conv(x1)  # conv中包含BN模块
x3 = conv(x2)

前面的一段代码的问题或许容易识别,后者的问题则稍隐蔽些。你能预测到会发生什么吗?在这样的代码中,同一个BN模块在训练时会分别获取x2x1的均值和方差,然后通过moving average将它们计入running_meanrunning_var。然而,由于x1x2服从不同的分布,因此running_meanrunning_var的统计将失去意义。 问题的表现是,在训练阶段,我们会观察到损失正常下降。测试时,我们开启eval模式,模型的表现不如预期;可是如果你关闭eval模式,也许会发现模型又能正常工作。

类似的问题也存在于特征金字塔(FPN)的实现中。如果你希望在类特征金字塔的结构中实现不同层级共享参数的话,注意卷积的参数也许能共享,但BN的参数不要共享。

5 总结

本文介绍了BN的工作原理,给出了一种基于pytorch的BN模块实现,并提供了详细的代码检查。最后,本文讨论了应用BN过程中容易遇到的几种问题。

在接触深度学习的过程中,Batch Normalization是一个让我反复(大概得有两三次吧)踩坑的模块,每次踩坑都得琢磨好久才能发现问题所在。现在我已经习惯性的选择Group Normalization[2],抛弃BN了。尽管如此,BN仍是一个经典的工作,它背后的思想很值得学习研究。

参考

[1]
IOFFE S, SZEGEDY C. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift[J]. : 9.
[2]
WU Y, HE K. Group Normalization[J]. arXiv:1803.08494 [cs], 2018.
By @執迷 in
Tags : #Batch Normalization, #Layer Normalization, #歸一化方法, #自然語言處理, #計算機視覺,