使用JAX和Haiku从头开始​​实现Transformer编码器

2023年11月08日 由 alex 发表 310 0

介绍


在2017年的开创性论文《Attention is all you need》中引入的Transformer架构可以说是近期深度学习历史上最有影响力的突破之一,它使得大型语言模型得以兴起,并在计算机视觉等领域发挥作用。


相比于之前依赖于循环结构(如长短时记忆网络(LSTM)或门控循环单元(GRU))的最先进架构,Transformer引入了自注意力的概念,并采用了编码器/解码器架构。


在本文中,我们将从头开始逐步实现Transformer的第一部分,即编码器。我们将使用JAX作为我们的主要框架,同时结合DeepMind的深度学习库Haiku。


主要参数


在我们开始之前,我们需要定义一些关键参数,这些参数在编码器块中起着至关重要的作用:


1. 序列长度(seq_len):序列中标记或词语的数目。


2. 嵌入维度(embed_dim):嵌入表示的维度,换句话说,描述单个标记或词语所使用的数值的数量。


3. 批大小(batch_size):输入批次的大小,即同时处理的序列数。


我们的编码器模型的输入序列通常具有形状(batch_size,seq_len)。在本文中,我们将使用batch_size=32和seq_len=10,这意味着我们的编码器将同时处理32个由10个词语组成的序列。


在处理的每个步骤中,关注数据的形状将使我们能够更好地可视化和理解数据在编码器块中的流动。这是我们编码器的一个高级概述,我们将从底部开始,即从嵌入层和位置编码开始:


4


嵌入层和位置编码


如前所述,我们的模型将批量化的标记序列作为输入。生成这些标记可以简单地收集数据集中独特单词的集合,并为每个单词分配一个索引。然后,我们会提取32个包含10个单词的序列,并用词汇表中每个单词的索引替换每个单词。这个过程会生成一个形状为(batch_size,seq_len)的数组,符合我们的期望。


我们现在准备开始编码器的工作。第一步是为我们的序列创建“位置嵌入”。位置嵌入是词嵌入和位置编码的总和。


词嵌入


词嵌入允许我们对词汇表中的单词的意义和语义关系进行编码。在本文中,嵌入维度被固定为64。这意味着每个单词由一个64维向量表示,因此具有相似含义的单词具有相似的坐标。此外,我们可以操作这些向量来提取单词之间的关系,如下图所示。


5


使用Haiku,生成可学习的嵌入向量就像调用一样简单。


hk.Embed(vocab_size, embed_dim)


这些嵌入将在模型训练过程中与其他可学习参数一同被更新。


位置编码


与循环神经网络不同,Transformer不能通过共享的隐藏状态推测出令牌的位置,因为它们缺乏循环或卷积结构。因此引入了位置编码,这些向量传达了输入序列中令牌的位置。


基本上,每个令牌被分配一个位置向量,由交替的正弦和余弦值组成。这些向量与单词嵌入的维度匹配,因此两者可以相加。


特别是,原始的Transformer论文使用以下函数:


6


6-1


以下数字使我们进一步理解位置编码的功能。让我们看一下最上方图的第一行,我们可以看到零和一的交替序列。事实上,行代表序列中令牌的位置(pos变量),而列代表嵌入维度(i变量)。


因此,当pos=0时,前面的方程在偶数嵌入维度中返回sin(0)=0,在奇数维度中返回cos(0)=1。


此外,我们看到相邻的行共享相似的值,而第一行和最后一行则有很大的不同。这种性质有助于模型评估序列中单词之间的距离以及它们的顺序。


最后,第三个图表示位置编码和嵌入的总和,这是嵌入块的输出。


7


使用俳句,我们定义嵌入层如下。与其他深度学习框架类似,Haiku 允许我们定义自定义模块(此处hk.Module)来存储可学习的参数并定义模型组件的行为。

每个 Haiku 模块都需要有一个__init__and__call__函数。hk.Embed在这里,调用函数只是使用函数和位置编码来计算嵌入,然后再对它们求和。


简而言之,它允许我们为单个样本vmap定义一个函数并将其矢量化,以便它可以应用于批量数据。该参数用于指定我们要迭代输入的第一个轴,即嵌入维度。另一方面,是 Python if/else 语句的 XLA 兼容版本。in_axesdim lax.cond


