Transformer 架构详解
大约 13 分钟约 3891 字
Transformer 架构详解
简介
Transformer 是现代大语言模型(LLM)的基础架构,通过自注意力机制(Self-Attention)实现并行化的序列建模。理解 Multi-Head Attention、位置编码和 Transformer 的变体(GPT、BERT、T5),有助于深入理解大模型的工作原理。
特点
自注意力机制
Self-Attention 原理
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Self-Attention 的核心:Q(查询)、K(键)、V(值)
# Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
class SelfAttention(nn.Module):
def __init__(self, embed_dim, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.scale = math.sqrt(embed_dim)
# 线性投影
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 计算查询、键、值
Q = self.W_q(x) # (batch, seq_len, embed_dim)
K = self.W_k(x) # (batch, seq_len, embed_dim)
V = self.W_v(x) # (batch, seq_len, embed_dim)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# scores: (batch, seq_len, seq_len)
# 因果掩码(GPT 类模型使用)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax 归一化
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 加权求和
output = torch.matmul(attn_weights, V)
# output: (batch, seq_len, embed_dim)
return output, attn_weights
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = math.sqrt(self.head_dim)
# 合并的 QKV 投影(更高效)
self.W_qkv = nn.Linear(embed_dim, 3 * embed_dim)
self.W_out = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 计算 QKV
qkv = self.W_qkv(x)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch, heads, seq, head_dim)
Q, K, V = qkv[0], qkv[1], qkv[2]
# 注意力计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 加权求和
context = torch.matmul(attn_weights, V)
# (batch, heads, seq, head_dim) → (batch, seq, embed_dim)
context = context.transpose(1, 2).contiguous().reshape(batch_size, seq_len, self.embed_dim)
output = self.W_out(context)
return output
# 位置前馈网络
class FeedForward(nn.Module):
def __init__(self, embed_dim, ff_dim, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(embed_dim, ff_dim),
nn.GELU(), # GELU 激活(比 ReLU 更平滑)
nn.Dropout(dropout),
nn.Linear(ff_dim, embed_dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)位置编码
正弦编码与 RoPE
# 1. 正弦位置编码(原始 Transformer)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embed_dim, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, embed_dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2).float() *
(-math.log(10000.0) / embed_dim))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
pe = pe.unsqueeze(0) # (1, max_len, embed_dim)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
# 2. 旋转位置编码(RoPE,LLaMA 使用)
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, head_dim, max_len=8192):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer('inv_freq', inv_freq)
self.max_len = max_len
self._set_cos_sin_cache(max_len)
def _set_cos_sin_cache(self, seq_len):
t = torch.arange(seq_len).float()
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer('cos_cached', emb.cos().unsqueeze(0).unsqueeze(0))
self.register_buffer('sin_cached', emb.sin().unsqueeze(0).unsqueeze(0)
def forward(self, x, seq_len=None):
if seq_len is None:
seq_len = x.shape[2]
return (
x * self.cos_cached[:, :, :seq_len, :],
x * self.sin_cached[:, :, :seq_len, :]
)
@staticmethod
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary(self, q, k):
q_cos, q_sin = self.forward(q)
k_cos, k_sin = self.forward(k)
q_rotated = q * q_cos + self.rotate_half(q) * q_sin
k_rotated = k * k_cos + self.rotate_half(k) * k_sin
return q_rotated, k_rotatedTransformer Block
GPT 风格的 Decoder Block
# GPT 风格的 Transformer Block(Decoder-Only)
class GPTBlock(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
# Pre-LayerNorm(比 Post-LayerNorm 训练更稳定)
self.ln1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.ln2 = nn.LayerNorm(embed_dim)
self.ffn = FeedForward(embed_dim, ff_dim, dropout)
def forward(self, x, mask=None):
# 自注意力 + 残差连接
x = x + self.attn(self.ln1(x), mask)
# 前馈网络 + 残差连接
x = x + self.ffn(self.ln2(x))
return x
# GPT 模型
class GPTModel(nn.Module):
def __init__(self, vocab_size, embed_dim=768, num_heads=12,
num_layers=12, ff_dim=3072, max_len=2048, dropout=0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
self.position_embedding = SinusoidalPositionalEncoding(embed_dim, max_len, dropout)
self.blocks = nn.ModuleList([
GPTBlock(embed_dim, num_heads, ff_dim, dropout)
for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(embed_dim)
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
# 权重共享
self.lm_head.weight = self.token_embedding.weight
# 参数初始化
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, input_ids):
batch_size, seq_len = input_ids.shape
# 因果掩码
mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq, seq)
# Token + 位置编码
x = self.token_embedding(input_ids)
x = self.position_embedding(x)
# Transformer Blocks
for block in self.blocks:
x = block(x, mask)
# 输出层
x = self.ln_f(x)
logits = self.lm_head(x)
return logits
# BERT 风格的 Encoder Block(双向注意力)
class BERTBlock(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.ln2 = nn.LayerNorm(embed_dim)
self.ffn = FeedForward(embed_dim, ff_dim, dropout)
def forward(self, x, attention_mask=None):
# 双向注意力(不使用因果掩码)
x = x + self.attn(self.ln1(x), attention_mask)
x = x + self.ffn(self.ln2(x))
return xKV Cache 优化
推理加速
# KV Cache — 缓存已计算的 Key 和 Value,避免重复计算
class CachedMultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)
self.W_out = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(self, x, kv_cache=None, start_pos=0):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# KV Cache
if kv_cache is not None:
cache_k, cache_v = kv_cache
K = torch.cat([cache_k, K], dim=2)
V = torch.cat([cache_v, V], dim=2)
new_kv_cache = (K, V)
# 注意力计算
scale = math.sqrt(self.head_dim)
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# 因果掩码(只需要掩码新 token)
if seq_len > 1:
mask = torch.tril(torch.ones(seq_len, K.size(2), device=x.device))
mask = mask[:, -seq_len:]
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
context = context.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
return self.W_out(context), new_kv_cacheFlash Attention
内存高效的注意力
# Flash Attention 思想:分块计算注意力,减少 HBM 访问
# 标准 Attention 的内存复杂度:O(N^2)
# Flash Attention 的内存复杂度:O(N)
# Flash Attention 伪代码
class FlashAttentionSimulation:
"""
Flash Attention 核心思想:
1. 将 Q, K, V 分成小块(block)
2. 对每个块计算注意力
3. 使用在线 Softmax 技巧(数值稳定)
4. 避免在 HBM 中存储完整的 N×N 注意力矩阵
"""
@staticmethod
def flash_attention_forward(Q, K, V, block_size=64):
batch, heads, seq_len, head_dim = Q.shape
scale = math.sqrt(head_dim)
# 输出初始化
O = torch.zeros_like(Q)
l = torch.zeros(batch, heads, seq_len, 1, device=Q.device) # 累积分母
m = torch.full((batch, heads, seq_len, 1), float('-inf'), device=Q.device) # 累积最大值
# 分块计算
for i in range(0, seq_len, block_size):
Qi = Q[:, :, i:i+block_size] # (batch, heads, block, head_dim)
for j in range(0, seq_len, block_size):
Kj = K[:, :, j:j+block_size]
Vj = V[:, :, j:j+block_size]
# 计算块内注意力分数
Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) / scale
# 在线 Softmax 更新
m_new = torch.max(m[:, :, i:i+block_size], Sij.max(dim=-1, keepdim=True).values)
Pij = torch.exp(Sij - m_new)
l_new = torch.exp(m[:, :, i:i+block_size] - m_new) * l[:, :, i:i+block_size] + Pij.sum(dim=-1, keepdim=True)
# 更新输出
O[:, :, i:i+block_size] = (
O[:, :, i:i+block_size] * torch.exp(m[:, :, i:i+block_size] - m_new) * l[:, :, i:i+block_size]
+ torch.matmul(Pij, Vj)
) / l_new
m[:, :, i:i+block_size] = m_new
l[:, :, i:i+block_size] = l_new
return O
# 实际使用 Flash Attention 2
# pip install flash-attn
# from flash_attn import flash_attn_func
# output = flash_attn_func(q, k, v, causal=True)Flash Attention 原理
Flash Attention 是现代大模型训练中最重要的优化之一。它通过分块计算(tiling)减少对 HBM(高带宽显存)的访问次数,在不牺牲精度的情况下大幅降低显存使用和提升计算速度。
# Flash Attention 原理示意
# 标准 Attention 的问题:
# 1. 需要实例化完整的 N x N 注意力矩阵,显存占用 O(N^2)
# 2. 多次读写 HBM,内存带宽成为瓶颈
# Flash Attention 的解决方案:
# 1. 分块计算:将 Q、K、V 分成小块,在 SRAM 中完成注意力计算
# 2. 在线 Softmax:逐块累积 softmax 结果,不需要完整矩阵
# 3. 重计算策略:前向不保存注意力矩阵,反向时重计算
# 使用 Flash Attention 2(推荐)
import torch
# pip install flash-attn
# from flash_attn import flash_attn_func
# PyTorch 2.0+ 内置支持
# scale_dot_product_attention 会自动使用 Flash Attention
# torch.nn.functional.scaled_dot_product_attention(q, k, v)
# 显存对比(以 seq_len=4096, hidden=4096 为例):
# 标准 Attention: 约 64GB 显存(N^2 矩阵)
# Flash Attention: 约 0.5GB 显存(线性增长)
# 速度提升:2-4x
# 在 HuggingFace Transformers 中启用
# model = AutoModel.from_pretrained("...", torch_dtype=torch.float16, attn_implementation="flash_attention_2")注意力优化变体
# 1. Multi-Query Attention (MQA) — GPT-4, PaLM
# 所有注意力头共享同一组 K 和 V,只有 Q 是多头
class MultiQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.W_q = nn.Linear(embed_dim, embed_dim) # Q: 多头
self.W_k = nn.Linear(embed_dim, self.head_dim) # K: 单头
self.W_v = nn.Linear(embed_dim, self.head_dim) # V: 单头
def forward(self, x):
B, L, _ = x.shape
Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
K = self.W_k(x).unsqueeze(1) # (B, 1, L, head_dim) 广播到所有头
V = self.W_v(x).unsqueeze(1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn = F.softmax(scores, dim=-1)
return torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, L, -1)
# 2. Grouped-Query Attention (GQA) — LLaMA 2, Mistral
# KV 头数介于 1 和 num_heads 之间,是 MHA 和 MQA 的折中
class GroupedQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads, num_kv_heads):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = embed_dim // num_heads
self.kv_groups = num_heads // num_kv_heads
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
self.W_v = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
def forward(self, x):
B, L, _ = x.shape
Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
K = self.W_k(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
V = self.W_v(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
# 扩展 KV 头以匹配 Q 头数
K = K.repeat_interleave(self.kv_groups, dim=1)
V = V.repeat_interleave(self.kv_groups, dim=1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn = F.softmax(scores, dim=-1)
return torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, L, -1)
# 3. Sliding Window Attention — Mistral, Longformer
# 每个位置只关注固定窗口内的 token,复杂度从 O(N^2) 降到 O(N*W)
# Mistral 使用 window_size=4096,信息通过多层传播间接覆盖长距离现代激活函数
# Transformer 中的激活函数演进
# 1. ReLU — 原始 Transformer 使用
# nn.ReLU()
# 2. GELU — BERT、GPT-2 使用(平滑的 ReLU 变体)
# nn.GELU()
# 3. SwiGLU — LLaMA、Mistral 等现代模型使用
class SwiGLU(nn.Module):
"""Swish-Gated Linear Unit — 现代 LLM 的标准激活函数"""
def __init__(self, embed_dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(embed_dim, hidden_dim, bias=False) # gate
self.w2 = nn.Linear(embed_dim, hidden_dim, bias=False) # up
self.w3 = nn.Linear(hidden_dim, embed_dim, bias=False) # down
def forward(self, x):
# SwiGLU(x) = (x * W1 * sigmoid(x * W1)) * (x * W2) * W3
return self.w3(F.silu(self.w1(x)) * self.w2(x))
# 对比:
# ReLU FFN: output = W2(ReLU(W1(x))) # 2 个权重矩阵
# GLU FFN: output = W2((W1(x) * sigmoid(W1(x)))) # 需要 gate 机制
# SwiGLU: output = W3(silu(W1(x)) * W2(x)) # 3 个权重矩阵,效果最好长上下文处理技术
# 扩展 Transformer 上下文长度的关键技术
# 1. RoPE 外推(Position Interpolation)
# 通过缩放位置索引来扩展上下文窗口
def apply_rotary_emb_with_scaling(x, cos, sin, scale_factor=1.0):
"""带缩放的旋转位置编码"""
# 原始位置:0, 1, 2, ..., max_len
# 缩放后: 0, 1/s, 2/s, ..., max_len/s
# 使得更长的位置可以映射到已训练的位置范围内
scaled_cos = torch.cos(torch.arange(x.shape[1]) / scale_factor * theta)
scaled_sin = torch.sin(torch.arange(x.shape[1]) / scale_factor * theta)
# ... 应用旋转
# 2. YaRN (Yet another RoPE extensioN)
# 结合位置插值和注意力温度调整
# 在 LLaMA 等模型中成功将上下文从 4K 扩展到 128K
# 3. ALiBi (Attention with Linear Biases)
# 通过线性偏置替代位置编码,天然支持长度外推
class ALiBiAttention(nn.Module):
"""ALiBi 注意力 — 无位置编码,通过线性偏置注入位置信息"""
def __init__(self, num_heads):
super().__init__()
# 每个头的斜率不同
slopes = 2 ** (-8 * torch.arange(1, num_heads + 1) / num_heads)
self.register_buffer('slopes', slopes)
def forward(self, q, k, v):
batch, heads, seq_len, dim = q.shape
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim)
# 添加线性偏置(距离越远,惩罚越大)
positions = torch.arange(seq_len)
distance = positions.unsqueeze(0) - positions.unsqueeze(1) # (L, L)
bias = self.slopes.unsqueeze(1).unsqueeze(1) * distance.unsqueeze(0) # (H, L, L)
scores = scores - bias
attn = F.softmax(scores, dim=-1)
return torch.matmul(attn, v)Transformer 模型规模对比
模型规模演进:
模型名称 参数量 层数 隐藏维度 注意力头数 上下文长度
-----------------------------------------------------------------------
BERT-base 110M 12 768 12 512
BERT-large 340M 24 1024 16 512
GPT-2 1.5B 48 1600 25 1024
GPT-3 175B 96 12288 96 2048
LLaMA-7B 7B 32 4096 32 2048
LLaMA-2-70B 70B 80 8192 64 4096
LLaMA-3-8B 8B 32 4096 32 8192
LLaMA-3-70B 70B 80 8192 64 8192
Mistral-7B 7B 32 4096 32(GQA) 32768
Mixtral-8x7B 47B(13B活跃) 32 4096 32(GQA) 32768
GPT-4 ~1.8T(推测) ? ? ? 128K
关键趋势:
1. 参数量从百万级增长到万亿级
2. 上下文长度从 512 扩展到 128K+
3. GQA/MQA 替代 MHA 减少 KV Cache 开销
4. SwiGLU 替代 GELU 提升效果
5. RoPE 替代正弦位置编码支持长度外推
6. MoE(混合专家)模型在保持推理成本的同时增加总参数量优点
缺点
总结
Transformer 通过自注意力机制实现序列建模,Multi-Head Attention 在多个子空间并行计算注意力。位置编码(正弦/RoPE)注入位置信息。GPT 使用 Decoder-Only 架构(因果注意力),BERT 使用 Encoder-Only 架构(双向注意力)。KV Cache 缓存已计算的 Key/Value 加速推理。Flash Attention 通过分块计算减少显存访问,是现代大模型训练的标准优化。
关键知识点
- 先分清模型能力边界、数据边界和工程边界。
- 任何 AI 主题都不只看效果,还要看延迟、成本、可解释性和安全性。
- 评估方式和失败样例往往比“换哪个模型”更重要。
- 这类主题通常同时涉及表示学习、上下文建模和推理成本。
项目落地视角
- 给数据来源、Prompt 模板、Embedding 版本、评估集和实验结果做版本管理。
- 上线前准备兜底策略,例如拒答、回退、人工审核或缓存降级。
- 观察错误类型时,区分数据问题、召回问题、提示词问题和模型问题。
- 把提示词拆成角色、任务、约束、输出格式和失败兜底几部分。
常见误区
- 只关注 Demo 效果,不考虑线上稳定性和可复现性。
- 没有评估集就频繁调参,最后无法解释为什么变好或变差。
- 忽略权限、审计、隐私和模型输出的安全边界。
- 把所有问题都归因于提示词,而忽略数据和模型能力边界。
进阶路线
- 继续补齐训练、推理、评估、MLOps 和治理链路。
- 把主题放回真实业务流程,思考谁提供数据、谁消费结果、谁负责兜底。
- 把 PoC 逐步升级到可观测、可回滚、可演进的生产方案。
- 继续深入长上下文、结构化输出、提示模板管理和模型评估。
适用场景
- 当你准备把《Transformer 架构详解》真正落到项目里时,最适合先在一个独立模块或最小样例里验证关键路径。
- 适合企业知识问答、内容生成、分类抽取和智能助手等场景。
- 当需求同时关注效果、时延、成本和安全边界时,这类主题最有价值。
落地建议
- 先定义评估集、成功标准和失败样例,再开始调模型或调提示。
- 把数据来源、分块方式、Embedding 版本和 Prompt 模板纳入版本管理。
- 上线前准备兜底策略,例如拒答、回退、人工审核或检索降级。
排错清单
- 先判断问题出在数据、检索、Prompt、模型还是后处理。
- 检查上下文是否过长、分块是否过碎或召回是否偏题。
- 对错误回答做分类,区分幻觉、事实过时、指令误解和格式错误。
复盘问题
- 如果把《Transformer 架构详解》放进你的当前项目,最先要验证的输入、输出和失败路径分别是什么?
- 《Transformer 架构详解》最容易在什么规模、什么边界条件下暴露问题?你会用什么指标或日志去确认?
- 相比默认实现或替代方案,采用《Transformer 架构详解》最大的收益和代价分别是什么?
