00335 收集的 pytorch 函数——create_completion_mask


前言

介绍create_completion_mask函数。

Operating System: Ubuntu 22.04.4 LTS

函数原型

def create_completion_mask(completion_ids, eos_token_id):
    """
    Creates a mask for completion tokens that excludes tokens after the EOS token.

    Args:
        completion_ids (torch.Tensor): Token IDs of the generated completions.
        eos_token_id (int): The ID of the end-of-sequence token.

    Returns:
        torch.Tensor: A binary mask with 1s for valid tokens and 0s after the EOS token.

    Explanation:
        1. Identifies positions where EOS tokens occur in each sequence.
        2. Finds the index of the first EOS token in each sequence.
        3. Creates a mask where positions before and including the first EOS are 1, others are 0.
        4. If no EOS token is found in a sequence, all positions are set to 1.
    """
    is_eos = completion_ids == eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
    mask_exists = is_eos.any(dim=1)
    eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
    sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
    return (sequence_indices <= eos_idx.unsqueeze(1)).int()

函数功能概述

作用:
create_completion_mask 用于为一批序列(通常是模型生成的 token 序列)创建一个掩码(mask),这个掩码会在第一个 EOS(end-of-sequence,序列结束)token 出现后把剩下的 token 标记为无效(0),而在第一个 EOS 之前(包括它自己)都标记为有效(1)。如果一条序列里没有 EOS,则全部为 1。

主要用途:
通常用于评估或损失计算时,只考虑 EOS 之前的 token,忽略无效生成。


代码详解

1. 判断 EOS 出现的位置

is_eos = completion_ids == eos_token_id
  • 作用:判断每一个 token 是否等于 EOS 的 token id。
  • 输入
    • completion_ids: (batch_size, seq_len) 的张量,每一行是一条序列的 token id。
    • eos_token_id: 一个整数,代表 EOS 的 token id。
  • 输出
    • is_eos: 同样 shape 的布尔型张量,True 表示该位置是 EOS。

2. 初始化 EOS 索引

eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
  • 作用:为每一条序列初始化一个默认的“EOS索引”,初值设为序列长度(即假设没有 EOS)。
  • 输入
    • is_eos.size(0): batch_size
    • is_eos.size(1): seq_len
  • 输出
    • eos_idx: (batch_size,) 的长整型张量,初始值全为 seq_len。

3. 判断每个序列是否包含 EOS

mask_exists = is_eos.any(dim=1)
  • 作用:判断每一条序列是否包含至少一个 EOS。
  • 输出
    • mask_exists: (batch_size,) 布尔型张量,为 True 表示有 EOS。

4. 找到每条序列第一个 EOS 的索引

eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
  • **is_eos.int()**:把布尔值转成 int(True 变成 1,False 变成 0)。
  • **argmax(dim=1)**:对每条序列找到第一个 1(即第一个 EOS)的位置索引。
  • 只对存在 EOS 的序列进行索引替换。

5. 构造掩码

sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
  • 作用:生成一个 shape 为 (batch_size, seq_len) 的索引矩阵,每行是 [0, 1, …, seq_len-1]。
  • expand:批量扩展到 batch_size。

6. 比较索引生成最终 mask

return (sequence_indices <= eos_idx.unsqueeze(1)).int()
  • **eos_idx.unsqueeze(1)**:把 (batch_size,) 变成 (batch_size, 1),便于广播。
  • **sequence_indices <= eos_idx.unsqueeze(1)**:每个 batch 内,索引 <= 第一个 EOS 的位置的地方为 True,否则为 False。
  • **.int()**:转为 0/1 张量,就是最终的 mask。

涉及到的 PyTorch 函数

  • ==:元素级比较
  • torch.full(shape, fill_value, dtype, device):创建一个指定 shape 和初值的张量
  • any(dim=1):判断某一维是否有 True
  • int():类型转换
  • argmax(dim=1):返回最大值的索引,这里用于第一个 True 的索引
  • torch.arange(n, device=...):生成 [0,1,…,n-1] 的一维序列
  • expand(batch_size, -1):批量扩展
  • unsqueeze(dim):增加一个维度
  • <=:元素级比较,自动广播

举例说明

假如:

completion_ids = torch.tensor([
    [10, 9, 8, 7, 1, 3, 5],  # EOS=1, 在第4个位置
    [4, 6, 2, 8, 9, 11, 12], # 没有 EOS
])
eos_token_id = 1

掩码为:

[
  [1, 1, 1, 1, 1, 0, 0],  # 前5个位置包含第一个EOS
  [1, 1, 1, 1, 1, 1, 1],  # 没有EOS,全部为1
]

总结

  • 该函数在处理批量生成任务时非常实用,能让你只关注“有意义”的输出部分。
  • 涉及的每一步都是在用张量运算高效地实现“找到第一个 EOS,并在它之后全部 mask 掉”。

结语

第三百三十五篇博文写完,开心!!!!

今天,也是充满希望的一天。


文章作者: LuYF-Lemon-love
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 LuYF-Lemon-love !
  目录