自从“Attention is All You need”[1]这篇文章发布之后,注意力机制开始广为人知。虽然一开始注意力机制被应用于自然语言处理领域,但人们很快发现它也能够用于处理图像、点云等数据结构,并且取得非常好的效果。
本文介绍如何用pytorch实现文章[1]提出的multi-head attention(MHA)。MHA是scaled dot-product attention(SDPA)[1,2]的改进。既然MHA是SDPA的multi-head版,那么也许将SDPA称为单头注意力会比较形象。让我们先从单头注意力开始。
1 Scaled Dot-product Attention
SDPA的计算公式如下: \[ \text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_k}})V \tag{1}\] 其中,\(Q\)、\(K\)和\(V\)分别为输入的query、key和value向量。\(Q\)的尺寸为\(L\times E_q\),\(K\)的尺寸为\(S\times E_q\),\(V\)的尺寸为\(S\times E_v\)。
公式 1中,\(QK^T\)的计算可以理解为求query向量和key向量间的“匹配程度”,使得SDPA能够根据匹配程度从value中取得相关的数据。Softmax函数在这里起到归一化的作用,使得SDPA的计算结果可以视为V的加权平均。
注意力机制是一个有趣的设计。RNN(Recurrent Neural Networks)结构面临著长程记忆困难的问题;类似的,CNN(Convolutional Neural Networks)中也存在远距离像素感知困难的问题。SDPA的特点是,为每一对query和key都作计算,而不考虑其时间或空间上的位置差距。
我个人的观察是,注意力机制就好像构造了一个“完全二部图”。这个二部图由\(A\)和\(B\)两个顶点集合组成,其中\(A\)集合代表query,\(B\)集合代表key和value。\(A\)和\(B\)之间的结点两两之间都有一条边,边的权重由query和key的“匹配程度”决定。类似SDPA,全连接层或卷积层也可以表示为二部图,只不过其边的权重是通过训练得到的,固定的值,而在注意力机制中,边的权重是动态的,根据query和key计算得到的。
1.1 Attention的缩放系数
从公式 1可以看到,在执行Softmax函数前,SDPA使用系数\(\frac{1}{\sqrt{d_k}}\)对输入作了缩放。为什么是\(\sqrt{d_k}\),而不是其它系数呢?本节主要摘录苏剑林的系列文章中的结论,对这个问题作简要的讨论。
Softmax是一个将向量映射为向量的函数。首先,我们注意到Softmax输入向量的方差对于其输出向量的分布有重要影响。
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
np.random.seed(0)
length = 32
samples = []
n_samples = 32
for i in range(n_samples):
m = i / n_samples * 10
x = np.random.randn(length) * m
y = np.array(torch.softmax(torch.tensor(x), dim=0))
samples.append((sorted(y), np.var(x)))
samples.sort(key=lambda x: x[1]) # 按方差排序
fig = plt.figure()
ax = plt.gca()
cax = ax.matshow(np.array([s for s, i in samples]).transpose())
ax.set_xticklabels([f'{i:.2f}' for s, i in samples])
fig.colorbar(cax);
由图 1可见,Softmax函数的输入向量方差越小,则输出越接近\(\vec 0\);反之,Softmax的输出越接近one-hot编码,即只有一个元素是1,其它元素都接近0. 然而,这两种极端情况都容易引发梯度消失问题。(这在我的《Softmax函数的性质》一文中已经给出了推导。)
为了避免梯度消失,我们既不希望Softmax的输出接近\(\vec 0\),也不希望其成为one-hot编码。下面的程序通过直观的方式演示了其中的原因:
# 向量的长度
length = 128
# 缩放系数,影响输入向量的方差大小
scales = [2**(i / 3) for i in range(10)]
xs = []
ys = []
for _ in range(1000):
# 随机选择一个缩放系数
m = np.random.choice(scales)
# 随机初始化Softmax的输入,并缩放
x = torch.tensor(np.random.randn(length) * m, requires_grad=True)
# 执行Softmax函数
y = torch.softmax(x, dim=0)
# 横轴为max(y)
xs.append(y.max().item())
label = torch.rand(y.shape)
loss = torch.abs(label - y).mean()
loss.backward()
# 纵轴为梯度绝对值的最大值
ys.append(x.grad.abs().max().item())
plt.scatter(xs, ys)
plt.xlabel('Softmax输出的最大值')
plt.ylabel('梯度绝对值的最大值')
plt.show()
从图 2可以看到,当Softmax的输入接近\(0\)向量或者one-hot向量的时候,容易发生梯度消失现象,此时梯度的绝对值接近\(0\).
综上所述,为尽量避免Softmax的梯度消失,控制输入向量的方差很重要。假设\(\vec q\)和\(\vec k\)都是\(0\)均值的,二者相互独立,方差分别为\(v_q\)和\(v_k\),不难证明它们的内积\(\vec q^T \vec k\)的数学期望为\(0\)。展开证明
\[ \begin{align*} \mathbb{E}(\vec q^T \vec k) &= \mathbb{E}\left(\sum_i q_i k_i\right) \\ &= \sum_i \mathbb{E}(q_i k_i) \\ &= \sum_i \mathbb{E}(q_i) \mathbb{E}(k_i) \\ &= \sum_i 0 \cdot 0 \\ &= 0. \end{align*} \]展开证明
\[ \begin{align*} \text{Var}(\vec q^T \vec k) &= \text{Var}\left(\sum_{i=1}^{d_k} (q_i k_i)\right) \\ \end{align*} \] 假设\(\vec q^T \vec k\)的任意两个维度都是独立同分布的,那么 \[ \begin{align*} \\ \text{Var}(\vec q^T \vec k) &= d_k Var(q_i k_i) \\ &= d_k (\mathbb E(q_i^2 k_i^2) - (\mathbb E(q_i k_i))^2)\\ &= d_k (\mathbb E(q_i^2)\mathbb E(k_i^2) - (\mathbb E(q_i) \mathbb E(k_i))^2)\\ &= d_k (\mathbb E(q_i^2)\mathbb E(k_i^2) )\\ &= d_k \left( Var(q_i) Var(k_i) \right) \\ &= d_k v_q v_k \end{align*} \] 于是\({\rm Var} (\vec q^T \vec k/\sqrt{d_k})=v_q v_k\)。因此,假设输入的query和value是标准正态分布(均值为\(0\),方差为\(1\)),那么经过\(1/\sqrt{d_k}\)的缩放处理,就能保持softmax的输入也是标准正态分布,避免方差过小或过大。
1.2 代码实现
在充分理解清楚SDPA的计算公式的前提下,用pytorch实现它就不困难。需要注意的是公式 1中没有体现attention mask,但我们要在代码实现中支持它。attention mask用于在训练中阻止模型到未来的token,这对训练聊天机器人、机器翻译模型是等生成式模型是必要的。
在实现中,mask矩阵的值会被加到注意力矩阵上。如果mask矩阵中有一个值非常小(接近\(-\infty\)),那么对应的注意力就会在softmax中被抑制。
下列代码实现了scaled dot-product attention:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def my_scaled_dot_product_attention(q, k, v, mask=None):
embed_dim = q.shape[-1]
attn_weights = torch.einsum('nle,nse->nls', q, k)
attn_weights /= np.sqrt(embed_dim)
if mask is not None: attn_weights += mask
attn_weights = torch.softmax(attn_weights, dim=-1)
ret = torch.einsum('nls,nse->nle', attn_weights, v)
return ret, attn_weights让我们看看实现的my_scaled_dot_product_attention函数能不能正常运行:
batch_size = 1
length = 2
embed_dim = 4
q = torch.rand((batch_size, length, embed_dim))
k = torch.rand((batch_size, length, embed_dim))
v = torch.rand((batch_size, length, embed_dim))
mask = torch.rand((length, length))
ret, attn_weights = my_scaled_dot_product_attention(q, k, v, mask)
print('attention操作返回结果:', ret)
print('attention矩阵为:', attn_weights)attention操作返回结果: tensor([[[0.5092, 0.4412, 0.1042, 0.1914],
[0.6837, 0.5136, 0.1024, 0.2412]]])
attention矩阵为: tensor([[[0.5479, 0.4521],
[0.3397, 0.6603]]])
一个检查实现正确性的小技巧是,观察返回attention矩阵的各列之和是否为1. 如果不为1,说明实现有问题。
print(attn_weights.sum(-1))
assert torch.all(attn_weights.sum(-1) - 1 < 1e-6)tensor([[1., 1.]])
在准备好scaled dot product attention的基础上,我们便可以著手实现multi-head attention了。
2 Multi-head Attention
根据原论文描述,MHA的实现为: \[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O \] 其中\(\text{head}_i = \text{Attention}(QW^Q_i , KW^K_i ,VW^V_i)\). 据此可知,MHA操作为:
- 将query、key和value拆分成若干个head,对这些head分别作线性变换;
- 对每个head分别执行scaled dot-product attention;
- 将所有head的attention的结果合并,经过一层线性变换后返回。
通过在多个head上执行attention,MHA允许模型在不同的子空间收集不同位置的信息,增加了attention机制的表达能力。
代码实现如下:
class MyMultiheadAttention(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.in_proj_weight = nn.Parameter(
torch.rand((3 * embed_dim, embed_dim))
)
self.in_proj_bias = nn.Parameter(
torch.rand((3 * embed_dim))
)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value, attn_mask=None):
L, N, q_dim = query.shape
S, N, k_dim = key.shape
S, N, v_dim = value.shape
assert self.embed_dim % self.num_heads == 0
weight_query, weight_key, weight_value = self.in_proj_weight.chunk(3)
bias_query, bias_key, bias_value = self.in_proj_bias.chunk(3)
query = F.linear(query, weight_query, bias_query)
key = F.linear(key, weight_key, bias_key)
value = F.linear(value, weight_value, bias_value)
# L, N, E -> N*n_heads, L, E/n_heads
query = query.view(L, N * self.num_heads, -1).permute(1, 0, 2)
key = key.view(S, N * self.num_heads, -1).permute(1, 0, 2)
value = value.view(S, N * self.num_heads, -1).permute(1, 0, 2)
out, attn_weights = my_scaled_dot_product_attention(query, key, value, mask=attn_mask)
out = out.permute(1, 0, 2).reshape(L, N, -1)
attn_weights = attn_weights.reshape(N, -1, L, S).mean(1)
out = self.out_proj(out)
return out, attn_weights首先,我们先比较一下本文实现的multi-head attention和torch官方的实现的参数设置。我们将会看到所有参数名称和参数尺寸都是一致的。这表明我们的multi-head attention与torch的官方实现兼容。
batch_size = 2
embed_dim = 16
num_heads = 4
torch_mha = nn.MultiheadAttention(embed_dim, num_heads)
for n, p in torch_mha.named_parameters():
print(n, p.shape)
print('')
my_mha = MyMultiheadAttention(embed_dim, num_heads)
for n, p in my_mha.named_parameters():
print(n, p.shape)in_proj_weight torch.Size([48, 16])
in_proj_bias torch.Size([48])
out_proj.weight torch.Size([16, 16])
out_proj.bias torch.Size([16])
in_proj_weight torch.Size([48, 16])
in_proj_bias torch.Size([48])
out_proj.weight torch.Size([16, 16])
out_proj.bias torch.Size([16])
接下来,我们检查我们实现的MHA和torch版本MHA的输出是否完全一致。在进行比较之前,我们先进行参数的复制,保持两个模块的参数完全相同。
my_mha_parameters = dict(my_mha.named_parameters())
for n, p in torch_mha.named_parameters():
my_mha_parameters[n].data.copy_(p)
assert torch.all(p == my_mha_parameters[n])下列代码对函数的输出进行检查:
query = torch.rand((length, batch_size, embed_dim))
key = torch.rand((length, batch_size, embed_dim))
value = torch.rand((length, batch_size, embed_dim))
# prepare mask
attn_bias = torch.zeros(length, length, dtype=query.dtype)
temp_mask = torch.ones(length, length, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
ret, attn_weights = my_mha(query, key, value, attn_mask=attn_bias)
ret2, attn_weights2 = torch_mha(query, key, value, need_weights=True, attn_mask=attn_bias)
error = torch.mean(torch.abs(ret - ret2)).item()
error2 = torch.mean(torch.abs(attn_weights - attn_weights2)).item()
print(f'Difference: {error:.4f}, {error2:.4f}')
assert error < 1e-6 and error2 < 1e-6Difference: 0.0000, 0.0000
检查结果表明我们实现的multi-head attention输出与torch官方实现的完全一致,验证了代码的正确性。
3 练习题
例子 1 改变MHA的heads数会如何改变MHA的参数量?
参考答案
MHA的参数量与head的数量无关,这从代码中也能很容易看出来。
实验也验证了这一点。
embed_dim = 16
num_heads = 4
mha1 = nn.MultiheadAttention(embed_dim, num_heads)
mha2 = nn.MultiheadAttention(embed_dim, num_heads * 2)
def count_parameters(m):
c = 0
for p in m.parameters(): c += p.numel()
return c
print(count_parameters(mha1), count_parameters(mha2))1088 1088
例子 2 在自注意力机制中,MHA的时间复杂度是多少?和head数有关吗?
参考答案
首先,由公式 1可以看出,单头注意力机制的时间复杂度是\(O(n^2d)\),其中\(n\)为序列长度,\(d\)为特征维度。
接下来考虑MHA的时间复杂度。MHA先对输入向量作线性变换得到Q、K、V,这部分的时间复杂度为\(O(nd^2)\);然后这些特征被拆分为\(n\)个head,分别做attention,这个部分的时间复杂度为\(O(n^2\frac{d}{h} * h)\),其中\(h\)为head数。最后MHA再做一次线性变化,时间复杂度为\(O(nd^2)\).
因此,总的来看,MHA的时间复杂度为\(O(n^2d + nd^2)\),与head数无关。
从stack overflow的这个网页可以看到一些有趣的讨论:https://stackoverflow.com/questions/65703260/computational-complexity-of-self-attention-in-the-transformer-model4 总结
推荐阅读:
- torch官方的scaled dot-product attention文档中简要介绍了一种scaled dot-product attention的python实现方式。
- 论文MQA[3]中的Background一节对各类Attention有详细介绍,推荐一读。