Linear Transformer論文精讀

本文瀏覽次數

发布于

2025年5月15日

1 前言

The Bitter Lesson提出了這樣的觀點:能充分適應算力增長的方法最終能脫穎而出。Transformer之所以能在很多領域取代RNN,就是因為前者更適合並行訓練。可是在大模型廣泛應用的今天,這些基於Transformer結構的模型仍然還有一些令人在意的缺點和局限性:

  • 每一個token的計算量隨著上下文的增長而增長
  • 隨著上下文越來越長,模型就越難從上下文中找到“重點”
  • 如何管理模型的記憶?製作一個聊天機器人,難道要將所有的聊天記錄都保存下來嗎?或者需要定時對上下文做一個文本總結嗎?

這些問題,總而言之是如何管理記憶的問題。在這方面Transformer的策略是簡單粗暴的:它把所有的歷史token都保留下來。而理想中的AI或許應該具有這樣的特性:他能夠維護一個大小有限的記憶,能夠選擇保留重要的信息,也具備“遺忘”的能力。這就讓人聯想到RNN模型。

當然,也許我們不能指望重新設計一個RNN模型,讓他和最領先的大模型一較高下。但像這樣的思考仍然是有益而且有趣的:如果我們要設計一個具有維護記憶力的能力的模型,他應該是什麼樣子?

本文是我閱讀Transformers are RNNs[1]這篇文章的筆記。這篇文章有一些有意思的貢獻:

  1. 從softmax attention這個特例推廣到一般的衡量query和key相似度的方法;
  2. 提出線性Transformer,它既能像傳統Transformer一樣並行訓練,也能像RNN模型那樣,在推理階段維持常數的空間複雜度和時間複雜度。

2 Transformer的结构

Transformer模型\(T\)可以看做是一個類型為\(\mathbb R^{N\times F}\to \mathbb R^{N\times F}\)的函數,這裡\(N\)是序列長度,\(F\)是特征向量的維度。

一個Transformer有L層,記第\(l\)層為\(T_l\),其結構為 \[ T_l(x) = f_l(A_l(x) + x), \] 其中\(A_l\)是Attention模塊,\(f_l\)是Feed-forward模塊。

在Attention模塊\(A_l\)中,\(x\)先經過線性轉換\(\mathbf W_Q\in \mathbb R^{F\times D},\mathbf W_K\in \mathbb R^{F\times D},\mathbf W_V\in \mathbb R^{F\times M}\),得到\(\mathbf Q, \mathbf K, \mathbf V\)(即query、key和value向量): \[ \begin{aligned} \mathbf Q &= \mathbf x\mathbf W_Q,\\ \mathbf K &= \mathbf x \mathbf W_K,\\ \mathbf V &= \mathbf x\mathbf W_V. \end{aligned} \] 接著,Attention模塊的輸出為 \[ \mathbf A_l(x) = \mathbf V' = \text{softmax}\left(\frac{\mathbf Q\mathbf K^T}{\sqrt{D}}\right)V. \tag{1}\] 可以看到\(V'\)實際上是\(V\)的加權和,其中每個位置特征的權重由\(Q\)\(K\)計算得到。

公式 1使用softmax函數計算權重,這實際是下面的這個函數的一個特例: \[ \mathbf V'_i = \frac{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)V_j}{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)}. \tag{2}\] softmax函數等同於將上式中的\(\text{sim}\)函數定義為\(\exp(\frac{\mathbf Q_i^T\mathbf K_j}{\sqrt D})\)

3 線性注意力

假設對於某\(\text{sim}\)函數,存在\(\phi\)函數,使得 \[ \text{sim}(\mathbf Q_i, \mathbf K_j) = \phi(\mathbf Q_i)^T \phi(\mathbf K_j) \] 這樣我們就將Attention模塊線性化了。 利用矩陣乘法的結合律,上式可以改寫為: \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) } \tag{3}\] 由於query和key兩兩之間都需要計算相似度,softmax-attention[2]的時間複雜度是\(O(N^2)\);其空間複雜度同理也是\(O(N^2)\),因為需要存儲注意力矩陣用於後續的反向傳播。而線性注意力的時間複雜度是\(O(N)\),空間複雜度也是\(O(N)\)。由公式 3可見,\(\sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T\)\(\sum_{j=1}^N \phi(\mathbf K_j)\)均可以在計算時可以被緩存和複用。

為了使\(\text{sim}(\cdot)\)是非負的,文章取\(\phi(x)\)\[ \phi(x) = \text{elu}(x) + 1 \]

4 因果掩膜

