批歸一化(Batch Normalization)原理分析和代碼實現

本文瀏覽次數

发布日期

2024年3月18日

1 簡介

批歸一化(Batch Normalization,BN)是在2015年由Sergey Loffe和Christan Szegedy[1]提出的一種加速深度學習模型收斂的方法。

在模型的訓練過程中,每個深度學習模型每一層模塊的輸出分佈都在不斷變化,後續的模塊需要不斷適應新的輸入模式,這個問題在Batch Normalization[1]中被稱為internal covariate shift。為克服這個問題,Batch Normalization提出在模型內部加入歸一化層。歸一化層的引入使得模型的訓練更加穩定,允許使用更大的學習率,使得模型對參數的初始化沒那麼敏感。

本文主要談一談BN的運行原理和實現細節,最後分享一些BN的使用經驗和容易踩到的坑。

2 實現方法

BN用如下公式作輸入樣本的歸一化: \[ \hat x = \frac{x - E(x)}{\sqrt{Var(x) + \epsilon}}, \tag{1}\] 其中\(x\)為BN模塊的輸入。\(E(x)\)\(x\)的數學期望,\(Var(x)\)\(x\)的方差,\(\epsilon\)是防止除零異常的一個接近\(0\)的正數。

隨後,歸一化的樣本經過一層線性層得到BN的輸出: \[ y = \hat x\gamma + \beta. \]

2.1 Batch size

實際訓練時,輸入的tensor是N維的。以圖像為例,一般圖像特徵\(x\)是4維,形狀可能是\(b\times d\times h \times w\),其中\(b\)為batch size,\(d\)為特徵維度,\(h\)為圖像的長寬。在這個例子中,BN對所有\(d\)維的特徵向量作歸一化(共\(b\times h\times w\)個向量)。

為了獲得盡量準確的統計,batch size最好取盡量大些。如果batch size太小,那麼\(E(x)\)\(Var(x)\)的估計不準確,模型的最終性能便可能下降。

2.2 測試和訓練階段的行為不一致

測試推理階段,模型往往一次只接受一個數據:\(\text{batch size}=1\)。BN不能像訓練時那樣在大batch size下估計\(E(x)\)\(Var(x)\). 為了應對這個問題,BN的對策是moving average。在訓練階段,BN使用Batch內統計的均值和方差作歸一化,同時使用moving average方法維護均值和方差預備測試時使用。設BN的momentum參數等於0.1,\(m\)是moving average方法跟蹤的一個統計量,那麼其更新方法為: \[ \hat m_{t} = \hat m_{t-1} \cdot (1 - \text{momentum}) + m_t \cdot \text{momentum}. \]

測試時用事先統計的均值和方差的moving average,帶入公式 1中作歸一化。

3 代碼實現

在實現BatchNorm之前,我們不妨先看看pytorch官方的BatchNorm2d模塊,觀察BatchNorm層要有哪些參數:

import torch 
import torch.nn as nn 
batch, ch, h, w = 2, 32, 128, 128
torch_batch_norm = nn.BatchNorm2d(num_features=ch)
for k, v in torch_batch_norm.named_parameters():
    print(f'parameter: {k}', v.shape)
for k, v in torch_batch_norm.named_buffers():
    print(f'buffer: {k}', v.shape)
parameter: weight torch.Size([32])
parameter: bias torch.Size([32])
buffer: running_mean torch.Size([32])
buffer: running_var torch.Size([32])
buffer: num_batches_tracked torch.Size([])

注意到BatchNorm的參數有兩種。一種是parameterweightbias),一種是bufferrunning_meanrunning_varnum_batches_tracked)。對於parameter,torch默認其參數是需要梯度反傳的;而buffer則用於存儲一些不需要梯度反傳的模型參數。與parameter一樣,在保存模型時,buffer參數也會存儲到state_dict中。

下面是本文提供的BN實現:

import torch
import torch.nn as nn 
class MyBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        # weight初始化為1,bias初始化為0.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        self.momentum = 0.1 
        self.eps = eps 

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0.))

    def forward(self, tensor): 
        bs, ch, h, w = tensor.shape 
        # 變換一下tensor的尺寸,方便處理
        tensor_flatten = tensor.permute(0, 2, 3, 1).flatten(0, 2)  # bs * h * w, ch 
        # 求均值和方差
        mean = torch.mean(tensor_flatten, 0)
        # 注意方差有biased和unbiased兩種。
        var = torch.var(tensor_flatten, 0, unbiased=False)
        var_unbiased = torch.var(tensor_flatten, 0, unbiased=True)

        if self.training:
            # 訓練時,我們要執行moving average,統計
            # running_mean和running_var,注意此時應
            # 使用unbiased版本的方差。
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * var_unbiased)
            self.num_batches_tracked.add_(1)

        # 訓練時用batch內的統計量,測試時用moving average
        # 保存的統計量。
        if self.training:
            tensor_flatten = (tensor_flatten - mean) / torch.sqrt(var + self.eps)
        else: 
            tensor_flatten = (tensor_flatten - self.running_mean) / torch.sqrt(self.running_var + self.eps)

        # 歸一化完成後,做線性變換。
        ret = tensor_flatten * self.weight + self.bias 
        ret = ret.view(bs, h, w, ch).permute(0, 3, 1, 2)
        return ret 

接下來驗證看看MyBatchNorm的行為和torch.nn.BatchNorm2d是否完全一致。

我們先檢查訓練模式下兩者的行為:

