TensorRT-LLM-06: Layers Module Deep Dive
1. Module Overview
1.1 Role of the Module
The Layers module is TensorRT-LLM's foundation layer: it provides the neural-network building blocks needed to assemble large language models, including Attention, MLP, Normalization, and Embedding layers. Each layer is optimized for TensorRT and supports mixed precision, quantization, and parallel execution.
Core responsibilities:
- Basic layers: Attention, MLP, LayerNorm, Embedding
- TensorRT optimizations: custom kernels, layer fusion, memory optimization
- Parallelism: Tensor Parallel, Pipeline Parallel
- Quantization: FP8, INT8, INT4, etc.
- Special architectures: MoE, Mamba, RoPE, ALiBi, etc.
1.2 Module Layout
tensorrt_llm/layers/
├── attention.py                  # Attention layers
│   ├── Attention                 # General-purpose attention
│   ├── BertAttention             # BERT-style attention
│   ├── DeepseekV2Attention       # DeepSeek V2 MLA
│   └── CogVLMAttention           # CogVLM multimodal attention
│
├── mlp.py                        # MLP layers
│   ├── MLP                       # Standard MLP
│   ├── GatedMLP                  # Gated MLP (SwiGLU etc.)
│   └── FusedGatedMLP             # Fused gated MLP
│
├── normalization.py              # Normalization layers
│   ├── LayerNorm                 # Standard LayerNorm
│   ├── RmsNorm                   # RMS LayerNorm
│   └── GroupNorm                 # Group Normalization
│
├── linear.py                     # Linear layers
│   ├── Linear                    # Standard linear layer
│   ├── ColumnLinear              # Column-parallel linear layer
│   └── RowLinear                 # Row-parallel linear layer
│
├── embedding.py                  # Embedding layers
│   ├── Embedding                 # Standard embedding
│   └── PromptTuningEmbedding     # Prompt-tuning embedding
│
├── moe.py                        # MoE layers
│   ├── MOE                       # Mixture of Experts
│   ├── MoeConfig                 # MoE configuration
│   └── SharedMoE                 # Shared-expert MoE
│
├── ssm.py                        # State-space models
│   ├── Mamba                     # Mamba architecture
│   └── Mamba2                    # Mamba2 architecture
│
└── lora.py                       # LoRA layers
    ├── Lora                      # LoRA implementation
    └── LoraParams                # LoRA parameters
1.3 Layer Taxonomy
| Category | Layer | Function | Typical Use |
|---|---|---|---|
| Attention | Attention | Multi-head self-attention | All Transformer models |
| | BertAttention | BERT-style attention | BERT-family models |
| | DeepseekV2Attention | Multi-Head Latent Attention | DeepSeek V2 |
| MLP | MLP | Standard feed-forward network | Classic Transformers |
| | GatedMLP | Gated MLP | Llama, Mistral, etc. |
| | FusedGatedMLP | Fused gated MLP | Performance-optimized variant |
| Normalization | LayerNorm | Layer normalization | GPT family |
| | RmsNorm | RMS normalization | Llama family |
| Linear | Linear | Standard linear transform | General |
| | ColumnLinear | Column-parallel linear | Tensor Parallel |
| | RowLinear | Row-parallel linear | Tensor Parallel |
| Special | MOE | Mixture of Experts | Mixtral, etc. |
| | Mamba | State-space model | Mamba architecture |
| | Lora | Low-rank adaptation | Parameter-efficient fine-tuning |
2. Core API Deep Dive
2.1 Attention Layer
2.1.1 Class Definition
class Attention(Module):
    """
    Multi-head self-attention layer.

    Supported features:
    - Multi-Head Attention (MHA)
    - Grouped-Query Attention (GQA)
    - Multi-Query Attention (MQA)
    - Positional encodings such as RoPE and ALiBi
    - Flash Attention optimization
    - KV Cache
    - Tensor Parallelism
    """
    def __init__(
        self,
        *,
        local_layer_idx: int,                    # Layer index
        hidden_size: int,                        # Hidden dimension
        num_attention_heads: int,                # Number of attention (query) heads
        num_kv_heads: int = None,                # Number of KV heads (GQA/MQA)
        max_position_embeddings: int = 1024,     # Maximum number of positions
        attention_head_size: int = None,         # Dimension per head
        attention_mask_type: AttentionMaskType = AttentionMaskType.padding,
        position_embedding_type: PositionEmbeddingType = PositionEmbeddingType.learned_absolute,
        rotary_embedding_base: float = 10000.0,  # RoPE base
        bias: bool = True,                       # Whether to add bias terms
        dtype = None,                            # Data type
        tp_group = None,                         # TP communication group
        tp_size: int = 1,                        # TP degree
        quant_mode: QuantMode = QuantMode(0),    # Quantization mode
        **kwargs
    ):
2.1.2 Key Parameters
Attention configuration
| Parameter | Type | Description | Example |
|---|---|---|---|
| num_attention_heads | int | Number of query heads | 32 (Llama-7B) |
| num_kv_heads | int | Number of key/value heads | 8 (GQA), 1 (MQA), 32 (MHA) |
| attention_head_size | int | Dimension of each head | 128 (4096/32) |
| max_position_embeddings | int | Maximum sequence length | 2048, 4096, 8192 |
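To make these numbers concrete, the short sketch below (plain Python arithmetic; the variable names are local assumptions, not a TensorRT-LLM API) derives the per-head size and the fused-QKV projection width for a Llama-7B-style GQA configuration. The fused width matches the (num_q_heads + 2*num_kv_heads) * head_size split used in forward() later in this section.
# Illustrative arithmetic only; names are local assumptions.
hidden_size = 4096
num_attention_heads = 32
num_kv_heads = 8                                 # GQA configuration

head_size = hidden_size // num_attention_heads   # 128
q_width = num_attention_heads * head_size        # 4096
kv_width = 2 * num_kv_heads * head_size          # 2048 (K and V together)
fused_qkv_width = q_width + kv_width             # 6144

print(head_size, fused_qkv_width)                # 128 6144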
Position-encoding types
| Type | Description | Used By |
|---|---|---|
| learned_absolute | Learned absolute position embedding | GPT-2 |
| rope_gpt_neox | RoPE (GPT-NeoX style) | Llama, Mistral |
| alibi | ALiBi linear bias | BLOOM, MPT |
| relative_position | Relative position encoding | T5 |
2.1.3 The forward() Method
def forward(
    self,
    hidden_states: Tensor,                       # [batch, seq_len, hidden_size]
    attention_mask: Tensor = None,               # Attention mask
    use_cache: bool = False,                     # Whether to use the KV cache
    kv_cache_params: KeyValueCacheParams = None, # KV-cache parameters
    attention_params: AttentionParams = None,    # Attention parameters
    lora_layer_params: LoraRuntimeParams = None, # LoRA parameters
    **kwargs
) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Attention forward pass.

    Returns:
        context: [batch, seq_len, hidden_size]
        present_key_value: KV cache (optional)
    """
    # 1. QKV projection
    batch_size, seq_len, hidden_size = hidden_states.size()

    # 1.1 QKV linear transforms
    if self.cross_attention:
        # Cross-attention: Q comes from the decoder, K/V from the encoder
        # (encoder_hidden_states is passed in via **kwargs)
        q = self.q(hidden_states)  # [batch, seq_len, num_q_heads * head_size]
        k = self.k(encoder_hidden_states)
        v = self.v(encoder_hidden_states)
    else:
        # Self-attention: Q, K, V all come from the same input
        if hasattr(self, 'qkv'):
            # Fused QKV projection (faster)
            qkv = self.qkv(hidden_states)  # [batch, seq_len, (num_q_heads + 2*num_kv_heads) * head_size]
            # Split into Q, K, V
            q_size = self.num_attention_heads * self.attention_head_size
            kv_size = self.num_attention_kv_heads * self.attention_head_size
            q = qkv[:, :, :q_size]
            k = qkv[:, :, q_size:q_size + kv_size]
            v = qkv[:, :, q_size + kv_size:]
        else:
            # Separate Q, K, V projections
            q = self.q(hidden_states)
            k = self.k(hidden_states)
            v = self.v(hidden_states)

    # 1.2 Reshape into multi-head layout
    q = q.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size)
    k = k.view(batch_size, seq_len, self.num_attention_kv_heads, self.attention_head_size)
    v = v.view(batch_size, seq_len, self.num_attention_kv_heads, self.attention_head_size)

    # 2. Positional encoding
    if self.position_embedding_type == PositionEmbeddingType.rope_gpt_neox:
        # RoPE (position_ids comes from attention_params)
        q, k = self._apply_rotary_pos_emb(q, k, position_ids)
    elif self.position_embedding_type == PositionEmbeddingType.alibi:
        # ALiBi biases are added later, inside the attention computation
        pass

    # 3. KV-cache handling
    if use_cache:
        if kv_cache_params is not None:
            # Fetch past_key_value from the cache
            past_key_value = kv_cache_params.get_cache_kv(self.local_layer_idx)
            if past_key_value is not None:
                past_k, past_v = past_key_value
                k = torch.cat([past_k, k], dim=1)  # concatenate along the seq_len dimension
                v = torch.cat([past_v, v], dim=1)
            # Write back the updated cache
            kv_cache_params.update_cache_kv(self.local_layer_idx, (k, v))

    # 4. Grouped-Query Attention handling
    if self.num_attention_kv_heads < self.num_attention_heads:
        # GQA/MQA: replicate KV heads to match the number of Q heads
        num_heads_per_kv = self.num_attention_heads // self.num_attention_kv_heads
        k = k.repeat_interleave(num_heads_per_kv, dim=2)
        v = v.repeat_interleave(num_heads_per_kv, dim=2)

    # 5. Attention computation
    if self.use_flash_attention:
        # Flash Attention (memory-efficient)
        context = self._flash_attention(q, k, v, attention_mask)
    else:
        # Standard attention
        context = self._standard_attention(q, k, v, attention_mask)

    # 6. Output projection
    context = context.view(batch_size, seq_len, self.attention_hidden_size)
    output = self.dense(context)  # [batch, seq_len, hidden_size]

    # 7. Return
    present_key_value = (k, v) if use_cache else None
    return output, present_key_value

def _standard_attention(self, q, k, v, attention_mask):
    """
    Standard (materialized) attention.
    """
    # 0. Move heads ahead of the sequence dimension:
    #    [batch, seq_len, num_heads, head_size] -> [batch, num_heads, seq_len, head_size]
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    # 1. Attention scores
    scores = torch.matmul(q, k.transpose(-2, -1))           # [batch, num_heads, seq_len, seq_len]
    scores = scores / math.sqrt(self.attention_head_size)   # scaling

    # 2. Positional bias
    if self.position_embedding_type == PositionEmbeddingType.alibi:
        alibi_biases = self._compute_alibi_biases(scores.size(-1))
        scores = scores + alibi_biases

    # 3. Attention mask
    if attention_mask is not None:
        scores = scores + attention_mask  # masked positions are set to -inf

    # 4. Softmax
    attn_weights = torch.softmax(scores, dim=-1)

    # 5. Dropout (training only)
    if self.training:
        attn_weights = self.dropout(attn_weights)

    # 6. Weighted sum, then restore [batch, seq_len, num_heads, head_size]
    context = torch.matmul(attn_weights, v)  # [batch, num_heads, seq_len, head_size]
    return context.transpose(1, 2).contiguous()

def _apply_rotary_pos_emb(self, q, k, position_ids):
    """
    Apply RoPE (rotary position embedding).
    """
    # 1. Inverse frequency for each pair of dimensions
    inv_freq = 1.0 / (self.rotary_embedding_base **
                      (torch.arange(0, self.attention_head_size, 2).float() / self.attention_head_size))

    # 2. Rotation angles; duplicate them so cos/sin span the full head dimension
    freqs = position_ids.unsqueeze(-1).float() * inv_freq  # [batch, seq_len, head_size/2]
    emb = torch.cat([freqs, freqs], dim=-1)                # [batch, seq_len, head_size]
    cos = emb.cos().unsqueeze(2)                           # [batch, seq_len, 1, head_size]
    sin = emb.sin().unsqueeze(2)

    # 3. Rotate
    def rotate_half(x):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
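A useful property to check against the implementation above: RoPE only rotates pairs of dimensions, so it preserves the norm of every head vector. The standalone sketch below (hedged: free functions mirroring _apply_rotary_pos_emb, not the class method itself) verifies this numerically.
import torch

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat([-x2, x1], dim=-1)

def apply_rope(x, position_ids, head_size, base=10000.0):
    # Standalone version of the method above, for experimentation.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
    freqs = position_ids.unsqueeze(-1).float() * inv_freq   # [batch, seq, head_size/2]
    emb = torch.cat([freqs, freqs], dim=-1)                 # [batch, seq, head_size]
    cos, sin = emb.cos().unsqueeze(2), emb.sin().unsqueeze(2)
    return x * cos + rotate_half(x) * sin

q = torch.randn(1, 16, 4, 64)          # [batch, seq, heads, head_size]
pos = torch.arange(16).unsqueeze(0)    # [batch, seq]
q_rot = apply_rope(q, pos, head_size=64)

# A pure rotation preserves per-head vector norms (up to float tolerance)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)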
2.1.4 Sequence Diagram
sequenceDiagram
    autonumber
    participant Input as hidden_states
    participant QKV as QKV Projection
    participant PosEnc as Position Encoding
    participant Cache as KV Cache
    participant AttnCalc as Attention Calculation
    participant Output as Output Projection
    Input->>QKV: linear transform
    activate QKV
    alt fused QKV
        QKV->>QKV: qkv = self.qkv(hidden_states)
        QKV->>QKV: split into Q, K, V
    else separate QKV
        QKV->>QKV: q = self.q(hidden_states)
        QKV->>QKV: k = self.k(hidden_states)
        QKV->>QKV: v = self.v(hidden_states)
    end
    QKV->>QKV: reshape to multi-head layout
    QKV-->>PosEnc: Q, K, V
    deactivate QKV
    PosEnc->>PosEnc: apply positional encoding
    activate PosEnc
    alt RoPE
        PosEnc->>PosEnc: _apply_rotary_pos_emb(q, k)
    else ALiBi
        PosEnc->>PosEnc: compute ALiBi biases (deferred to attention)
    end
    PosEnc-->>Cache: Q, K, V
    deactivate PosEnc
    alt KV cache enabled
        Cache->>Cache: fetch past_key_value
        Cache->>Cache: concatenate past KV
        Cache->>Cache: update cache
    end
    Cache-->>AttnCalc: Q, K, V
    AttnCalc->>AttnCalc: GQA handling (expand KV)
    activate AttnCalc
    alt Flash Attention
        AttnCalc->>AttnCalc: _flash_attention(q, k, v)
    else standard attention
        AttnCalc->>AttnCalc: scores = Q @ K^T / sqrt(d)
        AttnCalc->>AttnCalc: apply attention_mask
        AttnCalc->>AttnCalc: softmax(scores)
        AttnCalc->>AttnCalc: context = attn_weights @ V
    end
    AttnCalc-->>Output: context
    deactivate AttnCalc
    Output->>Output: reshape/flatten
    activate Output
    Output->>Output: self.dense(context)
    Output-->>Input: output, present_kv
    deactivate Output
2.2 GatedMLP Layer
2.2.1 Class Definition
class GatedMLP(Module):
    """
    Gated MLP (e.g. SwiGLU).

    Architecture:
        Input → Gate(W1) * Up(W2) → Activation → Down(W3) → Output

    Notes:
    - The Gate and Up projections are computed together
    - The activation is typically SiLU/Swish
    - Variants such as SwiGLU and GeGLU are supported
    """
    def __init__(
        self,
        hidden_size: int,        # Input/output dimension
        ffn_hidden_size: int,    # Intermediate dimension (typically ~2.7x hidden_size)
        hidden_act: str,         # Activation ("silu", "gelu", "swiglu", ...)
        bias: bool = True,       # Whether to add bias terms
        dtype = None,            # Data type
        tp_group = None,         # TP communication group
        tp_size: int = 1,        # TP degree
        **kwargs
    ):
        super().__init__()
        # Fused Gate and Up projections (performance optimization)
        self.gate_up_proj = ColumnLinear(
            in_features=hidden_size,
            out_features=2 * ffn_hidden_size,  # computes Gate and Up together
            bias=bias,
            dtype=dtype,
            tp_group=tp_group,
            tp_size=tp_size,
            gather_output=False,  # keep the output sharded (no AllGather)
        )
        # Down projection
        self.down_proj = RowLinear(
            in_features=ffn_hidden_size,
            out_features=hidden_size,
            bias=bias,
            dtype=dtype,
            tp_group=tp_group,
            tp_size=tp_size,
        )
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
2.2.2 The forward() Method
def forward(self, hidden_states: Tensor) -> Tensor:
    """
    Gated-MLP forward pass.

    Args:
        hidden_states: [batch, seq_len, hidden_size]
    Returns:
        output: [batch, seq_len, hidden_size]
    """
    # 1. Gate and Up projections (fused)
    gate_up = self.gate_up_proj(hidden_states)  # [batch, seq_len, 2 * ffn_hidden_size]

    # 2. Split into Gate and Up
    gate, up = gate_up.chunk(2, dim=-1)  # each [batch, seq_len, ffn_hidden_size]

    # 3. Gated activation
    if self.hidden_act == "silu":
        # SiLU/Swish: x * sigmoid(x)
        intermediate = gate * torch.sigmoid(gate) * up
    elif self.hidden_act == "gelu":
        # GELU gating
        intermediate = torch.nn.functional.gelu(gate) * up
    elif self.hidden_act == "swiglu":
        # SwiGLU: GLU variant with SiLU gating (same math as the "silu" branch)
        intermediate = torch.nn.functional.silu(gate) * up
    elif self.hidden_act == "gegelu":
        # GeGLU: GLU variant with GELU gating
        intermediate = torch.nn.functional.gelu(gate) * up
    else:
        raise ValueError(f"Unsupported activation: {self.hidden_act}")

    # 4. Down projection
    output = self.down_proj(intermediate)  # [batch, seq_len, hidden_size]
    return output
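Note that because SiLU(x) = x · sigmoid(x), the "silu" and "swiglu" branches above compute exactly the same quantity. A minimal check in plain PyTorch:
import torch
import torch.nn.functional as F

gate = torch.randn(2, 8, 16)
up = torch.randn(2, 8, 16)

manual = gate * torch.sigmoid(gate) * up   # the "silu" branch
fused = F.silu(gate) * up                  # the "swiglu" branch

assert torch.allclose(manual, fused, atol=1e-6)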
2.2.3 Tensor-Parallel Implementation
# ColumnLinear: the weight is split along its columns (output dimension)
class ColumnLinear(Module):
    """
    Column-parallel linear layer.

    Weight layout:
    - GPU 0: W[:, 0 : d/tp_size]
    - GPU 1: W[:, d/tp_size : 2*d/tp_size]
    - ...

    The outputs are gathered with AllGather only if gather_output=True.
    """
    def forward(self, input):
        # Local matmul: each rank computes its slice of the output
        output = torch.matmul(input, self.weight)
        if self.gather_output:
            # AllGather to assemble the full output
            output = self._all_gather(output)
        return output

# RowLinear: the weight is split along its rows (input dimension)
class RowLinear(Module):
    """
    Row-parallel linear layer.

    Weight layout:
    - GPU 0: W[0 : d/tp_size, :]
    - GPU 1: W[d/tp_size : 2*d/tp_size, :]
    - ...

    The input is expected to arrive sharded; the outputs are summed with AllReduce.
    """
    def forward(self, input):
        # The input is already sharded (e.g. produced by a ColumnLinear)
        output = torch.matmul(input, self.weight)  # local matmul
        # AllReduce sums the partial results from all GPUs
        output = self._all_reduce(output)
        return output
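The key point of this Column-then-Row pairing is that the whole MLP needs exactly one AllReduce. The sketch below simulates a 2-way split on a single device with plain matmuls (an illustration under assumed shapes; the real ColumnLinear/RowLinear use NCCL collectives). Because the activation is elementwise on the sharded dimension, each rank can apply it locally.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, ffn, tp = 8, 16, 2
x = torch.randn(4, hidden)
W1 = torch.randn(hidden, ffn)   # column-parallel: split along the output dim
W2 = torch.randn(ffn, hidden)   # row-parallel: split along the input dim

# Reference: dense, single-device computation
ref = F.silu(x @ W1) @ W2

# Simulated TP: each "GPU" holds one column shard of W1 and one row shard of W2
partials = []
for rank in range(tp):
    shard = slice(rank * ffn // tp, (rank + 1) * ffn // tp)
    h_local = F.silu(x @ W1[:, shard])       # ColumnLinear with gather_output=False
    partials.append(h_local @ W2[shard, :])  # RowLinear local matmul

out = sum(partials)  # AllReduce = elementwise sum over ranks
assert torch.allclose(out, ref, atol=1e-4)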
2.2.4 Architecture Diagram
graph TB
    subgraph "GatedMLP architecture (SwiGLU example)"
        A["Input<br/>[batch, seq, hidden]"]
        subgraph "TP-parallel compute"
            B1["GPU 0<br/>Gate+Up proj<br/>W[:, 0:d/4]"]
            B2["GPU 1<br/>Gate+Up proj<br/>W[:, d/4:d/2]"]
            B3["GPU 2<br/>Gate+Up proj<br/>W[:, d/2:3d/4]"]
            B4["GPU 3<br/>Gate+Up proj<br/>W[:, 3d/4:d]"]
        end
        C1["Split Gate and Up"]
        C2["SiLU(Gate) * Up"]
        subgraph "Down projection (row-parallel)"
            D1["GPU 0<br/>Down proj<br/>W[0:d/4, :]"]
            D2["GPU 1<br/>Down proj<br/>W[d/4:d/2, :]"]
            D3["GPU 2<br/>Down proj<br/>W[d/2:3d/4, :]"]
            D4["GPU 3<br/>Down proj<br/>W[3d/4:d, :]"]
        end
        E[AllReduce]
        F["Output<br/>[batch, seq, hidden]"]
        A --> B1 & B2 & B3 & B4
        B1 --> C1
        B2 --> C1
        B3 --> C1
        B4 --> C1
        C1 --> C2
        C2 --> D1 & D2 & D3 & D4
        D1 & D2 & D3 & D4 --> E
        E --> F
    end
    style A fill:#e1f5ff
    style F fill:#e1f5ff
    style B1 fill:#fff3e0
    style B2 fill:#fff3e0
    style B3 fill:#fff3e0
    style B4 fill:#fff3e0
    style D1 fill:#f3e5f5
    style D2 fill:#f3e5f5
    style D3 fill:#f3e5f5
    style D4 fill:#f3e5f5
2.3 RmsNorm Layer
2.3.1 Class Definition and Principle
class RmsNorm(Module):
    """
    Root Mean Square Layer Normalization.

    Advantages over LayerNorm:
    - Simpler computation (no mean subtraction)
    - Numerically stable
    - Faster in practice

    Formula:
        RMSNorm(x) = x / RMS(x) * γ
        where RMS(x) = sqrt(mean(x²) + ε)
    """
    def __init__(
        self,
        normalized_shape: int,  # Normalized dimension (usually hidden_size)
        eps: float = 1e-6,      # Numerical-stability epsilon
        dtype = None,           # Data type
        **kwargs
    ):
        super().__init__()
        # Learnable scale parameter
        self.weight = Parameter(
            shape=(normalized_shape,),
            dtype=dtype,
        )
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.dtype = dtype
2.3.2 The forward() Method
def forward(self, x: Tensor) -> Tensor:
    """
    RMSNorm forward pass.

    Args:
        x: [batch, seq_len, hidden_size]
    Returns:
        output: [batch, seq_len, hidden_size]
    """
    # 1. Root Mean Square
    variance = x.pow(2).mean(-1, keepdim=True)  # mean(x²): [batch, seq_len, 1]
    rms = torch.sqrt(variance + self.eps)       # RMS = sqrt(mean(x²) + ε)

    # 2. Normalize
    x_normalized = x / rms                      # [batch, seq_len, hidden_size]

    # 3. Scale
    output = self.weight * x_normalized
    return output
2.3.3 Comparison with LayerNorm
# LayerNorm vs RmsNorm
def layer_norm(x, weight, bias, eps=1e-5):
    """Standard LayerNorm."""
    # 1. Mean and variance
    mean = x.mean(-1, keepdim=True)
    variance = x.var(-1, keepdim=True, unbiased=False)
    # 2. Normalize
    x_normalized = (x - mean) / torch.sqrt(variance + eps)
    # 3. Affine transform
    return weight * x_normalized + bias

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm (simpler)."""
    # 1. Root Mean Square
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    # 2. Normalize and scale
    return weight * (x / rms)

# Performance comparison:
# - LayerNorm: needs both mean and variance (two reduction passes)
# - RmsNorm: needs only mean(x²) (one reduction pass)
# - RmsNorm is roughly 10-15% faster than LayerNorm
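On zero-mean input (with the same eps and a zero bias), LayerNorm reduces exactly to RMSNorm, since var(x) = mean(x²) when mean(x) = 0. A quick check using the two reference functions above:
import torch

d = 64
x = torch.randn(2, 8, d)
x = x - x.mean(-1, keepdim=True)   # force zero mean per token
w, b = torch.ones(d), torch.zeros(d)

assert torch.allclose(layer_norm(x, w, b, eps=1e-6), rms_norm(x, w, eps=1e-6),
                      atol=1e-5)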
3. Key Features in Depth
3.1 Flash Attention
3.1.1 Principle
Problems with standard attention:
1. Memory: the attention matrix costs O(N²) storage
2. Memory traffic: repeated reads/writes to HBM (High Bandwidth Memory)

Flash Attention's approach:
1. Tiling: split Q, K, V into blocks
2. Online softmax: never materialize the full attention matrix
3. Recomputation: recompute in the backward pass instead of storing
4. SRAM utilization: minimize HBM accesses

Performance gains:
- Memory: from O(N²) down to O(N)
- Speed: 2-4x (larger gains for longer sequences)
3.1.2 Pseudocode Implementation
import math
import torch

def flash_attention(Q, K, V, block_size=256):
    """
    Reference Flash Attention (single head, pedagogical; Q, K, V: [N, d]).
    """
    N, d = Q.shape                       # sequence length, head dimension
    Br = Bc = block_size                 # block sizes
    # Accumulators
    O = torch.zeros_like(Q)              # unnormalized output
    l = torch.zeros(N)                   # softmax denominator
    m = torch.full((N,), -float('inf'))  # running row maximum (numerical stability)
    # Loop over KV blocks
    for j in range(0, N, Bc):
        # Load one KV block into SRAM
        Kj = K[j:j+Bc]  # [Bc, d]
        Vj = V[j:j+Bc]  # [Bc, d]
        # Loop over Q blocks
        for i in range(0, N, Br):
            # Load one Q block into SRAM
            Qi = Q[i:i+Br]  # [Br, d]
            # Scaled attention scores for this tile
            Sij = Qi @ Kj.T / math.sqrt(d)  # [Br, Bc]
            # Online softmax update
            m_new = torch.maximum(m[i:i+Br], Sij.max(dim=1).values)
            l_new = torch.exp(m[i:i+Br] - m_new) * l[i:i+Br] + \
                    torch.exp(Sij - m_new.unsqueeze(1)).sum(dim=1)
            # Rescale the accumulated output, then add this block's contribution
            O[i:i+Br] = torch.exp(m[i:i+Br] - m_new).unsqueeze(1) * O[i:i+Br] + \
                        torch.exp(Sij - m_new.unsqueeze(1)) @ Vj
            # Update the running statistics
            m[i:i+Br] = m_new
            l[i:i+Br] = l_new
    # Final normalization
    O = O / l.unsqueeze(1)
    return O
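Since the loop never materializes the full N×N score matrix, it is worth confirming it still matches naive attention. Assuming the flash_attention function defined above, a quick numerical check:
import math
import torch

torch.manual_seed(0)
N, d = 512, 64
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

# Naive attention for reference
expected = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1) @ V

# Blocked computation with blocks smaller than N
actual = flash_attention(Q, K, V, block_size=128)
assert torch.allclose(actual, expected, atol=1e-4)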
3.2 MoE (Mixture of Experts)
3.2.1 Architectural Principle
Conventional MLP: Input → FC1 → Activation → FC2 → Output
All parameters are active for every token.

MoE MLP:
Input → Router (selects top-k experts)
    ├─> Expert 0: FC1_0 → Act → FC2_0
    ├─> Expert 1: FC1_1 → Act → FC2_1
    ├─> ...
    └─> Expert N: FC1_N → Act → FC2_N
Output = Σ(weight_i * Expert_i(Input))   # only the top-k experts run

Advantages:
- Large parameter count but small per-token compute
- Expert specialization (different experts learn different patterns)
- Good scalability
3.2.2 Implementation
class MOE(Module):
    """
    Mixture-of-Experts layer.
    """
    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_experts: int = 8,     # Number of experts
        top_k: int = 2,           # Experts selected per token
        gate_type: str = "top",   # Gating type
        **kwargs
    ):
        super().__init__()
        # Router (gating network)
        self.gate = Linear(
            in_features=hidden_size,
            out_features=num_experts,
            bias=False,
        )
        # Experts
        self.experts = ModuleList([
            GatedMLP(hidden_size, ffn_hidden_size, **kwargs)
            for _ in range(num_experts)
        ])
        self.num_experts = num_experts
        self.top_k = top_k

    def forward(self, hidden_states: Tensor) -> Tensor:
        batch_size, seq_len, hidden_size = hidden_states.shape
        # 1. Router
        router_logits = self.gate(hidden_states)  # [batch, seq, num_experts]
        router_probs = torch.softmax(router_logits, dim=-1)
        # 2. Select top-k experts per token
        top_k_probs, top_k_indices = torch.topk(
            router_probs, self.top_k, dim=-1
        )  # [batch, seq, top_k]
        # 3. Renormalize the selected weights
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # 4. Expert computation (flatten tokens for simple indexing)
        flat_input = hidden_states.view(-1, hidden_size)  # [num_tokens, hidden]
        flat_indices = top_k_indices.view(-1, self.top_k)
        flat_probs = top_k_probs.view(-1, self.top_k)
        flat_output = torch.zeros_like(flat_input)
        for expert_idx in range(self.num_experts):
            for k in range(self.top_k):
                # 4.1 Tokens whose k-th routing slot picked this expert
                token_mask = flat_indices[:, k] == expert_idx
                if not token_mask.any():
                    continue
                # 4.2 Run the expert on just those tokens
                expert_output = self.experts[expert_idx](flat_input[token_mask])
                # 4.3 Weight by the router probability and accumulate
                flat_output[token_mask] += flat_probs[token_mask, k].unsqueeze(-1) * expert_output
        return flat_output.view(batch_size, seq_len, hidden_size)
3.3 LoRA (Low-Rank Adaptation)
3.3.1 Principle
Full fine-tuning updates every parameter: W → W + ΔW.
Problem: ΔW is as large as the original model.

LoRA:
    W' = W + ΔW = W + BA
where:
- W: original weight [d, k] (frozen)
- A: down-projection matrix [r, k] (trainable)
- B: up-projection matrix [d, r] (trainable)
- r << min(d, k): the low-rank bottleneck

Trainable parameters: r(d + k) << dk
Example: a 4096×4096 projection in Llama-7B
- Full: 4096 × 4096 ≈ 16.8M parameters
- LoRA (r=16): 16 × (4096 + 4096) ≈ 131K parameters
- Roughly a 128:1 reduction
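The arithmetic is easy to reproduce; the snippet below just re-derives the numbers quoted above.
# Parameter-count arithmetic for the Llama-7B example above (illustrative).
d = k = 4096
r = 16

full_delta = d * k        # 16_777_216 ≈ 16.8M parameters for a dense ΔW
lora_delta = r * (d + k)  # 131_072 ≈ 131K parameters for A and B
print(full_delta // lora_delta)  # 128 → roughly a 128:1 reduction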
3.3.2 Implementation
class Lora(Module):
    """
    LoRA low-rank adaptation layer.
    """
    def __init__(
        self,
        in_hidden_size: int,           # Input dimension
        out_hidden_sizes: List[int],   # Output dimensions (supports several projections)
        max_low_rank: int,             # Maximum low-rank dimension
        **kwargs
    ):
        super().__init__()
        self.in_hidden_size = in_hidden_size
        self.out_hidden_sizes = out_hidden_sizes
        self.max_low_rank = max_low_rank
        # A matrix (down-projection to rank r): [max_rank, in_hidden_size]
        self.lora_a_weights = Parameter(
            shape=(max_low_rank, in_hidden_size),
            dtype=kwargs.get('dtype'),
        )
        # B matrix (up-projection back to the output size): [sum(out_hidden_sizes), max_rank]
        total_out_size = sum(out_hidden_sizes)
        self.lora_b_weights = Parameter(
            shape=(total_out_size, max_low_rank),
            dtype=kwargs.get('dtype'),
        )

    def forward(
        self,
        input: Tensor,  # [batch, seq, in_hidden_size]
        lora_runtime_params: LoraRuntimeParams = None
    ) -> Tensor:
        """
        LoRA forward pass: x → xA^T → (xA^T)B^T = x(BA)^T
        """
        if lora_runtime_params is None:
            # LoRA disabled: return zeros
            batch_size, seq_len = input.shape[:2]
            total_out_size = sum(self.out_hidden_sizes)
            return torch.zeros(
                batch_size, seq_len, total_out_size,
                dtype=input.dtype, device=input.device
            )
        # 1. LoRA runtime parameters
        lora_ranks = lora_runtime_params.lora_ranks  # actual rank of each LoRA
        lora_weights_pointers = lora_runtime_params.lora_weights_pointers
        # 2. First GEMM: input @ A^T
        #    [batch, seq, in_hidden] @ [in_hidden, rank] = [batch, seq, rank]
        first_gemm_output = torch.matmul(
            input,
            self.lora_a_weights[:lora_ranks[0], :].T
        )
        # 3. Second GEMM: (input @ A^T) @ B^T
        #    [batch, seq, rank] @ [rank, out_hidden] = [batch, seq, out_hidden]
        lora_output = torch.matmul(
            first_gemm_output,
            self.lora_b_weights[:, :lora_ranks[0]].T
        )
        return lora_output
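LoRA's defining identity is that adding the adapter output is equivalent to folding ΔW = B·A into the base weight. A minimal sketch of that equivalence, using plain tensors rather than the Lora class above:
import torch

d_in, d_out, r = 64, 64, 8
x = torch.randn(2, 10, d_in)
W = torch.randn(d_out, d_in)      # frozen base weight
A = torch.randn(r, d_in) * 0.01   # LoRA down-projection
B = torch.randn(d_out, r) * 0.01  # LoRA up-projection

# Adapter path: base output plus LoRA output (as in section 5.4)
adapted = x @ W.T + (x @ A.T) @ B.T

# Merged path: fold ΔW = B @ A into the weight once, then a single matmul
merged = x @ (W + B @ A).T
assert torch.allclose(adapted, merged, atol=1e-5)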
4. Data-Structure UML Diagrams
4.1 Core Class Diagram of Layers
classDiagram
class Module {
<<abstract>>
+forward(*args, **kwargs) Tensor
+named_parameters() Iterator
+parameters() Iterator
+training: bool
}
class Attention {
+local_layer_idx: int
+hidden_size: int
+num_attention_heads: int
+num_attention_kv_heads: int
+attention_head_size: int
+position_embedding_type: PositionEmbeddingType
+qkv: ColumnLinear
+dense: RowLinear
+forward(hidden_states, ...) Tuple[Tensor, Tensor]
-_apply_rotary_pos_emb(q, k, pos_ids)
-_standard_attention(q, k, v, mask)
-_flash_attention(q, k, v, mask)
}
class GatedMLP {
+hidden_size: int
+ffn_hidden_size: int
+hidden_act: str
+gate_up_proj: ColumnLinear
+down_proj: RowLinear
+forward(hidden_states) Tensor
}
class RmsNorm {
+normalized_shape: int
+eps: float
+weight: Parameter
+forward(x) Tensor
}
class LayerNorm {
+normalized_shape: int
+eps: float
+weight: Parameter
+bias: Parameter
+forward(x) Tensor
}
class ColumnLinear {
+in_features: int
+out_features: int
+tp_size: int
+gather_output: bool
+weight: Parameter
+forward(input) Tensor
-_all_gather(tensor) Tensor
}
class RowLinear {
+in_features: int
+out_features: int
+tp_size: int
+weight: Parameter
+forward(input) Tensor
-_all_reduce(tensor) Tensor
}
class MOE {
+num_experts: int
+top_k: int
+gate: Linear
+experts: ModuleList[GatedMLP]
+forward(hidden_states) Tensor
}
class Lora {
+in_hidden_size: int
+out_hidden_sizes: List[int]
+max_low_rank: int
+lora_a_weights: Parameter
+lora_b_weights: Parameter
+forward(input, lora_params) Tensor
}
Module <|-- Attention
Module <|-- GatedMLP
Module <|-- RmsNorm
Module <|-- LayerNorm
Module <|-- ColumnLinear
Module <|-- RowLinear
Module <|-- MOE
Module <|-- Lora
Attention --> ColumnLinear : uses (qkv)
Attention --> RowLinear : uses (dense)
GatedMLP --> ColumnLinear : uses (gate_up_proj)
GatedMLP --> RowLinear : uses (down_proj)
MOE --> GatedMLP : contains
4.2 Attention Computation Flow (State Diagram)
stateDiagram-v2
    [*] --> Preprocess_Input: hidden_states
    Preprocess_Input --> QKV_Projection: reshape, dtype checks
    QKV_Projection --> QKV_Split: fused or separate QKV
    QKV_Split --> MultiHead_Reshape: split Q, K, V
    MultiHead_Reshape --> Position_Encoding: reshape to multi-head layout
    Position_Encoding --> RoPE: PositionEmbeddingType.rope_gpt_neox
    Position_Encoding --> ALiBi: PositionEmbeddingType.alibi
    Position_Encoding --> Relative_Encoding: PositionEmbeddingType.relative_position
    Position_Encoding --> KV_Cache: no positional encoding
    RoPE --> KV_Cache: apply rotary embedding
    ALiBi --> KV_Cache: prepare biases
    Relative_Encoding --> KV_Cache: apply relative encoding
    KV_Cache --> GQA_Handling: cache enabled
    KV_Cache --> GQA_Handling: cache disabled
    GQA_Handling --> Expand_KV: num_kv_heads < num_attention_heads
    GQA_Handling --> Attention_Compute: num_kv_heads == num_attention_heads
    Expand_KV --> Attention_Compute: replicate KV heads
    Attention_Compute --> Flash_Attention: use_flash_attention=True
    Attention_Compute --> Standard_Attention: use_flash_attention=False
    Flash_Attention --> Output_Projection: memory-efficient kernel
    Standard_Attention --> Output_Projection: QK^T → softmax → weighted V
    Output_Projection --> Return_Result: dense projection
    Return_Result --> [*]: output, present_kv
5. Usage Examples
5.1 Building a Custom Attention Layer
from tensorrt_llm.layers import Attention, AttentionMaskType, PositionEmbeddingType

# 1. Standard Multi-Head Attention
attention = Attention(
    local_layer_idx=0,
    hidden_size=4096,
    num_attention_heads=32,
    num_kv_heads=32,  # MHA
    max_position_embeddings=2048,
    attention_mask_type=AttentionMaskType.causal,
    position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
    dtype="float16",
)

# 2. Grouped-Query Attention
gqa_attention = Attention(
    local_layer_idx=0,
    hidden_size=4096,
    num_attention_heads=32,
    num_kv_heads=8,  # GQA: every 4 Q heads share 1 KV head
    max_position_embeddings=4096,
    position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
    rotary_embedding_base=10000.0,
)

# 3. Multi-Query Attention
mqa_attention = Attention(
    local_layer_idx=0,
    hidden_size=4096,
    num_attention_heads=32,
    num_kv_heads=1,  # MQA: all Q heads share a single KV head
    max_position_embeddings=8192,
)
5.2 Building a Tensor-Parallel MLP
from tensorrt_llm.layers import GatedMLP
from tensorrt_llm.mapping import Mapping

# 1. Single-GPU MLP
mlp = GatedMLP(
    hidden_size=4096,
    ffn_hidden_size=11008,  # Llama-7B: ~2.7x hidden_size
    hidden_act="silu",
    bias=False,
)

# 2. 4-GPU tensor-parallel MLP
mapping = Mapping(world_size=4, tp_size=4)
tp_mlp = GatedMLP(
    hidden_size=4096,
    ffn_hidden_size=11008,
    hidden_act="silu",
    tp_group=mapping.tp_group,
    tp_size=mapping.tp_size,
)

# TP weight sharding (11008 / 4 = 2752 per rank):
# GPU i holds columns [i*2752 : (i+1)*2752] of both the Gate and the Up weight
# (i.e. a 2*2752-wide slice of the fused gate_up_proj output)
# plus rows [i*2752 : (i+1)*2752] of down_proj.weight.
5.3 Building an MoE Layer
from tensorrt_llm.layers import MOE

# Mixtral-8x7B-style MoE
moe = MOE(
    hidden_size=4096,
    ffn_hidden_size=14336,
    num_experts=8,  # 8 experts
    top_k=2,        # each token picks 2 experts
    hidden_act="silu",
    tp_group=mapping.tp_group,
    tp_size=mapping.tp_size,
)

# Combining Expert Parallel with Tensor Parallel:
# - 8 experts across 4 GPUs: 2 experts per GPU
# - each expert internally uses Tensor Parallel
5.4 LoRA Fine-Tuning
from tensorrt_llm.layers import Lora, LoraRuntimeParams

# 1. Create the LoRA layer
lora = Lora(
    in_hidden_size=4096,
    out_hidden_sizes=[4096, 4096, 4096],  # Q, K, V projections
    max_low_rank=64,
)

# 2. Create the LoRA runtime parameters
lora_params = LoraRuntimeParams(
    lora_ranks=[16, 16, 16],      # actual ranks for Q, K, V
    lora_weights_pointers=[...],  # weight pointers
)

# 3. Forward pass
hidden_states = torch.randn(2, 512, 4096)
lora_output = lora(hidden_states, lora_params)

# 4. Combine with the original layer
original_output = linear(hidden_states)
final_output = original_output + lora_output  # residual add
6. Performance-Tuning Recommendations
6.1 Attention Optimizations
# 1. Enable Flash Attention
attention = Attention(
    # ... other arguments
    use_flash_attention=True,     # memory-efficient attention
    use_paged_context_fmha=True,  # paged KV cache
)

# 2. Pick an appropriate positional encoding
# RoPE: extrapolates well, suits long sequences
# ALiBi: cheap to compute, memory-friendly
position_embedding_type=PositionEmbeddingType.rope_gpt_neox

# 3. Use GQA to shrink the KV cache
num_attention_heads=32,
num_kv_heads=8,  # 75% smaller KV cache
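The 75% figure follows directly from the KV-head count, since KV-cache size scales linearly with num_kv_heads. A back-of-envelope sketch (illustrative assumptions: 32 layers, FP16, one 4K-token sequence):
num_layers, head_size, seq_len, dtype_bytes = 32, 128, 4096, 2

def kv_cache_bytes(num_kv_heads):
    # Two tensors (K and V) per layer
    return 2 * num_layers * num_kv_heads * head_size * seq_len * dtype_bytes

mha = kv_cache_bytes(32)  # 2.0 GiB with 32 KV heads
gqa = kv_cache_bytes(8)   # 0.5 GiB with 8 KV heads
print(1 - gqa / mha)      # 0.75 → the 75% reduction quoted above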
6.2 MLP Optimizations
# 1. Fuse the Gate and Up projections
gate_up_proj = ColumnLinear(
    in_features=hidden_size,
    out_features=2 * ffn_hidden_size,  # fused computation
    gather_output=False,               # avoid unnecessary communication
)

# 2. Pick an efficient activation
# SiLU/Swish: fast on modern GPUs
# GELU: somewhat more expensive
hidden_act="silu"

# 3. Configure TP sensibly
# ffn_hidden_size must be divisible by tp_size
assert ffn_hidden_size % tp_size == 0
6.3 MoE Optimizations
# 1. Balance expert load
# Monitor the routing distribution to avoid overloaded experts
# (see the auxiliary-loss sketch below)

# 2. Combine Expert Parallel with Tensor Parallel
# For large models, use EP+TP together, e.g.:
# - 16 experts on 4 GPUs: 4 experts per GPU
# - each expert internally uses TP=2

# 3. Tune the routing strategy
# Top-1 vs Top-2 trade-off:
# - Top-1: least compute, but potentially larger accuracy loss
# - Top-2: moderate compute, small accuracy loss
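For point 1, a common way to keep expert load balanced is an auxiliary loss of the kind used in Switch Transformer: it is minimized when both the dispatch fractions and the mean router probabilities are uniform across experts. A hedged sketch (standalone illustration, not a TensorRT-LLM API):
import torch

def load_balancing_loss(router_probs, top_k_indices, num_experts):
    # router_probs: [num_tokens, num_experts]; top_k_indices: [num_tokens, top_k]
    one_hot = torch.zeros_like(router_probs)
    one_hot.scatter_(1, top_k_indices, 1.0)
    tokens_per_expert = one_hot.mean(dim=0)     # dispatch fraction f_i
    prob_per_expert = router_probs.mean(dim=0)  # mean router probability P_i
    # Minimized when both distributions are uniform over experts
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)

probs = torch.softmax(torch.randn(128, 8), dim=-1)
topk = probs.topk(2, dim=-1).indices
print(load_balancing_loss(probs, topk, num_experts=8))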
7. FAQ
Q1: How does GQA affect model accuracy?
- Llama2-70B: 64 attention heads → 8 KV heads, with <1% accuracy loss
- Recommendation: the number of Q heads should be a multiple of the number of KV heads
Q2: When does Flash Attention help?
- Clearly beneficial for sequence lengths above ~512
- Long sequences (2K+) can see 2-4x speedups
- Very short sequences may even get slower (kernel-launch overhead)
Q3: How many experts should an MoE use?
- Use a power of two (2, 4, 8, 16, ...)
- Top-K is usually 2
- Expert count vs quality: more experts improve quality, with diminishing returns
Q4: How should the LoRA rank be chosen?
- Small models: r = 8-16
- Large models: r = 32-64
- A larger r fine-tunes better but adds parameters
Q5: How is an MoE deployed across multiple GPUs?
# Expert Parallel example (8 experts, 4 GPUs, TP=2 inside each EP shard)
expert_parallel_group = [0, 1, 2, 3]  # the 4 GPUs form the EP group
tp_group = [0, 1]                     # every 2 GPUs form one TP group

# Expert placement (4 experts per TP pair, covering all 8 experts):
# GPU 0: Experts 0-3 (TP rank 0)
# GPU 1: Experts 0-3 (TP rank 1)
# GPU 2: Experts 4-7 (TP rank 0)
# GPU 3: Experts 4-7 (TP rank 1)