优化Transformer模型以处理可变长度输入序列

2024年11月28日 由 alex 发表 74 0

随着生成式人工智能(genAI)模型在受欢迎程度和规模上的不断增长,其训练和部署所需的计算需求及成本也在相应增加。优化这些模型对于提高它们的运行时性能和降低运营成本至关重要。现代genAI系统的核心是Transformer架构及其注意力机制,这一机制尤其需要大量的计算资源。


在本文中,我们将探索处理可变长度输入序列的挑战——这是包括文档、代码、时间序列等在内的真实世界数据的固有属性。


批处理可变长度输入的挑战

在典型的深度学习工作负载中,单个样本在被复制到GPU并输入到AI模型之前会被分组为批次。批处理提高了计算效率,并且通常有助于模型在训练过程中的收敛。通常,批处理涉及沿着一个新的维度(即批次维度)堆叠所有样本张量。然而,torch.stack要求所有张量具有相同的形状,这在处理可变长度序列时并不成立。


填充及其低效性

解决这一挑战的传统方法是将输入序列填充到固定长度,然后进行堆叠。这种解决方案需要在模型内部进行适当的掩码处理,以确保输出不受无关张量元素的影响。在注意力层中,填充掩码用于指示哪些标记是填充的,并且不应该被关注(例如,参见PyTorch的MultiheadAttention)。然而,填充会浪费大量的GPU资源,增加成本并减慢开发速度。对于大规模AI模型来说,这一点尤为明显。


不填充,而是连接

避免填充的一种方法是沿着现有维度连接序列,而不是沿着新维度堆叠它们。与torch.stack不同,torch.cat允许输入具有不同的形状。连接的输出是一个单一序列,其长度等于各个序列长度之和。为了使这种解决方案有效,我们的单一序列需要补充一个注意力掩码,以确保每个标记只关注同一原始序列中的其他标记,这一过程有时被称为文档掩码。用N表示所有单个序列长度之和,并采用“大O”表示法,这个掩码的大小需要是O(N²),同样,一个朴素的注意力层(在计算注意力得分后才应用掩码)的计算复杂度也是O(N²),这使得这种解决方案非常低效。


注意力层优化

这个问题的解决方案是以专门的注意力层的形式出现的。与标准的注意力层先执行完整的O(N²)注意力得分计算然后再屏蔽掉不相关的得分不同,这些优化的注意力核被设计为只计算有意义的得分。在本文中,我们将探索几种具有不同特点的解决方案,包括:

  • PyTorch的SDPA(缩放点积注意力)与NestedTensors,
  • FlashAttention2,以及
  • xFormers的内存高效注意力。


集成到现有的HuggingFace模型中

对于使用预训练模型的团队来说,过渡到这些优化可能看起来很有挑战性。我们将展示HuggingFace的API如何简化这一过程,使开发人员能够以最少的代码更改和精力集成这些技术。


玩具LLM模型

为了便于讨论,我们将定义一个简单的生成模型(部分灵感来自此处定义的GPT模型)。


Transformer块

我们首先构建一个基本的Transformer块,特别设计用于方便实验不同的注意力机制和优化方法。虽然我们的块执行与标准Transformer块相同的计算,但我们对通常选择的运算符进行了轻微的修改,以支持PyTorch NestedTensor输入的可能性。


