生成式模型的Beam Search採樣原理

本文瀏覽次數

Published

September 3, 2024

Abstract
Beam Search是生成式語言模型的常見採樣算法之一。本文結合一些簡單的例子介紹了Beam Search的原理。

生成式語言模型常見的採樣方式有貪心采樣、top k採樣、top p採樣和beam search采樣。 top k采樣和top p采樣都是對候選token作“截斷”,僅考慮概率最大的若干個token的方法。top k只從概率最高的k個候選token中進行隨機選擇。top p同樣優先考慮高概率的token,直到候選token的概率纍積達到p。貪心采樣可以視爲top k和top p的特例,即總是只考慮概率最高的那個token。

無論是貪心、top k還是top p,它們都只探索了一條生成序列,容易落入局部最優。Beam Search的特點是通過在搜索過程中同時考慮概率最高的若干條可能序列,這樣就擴大了搜索範圍。

以我個人的實際項目經驗來看,只要用上beam search,一般模型的評測得分都能小漲幾個點。在一些對輸出字符串有限制的任務中,我們還可以在beam search內引入約束,限制模型的生成。Beam search是一個基礎、方便、能夠即插即用的改善生成質量的生成策略。這篇文章就來簡單記錄下beam search的原理。

假設下面的model存儲了每個token被生成的概率,我們討論如何基於這個模型實現beam search采樣。

顯示初始化代碼
import math 
import numpy as np 
import numpy as np
from dataclasses import dataclass
from typing import List 
from numbers import Number 
import functools

import random
random.seed(0) 

model = {
    'a': 0.7, 
    '<eos>': 0.3, 
}

這是一個簡化了的語言模型——在實際應用中,語言模型的概率分佈會隨上下文變化,但這裡我們假設token產生的概率是與歷史token無關的。

Beam search的要點是,在每次生成下一個token時列舉出所有可能,然後只保留使整體序列概率最大的那n個。

為了理解beam search,我們先手推一遍基於model模型的beam search生成過程,然後再編寫程序與手推的結果進行比對。

beam search的兩個基本參數是beam_widthtokens_to_gen. beam_width表示在生成時要考慮多少個候選序列,而tokens_to_gen表示要生成多少個字符。

假設beam_width = 2,我們在空字符串的基礎上生成3個token。手推生成beam search的生成過程,我們會得到這樣的一個生成樹:

flowchart LR
    subgraph "第0步" 
    A
    end 
    subgraph "第1步" 
    A["空串, 1"] --> B("a, 1 * 0.7 = 0.7")
    A --> C("&lt;eos&gt;, 1 * 0.3 = 0.3")
    end 
    subgraph "第2步" 
    B --> D("aa, 0.7 * 0.7 = 0.49")
    B --> E("a&lt;eos&gt;, 0.3 * 0.7 = 0.21")
    C --> F("&lt;eos&gt;, 0.3")
    end 
    subgraph "第3步" 
    D --> G("aaa, 0.49 * 0.7 = 0.343")
    D --> H("aa&lt;eos&gt;, 0.49 * 0.3 = 0.147")
    F --> I("&lt;eos&gt;, 0.3")
    end

上圖以樹的形式展示了beam search的生成過程。每個節點都以“字符串,概率”的形式表示了當前生成的文本和對應的概率。每一步概率最大的兩個節點會被選中,在下一步生成後續token。在計算過程中需要注意,一旦<eos>token被生成,那麼它不會再產生後續token,其概率也不會再被改變。每生成一個token,當前生成結果的概率便等於上一個token序列的概率乘以生成token的概率。

在樹狀圖的最後一層,我們可以看到算法的最終生成結果。取出概率最大的兩個生成結果,分別得到aaa<eos>. 下面就讓我們編寫對應的beam search程序,並對照檢查結果是否正確。

1 Beam Search代碼實現

我們先構造一個BeamSearchCandidate類用於表示beam search的生成結果。這個類儲存了當前生成的token和對應的概率。

@functools.total_ordering
@dataclass 
class BeamSearchCandidate:
    tokens: List[str]
    logprob: Number 

    def __repr__(self):
        return f'{"".join(self.tokens)}:{math.exp(self.logprob):.2f}'

    def __ge__(self, other: 'BeamSearchCandidate'):
        return self.logprob > other.logprob

實踐中,我們經常存儲概率的對數,即代碼中的logprob,這樣能將概率的相乘轉化為對數概率的相加,減少計算代價,改善數值穩定性問題。

在理解算法運行過程的基礎上,程序的實現並不困難。同樣,以model這個簡單模型為例,beam search算法實現如下:

