Skip to content
 
📑标签
🏷yolo 🏷model 🏷python 🏷network

🗒初墨

🍊Hello,各位好,我是面包!

2017年,Google的一篇《Attention Is All You Need》论文犹如一颗深水炸弹,彻底改变了深度学习领域的格局。Transformer架构凭借其独特的自注意力机制,不仅横扫机器翻译任务,更衍生出BERT、GPT、ViT等划时代模型,成为ChatGPT等大模型的基石。它就像一位精通多国语言的画家,既能解析句子的深层语义,又能捕捉图像的全局特征。今天,我们将化身"架构考古学家",深入Transformer的每一块砖瓦,并用代码复现其核心组件。

一、Transformer整体架构:编码器与解码器的交响乐

如果把Transformer比作交响乐团,那么编码器(Encoder)和解码器(Decoder)就是指挥家手中的乐谱。

1.1 架构全景图

  • 编码器(左半部分):由N个相同层堆叠,每层包含多头自注意力机制和前馈神经网络,负责提取输入序列的深层特征。
  • 解码器(右半部分):同样由N个层堆叠,但在自注意力层增加了掩码机制,并引入编码器-解码器注意力,用于生成目标序列。

▲ Transformer整体架构(图源:Google论文)[1]

1.2 翻译任务的流水线演示

以英法翻译"I am a student → Je suis étudiant"为例:

  1. 输入嵌入:将法语词转换为向量,并添加位置编码(类似给每个词发"座位号")。
  2. 编码器处理:经过6层编码器,输出包含上下文信息的特征矩阵。
  3. 解码器生成:从"Begin"开始,逐步预测英语词,直至生成结束符。

二、核心组件拆解:Transformer的四大发明

2.1 输入表示:让模型理解词序的密码

传统RNN通过顺序处理捕捉位置信息,而Transformer另辟蹊径,使用位置编码(Positional Encoding):

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度用sin
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度用cos
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:x.size(1), :]  # 词嵌入 + 位置编码
        return x

▲ 位置编码代码实现[3]
原理:通过不同频率的正弦/余弦函数,为每个位置生成独一无二的"坐标"。高频信号区分邻近位置,低频信号捕获长远依赖。

2.2 自注意力机制:全局关系的读心术

自注意力的精髓在于让每个词都能与其他词"对话"。其计算分为四步:

步骤:

  1. 生成QKV矩阵:通过线性变换得到查询(Query)、键(Key)、值(Value)
  2. 计算注意力分数:Q与K的点积衡量词间相关性
  3. Softmax归一化:转化为概率分布
  4. 加权求和:用权重对V加权得到输出
python
# 自注意力代码示例  
input_sequence = torch.tensor([[[0.1, 0.2], [0.4, 0.5]]])  # 输入形状 (batch, seq_len, dim)
WQ = nn.Linear(2, 2); WK = nn.Linear(2,2); WV = nn.Linear(2,2)  

Q = WQ(input_sequence)  # Query矩阵
K = WK(input_sequence)  # Key矩阵  
V = WV(input_sequence)  # Value矩阵  

attn_scores = torch.matmul(Q, K.transpose(-1,-2)) / torch.sqrt(torch.tensor(2.0))  
attn_weights = F.softmax(attn_scores, dim=-1)  
output = torch.matmul(attn_weights, V)  # 加权和

▲ 自注意力核心代码[3]

数学公式:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V

其中dk\sqrt{d_k}用于防止点积过大导致梯度消失。

2.3 多头注意力:并行计算的艺术家

单一注意力如同只用一只眼睛看世界,而多头机制像多双眼睛从不同视角观察:

▲ 多头注意力结构(图源:[1])

python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.WO = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        batch_size = x.size(0)
        # 分头处理
        Q = self.WQ(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        K = self.WK(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        V = self.WV(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        # 各头独立计算注意力
        attn = torch.matmul(Q, K.transpose(-1,-2)) / math.sqrt(self.d_k)
        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, V).transpose(1,2).contiguous()
        # 合并多头输出
        output = output.view(batch_size, -1, self.d_model)
        return self.WO(output)

▲ 多头注意力实现[4]

优势:

  • 并行处理:各注意力头独立计算
  • 多视角学习:捕获语法、语义等不同特征

三、编码器与解码器:Transformer的双子星

3.1 编码器层:特征提取的流水线

每个编码器层包含:

  1. 多头自注意力:提取全局依赖
  2. 残差连接 + LayerNorm:防止梯度消失
  3. 前馈网络:非线性变换增强表达能力
python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=2048):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
    def forward(self, x):
        # 残差连接
        attn_output = self.self_attn(x)
        x = self.norm1(x + attn_output)
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)
        return x

▲ 编码器层实现[3]

3.2 解码器层:掩码与交叉注意力的奥秘

解码器的特殊设计:

  • 掩码多头注意力:防止当前位置看到未来信息(类似遮挡后半句)
  • 编码器-解码器注意力:让解码器聚焦编码器输出

▲ 解码器的掩码机制(图源:[1])

python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=2048):
        super().__init__()
        self.masked_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(...)  # 同编码器
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
    def forward(self, x, encoder_output, tgt_mask):
        # 掩码自注意力
        attn = self.masked_attn(x, x, x, tgt_mask)
        x = self.norm1(x + attn)
        # 编码器-解码器注意力
        cross_attn = self.cross_attn(x, encoder_output, encoder_output)
        x = self.norm2(x + cross_attn)
        # 前馈
        x = self.norm3(x + self.ffn(x))
        return x

▲ 解码器层核心代码[3]


四、Transformer的现代变体:进化之路

4.1 视觉Transformer(ViT)[6]

  • 图像分块:将图片划分为16x16的patch
  • 位置编码:学习图像的空间关系
  • 分类Token:类似NLP中的[CLS] token用于分类

4.2 Swin Transformer[6]

  • 窗口注意力:在局部窗口计算注意力,降低计算量
  • 移位窗口:通过窗口滑动实现跨窗口信息交互

4.3 最新研究进展[5]

  • 层共享表征:实验表明中间层共享相似表示空间
  • 动态路径:部分层可跳过或并行执行,提升效率

五、手撕Transformer:从零搭建迷你版

python
import torch
import torch.nn as nn

class MiniTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_layers=6, num_heads=8):
        super().__init__()
        # 嵌入层
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        
        # 编码器与解码器堆叠
        self.encoder = nn.ModuleList([EncoderLayer(d_model, num_heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, num_heads) for _ in range(num_layers)])
        
        # 输出层
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # 嵌入 + 位置编码
        src = self.pos_encoder(self.src_embed(src))
        tgt = self.pos_encoder(self.tgt_embed(tgt))
        
        # 编码器处理
        for layer in self.encoder:
            src = layer(src)
            
        # 解码器处理
        for layer in self.decoder:
            tgt = layer(tgt, src, tgt_mask)
            
        # 输出预测
        return self.fc_out(tgt)

▲ 迷你版Transformer实现(整合自[3][4])


结语:Transformer的无限可能

从机器翻译到图像生成,从蛋白质结构预测到自动驾驶,Transformer正在重塑AI的边界。其成功启示我们:对全局关系的建模能力,远比局部归纳偏好更重要。正如论文作者所言:"Attention is All You Need",但这远不是故事的终点——未来,我们或许会看到更多突破性架构,但Transformer的智慧火花将永远闪耀在AI的星空中。