1 簡介
批歸一化(Batch Normalization,BN)是在2015年由Sergey Loffe和Christan Szegedy[1]提出的一種加速深度學習模型收斂的方法。
在模型的訓練過程中,每個深度學習模型每一層模塊的輸出分佈都在不斷變化,後續的模塊需要不斷適應新的輸入模式,這個問題在Batch Normalization[1]中被稱為internal covariate shift。為克服這個問題,Batch Normalization提出在模型內部加入歸一化層。歸一化層的引入使得模型的訓練更加穩定,允許使用更大的學習率,使得模型對參數的初始化沒那麼敏感。
本文主要談一談BN的運行原理和實現細節,最後分享一些BN的使用經驗和容易踩到的坑。
2 實現方法
BN用如下公式作輸入樣本的歸一化: \[ \hat x = \frac{x - E(x)}{\sqrt{Var(x) + \epsilon}}, \tag{1}\] 其中\(x\)為BN模塊的輸入。\(E(x)\)為\(x\)的數學期望,\(Var(x)\)為\(x\)的方差,\(\epsilon\)是防止除零異常的一個接近\(0\)的正數。
隨後,歸一化的樣本經過一層線性層得到BN的輸出: \[ y = \hat x\gamma + \beta. \]
2.1 Batch size
實際訓練時,輸入的tensor是N維的。以圖像為例,一般圖像特徵\(x\)是4維,形狀可能是\(b\times d\times h \times w\),其中\(b\)為batch size,\(d\)為特徵維度,\(h\)和\(w\)為圖像的長寬。在這個例子中,BN對所有\(d\)維的特徵向量作歸一化(共\(b\times h\times w\)個向量)。
為了獲得盡量準確的統計,batch size最好取盡量大些。如果batch size太小,那麼\(E(x)\)和\(Var(x)\)的估計不準確,模型的最終性能便可能下降。
2.2 測試和訓練階段的行為不一致
測試推理階段,模型往往一次只接受一個數據:\(\text{batch size}=1\)。BN不能像訓練時那樣在大batch size下估計\(E(x)\)和\(Var(x)\). 為了應對這個問題,BN的對策是moving average。在訓練階段,BN使用Batch內統計的均值和方差作歸一化,同時使用moving average方法維護均值和方差預備測試時使用。設BN的momentum
參數等於0.1,\(m\)是moving average方法跟蹤的一個統計量,那麼其更新方法為: \[
\hat m_{t} = \hat m_{t-1} \cdot (1 - \text{momentum}) + m_t \cdot \text{momentum}.
\]
測試時用事先統計的均值和方差的moving average,帶入公式 1中作歸一化。
3 代碼實現
在實現BatchNorm之前,我們不妨先看看pytorch官方的BatchNorm2d
模塊,觀察BatchNorm層要有哪些參數:
import torch
import torch.nn as nn
= 2, 32, 128, 128
batch, ch, h, w = nn.BatchNorm2d(num_features=ch)
torch_batch_norm for k, v in torch_batch_norm.named_parameters():
print(f'parameter: {k}', v.shape)
for k, v in torch_batch_norm.named_buffers():
print(f'buffer: {k}', v.shape)
parameter: weight torch.Size([32])
parameter: bias torch.Size([32])
buffer: running_mean torch.Size([32])
buffer: running_var torch.Size([32])
buffer: num_batches_tracked torch.Size([])
注意到BatchNorm的參數有兩種。一種是parameter(weight
、bias
),一種是buffer(running_mean
、running_var
、num_batches_tracked
)。對於parameter,torch默認其參數是需要梯度反傳的;而buffer則用於存儲一些不需要梯度反傳的模型參數。與parameter
一樣,在保存模型時,buffer
參數也會存儲到state_dict
中。
下面是本文提供的BN實現:
import torch
import torch.nn as nn
class MyBatchNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
# weight初始化為1,bias初始化為0.
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.momentum = 0.1
self.eps = eps
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0.))
def forward(self, tensor):
= tensor.shape
bs, ch, h, w # 變換一下tensor的尺寸,方便處理
= tensor.permute(0, 2, 3, 1).flatten(0, 2) # bs * h * w, ch
tensor_flatten # 求均值和方差
= torch.mean(tensor_flatten, 0)
mean # 注意方差有biased和unbiased兩種。
= torch.var(tensor_flatten, 0, unbiased=False)
var = torch.var(tensor_flatten, 0, unbiased=True)
var_unbiased
if self.training:
# 訓練時,我們要執行moving average,統計
# running_mean和running_var,注意此時應
# 使用unbiased版本的方差。
self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
self.running_var.mul_(1 - self.momentum).add_(self.momentum * var_unbiased)
self.num_batches_tracked.add_(1)
# 訓練時用batch內的統計量,測試時用moving average
# 保存的統計量。
if self.training:
= (tensor_flatten - mean) / torch.sqrt(var + self.eps)
tensor_flatten else:
= (tensor_flatten - self.running_mean) / torch.sqrt(self.running_var + self.eps)
tensor_flatten
# 歸一化完成後,做線性變換。
= tensor_flatten * self.weight + self.bias
ret = ret.view(bs, h, w, ch).permute(0, 3, 1, 2)
ret return ret
接下來驗證看看MyBatchNorm
的行為和torch.nn.BatchNorm2d
是否完全一致。
我們先檢查訓練模式下兩者的行為:
= MyBatchNorm(num_features=ch)
my_batch_norm # 因為BN涉及running_mean和running_var的更新,所以我們要多跑幾輪來檢查moving average的正確性。
for _ in range(10):
= torch.rand(batch, ch, h, w)
a = torch_batch_norm(a)
ret1 = my_batch_norm(a)
ret2 = torch.mean(torch.abs(ret1 - ret2)).item()
diff
= torch_batch_norm.running_mean
running_mean1 = my_batch_norm.running_mean
running_mean2 = torch.mean(torch.abs(running_mean1 - running_mean2)).item()
diff_mean
= torch_batch_norm.running_var
running_var1 = my_batch_norm.running_var
running_var2 = torch.mean(torch.abs(running_var1 - running_var2)).item()
diff_var print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
可以看到,輸出的所有誤差項都為0!這表明MyBatchNorm
的實現和torch的BN相吻合。
不要忘記BN在訓練時的行為和測試時的行為不同。我們需要再檢查一遍測試階段下MyBatchNorm
的行為。
# .eval()開啟測試模型
eval()
torch_batch_norm.eval()
my_batch_norm.for _ in range(10):
= torch.rand(batch, ch, h, w)
a = torch_batch_norm(a)
ret1 = my_batch_norm(a)
ret2 = torch.mean(torch.abs(ret1 - ret2)).item()
diff
= torch_batch_norm.running_mean
running_mean1 = my_batch_norm.running_mean
running_mean2 = torch.mean(torch.abs(running_mean1 - running_mean2)).item()
diff_mean
= torch_batch_norm.running_var
running_var1 = my_batch_norm.running_var
running_var2 = torch.mean(torch.abs(running_var1 - running_var2)).item()
diff_var print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
至此我們已經完成了全部檢查,實驗結果表明MyBatchNorm
的實現是正確的。
4 正確地使用BN
BN作為一種實用、應用廣泛的歸一化模塊,是計算機視覺領域的一座里程碑。儘管BN的應用確實解決了一些實際問題,但它也存在一些“坑”,是在使用BN時應當注意的。
4.1 BN在訓練階段與測試階段的行為差異
在訓練階段,BN使用batch內統計的均值和方差作歸一化,並記錄它們的moving average;而在測試階段,BN不再統計新數據的均值和方差,也不再更新moving average。這種不一致性(inconsistency)在後續工作[2]中被認為是一種影響性能的潛在因素。
4.2 如何正確地凍結BN模塊
設想我們有一個模型經過了充分的預訓練,現在我們希望在一個小數據集上微調它。一般步驟包括(以pytorch為例):
- 阻止梯度反傳。這可以通過使用
torch.no_grad()
或將該各參數的requires_grad
屬性設置為False
做到; - 調用
module.eval()
,關閉train
模式;
針對第2點,一般人們有兩種意見。一種看法認為不開BN的eval
模式更好,這有助於讓模型學習如何對新數據做歸一化。而我傾向於採取的做法則是開啟eval
。在我的經驗中,如果BN處於訓練狀態,而模型的其它層則凍結著,那麼模型可能因為不適應BN在新數據上歸一化參數的改變而引發訓練不穩定。
總而言之,BN在訓練、遷移學習、測試時的行為不一致有時確實是一個麻煩的問題。如果遇到了這個問題,我建議考慮一下是否要開啟BN的eval
模式,或者試試後來的Group Normalization[2]。
4.3 分佈式訓練
在訓練參數量較大的模型時,可以用分佈式訓練,利用多個進程和多個計算設備執行計算。這種情況下,每張卡只需負責比較小的batch。注意原始的BN在batch size較小時,所產生的均值/方差的統計量不準確。因此,在分佈式訓練時,我們最好將原BatchNorm模塊替換為torch.SyncBatchNorm
。後者能同步所有計算設備,在更大的batch size上統計均值和方差。
4.4 不要遞歸地使用BN
最後介紹一個我踩過的,印象深刻的坑。 假如有這樣一段代碼:
= batch_norm(x1)
x2 = batch_norm(x2) x3
或
= conv(x1) # conv中包含BN模塊
x2 = conv(x2) x3
前面的一段代碼的問題或許容易識別,後者的問題則稍隱蔽些。你能預測到會發生什麼嗎?在這樣的代碼中,同一個BN模塊在訓練時會分別獲取x2
和x1
的均值和方差,然後通過moving average將它們計入running_mean
和running_var
。然而,由於x1
和x2
服從不同的分佈,因此running_mean
和running_var
的統計將失去意義。 問題的表現是,在訓練階段,我們會觀察到損失正常下降。測試時,我們開啟eval
模式,模型的表現不如預期;可是如果你關閉eval
模式,也許會發現模型又能正常工作。
類似的問題也存在於特征金字塔(FPN)的實現中。如果你希望在類特征金字塔的結構中實現不同層級共享參數的話,注意卷積的參數也許能共享,但BN的參數不要共享。
5 總結
本文介紹了BN的工作原理,給出了一種基於pytorch的BN模塊實現,並提供了詳細的代碼檢查。最後,本文討論了應用BN過程中容易遇到的幾種問題。
在接觸深度學習的過程中,Batch Normalization是一個讓我反復(大概得有兩三次吧)踩坑的模塊,每次踩坑都得琢磨好久才能發現問題所在。現在我已經習慣性的選擇Group Normalization[2],拋棄BN了。儘管如此,BN仍是一個經典的工作,它背後的思想很值得學習研究。