my_batch_norm = MyBatchNorm(num_features=ch)
# 因為BN涉及running_mean和running_var的更新,所以我們要多跑幾輪來檢查moving average的正確性。
for _ in range(10):
    a = torch.rand(batch, ch, h, w)
    ret1 = torch_batch_norm(a)
    ret2 = my_batch_norm(a)
    diff = torch.mean(torch.abs(ret1 - ret2)).item()
    
    running_mean1 = torch_batch_norm.running_mean 
    running_mean2 = my_batch_norm.running_mean 
    diff_mean = torch.mean(torch.abs(running_mean1 - running_mean2)).item()

    running_var1 = torch_batch_norm.running_var 
    running_var2 = my_batch_norm.running_var
    diff_var = torch.mean(torch.abs(running_var1 - running_var2)).item()
    print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

可以看到,輸出的所有誤差項都為0!這表明MyBatchNorm的實現和torch的BN相吻合。

不要忘記BN在訓練時的行為和測試時的行為不同。我們需要再檢查一遍測試階段下MyBatchNorm的行為。

# .eval()開啟測試模型
torch_batch_norm.eval()  
my_batch_norm.eval()
for _ in range(10):
    a = torch.rand(batch, ch, h, w)
    ret1 = torch_batch_norm(a)
    ret2 = my_batch_norm(a)
    diff = torch.mean(torch.abs(ret1 - ret2)).item()
    
    running_mean1 = torch_batch_norm.running_mean 
    running_mean2 = my_batch_norm.running_mean 
    diff_mean = torch.mean(torch.abs(running_mean1 - running_mean2)).item()

    running_var1 = torch_batch_norm.running_var 
    running_var2 = my_batch_norm.running_var
    diff_var = torch.mean(torch.abs(running_var1 - running_var2)).item()
    print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000

至此我們已經完成了全部檢查,實驗結果表明MyBatchNorm的實現是正確的。

4 正確地使用BN

BN作為一種實用、應用廣泛的歸一化模塊,是計算機視覺領域的一座里程碑。儘管BN的應用確實解決了一些實際問題,但它也存在一些“坑”,是在使用BN時應當注意的。

4.1 BN在訓練階段與測試階段的行為差異

在訓練階段,BN使用batch內統計的均值和方差作歸一化,並記錄它們的moving average;而在測試階段,BN不再統計新數據的均值和方差,也不再更新moving average。這種不一致性(inconsistency)在後續工作[2]中被認為是一種影響性能的潛在因素。

4.2 如何正確地凍結BN模塊

設想我們有一個模型經過了充分的預訓練,現在我們希望在一個小數據集上微調它。一般步驟包括(以pytorch為例):

  1. 阻止梯度反傳。這可以通過使用torch.no_grad()或將該各參數的requires_grad屬性設置為False做到;
  2. 調用module.eval(),關閉train模式;

針對第2點,一般人們有兩種意見。一種看法認為不開BN的eval模式更好,這有助於讓模型學習如何對新數據做歸一化。而我傾向於採取的做法則是開啟eval。在我的經驗中,如果BN處於訓練狀態,而模型的其它層則凍結著,那麼模型可能因為不適應BN在新數據上歸一化參數的改變而引發訓練不穩定。

總而言之,BN在訓練、遷移學習、測試時的行為不一致有時確實是一個麻煩的問題。如果遇到了這個問題,我建議考慮一下是否要開啟BN的eval模式,或者試試後來的Group Normalization[2]

4.3 分佈式訓練

在訓練參數量較大的模型時,可以用分佈式訓練,利用多個進程和多個計算設備執行計算。這種情況下,每張卡只需負責比較小的batch。注意原始的BN在batch size較小時,所產生的均值/方差的統計量不準確。因此,在分佈式訓練時,我們最好將原BatchNorm模塊替換為torch.SyncBatchNorm。後者能同步所有計算設備,在更大的batch size上統計均值和方差。

4.4 不要遞歸地使用BN

最後介紹一個我踩過的,印象深刻的坑。 假如有這樣一段代碼:

x2 = batch_norm(x1)
x3 = batch_norm(x2)

x2 = conv(x1)  # conv中包含BN模塊
x3 = conv(x2)

前面的一段代碼的問題或許容易識別,後者的問題則稍隱蔽些。你能預測到會發生什麼嗎?在這樣的代碼中,同一個BN模塊在訓練時會分別獲取x2x1的均值和方差,然後通過moving average將它們計入running_meanrunning_var。然而,由於x1x2服從不同的分佈,因此running_meanrunning_var的統計將失去意義。 問題的表現是,在訓練階段,我們會觀察到損失正常下降。測試時,我們開啟eval模式,模型的表現不如預期;可是如果你關閉eval模式,也許會發現模型又能正常工作。

類似的問題也存在於特征金字塔(FPN)的實現中。如果你希望在類特征金字塔的結構中實現不同層級共享參數的話,注意卷積的參數也許能共享,但BN的參數不要共享。

5 總結

本文介紹了BN的工作原理,給出了一種基於pytorch的BN模塊實現,並提供了詳細的代碼檢查。最後,本文討論了應用BN過程中容易遇到的幾種問題。

在接觸深度學習的過程中,Batch Normalization是一個讓我反復(大概得有兩三次吧)踩坑的模塊,每次踩坑都得琢磨好久才能發現問題所在。現在我已經習慣性的選擇Group Normalization[2],拋棄BN了。儘管如此,BN仍是一個經典的工作,它背後的思想很值得學習研究。

参考

[1]
IOFFE S, SZEGEDY C. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift[J]. : 9.
[2]
WU Y, HE K. Group Normalization[J]. arXiv:1803.08494 [cs], 2018.
By @執迷 in
Tags : #Batch Normalization, #Layer Normalization, #歸一化方法, #自然語言處理, #計算機視覺,