class EmbeddingLayer(hk.Module):
    def __init__(
        self,
        embed_dim: int,
        vocab_size: int,
        seq_len: int,
        name: str | None = None,
    ):
        super().__init__(name=name)
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.embedding_layer = hk.Embed(self.vocab_size, self.embed_dim)
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        embeddings = embedding_layer(x)
        positional_encodings = self._batched_positional_encoding(
            jnp.arange(self.seq_len),
            jnp.arange(self.embed_dim),
        )
        return embeddings + positional_encodings
    @partial(jit, static_argnums=(0))
    @partial(vmap, in_axes=(None, None, 0), out_axes=(1))  # iterate over the embedding dimensions
    def _batched_positional_encoding(self, pos: jnp.ndarray, dim: jnp.ndarray):
           def _even_encoding():
            return jnp.sin(pos / (jnp.power(10_000, 2 * dim / self.embed_dim)))
        def _odd_encoding():
            return jnp.cos(pos / (jnp.power(10_000, 2 * dim / self.embed_dim)))
        is_even = dim % 2 == 0
        return lax.cond(is_even, _even_encoding, _odd_encoding)


自注意力和多头注意力


注意力旨在计算序列中每个单词与输入单词的相关性以确定其重要性。例如,在以下句子中:


The black cat jumped on the sofa, lied down and fell asleep, as it was tired”。


对于模型来说,单词“it”可能相对含糊不清,因为严格来说,它可以指向“cat”和“sofa”两个对象。一个经过良好训练的注意力模型能够理解“it”指的是“cat”,并根据此为句子的其他部分分配注意力值。


从本质上讲,注意力值可以看作是描述在输入单词上下文中某个单词重要性的权重。


8


在Transformer论文中,使用了缩放点积注意力机制来计算注意力。其公式总结如下:


9


在这里,Q、K和V代表查询(Queries)、键(Keys)和值(Values)。这些矩阵是通过将学习得到的权重向量WQ、WK和WV与位置嵌入相乘得到的。


这些名称主要是抽象的,用于帮助理解信息在注意力块中如何被处理和加权。


以下是一个直观的解释:


1. 查询(Queries):可以理解为关于序列中所有位置的“一系列问题”。例如,询问单词的上下文并尝试识别序列中最相关的部分。


2. 键(Keys):可以看作是与查询(Queries)进行交互的信息,查询与键之间的兼容性决定查询应该对相应的值(Values)给予多少注意力。


3. 值(Values):匹配键和查询使我们能够决定哪些键是相关的,值是与键配对的实际内容。


在下面的图中,查询是YouTube搜索,键是视频描述和元数据,而值是与之关联的视频。


10


在我们的例子下,查询、键和值都来自同一源(因为它们都是从输入序列中派生出来的),因此被称为自注意力(self-attention)。


注意力得分的计算通常会并行执行多次,每次使用一小部分的嵌入。这个机制被称为“多头注意力”,使得每个注意力头可以并行学习数据的几个不同表示,从而得到更强大的模型。


一个单独的注意力头通常处理形状为(batch_size,seq_len,d_k)的数组,其中 d_k 可以设为头的数量与嵌入维度之间的比率(d_k = n_heads / embed_dim)。通过这种方式,方便地将每个头的输出串联起来,得到形状为(batch_size,seq_len,embed_dim)的数组,作为输入。


注意力矩阵的计算可以分解为几个步骤:


1. 首先,我们定义可学习的权重向量 WQ、WK 和 WV。这些向量的形状为(n_heads,embed_dim,d_k)。


2. 同时,我们将位置嵌入与权重向量相乘。我们得到形状为(batch_size,seq_len,d_k)的 Q、K 和 V 矩阵。


3. 然后,我们对 Q 和 K(转置)的点积进行缩放。这个缩放过程涉及将点积结果除以 d_k 的平方根,并对矩阵的行应用 softmax 函数。因此,输入标记(即一行)的注意力得分之和为一,这有助于防止值变得过大并减慢计算速度。输出的形状为(batch_size,seq_len,seq_len)。


4. 最后,我们将上述操作的结果与 V 相乘,使输出的形状为(batch_size,seq_len,d_k)。


11


5. 每个注意力头的输出可以连接在一起,形成一个形状为(batch_size, seq_len, embed_dim)的矩阵。Transformer论文还在多头注意力模块的末尾添加了一个线性层,用于汇总和组合来自所有注意力头的学习表示。


12


