前言
Transformer 模型相較於傳統的 RNN 模型更適合并行訓練,因此得以在 NLP 領域得到廣汎應用。但是它也存在一些問題:
- 時間複雜度和空間複雜度:\(O(n^2)\) 空間和時間複雜度是導致 Transformer 模型消耗顯存和算力的痛點。
- 缺少對上下文的壓縮:有一個觀點是”壓縮即智能”。但是我們在 Transformer 上沒有看到任何”記憶壓縮”的影子。它只是將所有 token 的特徵排列在那裏,可供後續隨時取用。有沒有一種辦法,能夠實現上下文的壓縮,用有限的空間管理我們感興趣的知識呢?
出於對以上問題的興趣,我著手調研(或者也可以説是考古吧😀)和實驗一些綫性 Transformer 的工作。
本文是我閱讀 Transformers are RNNs [1] 這篇文章的筆記。這篇文章有一些有意思的貢獻:
- 從 softmax attention 這個特例推廣到一般的衡量 query 和 key 相似度的方法;
- 提出線性 Transformer,它既能像傳統 Transformer 一樣並行訓練,也能像 RNN 模型那樣,在推理階段維持常數的空間複雜度和時間複雜度。
Transformer 的結構
Transformer 模型 \(T\) 可以看做是一個類型為 \(\mathbb R^{N\times F}\to \mathbb R^{N\times F}\) 的函數,這裡 \(N\) 是序列長度,\(F\) 是特徵向量的維度。
一個 Transformer 有 \(L\) 層,記第 \(l\) 層為 \(T_l\),其結構為
\[ T_l(x) = f_l(A_l(x) + x), \]
其中 \(A_l\) 是 Attention 模塊,\(f_l\) 是 Feed-forward 模塊。
在 Attention 模塊 \(A_l\) 中,\(x\) 先經過線性轉換 \(\mathbf W_Q\in \mathbb R^{F\times D},\mathbf W_K\in \mathbb R^{F\times D},\mathbf W_V\in \mathbb R^{F\times M}\),得到 \(\mathbf Q, \mathbf K, \mathbf V\)(即 query、key 和 value 向量):
\[ \begin{aligned} \mathbf Q &= \mathbf x\mathbf W_Q,\\ \mathbf K &= \mathbf x \mathbf W_K,\\ \mathbf V &= \mathbf x\mathbf W_V. \end{aligned} \]
接著,Attention 模塊的輸出為
\[ \mathbf A_l(x) = \mathbf V' = \text{softmax}\left(\frac{\mathbf Q\mathbf K^T}{\sqrt{D}}\right)V. \tag{1}\]
可以看到 \(V'\) 實際上是 \(V\) 的加權和,其中每個位置特徵的權重由 \(Q\) 和 \(K\) 計算得到。
Equation 1 使用 softmax 函數計算權重,這實際是下面的這個函數的一個特例:
\[ \mathbf V'_i = \frac{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)V_j}{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)}. \tag{2}\]
softmax 函數等同於將上式中的 \(\text{sim}\) 函數定義為 \(\exp(\frac{\mathbf Q_i^T\mathbf K_j}{\sqrt D})\)。
線性注意力
假設對於某 \(\text{sim}\) 函數,存在 \(\phi\) 函數,使得
\[ \text{sim}(\mathbf Q_i, \mathbf K_j) = \phi(\mathbf Q_i)^T \phi(\mathbf K_j) \]
這樣我們就將 Attention 模塊線性化了。利用矩陣乘法的結合律,上式可以改寫為:
\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) } \tag{3}\]
由於 query 和 key 兩兩之間都需要計算相似度,softmax-attention [2] 的時間複雜度是 \(O(N^2)\);其空間複雜度同理也是 \(O(N^2)\),因為需要存儲注意力矩陣用於後續的反向傳播。而線性注意力的時間複雜度是 \(O(N)\),空間複雜度也是 \(O(N)\)。由 Equation 3 可見,\(\sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T\) 和 \(\sum_{j=1}^N \phi(\mathbf K_j)\) 均可以在計算時可以被緩存和複用。
為了使 \(\text{sim}(\cdot)\) 是非負的,文章取 \(\phi(x)\) 為
\[ \phi(x) = \text{elu}(x) + 1 \]
因果掩膜
在訓練自回歸模型時,每個位置的 token 都不能受到未來 token 的影響。位置 \(i\) 上的 token 受到位置 \(j\) 上的 token 的影響,當且僅當 \(j\leq i\)。這應用於線性 Transformer,需要將 Equation 2 改為
\[ \mathbf V'_i=\frac{\sum_{j=1}^i \text{sim}(\mathbf Q_i, \mathbf K_j)\mathbf V_j}{\sum_{j=1}^i \text{sim}(\mathbf Q_i,\mathbf K_j)}. \tag{4}\]
類似前面的推理,上式可改寫為
\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) }, \tag{5}\]
引入變量 \(S_i\) 和 \(Z_i\) 如下:
\[ \begin{aligned} \mathbf S_i &= \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T,\\ \mathbf Z_i &= \sum_{j=1}^i \phi(\mathbf K_j). \end{aligned} \]
Equation 5 可以改寫為
\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \mathbf S_i }{ \phi(\mathbf Q_i)^T \mathbf Z_i }. \tag{6}\]
注意 \(\mathbf S_i, \mathbf Z_i\) 可以由 \(\mathbf S_{i-1}, \mathbf Z_{i-1}\) 計算得到,時間複雜度為常數。
梯度計算
如果根據 Equation 6 實現簡單的深度學習模型,那麼為了梯度計算,我們需要緩存所有的 \(S_i, Z_i\),這將帶來很大的空間複雜度。文章提出將對 Equation 5 分子的求導實現為一種基於累加和(cumulative sum)的方法,實現前向傳播和反向傳播時的線性時間複雜度和常數空間複雜度。
為了簡便,本節的推導假設 \(\mathbf Q, \mathbf K\) 中已經包含了 \(\phi\) 函數。
假設 \(\overline{\mathbf V}_i\) 是 Equation 5 的分子,即
\[ \overline{\mathbf V}_i = \mathbf Q_i^T \sum_{j=1}^i \mathbf K_j \mathbf V_j^T. \tag{7}\]
設 \(L\) 為損失,已知 \(\overline{\mathbf V}_i\) 和 \(L\),則 \(L\) 對 \(\mathbf Q, \mathbf K, \mathbf V\) 的梯度分別為
\[ \nabla_{\mathbf Q_i}L = (\nabla_{\overline{\mathbf V}_i} L)(\sum_{j=1}^i \mathbf K_j \mathbf V_j^T)^T, \tag{8}\]
\[ \nabla_{\mathbf K_i}L = \left(\sum_{j=i}^N\mathbf Q_j (\nabla_{\overline{\mathbf V}_j}L)^T\right)\mathbf V_i, \tag{9}\]
\[ \nabla_{\mathbf V_i}L = \left(\sum_{j=i}^N \mathbf Q_j (\nabla_{\overline{V}_j}L)^T\right)^T \mathbf K_i, \tag{10}\]
其中 \(\mathbf Q\in\mathbb R^{N\times D},\mathbf K\in\mathbb R^{N\times D}, \mathbf V\in \mathbb R^{N\times M}\)。
文章只考慮了分子的梯度計算。分母、整個分式的梯度計算交給 torch 自動處理。
以下是詳細推導:
首先我們考慮矩陣中每個元素的計算,將 Equation 7 中的矩陣、向量記號去除,得到
\[ \overline{V}_{ie} = \sum_{d=1}^D Q_{id} \sum_{j=1}^i K_{jd} V_{je} = \sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je}. \]
於是對於任意的 \(Q_{lt}\),我們可以推導出梯度為
\[ \frac{\partial L}{\partial Q_{lt}} = \sum_{e=1}^M \frac{\partial L}{\partial \overline{V}_{le}}\frac{\partial \overline{V}_{le}}{\partial Q_{lt}} = \sum_{e=1}^M\frac{\partial L}{\partial \overline{V}_{le}}(\sum_{j=1}^l K_{jt} V_{je}). \tag{11}\]
將其整理成矩陣形式,得到 Equation 8。
在 Equation 11 中,我們利用了 \(\overline{\mathbf V}_l\) 只受 \(l\) 位置的 query(即 \(\mathbf Q_l\))影響的性質。query 只影響當下,而每個 key 和 value 都會對未來的計算產生影響。對於 key,其梯度的計算方式為:
\[ \begin{aligned} \frac{\partial L}{\partial K_{lt}} &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial \overline{V}_{ie}}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial(\sum_{d=1}^D\sum_{j=1}^i Q_{id}K_{jd}V_{je})}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} Q_{it} V_{le} \end{aligned} \]
將其整理為矩陣形式,得到 Equation 9。
類似的,對於 value,其梯度的計算方式為:
\[ \begin{aligned} \frac{\partial L}{\partial V_{lt}} &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\overline{V}_{ie}}{\partial V_{lt}} \\ &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial (\sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je})}{\partial V_{lt}}\\ &= \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{it}} \sum_{d=1}^D Q_{id}K_{ld} \end{aligned} \]
將其整理為矩陣形式,得到 Equation 10。
訓練和推理
訓練時,完整的訓練序列是已知的,這允許 Transformer 模型實現並行的訓練;而受限於計算方式,傳統的 RNN 模型一般難並行訓練。Transformer 模型的每一步推理的時間複雜度是不同的,隨著上下文長度的增加而增加;而 RNN 模型的時間複雜度是固定的。
文章提出的線性 Transformer 結合了兩者的優點。
Transformer 模型是 RNN 模型的特例
從本文的討論,我們可以明顯的看出,帶因果掩膜的 Transformer 模型可以視作是 RNN 模型的特例,即 Transformer 模型可以看做是一個能維護一個內部狀態(\(\mathbf S_i\) 和 \(\mathbf Z_i\)),在每次獲得新輸入時更新內部狀態的模型。
推薦閲讀
- 《線性Attention的探索:Attention必須有個Softmax嗎?》。蘇劍林的文章,非常好的總結分享了各種各樣的綫性 Attention,甚至提出了自己的新設計。