vLLM-03-Attention模块-数据结构
关键数据结构概览
Attention 模块的数据结构设计围绕高效的注意力计算和内存管理展开,包括核心类定义、元数据管理和缓存结构三个层次。
classDiagram
class AttentionType {
+String DECODER
+String ENCODER
+String ENCODER_ONLY
+String ENCODER_DECODER
}
class AttentionBackend {
<<abstract>>
+bool accept_output_buffer
+bool supports_quant_query_input
+get_name()* String
+get_impl_cls()* Type
+get_metadata_cls()* Type
+get_builder_cls()* Type
+get_kv_cache_shape()* Tuple
+make_metadata() AttentionMetadata
}
class AttentionMetadata {
<<abstract>>
}
class Attention {
-num_heads: int
-head_size: int
-scale: float
-num_kv_heads: int
-sliding_window: Optional[int]
-kv_cache_dtype: String
-block_size: int
-calculate_kv_scales: bool
-impl: AttentionImpl
-kv_cache: List[torch.Tensor]
+forward(query, key, value) torch.Tensor
+load_weights(weights) None
}
class PagedAttentionMetadata {
+seq_lens_tensor: torch.Tensor
+max_decode_seq_len: int
+block_tables: torch.Tensor
}
class AttentionImpl {
<<abstract>>
+can_return_lse: bool
+forward()* torch.Tensor
}
class AttentionLayer {
<<protocol>>
+_q_scale: torch.Tensor
+_k_scale: torch.Tensor
+_v_scale: torch.Tensor
+forward() torch.Tensor
}
AttentionBackend --> AttentionMetadata : creates
AttentionBackend --> AttentionImpl : creates
Attention --> AttentionImpl : uses
Attention --> AttentionBackend : configured by
PagedAttentionMetadata --|> AttentionMetadata
AttentionLayer <|.. Attention : implements
核心类定义
1. Attention 主类
class Attention(nn.Module, AttentionLayerBase):
"""
多头注意力计算的主要实现类
支持各种注意力变体和优化后端
"""
def __init__(
self,
num_heads: int, # 注意力头数
head_size: int, # 每个头的维度大小
scale: float, # 注意力缩放因子 (1/sqrt(head_size))
num_kv_heads: Optional[int] = None, # KV 头数(支持 GQA/MQA)
alibi_slopes: Optional[List[float]] = None, # ALiBi 位置编码斜率
cache_config: Optional[CacheConfig] = None, # 缓存配置
quant_config: Optional[QuantizationConfig] = None, # 量化配置
logits_soft_cap: Optional[float] = None, # Logits 软限制
per_layer_sliding_window: Optional[int] = None, # 滑动窗口大小
use_mla: bool = False, # 是否使用 MLA (Multi-Level Attention)
use_sparse: bool = False, # 是否使用稀疏注意力
prefix: str = "", # 层名前缀
attn_type: str = AttentionType.DECODER, # 注意力类型
kv_sharing_target_layer_name: Optional[str] = None, # KV 共享目标层
attn_backend: Optional[type[AttentionBackend]] = None, # 指定后端
**extra_impl_args, # 额外实现参数
) -> None
字段语义与约束:
| 字段 | 类型 | 约束 | 默认值 | 说明 |
|---|---|---|---|---|
num_heads |
int | > 0 | 必填 | 查询头数,决定并行度 |
head_size |
int | {32,64,80,96,112,120,128,192,256} | 必填 | 每个头的维度,受硬件优化限制 |
scale |
float | > 0 | 必填 | 注意力缩放因子,通常为 1/√head_size |
num_kv_heads |
int | ≤ num_heads | num_heads | KV 头数,支持分组查询注意力 |
sliding_window |
int | > 0 或 None | None | 滑动窗口大小,限制注意力范围 |
block_size |
int | 2^n, 通常 16 | 16 | KV 缓存块大小,影响内存分配粒度 |
2. AttentionBackend 抽象基类
class AttentionBackend(ABC):
"""
注意力后端的抽象基类
定义了所有后端必须实现的接口
"""
# 类属性
accept_output_buffer: bool = False # 是否接受预分配的输出缓冲区
supports_quant_query_input: bool = False # 是否支持量化查询输入
@staticmethod
@abstractmethod
def get_name() -> str:
"""返回后端名称"""
raise NotImplementedError
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
"""返回具体实现类"""
raise NotImplementedError
@staticmethod
@abstractmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
"""返回元数据类"""
raise NotImplementedError
@staticmethod
@abstractmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
"""计算 KV 缓存的张量形状"""
raise NotImplementedError
3. AttentionType 枚举类
class AttentionType:
"""
注意力类型定义
兼容 torch.compile 的字符串枚举
"""
DECODER = "decoder" # 解码器自注意力
ENCODER = "encoder" # 编码器自注意力
ENCODER_ONLY = "encoder_only" # 仅编码器模型的注意力
ENCODER_DECODER = "encoder_decoder" # 编码器-解码器交叉注意力
元数据结构
1. AttentionMetadata 基类
class AttentionMetadata:
"""
注意力计算的元数据基类
具体子类由各个后端实现
"""
pass # 基类为空,由子类扩展
2. PagedAttentionMetadata 结构
@dataclass
class PagedAttentionMetadata:
"""
PagedAttention 专用元数据
管理分页内存布局和序列信息
"""
# 序列长度信息
seq_lens_tensor: Optional[torch.Tensor] # 形状: (batch_size,)
max_decode_seq_len: int # 批次中最大解码序列长度
# 内存块映射表
block_tables: Optional[torch.Tensor] # 形状: (batch_size, max_blocks_per_seq)
字段详细说明:
| 字段 | 形状 | 数据类型 | 说明 |
|---|---|---|---|
seq_lens_tensor |
(batch_size,) |
torch.int32 | 每个序列的累计长度 |
max_decode_seq_len |
标量 | int | 当前批次中解码序列的最大长度 |
block_tables |
(batch_size, max_blocks_per_seq) |
torch.int32 | 物理块地址映射表 |
3. 前向上下文结构
class ForwardContext:
"""
前向传播的全局上下文
通过上下文管理器在层间传递信息
"""
attn_metadata: Union[AttentionMetadata, Dict[str, AttentionMetadata]]
# 注意力元数据,可以是单个对象或按层索引的字典
virtual_engine: int = 0
# 虚拟引擎索引,用于多引擎并行
no_compile_layers: Dict[str, "Attention"]
# 不参与编译的层映射,用于 torch.compile 优化
KV 缓存数据结构
1. 缓存张量布局
# KV 缓存张量的典型形状组织
kv_cache_shape = (
num_blocks, # 总块数
2, # Key 和 Value 两个张量
block_size, # 每块的 token 数
num_kv_heads, # KV 头数
head_size # 每个头的维度
)
2. 块管理结构
class BlockTable:
"""
逻辑到物理块的映射表
支持动态分配和释放
"""
def __init__(self):
self.logical_to_physical: Dict[int, int] = {} # 逻辑块 -> 物理块
self.physical_to_logical: Dict[int, int] = {} # 物理块 -> 逻辑块
self.free_blocks: Set[int] = set() # 空闲物理块集合
self.allocated_blocks: Set[int] = set() # 已分配物理块集合
量化相关结构
1. 量化参数
class QuantizationScales:
"""
注意力计算中的量化缩放参数
"""
# 输入量化缩放
_q_scale: torch.Tensor # Query 量化缩放因子
_k_scale: torch.Tensor # Key 量化缩放因子
_v_scale: torch.Tensor # Value 量化缩放因子
# 浮点版本(用于某些计算)
_q_scale_float: float
_k_scale_float: float
_v_scale_float: float
# 概率量化缩放
_prob_scale: torch.Tensor # 注意力概率的量化缩放
2. 量化配置集成
def setup_quantization(self, quant_config: QuantizationConfig):
"""
根据量化配置设置量化参数
"""
if quant_config is not None and quant_config.enable_kv_cache_quantization:
# 启用 KV 缓存量化
self.kv_cache_dtype = quant_config.kv_cache_dtype
self.calculate_kv_scales = True
# 设置量化算子
self.query_quant = quant_config.get_query_quantizer()
else:
# 禁用量化
self.kv_cache_dtype = "auto"
self.calculate_kv_scales = False
self.query_quant = None
数据流映射关系
1. 输入到缓存的映射
def store_kv_mapping(
key: torch.Tensor, # 输入 Key 张量 [batch_size, seq_len, num_kv_heads, head_size]
value: torch.Tensor, # 输入 Value 张量 [batch_size, seq_len, num_kv_heads, head_size]
block_tables: torch.Tensor, # 块映射表 [batch_size, max_blocks_per_seq]
kv_cache: torch.Tensor # KV 缓存 [num_blocks, 2, block_size, num_kv_heads, head_size]
) -> None:
"""
将输入的 Key/Value 张量存储到分页的 KV 缓存中
映射规则:
1. 根据 block_tables 确定每个序列的物理块位置
2. 将 Key/Value 按 block_size 切分并存储到对应块中
3. 处理跨块的序列,确保连续性
"""
2. 版本演进说明
| 版本 | 变更内容 | 兼容性 | 迁移建议 |
|---|---|---|---|
| v0.1.x | 基础注意力实现 | 不兼容 | 已废弃 |
| v0.2.x | 引入 PagedAttention | 向后兼容 | 建议升级 |
| v0.3.x | 多后端架构 | 部分兼容 | 配置需调整 |
| v0.4.x | 量化支持 | 向后兼容 | 新增配置选项 |
| 当前版本 | 统一前向上下文 | 向后兼容 | 推荐最新特性 |
内存布局优化
1. 缓存内存对齐
# 内存对齐策略
CACHE_ALIGNMENT = 16 # 字节对齐边界
def align_cache_size(size: int) -> int:
"""将缓存大小对齐到指定边界"""
return (size + CACHE_ALIGNMENT - 1) // CACHE_ALIGNMENT * CACHE_ALIGNMENT
2. 批处理内存布局
# 批处理优化的内存布局
batch_attention_layout = {
"query": [batch_size, num_heads, seq_len, head_size], # 查询张量
"key_cache": [num_blocks, num_kv_heads, block_size, head_size], # Key 缓存
"value_cache": [num_blocks, num_kv_heads, head_size, block_size], # Value 缓存(转置)
"output": [batch_size, num_heads, seq_len, head_size], # 输出张量
}
性能考虑:
- Key 缓存使用标准布局便于写入
- Value 缓存转置布局优化矩阵乘法
- 对齐边界减少内存访问延迟
- 块大小选择平衡内存利用率和计算效率