在Haiku中,多头注意力模块可以实现如下。__call__函数遵循上述图的相同逻辑,而类方法利用JAX工具(如vmap,用于在不同的注意力头和矩阵上向量化操作)和tree_map(将矩阵点积映射到权重向量上)的优势。


class Multihead_Attention(hk.Module):
    def __init__(
        self,
        embed_dim: int,
        batch_size: int,
        seq_len: int,
        n_heads: int,
        d_k: int,
        name: str | None = None,
    ):
        super().__init__(name)
        self.embed_dim = embed_dim
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.n_heads = n_heads
        self.d_k = d_k
    def __call__(self, x: jnp.ndarray):
        WQ, WK, WV = self._init_attention_weights()
        Q, K, V = self._get_multihead_Q_K_V_matrices(WQ, WK, WV, x)
        attention_matrices = self._multihead_attention(Q, K, V)
        # concatenate the matrices obtained by the different attention heads
        attention_matrix = attention_matrices.transpose(1, 2, 0, 3).reshape(
            self.batch_size, self.seq_len, -1
        )
        # scale and combine the attention vectors using a linear layer
        return hk.Linear(self.embed_dim)(attention_matrix)
    def _init_attention_weights(self):
        init = hk.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")
        init_parameters = {
            "shape": (self.n_heads, self.embed_dim, self.d_k),
            "dtype": jnp.float32,
            "init": init,
        }
        variable_names = ["WQ", "WK", "WV"]
        
        return jax.tree_map(
            lambda name: hk.get_parameter(name, **init_parameters), variable_names
        )
    @staticmethod
    @jit
    @partial(vmap, in_axes=(0, 0, 0, None))
    def _get_multihead_Q_K_V_matrices(WQ, WK, WV, positional_embeddings):
        return jax.tree_map(
            lambda x: jnp.matmul(positional_embeddings, x), [WQ, WK, WV]
        )
    @partial(jit, static_argnums=(0))
    @partial(vmap, in_axes=(None, 0, 0, 0))  # iterate over the heads
    def _multihead_attention(self, Q, K, V):
        # transpose K to (N_HEADS, BATCH_SIZE, D_K, SEQ_LEN)
        attention_score = jnp.matmul(Q, K.transpose(0, 2, 1))
        # apply row-wise softmax
        scaled_attention = jax.nn.softmax(attention_score / jnp.sqrt(self.d_k), axis=-1)
        
        return jnp.matmul(scaled_attention, V)


剩余链接和层标准化


正如你可能在Transformer图中注意到的那样,多头注意力块和前馈网络之后跟着剩余连接和层归一化。


剩余连接


剩余连接是解决梯度消失问题的标准解决方案,当梯度变得太小以至于无法有效地更新模型参数时,就会出现梯度消失问题。


由于这个问题在特别深的架构中会自然产生,所以剩余连接被用于各种复杂的模型,如计算机视觉中的ResNet(Kaiming等,2015),强化学习中的AlphaZero(Silver等,2017),当然还有Transformer。


在实践中,剩余连接简单地将特定层的输出传递给后续层,跳过一个或多个层。例如,围绕多头注意力的剩余连接等效于将多头注意力的输出与位置嵌入相加。


这使得在反向传播过程中梯度能够更有效地在架构中流动,并且通常可以实现更快的收敛和更稳定的训练。


13


层归一化


层归一化有助于确保通过模型传播的值不会“爆炸”(趋向无穷大),这在注意力块中很容易发生,在每个前向传递中都要乘以多个矩阵。


与批归一化不同,批归一化假设均匀分布并沿批维度进行归一化,而层归一化在特征维度上进行操作。这种方法适用于句子批次,因为每个批次可能由于不同的含义和词汇而具有独特的分布。


通过在嵌入或注意力值等特征上进行归一化,层归一化将数据标准化到一致的尺度,而不会混淆不同的句子特征,保持每个句子的独特分布。


14


层归一化的实现非常直观,我们初始化可学习的参数alpha和beta,并沿着所需的特征轴进行归一化。


lass LayerNorm(hk.Module):
    def __init__(self, epsilon: float = 1e-6, feature_axis: int = -1, name: str | None = None,):
        super().__init__(name=name)
        self.epsilon = epsilon
        self.feature_axis = feature_axis
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        gamma = hk.get_parameter(
            "gamma", shape=(x.shape[-1],), init=hk.initializers.Constant(1.0)
        )
        beta = hk.get_parameter(
            "beta", shape=(x.shape[-1],), init=hk.initializers.Constant(0.0)
        )
        means = jnp.mean(x, axis=self.feature_axis)
        variances = jnp.var(x, axis=self.feature_axis)
        normalized = (x - jnp.expand_dims(means, -1)) / jnp.sqrt(
            jnp.expand_dims(variances, -1) + self.epsilon
        )
        return gamma * normalized + beta