def simple_top_k(arr, k):
    '''一種簡單的top k算法。'''
    arr.sort()
    return arr[-k:]

import tqdm 


def beam_search(
    model, 
    input_sequence:str,
    beam_width:int,
    tokens_to_gen:int,
    top_k=simple_top_k, # ←後續我們會討論top k算法的優化
) -> List[BeamSearchCandidate]:
    candidates = [BeamSearchCandidate([input_sequence], 0.)]
    for i in tqdm.tqdm(range(tokens_to_gen)):
        next_candidates = []
        for cdd in candidates:
            if len(cdd.tokens) and cdd.tokens[-1] == '<eos>':
                next_candidates.append(cdd)
            else:
                for c, lp in zip(model.tokens, model.logprob(cdd.tokens)):
                    next_candidates.append(
                        BeamSearchCandidate(
                            cdd.tokens + [c],  # 計算每個候選語句的概率
                            cdd.logprob + lp
                        )
                    )
        # 每一步都只保留概率最高的那些候選語句
        candidates = top_k(next_candidates, k=beam_width)
    return candidates

從beam search的運算結果可以看到,其生成結果與我們手動計算的一致。

ret = beam_search(model, '', beam_width=2, tokens_to_gen=3)
print(ret)
100%|██████████| 3/3 [00:00<00:00, 2092.62it/s]
[<eos>:0.30, aaa:0.34]

接下來本文記錄筆者實踐中遇到的一些問題和解決思路。

2 Top K算法優化

Beam search算法中,選擇概率最大的若干個候選token序列是一個比較耗時的操作。以上beam search函數使用了排序算法來找到前\(k\)個概率最大的候選文本,這一操作的時間複雜度為\(O(n\log n)\)\(n\)為總的候選token數量。因為大模型的詞表比較大,所以改進beam search所用的top k算法能帶來不少加速。

第一個思路是對選擇排序進行改造。在排序的過程中,使其在完成k個數字的排序時停止。這樣就獲得了一種簡單的top k選擇算法,此類方法的時間複雜度是\(O(nk)\)\(k\)beam width

另一種思路是基於優先隊列(堆)實現top k算法。我們可以維護一個大小為\(k\)的堆,在遍歷過所有\(n\)個候選token後,堆中留下的便是概率最大的\(k\)個預測序列。維護堆的複雜度為\(O(\log k)\),共維護\(n\)次,因此總的時間複雜度為\(O(n\log k)\).

最後,一種非常高效的,但知道的人比較少的top k算法是快速選擇算法(quick select),其時間複雜度為\(O(n)\). 顯然這是top k算法所能達到的最優時間複雜度——因為不管哪種算法也要把原始輸入序列逐個過一遍。
與前面幾種top k算法相比,快速選擇算法的缺點在於其無法保證返回的k個結果的有序性。即quick select方法以犧牲結果的有序性為代價改進了其時間複雜度。

3 數值精度優化

  1. 大模型的詞表一般很大,這就導致所有token的生成概率可能會整體偏低;
  2. 模型內部實現上並不是直接返回\([0, 1)\)之間的概率數值,而是返回範圍在\((-\infty, \infty)\)內的logits;
  3. 為了加速推理和訓練,人們經常使用float16等低精度浮點型。

以上因素的共同作用下,beam search很容易遇到數值溢出的問題。在實踐中,可以利用將logits同時加上或者減去任意數,不影響概率大小的特性來調控logits數值的範圍,盡量避免溢出。這一原理我在之前的《Softmax原理》一文中也有講解。

4 長度歸一化

模型有時候可能會傾向於產生較長或較短的文本,這時候對概率作長度歸一化是一種可以考慮的手段。

在未引入長度歸一化時,beam search選取候選token的條件為 \[ \arg\max \sum_{t=1}^{T_y}\log P(y_t|X, y_1, y_2, \dots, y_{t-1}), \] 其中\(T_y\)為序列長度。

在引入長度歸一化後,條件變為: \[ \arg\max \frac{1}{T_y^\alpha} \sum_{t=1}^{T_y}\log P(y_t|X, y_1, y_2, \dots, y_{t-1}). \] 新的條件引入了一個額外的參數\(\alpha\in[0, 1]\)。當\(\alpha\)為0時,beam search的行為相當於不作任何歸一化;\(\alpha=1\)則對應著“標準的”歸一化操作。如果想要取得一個折中的效果,可以設\(\alpha\)\((0, 1)\)內的一個數字。

5 推薦閱讀

By @執迷 in
Tags : #LLM, #大模型, #Beam Search, #生成式模型,