# general imports
import time, functools
# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
# Define Transformer settings
BATCH_SIZE = 32
NUM_HEADS = 16
HEAD_DIM = 64
DIM = NUM_HEADS * HEAD_DIM
DEPTH = 24
NUM_TOKENS = 1024
MAX_SEQ_LEN = 1024
PAD_ID = 0
DEVICE = 'cuda'
class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            dim,
            num_heads,
            format=None,
            **kwargs
    ):
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # mlp layers
        self.fc1 = nn.Linear(dim, dim * 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * 4, dim)
        self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)
        if format == 'bshd':
            self.permute = nn.Identity()
    def mlp(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
    def reshape_and_permute(self,x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return self.permute(x)
    def forward(self, x_in, attn_mask=None):
        batch_size = x_in.size(0)
        x = self.norm1(x_in)
        qkv = self.qkv(x)
        # rather than first reformatting and then splitting the input
        # state, we first split and then reformat q, k, v in order to
        # support PyTorch Nested Tensors
        q, k, v = qkv.chunk(3, -1)
        q = self.reshape_and_permute(q, batch_size)
        k = self.reshape_and_permute(k, batch_size)
        v = self.reshape_and_permute(v, batch_size)
        
        # call the attn_fn with the input attn_mask
        x = self.attn_fn(q, k, v, attn_mask=attn_mask)
        # reformat output
        x = self.permute(x).reshape(batch_size, -1, self.dim)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x


Transformer 解码器模型

基于我们可编程的Transformer模块,我们构建了一个典型的Transformer解码器模型。


class MyDecoder(nn.Module):
    def __init__(
            self,
            block_fn,
            num_tokens,
            dim,
            num_heads,
            num_layers,
            max_seq_len,
            pad_idx=None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)
        self.positional_embedding = nn.Embedding(max_seq_len, dim)
        self.blocks = nn.ModuleList([
            block_fn(
                dim=dim,
                num_heads=num_heads
            )
            for _ in range(num_layers)])
        self.output = nn.Linear(dim, num_tokens)
    def embed_tokens(self, input_ids, position_ids=None):
        x = self.embedding(input_ids)
        if position_ids is None:
            position_ids = torch.arange(input_ids.shape[1],
                                        device=x.device)
        x = x + self.positional_embedding(position_ids)
        return x
    def forward(self, input_ids, position_ids=None, attn_mask=None):
        # Embed tokens and add positional encoding
        x = self.embed_tokens(input_ids, position_ids)
        if self.pad_idx is not None:
            assert attn_mask is None
            # create a padding mask - we assume boolean masking
            attn_mask = (input_ids != self.pad_idx)
            attn_mask = attn_mask.view(BATCH_SIZE, 1, 1, -1) \
                .expand(-1, self.num_heads, -1, -1)
        for b in self.blocks:
            x = b(x, attn_mask)
        logits = self.output(x)
        return logits


可变长度序列输入

接下来,我们创建一个包含可变长度序列的数据集,其中每个序列由随机生成的标记组成。为了简化,我们(任意地)为序列长度选择一个固定的分布。在现实场景中,序列长度的分布通常反映了数据的性质,如文档的长度或音频段的长度。请注意,长度的分布直接影响由填充引起的计算效率低下问题。


# Use random datadata
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000
    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,))
        sequence = torch.randint(1, NUM_TOKENS, (length + 1,))
        inputs = sequence[:-1]
        targets = sequence[1:]
        return inputs, targets
def pad_sequence(sequence, length, pad_val):
    return torch.nn.functional.pad(
        sequence,
        (0, length - sequence.shape[0]),
        value=pad_val
    )
def collate_with_padding(batch):
    padded_inputs = []
    padded_targets = []
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))
        padded_targets.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }
def data_to_device(data, device):
    if isinstance(data, dict):
        return {
            key: data_to_device(val,device)
            for key, val in data.items()
        }
    elif isinstance(data, (list, tuple)):
        return type(data)(
            data_to_device(val, device) for val in data
        )
    elif isinstance(data, torch.Tensor):
        return data.to(device=device, non_blocking=True)
    else:
        return data.to(device=device)


训练/评估循环

最后,我们实现了一个主函数,用于对可变长度的输入序列进行训练/评估。


