1 背景
pyTorch和numpy等库都提供了einsum
函数函数。einsum,全称Einstein summation convention(爱因斯坦求和约定),是一个方便而强大的工具。
我们熟悉的矩阵乘法、转置、求矩阵的迹等操作都可视为einsum的特例。除此之外,它还可以表示更复杂的运算,例如将矩阵乘法、转置等操作复合在一起表示为一个操作。
先从矩阵乘法说起。我们可以这样用einsum表示矩阵乘法,
import torch
= torch.rand(2, 2)
a = torch.rand(2, 2)
b = torch.einsum('ij,jk->ik', a, b) c
在ij,jk->ik
这个式子中,j
在->
左侧重复使用,且没有出现在->
右侧,因此我们需要对j
求和。由此可知einsum('ij,jk->ik', a, b)
的实际含义为\(c_{ik}=\sum_j a_{ij} b_{jk}\),等同于矩阵乘法。
下面我们简单验证一下计算的正确性。
= 1e-5
eps assert (c - a @ b).abs().mean() < eps
print('检查通过')
检查通过
上面的代碼中,a @ b
是pytorch中計算a
和b
矩陣乘法時的一般寫法。可以看到我們用einsum得到的計算結果與標準寫法的計算結果是一致的。
代码torch.einsum('ij,jk->ik', a, b)
使用了->
符号分隔输入和输出。这被称为显式的einsum。你也可以不使用->
,只给出输入的索引,这被称为隐式的einsum。例如torch.einsum('ij,jk', a, b)
. 注意因为隐式的einsum不指定输出的索引,因此输出的索引将会按字母表顺序排列。
为了简便起见。这篇文章里我只讨论显式的einsum。基于以上介绍的基础规则,我想在这篇文章中分享一些有趣的发现和思考:
- 顺序无关性:einsum的输入参数的顺序是可交换的;
- 手写einsum函数:如何基于python实现一个基础的einsum函数;
- einsum的梯度:einsum函数的梯度仍然可以用einsum表示;
2 顺序无关性
einsum函数的要点是:1. 根据索引从输入参数中取数,计算累积的乘积;2. 对重名的,在输出中没有出现的索引求和。
可以看到einsum的规则和输入参数的顺序没有关系。因此,einsum函数的输入参数是可以交换顺序的。
例如,虽然矩阵乘法没有交换律,但是基于einsum计算矩阵乘法时,你可以交换参数的顺序。
for _ in range(100):
= torch.rand(2, 2)
a = torch.rand(2, 2)
b = torch.einsum('ij,jk->ik', a, b) # 正序
c = torch.einsum('jk,ij->ik', b, a) # 反序
d assert (c - d).abs().mean() < eps
print('检查通过')
检查通过
3 手写einsum
einsum函数的计算是這樣一個過程:一个遍历所有索引,計算輸入元素的乘積,然後求和的過程。如果我们有一个函数能够遍历所有的索引,那会很方便。
我们先来实现这样的一个函数iter_elements
:
import torch
def iter_elements(sizes):
if len(sizes) == 0:
yield []
return
for i in range(sizes[0]):
for j in iter_elements(sizes[1:]):
yield (i, *j)
for idx in iter_elements([2, 3]):
print(idx)
(0, 0)
(0, 1)
(0, 2)
(1, 0)
(1, 1)
(1, 2)
上面的代码实现了一个遍历所有可能索引的函数iter_elements
. 对于尺寸为\(2\times 3\)的矩阵,这个函数输出了所有可能的六个索引。这是一个非常方便的函数。接下来我们就基于它在einsum函数中遍历每一个元素。
def einsum_forward(equation, *operands):
= equation.split('->')
input_dims, output_dim = input_dims.split(',')
input_dims
# 收集所有的索引名
= list(set(''.join(input_dims)))
index_names # 收集每个索引名对应的维度大小
= [-1] * len(index_names)
sizes for shape, tensor in zip(input_dims, operands):
for index_name, size in zip(shape, tensor.shape):
assert sizes[index_names.index(index_name)] == -1 or \
== size
sizes[index_names.index(index_name)] = size
sizes[index_names.index(index_name)]
# 计算输出矩阵的尺寸
= [sizes[index_names.index(name)] for name in output_dim]
output_size # 将输出矩阵用0初始化
= operands[0].new_zeros(output_size)
output # 遍历所有的索引
for idx in iter_elements(sizes):
# 映射到输出矩阵的索引
= tuple(idx[index_names.index(name)] for name in output_dim)
idx_output = 1
prod for input_dim, tensor in zip(input_dims, operands):
# 对于每一个输入tensor,取得对应的索引
= tuple(idx[index_names.index(name)] for name in input_dim)
idx_input # 计算累积的乘积
= prod * tensor[idx_input]
prod # 求和
+= prod
output[idx_output] return output
接下來我們編寫一些測試用例,檢查einsum_forward
和torch.einsum
的計算結果是否一致。
# 矩阵乘法
= torch.rand((5, 4))
a = torch.rand((4, 5))
b = einsum_forward('ij,jk->ik', a, b)
c = torch.einsum('ij,jk->ik', a, b)
c2 = a @ b
c3 assert (c - c2).abs().mean() < eps
assert (c - c3).abs().mean() < eps
# 矩阵转置
= torch.rand((5, 4))
a = einsum_forward('ij->ji', a)
b = torch.einsum('ij->ji', a)
b2 = a.T
b3 assert (b - b2).abs().mean() < eps
assert (b - b3).abs().mean() < eps
# 矩阵的迹
= torch.rand((5, 5))
a = einsum_forward('ii->', a)
b = torch.einsum('ii->', a)
b2 = torch.trace(a)
b3 assert (b - b2).abs().mean() < eps
assert (b - b3).abs().mean() < eps
print('全部檢查通過。')
全部檢查通過。
4 einsum函数的梯度
舉一個簡單的例子。假設\(c_{ik} = \sum_{j} a_{ij} b_{jk}\). 對於任意的\(i,j\),導數\(\frac{\partial L}{\partial c_{ik}}\)都已知,問如何計算\(\frac{\partial L}{\partial a_{ij}}\)和\(\frac{\partial L}{\partial b_{jk}}\)?
顯然 \[ \frac{\partial L}{\partial a_{ij}} = \sum_k \frac{\partial L}{\partial c_{ik}}b_{jk}, \] \[ \frac{\partial L}{\partial b_{jk}} = \sum_i \frac{\partial L}{\partial c_{ik}} a_{ij}. \] 可以看到,在這個例子中,不論是前向的計算還是梯度的計算都可以表示為一系列乘積的和。
如果你仔細地觀察和理解了einsum函數的計算方式,不難得出更一般的結論,一個由einsum函數定義的表達式,其對任意一個輸入參數的梯度同樣可以表示為einsum函數的形式。
現在假設a, b, c
是某個einsum操作的輸入,經過操作我們得到輸出d
,如下所示:
= torch.rand((2, 2), requires_grad=True)
a = torch.rand((2, 2), requires_grad=True)
b = torch.rand((2, 2), requires_grad=True)
c = torch.einsum('ij,jk,kl->il', a, b, c) d
假設\(L\)是\(d\)經過某個函數計算得到的實數,\(\frac{\partial L}{\partial d}\)已知,記為grad_d
:
= torch.rand(d.shape) # 隨機初始化grad_d,假設它就是梯度 grad_d
pytorch能很方便地幫我們分別算出a, b, c
的梯度。
d.backward(grad_d)print(a.grad.shape, b.grad.shape, c.grad.shape)
torch.Size([2, 2]) torch.Size([2, 2]) torch.Size([2, 2])
但是如前所述,einsum函數的梯度仍然可以組織成einsum的形式。讓我們來驗證這一點。
首先是a
的梯度\(\frac{\partial L}{\partial a}\)
= torch.einsum('il,jk,kl->ij', grad_d, b, c)
grad_a assert torch.mean(torch.abs(a.grad - grad_a)) < eps
接著是b
的梯度:
= torch.einsum('il,ij,kl->jk', grad_d, a, c)
grad_b assert torch.mean(torch.abs(b.grad - grad_b)) < eps
c
的梯度:
= torch.einsum('il,ij,jk->kl', grad_d, a, b)
grad_c assert torch.mean(torch.abs(c.grad - grad_c)) < eps
至此,我們驗證了a, b, c
三個參數的梯度計算方式,它們都可以用einsum函數表示,而且計算結果和pytorch的計算結果一致!
5 總結
最近在實現自己的深度學習模型時,我設計了一個簡單的模塊,其中一部分用到了einsum。為了加速這個模塊在梯度反傳階段的計算,我作了詳細的推導,發現einsum有一些有趣的性質,於是就記錄下來,形成了這篇文章。
仔細分析可以發現,其實einsum本質上是計算元素間乘積的和的過程。儘管einsum的使用方式多種多樣,但只要把握這個本質,其性質也就不難理解了。