位置通用前馈网络


位置通用前馈网络是编码器的最后一个组件。这个全连接网络将注意力模块的归一化输出作为输入,并用于引入非线性以及增加模型学习复杂函数的能力。


它由两个经过gelu激活分割的稠密层组成。


class FeedForwardNet(hk.Module):
    def __init__(
        self,
        embed_dim: int,
        hidden_dim: int,
        name: str | None = None,
    ):
        super().__init__(name=name)
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
    def __call__(self, x):
        x = hk.Linear(self.hidden_dim)(x)
        x = jax.nn.gelu(x)
        x = hk.Linear(self.embed_dim)(x)
        return x


在这个块之后,我们还有另一个残差连接和层归一化来完成编码器。


总结


到此为止!现在你应该对Transformer编码器的主要概念已经很熟悉了。下面是完整的编码器类,注意在Haiku中,我们为每个层分配一个名称,这样可以将可学习参数分开并且易于访问。__call__函数提供了我们编码器中不同步骤的良好概述。


class Encoder(hk.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        seq_len: int,
        batch_size: int,
        n_heads: int,
        d_k: int,
        hidden_dim: int,
        name: str,
    ) -> None:
        super().__init__(name=name)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.n_heads = n_heads
        self.d_k = d_k  # common value: int(self.embed_dim / self.n_heads)
        self.hidden_dim = hidden_dim
        self.embedding_layer = EmbeddingLayer(
            embed_dim, vocab_size, seq_len, name="Embedding_layer"
        )
        self.multihead_attention = Multihead_Attention(
            embed_dim, batch_size, seq_len, n_heads, d_k, name="Multihead_Attention"
        )
        self.post_attention_layer_norm = LayerNorm(
            feature_axis=-1, name="Post_Attention_Layer_Norm"
        )
        self.post_feedforward_layer_norm = LayerNorm(
            feature_axis=-1, name="Post_FeedForward_Layer_Norm"
        )
        self.feedforward_net = FeedForwardNet(
            embed_dim, hidden_dim, name="FeedForward_Net"
        )
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        embeddings = self.embedding_layer(x)
        attention_matrices = self.multihead_attention(embeddings)
        normalized_residual_attention = self.post_attention_layer_norm(
            attention_matrices + embeddings
        )
        feedforward_outputs = self.feedforward_net(normalized_residual_attention)
        return self.post_feedforward_layer_norm(
            feedforward_outputs + normalized_residual_attention
        )


要在实际数据上使用该模块,我们需要对封装编码器类的函数应用hk.transform。事实上,你可能记得JAX采用了函数式编程范式,因此Haiku遵循相同的原则。


我们定义一个包含编码器类实例的函数,并返回前向传递的输出。应用hk.transform会返回一个转换后的对象,可以访问两个函数:init和apply。


前者使我们能够使用随机密钥和一些虚拟数据(请注意,这里我们传递了一个形状为batch_size、seq_len的零数组)初始化模块,而后者允许我们处理真实数据。


@hk.transform
def encoder(x):
    model = Encoder(
        vocab_size=VOCAB_SIZE,
        embed_dim=EMBED_DIM,
        seq_len=SEQ_LEN,
        batch_size=BATCH_SIZE,
        n_heads=N_HEADS,
        d_k=D_K,
        hidden_dim=HIDDEN_DIM,
        name="Encoder_module"
    )
    return model(x)
key = jax.random.PRNGKey(0)
params = encoder.init(key, jnp.zeros((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32))
outputs = encoder.apply(params, key, encoded_sequences)


# Note: the two following syntaxes are equivalent
# 1: Using transform as a class decorator
@hk.transform
def encoder(x):
  ...
  return model(x) 
 
encoder.init(...)
encoder.apply(...)
# 2: Applying transfom separately
def encoder(x):
  ...
  return model(x)
encoder_fn = hk.transform(encoder)
encoder_fn.init(...)
encoder_fn.apply(...)



文章来源:https://medium.com/towards-data-science/implementing-a-transformer-encoder-from-scratch-with-jax-and-haiku-791d31b4f0dd
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消