def main(main(
    block_fn, 
    data_collate_fn=collate_with_padding,
    pad_idx=None,
    train=True,
    compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")
    # Create dataset and dataloader
    data_set = FakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=data_collate_fn,
        num_workers=12,
        pin_memory=True,
        drop_last=True
    )
    model = MyDecoder(
        block_fn=block_fn,
        num_tokens=NUM_TOKENS,
        dim=DIM,
        num_heads=NUM_HEADS,
        num_layers=DEPTH,
        max_seq_len=MAX_SEQ_LEN,
        pad_idx=pad_idx
    ).to(device)
    if compile:
        model = torch.compile(model)
    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())
    def train_step(model, inputs, targets, 
                   position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            outputs = outputs.view(-1, NUM_TOKENS)
            targets = targets.flatten()
            loss = criterion(outputs, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    @torch.no_grad()
    def eval_step(model, inputs, targets, 
                  position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            if outputs.is_nested:
                outputs = outputs.data._values
                targets = targets.data._values
            else:
                outputs = outputs.view(-1, NUM_TOKENS)
                targets = targets.flatten()
            loss = criterion(outputs, targets)
        return loss
    if train:
        model.train()
        step_fn = train_step
    else:
        model.eval()
        step_fn = eval_step
    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        step_fn(model, data['inputs'], data['targets'],
                       position_ids=data.get('indices'),
                       attn_mask=data.get('attn_mask'))
        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')


PyTorch中带填充的SDPA

在我们的基线实验中,我们将Transformer模块配置为使用PyTorch的SDPA(可能是指某种特定的加速或优化机制,但SDPA并非PyTorch的官方术语,这里根据上下文推测其含义)机制。在我们的实验中,我们分别在有和没有使用torch.compile的情况下进行了训练和评估。这些实验是在配备CUDA 12.4和PyTorch 2.5.1的NVIDIA H100上运行的。


from torch.nn.functional import scaled_dot_product_attention as sdpa
block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)
causal_block_fn = functools.partial(
    MyAttentionBlock,
    attn_fn=functools.partial(sdpa, is_causal=True)
)
for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn\
            if mode == 'train' else block_fn
        print(f'{mode} with {collate}, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             pad_idx=PAD_ID,
             train=mode=='train',
             compile=compile)


性能结果:

  • 评估:不使用torch.compile时为132毫秒(ms),使用torch.compile时为130毫秒
  • 训练:不使用torch.compile时为342毫秒,使用torch.compile时为299毫秒


针对可变长度输入的优化

在本节中,我们将探讨几种处理Transformer模型中可变长度输入序列的优化技术。


填充优化

我们的第一个优化与注意力核无关,而是与我们的填充机制有关。我们不是将每个批次中的序列填充到固定长度,而是将它们填充到该批次中最长序列的长度。以下代码块包含我们修订后的整理函数和更新的实验。


def collate_pad_to_longest(batch):
    padded_inputs = []
    padded_targets = []
    max_length = max([b[0].shape[0] for b in batch])
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))
        padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }
for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn\
            if mode == 'train' else block_fn
        print(f'{mode} with {collate}, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             data_collate_fn=collate_pad_to_longest,
             pad_idx=PAD_ID,
             train=mode=='train',
             compile=compile)


将每个批次中的序列填充到最长序列的长度会导致性能略有提升:

  • 评估:不使用torch.compile时为129毫秒,使用torch.compile时为116毫秒
  • 训练:不使用torch.compile时为337毫秒,使用torch.compile时为294毫秒


PyTorch NestedTensors与SDPA结合使用

接下来,我们利用SDPA在评估模式下对PyTorch NestedTensors的内置支持。目前,PyTorch NestedTensors是一个原型功能,它允许将不同长度的张量组合在一起。这些有时被称为锯齿状或不规则张量。在下面的代码块中,我们定义了一个整理函数,用于将我们的序列组合成NestedTensors。我们还定义了一个索引条目,以便我们能够正确地计算位置嵌入。


PyTorch NestedTensors仅受有限数量的PyTorch操作支持。绕过这些限制可能需要一些创造力。例如,只有在NestedTensors具有完全相同的“锯齿状”形状时,才支持它们之间的加法。在下面的代码中,我们使用了一种解决方法来确保索引条目与模型输入具有相同的形状。


def nested_tensor_collate(batch):
    inputs = torch.nested.as_nested_tensor([b[0] for b in batch],
                                           layout=torch.jagged)
    targets = torch.nested.as_nested_tensor([b[1] for b in batch],
                                            layout=torch.jagged)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    # workaround for creating a NestedTensor with identical "jagged" shape
    xx = torch.empty_like(inputs)
    xx.data._values[:] = indices
    return {
        'inputs': inputs,
        'targets': targets,
        'indices': xx
    }
for compile in [False, True]:
    print(f'eval with nested tensors, '
          f'{"compiled" if compile else "uncompiled"}')
    main(
        block_fn=block_fn,
        data_collate_fn=nested_tensor_collate,
        train=False,
        compile=compile
    )


尽管在使用torch.compile时,NestedTensor优化得到的单步时间为131毫秒,与我们的基线结果相似,但在编译模式下,单步时间降低到了42毫秒,实现了令人印象深刻的约3倍性能提升。


FlashAttention2

在这篇文章中,我们展示了如何使用来自flash-attn(2.7.0版本)的flash_attn_varlen_func,这是一个专为可变大小输入设计的API。为了使用这个函数,我们将批次中的所有序列连接成一个单一的序列。我们还创建了一个cu_seqlens张量,它指向连接后的张量中每个单独序列开始的索引位置。下面的代码块包含了我们的整理函数,以及随后的评估和训练实验。请注意,flash_attn_varlen_func(在撰写本文时)不支持torch.compile。


