1 前言
The Bitter Lesson提出了这样的观点:能充分适应算力增长的方法最终能脱颖而出。Transformer之所以能在很多领域取代RNN,就是因为前者更适合并行训练。可是在大模型广泛应用的今天,这些基于Transformer结构的模型仍然还有一些令人在意的缺点和局限性:
- 每一个token的计算量随著上下文的增长而增长
- 随著上下文越来越长,模型就越难从上下文中找到“重点”
- 如何管理模型的记忆?制作一个聊天机器人,难道要将所有的聊天记录都保存下来吗?或者需要定时对上下文做一个文本总结吗?
这些问题,总而言之是如何管理记忆的问题。在这方面Transformer的策略是简单粗暴的:它把所有的历史token都保留下来。而理想中的AI或许应该具有这样的特性:他能够维护一个大小有限的记忆,能够选择保留重要的信息,也具备“遗忘”的能力。这就让人联想到RNN模型。
当然,也许我们不能指望重新设计一个RNN模型,让他和最领先的大模型一较高下。但像这样的思考仍然是有益而且有趣的:如果我们要设计一个具有维护记忆力的能力的模型,他应该是什么样子?
本文是我阅读Transformers are RNNs[1]这篇文章的笔记。这篇文章有一些有意思的贡献:
- 从softmax attention这个特例推广到一般的衡量query和key相似度的方法;
- 提出线性Transformer,它既能像传统Transformer一样并行训练,也能像RNN模型那样,在推理阶段维持常数的空间复杂度和时间复杂度。
2 Transformer的结构
Transformer模型\(T\)可以看做是一个类型为\(\mathbb R^{N\times F}\to \mathbb R^{N\times F}\)的函数,这里\(N\)是序列长度,\(F\)是特征向量的维度。
一个Transformer有L层,记第\(l\)层为\(T_l\),其结构为 \[ T_l(x) = f_l(A_l(x) + x), \] 其中\(A_l\)是Attention模块,\(f_l\)是Feed-forward模块。
在Attention模块\(A_l\)中,\(x\)先经过线性转换\(\mathbf W_Q\in \mathbb R^{F\times D},\mathbf W_K\in \mathbb R^{F\times D},\mathbf W_V\in \mathbb R^{F\times M}\),得到\(\mathbf Q, \mathbf K, \mathbf V\)(即query、key和value向量): \[ \begin{aligned} \mathbf Q &= \mathbf x\mathbf W_Q,\\ \mathbf K &= \mathbf x \mathbf W_K,\\ \mathbf V &= \mathbf x\mathbf W_V. \end{aligned} \] 接著,Attention模块的输出为 \[ \mathbf A_l(x) = \mathbf V' = \text{softmax}\left(\frac{\mathbf Q\mathbf K^T}{\sqrt{D}}\right)V. \tag{1}\] 可以看到\(V'\)实际上是\(V\)的加权和,其中每个位置特征的权重由\(Q\)和\(K\)计算得到。
公式 1使用softmax函数计算权重,这实际是下面的这个函数的一个特例: \[ \mathbf V'_i = \frac{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)V_j}{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)}. \tag{2}\] softmax函数等同于将上式中的\(\text{sim}\)函数定义为\(\exp(\frac{\mathbf Q_i^T\mathbf K_j}{\sqrt D})\)
3 线性注意力
假设对于某\(\text{sim}\)函数,存在\(\phi\)函数,使得 \[ \text{sim}(\mathbf Q_i, \mathbf K_j) = \phi(\mathbf Q_i)^T \phi(\mathbf K_j) \] 这样我们就将Attention模块线性化了。 利用矩阵乘法的结合律,上式可以改写为: \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) } \tag{3}\] 由于query和key两两之间都需要计算相似度,softmax-attention[2]的时间复杂度是\(O(N^2)\);其空间复杂度同理也是\(O(N^2)\),因为需要存储注意力矩阵用于后续的反向传播。而线性注意力的时间复杂度是\(O(N)\),空间复杂度也是\(O(N)\)。由公式 3可见,\(\sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T\)和\(\sum_{j=1}^N \phi(\mathbf K_j)\)均可以在计算时可以被缓存和复用。
为了使\(\text{sim}(\cdot)\)是非负的,文章取\(\phi(x)\)为 \[ \phi(x) = \text{elu}(x) + 1 \]
4 因果掩膜
在训练自回归模型时,每个位置的token都不能受到未来token的影响。位置\(i\)上的token受到位置\(j\)上的token的影响,当且仅当\(j\leq i\). 这应用于线性Transformer,需要将公式 2改为 \[ \mathbf V'_i=\frac{\sum_{j=1}^i \text{sim}(\mathbf Q_i, \mathbf K_j)\mathbf V_j}{\sum_{j=1}^i \text{sim}(\mathbf Q_i,\mathbf K_j)}. \tag{4}\] 类似前面的推理,上式可改写为 \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) }, \tag{5}\] 引入变量\(S_i\)和\(Z_i\)如下: \[ \begin{aligned} \mathbf S_i &= \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T,\\ \mathbf Z_i &= \sum_{j=1}^i \phi(\mathbf K_j). \end{aligned} \] 公式 5可以改写为 \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T\mathbf S_i }{ \phi(\mathbf Q_i)^T \mathbf Z_i }. \tag{6}\] 注意\(\mathbf S_i,\mathbf Z_i\)可以由\(\mathbf S_{i-1},\mathbf Z_{i-1}\)计算得到,时间复杂度为常数。
5 梯度计算
如果根据公式 6实现简单的深度学习模型,那么为了梯度计算,我们需要缓存所有的\(S_i, Z_i\),这将带来很大的空间复杂度。文章提出将对公式 5分子的求导实现为一种基于累加和(cumulative sum)的方法,实现前向传播和反向传播时的线性时间复杂度和常数空间复杂度。
为了简便,本节的推导假设\(\mathbf Q, \mathbf K\)中已经包含了\(\phi\)函数。
假设\(\overline{\mathbf V}_i\)是公式 5的分子,即 \[ \overline{\mathbf V}_i = \mathbf Q_i^T \sum_{j=1}^i \mathbf K_j \mathbf V_j^T. \tag{7}\] 设\(L\)为损失,已知\(\overline{\mathbf V}_i\)和\(L\),则\(L\)对\(\mathbf Q, \mathbf K, \mathbf V\)的梯度分别为 \[ \nabla_{\mathbf Q_i}L = (\nabla_{\overline{\mathbf V}_i} L)(\sum_{j=1}^i \mathbf K_j \mathbf V_j^T)^T, \tag{8}\] \[ \nabla_{\mathbf K_i}L = \left(\sum_{j=i}^N\mathbf Q_j (\nabla_{\overline{\mathbf V}_j}L)^T\right)\mathbf V_i, \tag{9}\] \[ \nabla_{\mathbf V_i}L = \left(\sum_{j=i}^N \mathbf Q_j (\nabla_{\overline{V}_j}L)^T\right)^T \mathbf K_i, \tag{10}\] 其中\(\mathbf Q\in\mathbb R^{N\times D},\mathbf K\in\mathbb R^{N\times D}, \mathbf V\in \mathbb R^{N\times M}\).
文章只考虑了分子的梯度计算。分母、整个分式的梯度计算交给torch
自动处理。
以下是详细推导:
首先我们考虑矩阵中每个元素的计算,将公式 7中的矩阵、向量记号去除,得到 \[ \overline{V}_{ie} = \sum_{d=1}^D Q_{id} \sum_{j=1}^i K_{jd} V_{je} = \sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je}。 \]
于是对于任意的\(Q_{lt}\),我们可以推导出梯度为 \[ \frac{\partial L}{\partial Q_{lt}} = \sum_{e=1}^M \frac{\partial L}{\partial \overline{V}_{le}}\frac{\partial \overline{V}_{le}}{\partial Q_{lt}} = \sum_{e=1}^M\frac{\partial L}{\partial \overline{V}{le}}(\sum_{j=1}^lK_{jt}V_{je}). \tag{11}\]
将其整理成矩阵形式,得到公式 8
在公式 11中,我们利用了\(\overline{\mathbf V}_l\)只受\(l\)位置的query(即\(\mathbf Q_l\))影响的性质。query只影响当下,而每个key和value都会对未来的计算产生影响。对于key,其梯度的计算方式为: \[ \begin{aligned} \frac{\partial L}{\partial K_{lt}} &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial \overline{V}_{ie}}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial(\sum_{d=1}^D\sum_{j=1}^i Q_{id}K_{jd}V_{je})}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} Q_{it} V_{le} \end{aligned} \] 将其整理为矩阵形式,得到公式 9
类似的,对于value,其梯度的计算方式为: \[ \begin{aligned} \frac{\partial L}{\partial V_{lt}} &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\overline{V}_{ie}}{\partial V_{lt}} \\ &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial (\sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je})}{\partial V_{lt}}\\ &= \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{it}} \sum_{d=1}^D Q_{id}K_{ld} \end{aligned} \] 将其整理为矩阵形式,得到公式 10
6 训练和推理
训练时,完整的训练序列是已知的,这允许Transformer模型实现并行的训练;而受限于计算方式,传统的RNN模型一般难并行训练。Transformer模型的每一步推理的时间复杂度是不同的,随著上下文长度的增加而增加;而RNN模型的时间复杂度是固定的。
文章提出的线性Transformer结合了两者的优点。
7 Transformer模型是RNN模型的特例
从本文的讨论,我们可以明显的看出,带因果掩膜的Transformer模型可以视作是RNN模型的特例,即Transformer模型可以看做是一个能维护一个内部状态(\(\mathbf S_i\)和\(\mathbf Z_i\)),在每次获得新输入时更新内部状态的模型。