本人求职中!大模型 Post-Training 方向 · 长三角、福建地区优先 · 如果您有兴趣,欢迎联系 zhimi64@foxmail.com 👀 了解更多

KTO算法的学习和推导

本文瀏覽次數

之前我在一个实际项目里面负责一个LLM的偏好的对齐。神奇的是这个项目只有单个样本的拒绝/采样信号,没有pair-wise的数据。我还是第一次遇到这种情况。

很自然的,我想到要试试KTO的效果。KTO是一种只需要单个样本偏好数据的偏好对齐算法。

这篇文章记录了我阅读KTO这篇论文的笔记,以及对KTO背后思想的一些理解。

HALO

KTO提出了human-aware losses(HALO)的分类学概念。损失函数可以分为符合HALO和不符合HALO规范的。虽然没有办法证明HALO类的损失函数一定优于非HALO的,但是论文表示,从经验上看,HALO损失函数的效果一般更好。

HALO的定义: 假设\(\theta\)表示模型\(\pi_\theta:\mathcal X\rightarrow\mathcal P(\mathcal Y)\)的可训练参数。\(\pi_\text{ref}\)是基准模型,\(l:\mathcal Y\rightarrow \mathbb R^+\)是一个归一化系数,\(r_\theta(x, y)=l(y) \log [\pi_\theta(y | x)/\pi_\text{ref}(y|x)]\)是隐含的奖励函数。数据\((x, y)\)的“人类价值”应该可以形式化为: \[ v(r_\theta(x, y) - \mathbb E_Q[r_\theta(x, y')]), \] 其中\(Q(Y'|x)\)\(\mathcal Y\)上的一个基准点数据的分布,\(v:\mathbb R\rightarrow \mathbb R\)是非降的,在\((0, \infty)\)内为凹的函数。

函数\(f\)如果满足以下性质,就可以视为基于v的“human-aware的损失函数”: \[ f(\pi_\theta, \pi_\text{ref})=\mathbb E_{x, y\sim\mathcal D}[a_{x, y} v(r_\theta(x, y) - \mathbb E_Q[r_\theta(x, y')])] + C_{\mathbfcal D} \] 其中 \(a_{x, y}\in \{-1, +1\}\). \(\mathcal D\)是数据集,\(C_{\mathcal D}\in\mathbb R\)是由数据集决定的常数。

如何判断一个对齐方法是否属于HALO

CSFT不是HALO

CSFT会在prompt中注入control token,例如<good><bad>,与相应的回复数据拼接在一起,构造得到训练数据。推理的时候,一律采用<good>这个control token来获得回复。

所以CSFT非常类似普通的SFT。

\[ L_\text{CSFT} = -\log \pi_\theta(y|x, c), \] 其中\(c=<\text{good}>\)或者\(<\text{bad}>\)

假设\(L_\text{CSFT}\)符合HALO的定义,那么存在 \(r_\theta(x, y)=l(y)\log \frac{\pi_\theta(y | x)}{\pi_\text{ref}(y|x)}\)\(R(x) = E_Q[r_\theta(x, y')]\),使得 \[ L_\text{CSFT} = a v(r_\theta(x, y) - R(x)) + C \] 代入HALO对\(r_\theta\)的定义,上式展开得到 \[ \begin{aligned} L_\text{CSFT} = a v \left(l(y)\log \pi_\theta(y|x) - l(y)\log \pi_\text{ref}(y|x)- R(x) \right) + C\\ \end{aligned} \] 要看CSFT能否构成HALO,就是看能否找到合适的\(a, v, l(y), R(x), C\)使得下面的等式成立: \[ -\log \pi_\theta(y|x, c) = a v \left(l(y)\log \pi_\theta(y|x) - l(y)\log \pi_\text{ref}(y|x)- R(x) \right) + C \]

我们比较等式两边的依赖:

  • 左边:只依赖\(\pi_\theta\)
  • 右边:
    • \(\log \pi_\theta(y|x)\)
    • \(\log \pi_\text{ref}(y|x)\)
    • \(R(x)\)

为了让等式对所有的y成立,\(R(x)\)就必须包含\(\pi_\text{ref}\)。除非\(\pi_\text{ref}\)是不依赖\(y\)的均匀分布,否则是无法成功构造的。

因此论文得出结论:CSFT不是一种HALO。

SLiC不是HALO

SLiC(Sequence Likelihood Calibration)的损失函数如下: \[ \begin{aligned} &L_\text{cal} (\pi_\theta)= \mathbb E_{x, y_w, y_l\sim D}\left[\max\left(0, \delta - \log\frac{\pi_\theta(y_w|x)}{\pi_\theta(y_l|x)}\right)\right]\\ &L_\text{reg} (\pi_\theta, \pi_\text{ref}) = \mathbb E_{x\sim D, y\sim\pi_\text{ref}(x)} [-\log \pi_\theta(y|x)]\\ &L_\text{SLiC}(\pi_\theta, \pi_\text{ref}) = L_\text{cal} (\pi_\theta) + \lambda_\text{reg}L_\text{reg}(\pi_\theta, \pi_\text{ref}) \end{aligned} \] 分析SLiC的损失函数,可以看到 \[ \begin{aligned} &max(0, \delta - \log \frac{\pi_\theta(y_w|x)}{\pi_\theta(y_l|x)})\\ =&max(0, \delta - (\log \pi_\theta(y_w|x) - \log\pi_\theta(y_l|x))) \end{aligned} \] 这部分实际上希望模型输出\(y_w\)的对数似然大于输出\(y_l\)的对数似然,但差距不要差太大。

\(L_\text{reg}\)做的事情就是让\(\pi_\text{ref}\)负责\(y\)的采样,用\(\text{ref}\)模型产生的数据约束当前训练中的模型,约束模型不要偏离基础模型太远。

和CSFT的推导过程类似,由于\(\pi_\text{ref}\)只用来采样数据,但不直接构成损失函数的一项,SLiC也不是HALO。

DPO是HALO

DPO的损失函数是 \[ L_\text{DPO}(\pi_\theta, \pi_\text{ref}) = \mathbb E_{x, y_w, y_l}\left[ -\log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_\text{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_\text{ref}(y_l|x)} \right) \right], \]\(l(y)=\beta, r_\theta = \beta \log(\frac{\pi_\theta(y|x)}{\pi_\text{ref}(y|x)}), v(\cdot) = \log \sigma(\cdot)\)\(Q\)是一个将全部概率质量集中在\((x, y_l)\)上的分布。\(a_{x, y}=-1\). 通过这样的构造,我们可以验证DPO符合HALO的定义。

PPO-Clip是HALO

\[ L_\text{PPO~(offline)} = - \mathbb E _{x, y, t\sim D}\left[ \min(q_\theta A(x:y_{<t}, y_t), \text{clip}(q_\theta, 1 - \epsilon, 1 + \epsilon)A(x:y_{<t}, y_t)) \right] \] 其中\(q_\theta = \frac{\pi_\theta(y_t|x:y_{<t})}{\pi_\text{ref}(y_t|x:y_{<t})}\)是token级别的概率比值。

\(A(x:y_{<t}, y_t)\)是token级别的优势函数,可以表示为\(Q^\pi(x:y_{<t}, y_t) - V^\pi(x:y_{<t})\),即动作-价值函数和价值函数的差值。因为\(V^\pi(x:y_{<t}) = \mathbb E_{y\sim \pi}Q^\pi(x:y_{<t}, y)\),因此基准点分布(reference distribution)就是policy本身。

那么根据HALO的定义,\(r_\theta\)可以构造为\(q_\theta Q^\pi(x:y_{<t}, t)\)。这里需要不失一般性地,假设\(Q^\pi\)是非负的,因为\(Q^\pi\)总是可以加上一个正数而不改变优势函数。这意味着\(\exists u \geq 1, q_\theta Q^\pi(x:y_{<t}, y)=\log u = \log \hat \pi_\theta(x:y_{<t}, y) / \hat \pi_\text{ref}(x:y_{<t}, y)\),其中\(\hat \pi_\theta, \hat \pi_\text{ref}\)是隐含的策略分布和参考分布。 \(q_\theta A = r_\theta - z_0, v(q_\theta A) = min(q_\theta A, A(1 + \text{sign}(q_\theta A)\epsilon)), a_{x, y}=-1\)就能完成构造。

重新梳理一遍:

\[ \begin{aligned} & v(r_\theta(x, y)) - \mathbb E_Q[r_\theta(x, y')] & \\ = & v(q_\theta Q^\pi - \mathbb E_Q[q_\theta Q^\pi]) & (\text{令} r_\theta = q_\theta Q^\pi)\\ = & v(q_\theta Q^\pi - \mathbb E_{y\sim\pi_\text{ref}}[q_\theta Q^\pi]) & (\text{基准分布就是策略本身})\\ = & v(q_\theta Q^\pi - q_\theta V) & (\text{因为}\mathbb E_{\pi_\text{ref}}[q_\theta Q^\pi] = \mathbb E_{\pi_\theta} [Q^{\pi_\theta}] = V^{\pi_\theta})\\ = & v(q_\theta A) & \text{因为}(A = Q^{\pi_\theta} - V^{\pi_\theta}) \\ = & min(q_\theta A, \text{clip}(q_\theta, 1-\epsilon, 1+\epsilon) A) & \text{代入}v\text{的构造} \end{aligned}\\ \]

Note

根据论文的证明,\(r_\theta\)的构造实际对应着隐含的策略\(\hat \pi_\theta\)\(\hat \pi_\text{ref}\)。实际上这两个策略并不是真实存在的,或者可以直接由真实的策略构造出来的。因此这里的证明与其说是说明“对于PPO所优化的策略”而言,PPO是一种HALO,倒不如说“存在一种隐含的策略”,对于它来说PPO可以解释为HALO。

KTO

前文中已经罗列了一些偏好对齐方法,并且判断他们是否属于HALO。接著论文作者提出了自己的设计:Kahneman-Tversky Optimization(KTO)。Kahneman和Tversky认为人类价值函数可以用如下的算式建模: \[ v(z; \lambda ,\alpha, z_0) = \left\{ \begin{aligned} & (z - z_0)^\alpha & \text{if} z \geq z_0 \\ & -\lambda (z_0 - z) ^\alpha & \text{if} z < z_0 \end{aligned} \right. \] 但是这样的设计在训练模型的时候会遇到数值稳定性问题,因此作者将指数函数替换为logistic函数。

为了模拟人类的“损失厌恶”倾向,KTO引入了\(\beta \in \mathbb R^+\)参数。\(\beta\)越大,价值函数越容易饱和,对应著人类在正收益时的损失厌恶,在面临损失时又倾向于接受风险。

原模型中的\(\lambda\)参数被分化为\({\lambda_D, \lambda_U}\)两个参数,用于分别控制desirale和undesirable两种数据的权重。

KTO函数设计为: \[ L_\text{KTO} (\pi_\theta, \pi_\text{ref}) = \mathbb E_{x, y\sim D}[\lambda_y - v(x, y)], \] 其中 \[ \begin{aligned} r_\theta(x, y) & = \log \frac{\pi_\theta(y|x)}{\pi_\text{ref}(y|x)} \\ z_0 &= \text{KL}(\pi_\theta(y'|x)\Vert \pi_\text{ref}(y'|x))\\ v(x, y) &= \left \{ \begin{aligned} \lambda_D \sigma(\beta(r_\theta(x, y) - z_0))~~~~& \text{if} ~y\sim y_\text{desirable} | x \\ \lambda_U \sigma(\beta(z_0 - r_\theta(x, y)))~~~~& \text{if} ~y\sim y_\text{undesirable} | x \end{aligned} \right. \end{aligned} \]

其中\(z_0\)不参与梯度反向传播,以保持训练过程稳定。

\(z_0\)的计算本质是估计一个KL散度。理论上需要用蒙特卡洛方法从\(\pi_\theta\)中采样\(y\),再平均。采样过程是很慢的,代价很大。所以作者这里用了一个取巧的方法。对一个同一个batch内的一组数据\(\{(x_1, y_1), (x_2, y_2), \dots, (x_m, y_m)\}\),KTO首先让\(x\)\(y\)错位组合变成\(\{(x_1, y_2), (x_2, y_3), \dots, (x_m, y_1)\}\). KTO使用以下的公式估计\(z_0\): \[ \begin{aligned} &\hat z_0 = \max\left( 0, \frac{1}{m}\sum_{i\leq i\lt m} \log \frac{\pi_\theta(y_i|x_i)}{\pi_\text{ref}(y_j|x_i)} \right),\\ &\text{where} ~j = (i + 1) \mod m \end{aligned} \]

也就是用训练集中固有的\(y\)来避免重新采样。但是\(y_j\)\(x_i\)完全没有关系怎么办?作者认为这是biased,但是方便啊。估计需要使用\(\max(0, \cdot)\)保证估计得到的KL散度是非负的。

\(\hat z_0\)对于每个batch计算一次,在batch内是共享的。

一些思考

KTO其实依赖一个比较强的假设——损失厌恶。但是这在我的实际项目里是不成立的。在我的项目中,被拒绝的数据不一定是坏数据,被采纳的数据不一定是好数据。噪声太大了。我不得不将undesired_weight调整到一个比较低的水平,才能跑起来,否则训练就不稳定。

在“偏好”这个信号有很强的噪声时,该如何利用偏好数据蕴含的信息进行训练呢?这个问题没法很好地用KTO解决,但很有实际研究价值。

By @執迷 in
Tags :