def collate_concat(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = torch.tensor([b[0].shape[0] for b in batch])
    seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(seqlens, (1, 0))
    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': cu_seqlens
    }
from flash_attn import flash_attn_varlen_func
fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN
).unsqueeze(0)
fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN,
    causal=True
).unsqueeze(0)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=fa_varlen,
                             format='bshd')
causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=fa_varlen_causal,
                                    format='bshd')
print('flash-attn eval')
main(
    block_fn=block_fn,
    data_collate_fn=collate_concat,
    train=False
)
print('flash-attn train')
main(
    block_fn=causal_block_fn,
    data_collate_fn=collate_concat,
    train=True,
)


这种优化的效果非常显著,评估时间缩短至51毫秒,训练时间缩短至160毫秒,与我们的基线实验相比,性能分别提升了2.6倍和2.1倍。


XFormers 内存高效注意力机制

在这里,我们展示了BlockDiagonalMask的使用,它是专门为任意长度的输入序列设计的。所需的整理函数出现在下面的代码块中,随后是评估和训练实验。请注意,在训练模式下,torch.compile失败了。


from xformers.ops import fmha
from xformers.ops import memory_efficient_attention as mea
def collate_xformer(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = [b[0].shape[0] for b in batch]
    batch_sizes = [1 for b in batch]
    block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')
    block_diag._batch_sizes = batch_sizes
    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': block_diag
    }
mea_eval = lambda q, k, v, attn_mask: mea(
    q,k,v, attn_bias=attn_mask)
mea_train = lambda q, k, v, attn_mask: mea(
    q,k,v, attn_bias=attn_mask.make_causal())
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea_eval,
                             format='bshd')
causal_block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea_train,
                             format='bshd')
print(f'xFormer Attention ')
for compile in [False, True]:
    print(f'eval with xFormer Attention, '
          f'{"compiled" if compile else "uncompiled"}')
    main(block_fn=block_fn,
         train=False,
         data_collate_fn=collate_xformer,
         compile=compile)
print(f'train with xFormer Attention')
main(block_fn=causal_block_fn,
     train=True,
     data_collate_fn=collate_xformer)


未使用torch.compile时,评估和训练的所得单步时间分别为50毫秒和159毫秒。使用torch.compile进行评估时,单步时间为42毫秒。


结果

下表总结了我们优化方法的结果。


8


对于我们的小型模型来说,表现最佳的是xFormer的内存高效注意力机制,它在评估时提供了约3倍的性能提升,在训练时提供了约2倍的性能提升。我们提醒不要从这些结果中得出任何结论,因为不同注意力函数对性能的影响会根据具体的模型和使用场景而有显著差异。


为可变长度输入优化HuggingFace模型

当从头开始创建模型时,上述工具和技术很容易实现。然而,如今机器学习开发者采用现有(预训练)模型并根据其使用场景进行微调的情况并不罕见。虽然我们描述的优化可以在不改变模型权重集和不改变模型行为的情况下集成,但如何做到这一点的最佳方法并不完全清楚。在理想情况下,我们的机器学习框架将允许我们编程使用针对可变长度输入优化的注意力机制。在本节中,我们将演示如何为可变长度输入优化HuggingFace模型。


HuggingFace玩具模型 - GPT2LMHeadModel

为了便于讨论,我们创建了一个玩具示例,其中我们在可变长度序列上训练HuggingFace的GPT2LMHead模型。这需要根据HuggingFace的输入规范调整我们的随机数据集和数据填充整理函数。


from transformers import GPT2Config, GPT2LMHeadModel
# Use random data
class HuggingFaceFakeDataset(Dataset):
    def __len__(self):
        return 1000000
    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,))
        input_ids = torch.randint(1, NUM_TOKENS, (length,))
        labels = input_ids.clone()
        labels[0] = PAD_ID # ignore first token
        return {
            'input_ids': input_ids,
            'labels': labels
        }
        return input_ids, labels
def hf_collate_with_padding(batch):
    padded_inputs = []
    padded_labels = []
    for b in batch:
        input_ids = b['input_ids']
        labels = b['labels']
        padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))
        padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_labels = torch.stack(padded_labels, dim=0)
    return {
        'input_ids': padded_inputs,
        'labels': padded_labels,
        'attention_mask': (padded_inputs != PAD_ID)
    }


