RoPE is the work of Su Jianlin (苏剑林) and was first used in his own RoFormer (Rotary Transformer). It is a relative position encoding, and it performs better than both absolute position encoding and the classical relative position encodings. It comes from the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
As far as I know, recently released large language models such as Meta's LLaMA and Tsinghua's ChatGLM both adopt RoPE, which says a lot about its advantages.
For the rotation angle \(\theta\), Su keeps the Transformer's choice \(\theta_i = 10000^{-2i/d}\). His experiments found that this \(\theta\) also gives RoPE a degree of long-range decay, meaning the dependency between tokens weakens as the distance between them grows, which matches our intuition. Other choices of \(\theta\) work as well, as long as they satisfy this long-range decay.
import torch
dim = 512
base = 10000
# theta_i = 1 / base^(2i / dim), i = 0, 1, ..., dim // 2 - 1; shape: (dim // 2,)
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
print(inv_freq.shape)
print(inv_freq[:10])
"""
Output:
torch.Size([256])
tensor([1.0000, 0.9647, 0.9306, 0.8977, 0.8660, 0.8354, 0.8058, 0.7774, 0.7499,
0.7234])
"""
From the output we can see that \(\theta_i = 10000^{-2i/d}\) exhibits the decaying behaviour: the farther an entry is from index 0, the smaller its value.
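To make the implementation below easier to follow, the rotation can be written out explicitly (using the same notation as the code comments): the query/key dimensions are grouped into pairs \((q_{2i}, q_{2i+1})\), and at position \(m\) the \(i\)-th pair is rotated by the angle \(m\theta_i\):

\[
\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}
\]

Because \(k\) at position \(n\) is rotated in the same way, the dot product \(\langle q'_m, k'_n \rangle\) ends up depending only on the relative distance \(m - n\), which is what makes RoPE a relative position encoding.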
Code Implementation (PyTorch)
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
"""
Compute cos(m * \theta_i) and sin(m * \theta_i) for every position m and pair index i
"""
def sinusoidal_position_embeddings(batch_size = 16, num_head = 8, seq_len = 128, hidden_dim = 768 // 8):
    # (seq_len,) -> (seq_len, 1), i.e. the position m in the formula
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    # (hidden_dim // 2,), i.e. the index i in the formula
    ids = torch.arange(0, hidden_dim // 2, dtype=torch.float)
    # rotation angles theta_i = 10000^(-2i / d)
    # (hidden_dim // 2,)
    theta = torch.pow(10000, -2 * ids / hidden_dim)
    # m * theta_i in the formula
    # (seq_len, 1) * (hidden_dim // 2,) -> (seq_len, hidden_dim // 2)
    embeddings = position * theta
    # compute sin(m * theta_i) and cos(m * theta_i)
    # (seq_len, hidden_dim // 2, 2)
    embeddings = torch.stack([
        torch.sin(embeddings),
        torch.cos(embeddings)
    ], dim=-1)
    # (bs, head, seq_len, hidden_dim // 2, 2)
    # repeat along the bs and head dimensions; every other dimension is repeated once (i.e. unchanged)
    embeddings = embeddings.repeat((batch_size, num_head, *([1] * len(embeddings.shape))))
    # (bs, head, seq_len, hidden_dim)
    # after the reshape: even indices hold sin, odd indices hold cos
    embeddings = torch.reshape(embeddings, (batch_size, num_head, seq_len, hidden_dim))
    return embeddings
"""
Apply RoPE: compute q * cos(m * \theta_i) + rotated q * sin(m * \theta_i), and the same for k
"""
def RoPE(q, k):
    # q, k: (bs, head, max_len, output_dim)
    batch_size, num_head, max_len, output_dim = q.shape
    # (bs, head, seq_len, hidden_dim)
    pos_emb = sinusoidal_position_embeddings(batch_size, num_head, max_len, output_dim)
    # cos_pos, sin_pos: (bs, head, max_len, output_dim // 2) -> (bs, head, max_len, output_dim)
    # From the RoPE formula, the two dimensions of a pair share the same cos / sin value, so each value is duplicated,
    # e.g. (1, 2, 3) becomes (1, 1, 2, 2, 3, 3)
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # take the odd columns (cos) and duplicate each
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # take the even columns (sin) and duplicate each
    # q, k: (bs, head, max_len, output_dim)
    q2 = torch.stack([
        -q[..., 1::2],
        q[..., ::2]
    ], dim=-1)
    q2 = q2.reshape(q.shape) # after the reshape the signs alternate: (-q1, q0, -q3, q2, ...)
    # update: q * cos(m * theta) + rotated q * sin(m * theta) (element-wise multiplication)
    q = q * cos_pos + q2 * sin_pos
    k2 = torch.stack([
        -k[..., 1::2],
        k[..., ::2]
    ], dim=-1)
    k2 = k2.reshape(k.shape) # after the reshape the signs alternate: (-k1, k0, -k3, k2, ...)
    # update: k * cos(m * theta) + rotated k * sin(m * theta) (element-wise multiplication)
    k = k * cos_pos + k2 * sin_pos
    return q, k
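As a quick sanity check (a minimal sketch added here on top of the RoPE function above; the vectors q0, k0 and the sequence length of 6 are arbitrary choices), we can verify numerically that after RoPE the dot product between a query at position m and a key at position n depends only on m - n:
torch.manual_seed(0)
q0 = torch.randn(32)
k0 = torch.randn(32)
# (bs=1, head=1, seq_len=6, dim=32): the same vector placed at every position
q_seq = q0.repeat(1, 1, 6, 1)
k_seq = k0.repeat(1, 1, 6, 1)
q_rot, k_rot = RoPE(q_seq, k_seq)
# (6, 6) score matrix; entries with the same m - n should match up to float error
scores = torch.matmul(q_rot, k_rot.transpose(-2, -1))[0, 0]
print(torch.allclose(scores[2, 1], scores[5, 4], atol=1e-5)) # m - n = 1 -> True
print(torch.allclose(scores[3, 0], scores[5, 2], atol=1e-5)) # m - n = 3 -> True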
Compute self-attention:
def attention(q, k, v, mask=None, dropout=None, use_rope=True):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)
    if use_rope:
        q, k = RoPE(q, k)
    d_k = k.size()[-1]
    # q * k^T
    # (bs, head, seq_len, seq_len)
    att_logits = torch.matmul(q, k.transpose(-2, -1))
    att_logits /= math.sqrt(d_k)
    if mask is not None:
        # mask out the positions where mask == 0 by setting them to a large negative value (-1e9), so softmax ignores them
        att_logits = att_logits.masked_fill(mask == 0, -1e9)
    # softmax
    # (bs, head, seq_len, seq_len)
    att_scores = F.softmax(att_logits, dim=-1)
    if dropout is not None:
        att_scores = dropout(att_scores)
    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores
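The mask argument is not exercised by the test below, so here is a small usage sketch (my own addition, not part of the original test) with a causal mask, where each position may only attend to itself and earlier positions:
seq_len = 10
# (1, 1, seq_len, seq_len): 1 on and below the diagonal, 0 above it
causal_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))
q = torch.randn((2, 4, seq_len, 32))
k = torch.randn((2, 4, seq_len, 32))
v = torch.randn((2, 4, seq_len, 32))
out, scores = attention(q, k, v, mask=causal_mask, dropout=None, use_rope=True)
# for the first query position, only the first attention weight is effectively non-zero
print(scores[0, 0, 0])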
Test:
# (bs, head, seq_len, dk)
q = torch.randn((8, 12, 10, 32))
k = torch.randn((8, 12, 10, 32))
v = torch.randn((8, 12, 10, 32))
batch_size, num_head, max_len, output_dim = q.shape
embeddings = sinusoidal_position_embeddings(batch_size, num_head, max_len, output_dim)
print(embeddings.shape)
print(embeddings[1, 1, 1, 1:17:2]) # odd indices -> cos
print(embeddings[1, 1, 1, 0:16:2]) # even indices -> sin
print(embeddings[1, 1, 0, 1:17:2]) # odd indices -> cos
print(embeddings[1, 1, 0, 0:16:2]) # even indices -> sin
print(embeddings[1, 1, 2, 1:17:2]) # odd indices -> cos
print(embeddings[1, 1, 2, 0:16:2]) # even indices -> sin
print("-----" * 10)
res, att_scores = attention(q, k, v, mask=None, dropout=None, use_rope=True)
# (bs, head, seq_len, dk), (bs, head, seq_len, seq_len)
print(res.shape, att_scores.shape)
Output:
torch.Size([8, 12, 10, 32])
tensor([0.5403, 0.8460, 0.9504, 0.9842, 0.9950, 0.9984, 0.9995, 0.9998])
tensor([0.8415, 0.5332, 0.3110, 0.1769, 0.0998, 0.0562, 0.0316, 0.0178])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([-0.4161, 0.4315, 0.8066, 0.9374, 0.9801, 0.9937, 0.9980, 0.9994])
tensor([0.9093, 0.9021, 0.5911, 0.3482, 0.1987, 0.1122, 0.0632, 0.0356])
--------------------------------------------------
torch.Size([8, 12, 10, 32]) torch.Size([8, 12, 10, 10])
The full code can be run in the Colab ipynb file.
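As an aside that is not part of the original post: the same rotation can also be expressed with complex multiplication, similar in spirit to how LLaMA implements rotary embeddings. The sketch below (rope_complex is a name made up for illustration) should match the RoPE function above up to floating-point error:
def rope_complex(x, base=10000):
    # x: (bs, head, seq_len, dim), dim must be even
    bs, head, seq_len, dim = x.shape
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) # (dim // 2,)
    m = torch.arange(seq_len, dtype=torch.float) # (seq_len,)
    freqs = torch.outer(m, theta) # (seq_len, dim // 2), i.e. m * theta_i
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # e^(i * m * theta_i)
    # view adjacent pairs (x_{2i}, x_{2i+1}) as complex numbers, rotate them, then flatten back
    x_complex = torch.view_as_complex(x.float().reshape(bs, head, seq_len, dim // 2, 2))
    x_rotated = torch.view_as_real(x_complex * freqs_cis).reshape(bs, head, seq_len, dim)
    return x_rotated
For an even-dimensional q, torch.allclose(rope_complex(q), RoPE(q, q)[0], atol=1e-5) should return True.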
References
- Paper: 《RoFormer: Enhanced Transformer with Rotary Position Embedding》
- 科学空间 (Scientific Spaces, Su Jianlin's blog)
- CSDN: [Rotary Position Embedding (RoPE, 旋转式位置编码) 原理讲解 + torch 代码实现](https://blog.csdn.net/weixin_43646592/article/details/130924280)
Document Information
- Author: Bookstall
- Link: https://bookstall.github.io/2024/01/02/rope/
- License: free to redistribute for non-commercial use, no derivatives, attribution required (Creative Commons 3.0 License)