在訓練自回歸模型時,每個位置的token都不能受到未來token的影響。位置\(i\)上的token受到位置\(j\)上的token的影響,當且僅當\(j\leq i\). 這應用於線性Transformer,需要將公式 2改為 \[ \mathbf V'_i=\frac{\sum_{j=1}^i \text{sim}(\mathbf Q_i, \mathbf K_j)\mathbf V_j}{\sum_{j=1}^i \text{sim}(\mathbf Q_i,\mathbf K_j)}. \tag{4}\] 類似前面的推理,上式可改寫為 \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) }, \tag{5}\] 引入變量\(S_i\)\(Z_i\)如下: \[ \begin{aligned} \mathbf S_i &= \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T,\\ \mathbf Z_i &= \sum_{j=1}^i \phi(\mathbf K_j). \end{aligned} \] 公式 5可以改寫為 \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T\mathbf S_i }{ \phi(\mathbf Q_i)^T \mathbf Z_i }. \tag{6}\] 注意\(\mathbf S_i,\mathbf Z_i\)可以由\(\mathbf S_{i-1},\mathbf Z_{i-1}\)計算得到,時間複雜度為常數。

5 梯度計算

如果根據公式 6實現簡單的深度學習模型,那麼為了梯度計算,我們需要緩存所有的\(S_i, Z_i\),這將帶來很大的空間複雜度。文章提出將對公式 5分子的求導實現為一種基於累加和(cumulative sum)的方法,實現前向傳播和反向傳播時的線性時間複雜度和常數空間複雜度。

为了简便,本节的推导假设\(\mathbf Q, \mathbf K\)中已经包含了\(\phi\)函数。

假设\(\overline{\mathbf V}_i\)公式 5的分子,即 \[ \overline{\mathbf V}_i = \mathbf Q_i^T \sum_{j=1}^i \mathbf K_j \mathbf V_j^T. \tag{7}\]\(L\)为损失,已知\(\overline{\mathbf V}_i\)\(L\),則\(L\)\(\mathbf Q, \mathbf K, \mathbf V\)的梯度分别为 \[ \nabla_{\mathbf Q_i}L = (\nabla_{\overline{\mathbf V}_i} L)(\sum_{j=1}^i \mathbf K_j \mathbf V_j^T)^T, \tag{8}\] \[ \nabla_{\mathbf K_i}L = \left(\sum_{j=i}^N\mathbf Q_j (\nabla_{\overline{\mathbf V}_j}L)^T\right)\mathbf V_i, \tag{9}\] \[ \nabla_{\mathbf V_i}L = \left(\sum_{j=i}^N \mathbf Q_j (\nabla_{\overline{V}_j}L)^T\right)^T \mathbf K_i, \tag{10}\] 其中\(\mathbf Q\in\mathbb R^{N\times D},\mathbf K\in\mathbb R^{N\times D}, \mathbf V\in \mathbb R^{N\times M}\).

注记

文章只考虑了分子的梯度计算。分母、整个分式的梯度计算交给torch自动处理。

以下是详细推导:

首先我们考虑矩阵中每个元素的计算,将公式 7中的矩阵、向量记号去除,得到 \[ \overline{V}_{ie} = \sum_{d=1}^D Q_{id} \sum_{j=1}^i K_{jd} V_{je} = \sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je}。 \]

于是对于任意的\(Q_{lt}\),我们可以推导出梯度为 \[ \frac{\partial L}{\partial Q_{lt}} = \sum_{e=1}^M \frac{\partial L}{\partial \overline{V}_{le}}\frac{\partial \overline{V}_{le}}{\partial Q_{lt}} = \sum_{e=1}^M\frac{\partial L}{\partial \overline{V}{le}}(\sum_{j=1}^lK_{jt}V_{je}). \tag{11}\]

將其整理成矩陣形式,得到公式 8

公式 11中,我們利用了\(\overline{\mathbf V}_l\)只受\(l\)位置的query(即\(\mathbf Q_l\))影響的性質。query只影響當下,而每個key和value都會對未來的計算產生影響。對於key,其梯度的計算方式為: \[ \begin{aligned} \frac{\partial L}{\partial K_{lt}} &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial \overline{V}_{ie}}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial(\sum_{d=1}^D\sum_{j=1}^i Q_{id}K_{jd}V_{je})}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} Q_{it} V_{le} \end{aligned} \] 將其整理為矩陣形式,得到公式 9

類似的,對於value,其梯度的計算方式為: \[ \begin{aligned} \frac{\partial L}{\partial V_{lt}} &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\overline{V}_{ie}}{\partial V_{lt}} \\ &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial (\sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je})}{\partial V_{lt}}\\ &= \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{it}} \sum_{d=1}^D Q_{id}K_{ld} \end{aligned} \] 將其整理為矩陣形式,得到公式 10

6 訓練和推理

訓練時,完整的訓練序列是已知的,這允許Transformer模型實現並行的訓練;而受限於計算方式,傳統的RNN模型一般難並行訓練。Transformer模型的每一步推理的時間複雜度是不同的,隨著上下文長度的增加而增加;而RNN模型的時間複雜度是固定的。

文章提出的線性Transformer結合了兩者的優點。

7 Transformer模型是RNN模型的特例

從本文的討論,我們可以明顯的看出,帶因果掩膜的Transformer模型可以視作是RNN模型的特例,即Transformer模型可以看做是一個能維護一個內部狀態(\(\mathbf S_i\)\(\mathbf Z_i\)),在每次獲得新輸入時更新內部狀態的模型。

参考文献

[1]
KATHAROPOULOS A, VYAS A, PAPPAS N, 等. Transformers are RNNs: fast autoregressive transformers with linear attention[C]//Proceedings of the 37th International Conference on Machine Learning. JMLR.org.
[2]
VASWANI A, SHAZEER N, PARMAR N, 等. Attention Is All You Need[J]. arXiv:1706.03762 [cs], 2017.
By @執迷 in
Tags : #Transformer, #Linear Transformer, #線性Transformer, #RNN,