训练函数

我们的训练函数会根据所请求的GPT2配置实例化一个GPT2LMHeadModel,并继续在我们的可变长度序列上对其进行训练。


def hf_main(hf_main(
    config,
    collate_fn=hf_collate_with_padding,
    compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")
    # Create dataset and dataloader
    data_set = HuggingFaceFakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=12 if DEVICE == "CUDA" else 0,
        pin_memory=True,
        drop_last=True
    )
    model = GPT2LMHeadModel(config).to(device)
    if compile:
        model = torch.compile(model)
    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())
    model.train()
    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        input_ids = data['input_ids']
        labels = data['labels']
        position_ids = data.get('position_ids')
        attn_mask = data.get('attention_mask')
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(input_ids=input_ids,
                            position_ids=position_ids,
                            attention_mask=attn_mask)
            logits = outputs.logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')


使用填充的SDPA

在下面的回调中,我们调用训练函数,并使用默认的序列填充整理器。


config = GPT2Config(
        n_layer=DEPTH,
        n_embd=DIM,
        n_head=NUM_HEADS,
        vocab_size=NUM_TOKENS,
    )
for compile in [False, True]:
    print(f"HF GPT2 train with SDPA, compile={compile}")
    hf_main(config=config, compile=compile)


未使用torch.compile时,所得的单步时间为815毫秒,而使用torch.compile后则为440毫秒。


FlashAttention2

现在,我们通过将attn_implementation参数设置为“flash_attention_2”来利用HuggingFace对FlashAttention2的内置支持。在后台,HuggingFace会先对填充的数据输入进行去填充处理,然后将它们传递给我们之前看到的优化过的flash_attn_varlen_func函数。


flash_config = GPT2Config(
        n_layer=DEPTH,
        n_embd=DIM,
        n_head=NUM_HEADS,
        vocab_size=NUM_TOKENS,
        attn_implementation='flash_attention_2''flash_attention_2'
    )
print(f"HF GPT2 train with flash")
hf_main(config=flash_config)


所得的时间步长为620毫秒,仅通过简单切换就实现了(在未编译模式下)30%的性能提升。


使用未填充输入的FlashAttention2

当然,在整理函数中填充序列然后又将其去填充,这种做法似乎并不合理。在HuggingFace最近的一次更新中,增加了对将连接(未填充)序列传递给选定模型的支持。不幸的是,(在撰写本文时)我们的GPT2模型并未被纳入支持范围。然而,为GPT2模型添加支持只需在modeling_gpt2.py文件中增加五行小的代码更改,以便将序列的position_ids传播到flash-attention内核中。完整的补丁代码出现在下面的代码块中:


@@ -370,0 +371 @@
+        position_ids = None
@@ -444,0 +446 @@
+            position_ids=position_ids
@@ -611,0 +614 @@
+        position_ids=None
@@ -621,0 +625 @@
+            position_ids=position_ids
@@ -1140,0 +1145 @@
+                    position_ids=position_ids


我们定义了一个整理函数来连接我们的序列,并在未填充的序列上训练我们的Hugging Face模型。


def collate_flatten(batch):
    input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)concat([b['input_ids'] for b in batch]).unsqueeze(0)
    labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)
    position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]
    position_ids = torch.concat(position_ids)
    return {
        'input_ids': input_ids,
        'labels': labels,
        'position_ids': position_ids
    }
print(f"HF GPT2 train with flash, no padding")
hf_main(config=flash_config, collate_fn=collate_flatten)


所得的单步时间为323毫秒,比在运行填充输入的flash-attention上快了90%。


结果

我们HuggingFace实验的结果总结如下。


9


通过少量的努力,我们与未编译的基线实验相比,运行时性能提高了2.5倍,与编译版本相比提高了36%。


在本节中,我们展示了HuggingFace API如何让我们利用FlashAttention2中的优化内核,从而显著提高现有模型在可变长度序列上的训练性能。


总结

随着人工智能模型在普及度和复杂性上的不断增长,优化其性能对于减少运行时间和成本变得至关重要。这对于像注意力层这样计算密集型的组件来说尤其如此。在这篇文章中,我们继续探索了注意力层的优化,并展示了提高Transformer模型性能的新工具和技术。

文章来源:https://medium.com/towards-data-science/optimizing-transformer-models-for-variable-length-input-sequences-19fb88fddf71
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消