前言
介绍create_completion_mask函数。
Operating System: Ubuntu 22.04.4 LTS
函数原型
def create_completion_mask(completion_ids, eos_token_id):
"""
Creates a mask for completion tokens that excludes tokens after the EOS token.
Args:
completion_ids (torch.Tensor): Token IDs of the generated completions.
eos_token_id (int): The ID of the end-of-sequence token.
Returns:
torch.Tensor: A binary mask with 1s for valid tokens and 0s after the EOS token.
Explanation:
1. Identifies positions where EOS tokens occur in each sequence.
2. Finds the index of the first EOS token in each sequence.
3. Creates a mask where positions before and including the first EOS are 1, others are 0.
4. If no EOS token is found in a sequence, all positions are set to 1.
"""
is_eos = completion_ids == eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
mask_exists = is_eos.any(dim=1)
eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
return (sequence_indices <= eos_idx.unsqueeze(1)).int()
函数功能概述
作用:create_completion_mask 用于为一批序列(通常是模型生成的 token 序列)创建一个掩码(mask),这个掩码会在第一个 EOS(end-of-sequence,序列结束)token 出现后把剩下的 token 标记为无效(0),而在第一个 EOS 之前(包括它自己)都标记为有效(1)。如果一条序列里没有 EOS,则全部为 1。
主要用途:
通常用于评估或损失计算时,只考虑 EOS 之前的 token,忽略无效生成。
代码详解
1. 判断 EOS 出现的位置
is_eos = completion_ids == eos_token_id
- 作用:判断每一个 token 是否等于 EOS 的 token id。
- 输入:
completion_ids: (batch_size, seq_len) 的张量,每一行是一条序列的 token id。eos_token_id: 一个整数,代表 EOS 的 token id。
- 输出:
is_eos: 同样 shape 的布尔型张量,True表示该位置是 EOS。
2. 初始化 EOS 索引
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
- 作用:为每一条序列初始化一个默认的“EOS索引”,初值设为序列长度(即假设没有 EOS)。
- 输入:
is_eos.size(0): batch_sizeis_eos.size(1): seq_len
- 输出:
eos_idx: (batch_size,) 的长整型张量,初始值全为 seq_len。
3. 判断每个序列是否包含 EOS
mask_exists = is_eos.any(dim=1)
- 作用:判断每一条序列是否包含至少一个 EOS。
- 输出:
mask_exists: (batch_size,) 布尔型张量,为 True 表示有 EOS。
4. 找到每条序列第一个 EOS 的索引
eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
- **is_eos.int()**:把布尔值转成 int(True 变成 1,False 变成 0)。
- **argmax(dim=1)**:对每条序列找到第一个 1(即第一个 EOS)的位置索引。
- 只对存在 EOS 的序列进行索引替换。
5. 构造掩码
sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
- 作用:生成一个 shape 为 (batch_size, seq_len) 的索引矩阵,每行是 [0, 1, …, seq_len-1]。
- expand:批量扩展到 batch_size。
6. 比较索引生成最终 mask
return (sequence_indices <= eos_idx.unsqueeze(1)).int()
- **eos_idx.unsqueeze(1)**:把 (batch_size,) 变成 (batch_size, 1),便于广播。
- **sequence_indices <= eos_idx.unsqueeze(1)**:每个 batch 内,索引 <= 第一个 EOS 的位置的地方为 True,否则为 False。
- **.int()**:转为 0/1 张量,就是最终的 mask。
涉及到的 PyTorch 函数
==:元素级比较torch.full(shape, fill_value, dtype, device):创建一个指定 shape 和初值的张量any(dim=1):判断某一维是否有 Trueint():类型转换argmax(dim=1):返回最大值的索引,这里用于第一个 True 的索引torch.arange(n, device=...):生成 [0,1,…,n-1] 的一维序列expand(batch_size, -1):批量扩展unsqueeze(dim):增加一个维度<=:元素级比较,自动广播
举例说明
假如:
completion_ids = torch.tensor([
[10, 9, 8, 7, 1, 3, 5], # EOS=1, 在第4个位置
[4, 6, 2, 8, 9, 11, 12], # 没有 EOS
])
eos_token_id = 1
掩码为:
[
[1, 1, 1, 1, 1, 0, 0], # 前5个位置包含第一个EOS
[1, 1, 1, 1, 1, 1, 1], # 没有EOS,全部为1
]
总结
- 该函数在处理批量生成任务时非常实用,能让你只关注“有意义”的输出部分。
- 涉及的每一步都是在用张量运算高效地实现“找到第一个 EOS,并在它之后全部 mask 掉”。
结语
第三百三十五篇博文写完,开心!!!!
今天,也是充满希望的一天。