RoPE: Rotary Position Embedding

2024/01/02 · LLM

RoPE was created by Su Jianlin (苏剑林) and was first applied in his RoFormer (Rotary Transformer). It is a relative position encoding, and it performs better than absolute position encoding and the classic relative position encodings. It comes from the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".

As far as I know, recently released large language models such as Meta's LLaMA and Tsinghua's ChatGLM both adopt RoPE, which is a good indication of its strengths.

For the rotation angle \(\theta\), Su keeps the Transformer choice \(\theta_i = 10000^{-2i/d}\): his experiments found that this \(\theta\) also gives RoPE a degree of long-range decay, meaning the dependency between tokens weakens as their distance grows, which matches our intuition. Other choices of \(\theta\) work as well, as long as they satisfy long-range decay.

import torch

dim = 512
base = 10000

# theta_i = 1 / base^(2i / dim), i = 0, 1, ..., dim/2 - 1  -> shape (dim // 2,)
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
print(inv_freq.shape)
print(inv_freq[:10])

"""
Output:
torch.Size([256])

tensor([1.0000, 0.9647, 0.9306, 0.8977, 0.8660, 0.8354, 0.8058, 0.7774, 0.7499,
        0.7234])
"""

From the output we can see that \(\theta_i = 10000^{-2i/d}\) shows the decaying pattern behind the long-range decay property: the farther the index is from 0, the smaller the value, i.e. the rotation frequency shrinks as \(i\) grows.
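
To make the decay itself visible, here is a rough sketch of my own (not from the original post): it sums the rotation phases \(\sum_i e^{\mathrm{i}(m-n)\theta_i}\) and prints the magnitude at several relative distances. The magnitude starts at \(d/2\) at distance 0, and its envelope shrinks (with some oscillation) as the distance grows.

import torch

# A rough illustration (my own sketch): |sum_i exp(1j * t * theta_i)| as a
# function of the relative distance t = m - n
dim = 512
theta = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)

t = torch.arange(0, 1001, 100).float().unsqueeze(1)  # relative distances 0, 100, ..., 1000
angles = t * theta                                    # (11, dim // 2)
magnitude = (torch.cos(angles).sum(-1) ** 2 + torch.sin(angles).sum(-1) ** 2).sqrt()

print(magnitude)  # 256 at distance 0, then a shrinking (oscillating) envelope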

Code implementation (PyTorch)

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

"""
计算 cos m \theta_i 和 sin m \theta_i
"""
def sinusoidal_position_embeddings(batch_size = 16, num_head = 8, seq_len = 128, hidden_dim = 768 // 8):
    # (seq_len) -> (seq_len, 1) 即公式中的 m
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)

    # (hidden_dim // 2) 即公式中的 i
    ids = torch.arange(0, hidden_dim // 2, dtype=torch.float)

    # 旋转角度 theta = 10000^(-2i / d)
    # (hidden_dim // 2)
    theta = torch.pow(10000, -2 * ids / hidden_dim)

    # 公式中的 m * \theta_i
    # (seq_len, 1) * (hidden_dim // 2) -> (seq_len, hidden_dim // 2)
    embeddings = position * theta

    # 计算 sin(m * \theta_i) cos(m * \theta_i)
    # (seq_len, hidden_dim // 2, 2)
    embeddings = torch.stack([
        torch.sin(embeddings),
        torch.cos(embeddings)
    ], dim=-1)

    # (bs, head, seq_len, hidden_dim // 2, 2)
    # 在 bs、head 维度重复,其他维度都是 1(不重复)
    embeddings = embeddings.repeat((batch_size, num_head, *([1] * len(embeddings.shape))))
    
    # (bs, head, seq_len, hidden_dim)
    # reshape 后就是:偶数 sin, 奇数 cos 了
    embeddings = torch.reshape(embeddings, (batch_size, num_head, seq_len, hidden_dim))

    return embeddings
"""
计算 q_i * cos m\theta_i + q_i * sin m\theta_i
"""
def RoPE(q, k):
    # q, k: (bs, head, max_len, output_dim)
    batch_size, num_head, max_len, output_dim = q.shape

    # (bs, head, seq_len, hidden_dim)
    pos_emb = sinusoidal_position_embeddings(batch_size, num_head, max_len, output_dim)

    # cos_pos, sin_pos: (bs, head, max_len, output_dim // 2) -> (bs, head, max_len, output_dim)
    # From the RoPE formula, each adjacent pair shares the same cos/sin, so duplicate each entry.
    # e.g. (1, 2, 3) becomes (1, 1, 2, 2, 3, 3)
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)  # odd indices hold cos; extract and duplicate
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)   # even indices hold sin; extract and duplicate

    # q, k: (bs, head, max_len, output_dim)
    q2 = torch.stack([
        -q[..., 1::2],
        q[...,::2]
    ], dim=-1)
    q2 = q2.reshape(q.shape)  # after the reshape the signs alternate: (-q1, q0, -q3, q2, ...)
    # update: q <- q * cos(m * theta) + rotate_half(q) * sin(m * theta) (element-wise)
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([
        -k[..., 1::2],
        k[...,::2]
    ], dim=-1)
    k2 = k2.reshape(k.shape)  # after the reshape the signs alternate: (-k1, k0, -k3, k2, ...)
    # update: k <- k * cos(m * theta) + rotate_half(k) * sin(m * theta) (element-wise)
    k = k * cos_pos + k2 * sin_pos

    return q, k
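
A quick sanity check I added (not part of the original post): after RoPE, the dot product between the query at position m and the key at position n should depend only on the offset m - n. If every position is filled with the same vectors, the score matrix therefore becomes constant along each diagonal.

# sanity check (my own addition): fill every position with the same q0 / k0;
# after RoPE the score matrix should be constant along each diagonal,
# i.e. it depends only on the relative position m - n
q0, k0 = torch.randn(32), torch.randn(32)
q_same = q0.expand(1, 1, 10, 32).clone()   # (bs=1, head=1, seq_len=10, dim=32)
k_same = k0.expand(1, 1, 10, 32).clone()

q_rot, k_rot = RoPE(q_same, k_same)
scores = torch.matmul(q_rot, k_rot.transpose(-2, -1))[0, 0]   # (10, 10)

print(torch.allclose(scores[2, 0], scores[5, 3], atol=1e-4))  # same offset m - n = 2, expected True
print(torch.allclose(scores[0, 2], scores[3, 5], atol=1e-4))  # same offset m - n = -2, expected True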

Computing self-attention:

def attention(q, k, v, mask=None, dropout=None, use_rope=True):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)
    if use_rope:
        q, k = RoPE(q, k)
    
    d_k = k.size()[-1]

    # q * k^T
    # (bs, head, seq_len, seq_len)
    att_logits = torch.matmul(q, k.transpose(-2, -1))
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        # mask out positions where mask == 0 by filling them with a large negative value (-1e9), so they contribute ~0 after softmax
        att_logits = att_logits.masked_fill(mask == 0, -1e9)
    
    # softmax
    # (bs, head, seq_len, seq_len)
    att_scores = F.softmax(att_logits, dim=-1)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores
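
As a small usage sketch (my own example; the shapes are arbitrary), a causal mask can be passed via the mask argument, with 1 meaning "keep" and 0 meaning "mask out":

# usage sketch: causal (lower-triangular) mask, 1 = keep, 0 = mask out
q_demo = torch.randn(2, 4, 10, 32)
k_demo = torch.randn(2, 4, 10, 32)
v_demo = torch.randn(2, 4, 10, 32)

causal_mask = torch.tril(torch.ones(1, 1, 10, 10))  # broadcasts over bs and head
out, scores = attention(q_demo, k_demo, v_demo, mask=causal_mask, use_rope=True)
print(scores[0, 0, 0])  # position 0 attends only to itself; later entries are ~0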

Test:

# (bs, head, seq_len, dk)
q = torch.randn((8, 12, 10, 32))
k = torch.randn((8, 12, 10, 32))
v = torch.randn((8, 12, 10, 32))

batch_size, num_head, max_len, output_dim = q.shape

embeddings = sinusoidal_position_embeddings(batch_size, num_head, max_len, output_dim)
print(embeddings.shape)
print(embeddings[1, 1, 1, 1:17:2]) # odd indices -> cos_pos (position m = 1)
print(embeddings[1, 1, 1, 0:16:2]) # even indices -> sin_pos (position m = 1)

print(embeddings[1, 1, 0, 1:17:2]) # odd indices -> cos_pos (position m = 0)
print(embeddings[1, 1, 0, 0:16:2]) # even indices -> sin_pos (position m = 0)

print(embeddings[1, 1, 2, 1:17:2]) # odd indices -> cos_pos (position m = 2)
print(embeddings[1, 1, 2, 0:16:2]) # even indices -> sin_pos (position m = 2)
print("-----" * 10)

res, att_scores = attention(q, k, v, mask=None, dropout=None, use_rope=True)

# (bs, head, seq_len, dk),  (bs, head, seq_len, seq_len)
print(res.shape, att_scores.shape)

Output:

torch.Size([8, 12, 10, 32])
tensor([0.5403, 0.8460, 0.9504, 0.9842, 0.9950, 0.9984, 0.9995, 0.9998])
tensor([0.8415, 0.5332, 0.3110, 0.1769, 0.0998, 0.0562, 0.0316, 0.0178])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([-0.4161,  0.4315,  0.8066,  0.9374,  0.9801,  0.9937,  0.9980,  0.9994])
tensor([0.9093, 0.9021, 0.5911, 0.3482, 0.1987, 0.1122, 0.0632, 0.0356])
--------------------------------------------------
torch.Size([8, 12, 10, 32]) torch.Size([8, 12, 10, 10])
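
For readers who prefer the complex-number view of RoPE, where each (even, odd) pair of coordinates is treated as one complex number and rotated by \(e^{\mathrm{i} m\theta_i}\), the following sketch (my own addition, reusing the q and k from the test above) checks that this view matches the RoPE implementation:

# equivalence check (my own sketch): the even/odd pairing in RoPE above is the
# same as multiplying each complex pair (x[2i], x[2i+1]) by e^{i * m * theta_i}
def rope_complex(x, base=10000):
    bs, head, seq_len, dim = x.shape
    theta = base ** (-2 * torch.arange(0, dim // 2).float() / dim)   # (dim // 2,)
    m = torch.arange(seq_len).float().unsqueeze(1)                   # (seq_len, 1)
    rot = torch.polar(torch.ones(seq_len, dim // 2), m * theta)      # e^{i * m * theta_i}
    x_c = torch.view_as_complex(x.float().reshape(bs, head, seq_len, dim // 2, 2))
    return torch.view_as_real(x_c * rot).reshape(bs, head, seq_len, dim)

q_rot, k_rot = RoPE(q, k)
print(torch.allclose(q_rot, rope_complex(q), atol=1e-5))  # expected: True
print(torch.allclose(k_rot, rope_complex(k), atol=1e-5))  # expected: True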

The code can be run in the Colab ipynb file.
