vLLM-09-Compilation模块-数据结构
关键数据结构概览
Compilation 模块的数据结构设计围绕编译流程管理、缓存系统和性能监控展开,包括编译配置、缓存管理、编译器接口和监控统计四个层次。
classDiagram
class CompilationConfig {
+bool use_inductor
+str compilation_level
+Optional[str] backend
+Dict[str, Any] custom_ops
+bool enable_chunked_prefill
+int max_chunked_prefill_size
+bool use_v2_block_manager
+bool disable_compile_cache
+str cache_dir
+int max_compile_time
+to_dict() Dict[str, Any]
+from_dict() CompilationConfig
}
class CompilerManager {
+CompilationConfig compilation_config
+CompilerInterface compiler
+str cache_dir
+bool disable_cache
+Dict[str, Any] cache
+CompilationCounter compilation_counter
+PerformanceMonitor performance_monitor
+compile() Any
+load() Optional[Callable]
+save() None
+_compute_hash_keys() Tuple[str, str]
}
class CompilerInterface {
<<abstract>>
+str name
+initialize_cache() None
+compute_hash() str
+compile() Tuple[Optional[Callable], Optional[Any]]
+load() Callable
}
class VllmBackend {
+VllmConfig vllm_config
+CompilationConfig compilation_config
+str prefix
+fx.GraphModule graph
+fx.GraphModule split_gm
+List[SplitItem] piecewise_graphs
+Callable returned_callable
+PostGradPassManager post_grad_pass_manager
+CompilerManager compiler_manager
+__call__() Callable
+configure_post_pass() None
}
class InductorAdaptor {
+str name = "inductor"
+initialize_cache() None
+compute_hash() str
+compile() Tuple[Optional[Callable], Optional[Any]]
+load() Callable
+_merge_inductor_config() Dict[str, Any]
}
class CompilationCounter {
+int num_backend_compilations
+int num_inductor_compiles
+int num_cache_hits
+int num_cache_misses
+int num_compiled_artifacts_saved
+int num_compiled_artifacts_loaded
+cache_hit_rate: float
+reset() None
}
class PerformanceMonitor {
+List[float] compilation_times
+List[float] cache_load_times
+float total_compilation_time
+float avg_compilation_time
+float avg_cache_load_time
+record_compilation_time() None
+record_cache_load_time() None
+get_statistics() Dict[str, float]
}
CompilerManager --> CompilationConfig : configured by
CompilerManager --> CompilerInterface : uses
CompilerManager --> CompilationCounter : monitors with
CompilerManager --> PerformanceMonitor : tracks with
VllmBackend --> CompilerManager : contains
CompilerInterface <|-- InductorAdaptor : implements
CompilerInterface <|-- EagerAdaptor : implements
CompilerInterface <|-- InductorStandaloneAdaptor : implements
核心类定义
1. CompilationConfig 编译配置
@dataclass
class CompilationConfig:
    """Configuration for the compilation module.

    Controls compilation behavior, cache strategy, and performance limits.
    Validation and normalization happen in ``__post_init__``.
    """

    # --- Core compilation options ---
    use_inductor: bool = True             # compile with Inductor
    compilation_level: str = "piecewise"  # granularity: "dynamo_once" or "piecewise"
    backend: Optional[str] = None         # explicit compiler backend, if any
    # --- Custom ops and extensions ---
    custom_ops: Dict[str, Any] = field(default_factory=dict)
    # --- Chunked prefill ---
    enable_chunked_prefill: bool = False
    max_chunked_prefill_size: int = 8192
    # --- Block manager ---
    use_v2_block_manager: bool = False
    # --- Cache ---
    disable_compile_cache: bool = False
    cache_dir: str = field(default_factory=lambda: tempfile.gettempdir())
    # --- Performance limits ---
    max_compile_time: int = 600           # compile timeout in seconds (> 0)
    debug_compile: bool = False           # verbose compile debugging

    def __post_init__(self):
        """Validate and normalize the configuration."""
        self._validate_and_normalize()

    def _validate_and_normalize(self):
        """Check parameter validity; expand and create the cache directory.

        Raises:
            ValueError: on an unknown compilation level or a non-positive
                ``max_compile_time``.
        """
        # 1) Compilation level must be one of the known values.
        valid_levels = ["dynamo_once", "piecewise"]
        if self.compilation_level not in valid_levels:
            raise ValueError(f"compilation_level must be one of {valid_levels}")
        # 2) Only touch the filesystem when caching is actually enabled.
        if not self.disable_compile_cache:
            self.cache_dir = os.path.expanduser(self.cache_dir)
            os.makedirs(self.cache_dir, exist_ok=True)
        # 3) Timeout must be positive.
        if self.max_compile_time <= 0:
            raise ValueError("max_compile_time must be positive")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the config to a plain dict.

        Uses dataclasses.asdict so newly added fields are picked up
        automatically instead of having to be appended to a hand-written
        field list (the previous mapping would silently drift).
        """
        from dataclasses import asdict
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "CompilationConfig":
        """Build a config from a dict produced by :meth:`to_dict`."""
        return cls(**config_dict)
字段语义与约束:
| 字段 | 类型 | 约束 | 默认值 | 说明 |
|---|---|---|---|---|
| use_inductor | bool | 无 | True | 是否启用Inductor编译器 |
| compilation_level | str | 枚举值 | "piecewise" | 编译粒度级别 |
| backend | Optional[str] | 有效后端名 | None | 指定特定编译后端 |
| custom_ops | Dict[str, Any] | 有效配置 | {} | 自定义操作配置 |
| max_compile_time | int | > 0 | 600 | 编译超时时间 |
| cache_dir | str | 有效路径 | 临时目录 | 编译缓存存储路径 |
2. CompilerManager 编译管理器
class CompilerManager:
    """Core orchestrator of the compilation flow.

    Owns the compiler backend, the (memory + on-disk) compilation cache,
    and the counters/monitors used for performance bookkeeping.
    """

    def __init__(
        self,
        compilation_config: CompilationConfig,
        vllm_config: Optional[VllmConfig] = None
    ):
        # Configuration
        self.compilation_config = compilation_config
        self.vllm_config = vllm_config
        # Compiler backend (factory defined elsewhere in this module)
        self.compiler = make_compiler(compilation_config)
        # Cache system: directory + in-memory dict, or fully disabled.
        self.disable_cache = compilation_config.disable_compile_cache
        if not self.disable_cache:
            self.cache_dir = Path(compilation_config.cache_dir)
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.cache: Dict[str, Any] = {}  # in-memory cache
        else:
            self.cache_dir = None
            self.cache = None
        # Performance bookkeeping
        self.compilation_counter = CompilationCounter()
        self.performance_monitor = PerformanceMonitor()
        # Runtime state: out-of-process compiles and in-flight keys
        self._compiler_processes: Dict[str, subprocess.Popen] = {}
        self._active_compilations: Set[str] = set()

    def _compute_hash_keys(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None
    ) -> Tuple[str, str]:
        """Compute the cache key and cache filename for one compilation.

        The key folds together the graph code, the input signature, the
        runtime shape, the compilation config, and the compiler version.

        Returns:
            (cache_key, cache_filename)
        """
        # 1) Collect hash factors.
        hash_factors = []
        # Graph structure. md5 is used as a fast fingerprint, not for
        # security (usedforsecurity=False, consistently on every digest).
        graph_str = str(graph.code)
        hash_factors.append(
            ("graph",
             hashlib.md5(graph_str.encode(),
                         usedforsecurity=False).hexdigest()[:8]))
        # Input signature: (shape, dtype) for tensor-like inputs only.
        input_signature = []
        for inp in example_inputs:
            if hasattr(inp, 'shape') and hasattr(inp, 'dtype'):
                input_signature.append((tuple(inp.shape), str(inp.dtype)))
        hash_factors.append(("inputs", input_signature))
        # Runtime shape (None means dynamic).
        if runtime_shape is not None:
            hash_factors.append(("runtime_shape", runtime_shape))
        # Compilation config.
        config_hash = hashlib.md5(
            json.dumps(self.compilation_config.to_dict(), sort_keys=True).encode(),
            usedforsecurity=False
        ).hexdigest()[:8]
        hash_factors.append(("config", config_hash))
        # Compiler version.
        compiler_hash = self.compiler.compute_hash(self.vllm_config)
        hash_factors.append(("compiler", compiler_hash))
        # 2) Final key.
        hash_content = json.dumps(hash_factors, sort_keys=True)
        cache_key = hashlib.md5(hash_content.encode(), usedforsecurity=False).hexdigest()[:16]
        # 3) Filename. BUGFIX: check `is not None` — a truthiness test
        # mislabeled runtime_shape == 0 as "dynamic" even though the key
        # above hashed it as a concrete shape.
        runtime_shape_str = f"shape_{runtime_shape}" if runtime_shape is not None else "dynamic"
        cache_filename = f"compiled_graph_{graph_index}_{runtime_shape_str}_{cache_key}"
        return cache_key, cache_filename
3. CompilerInterface 编译器接口
class CompilerInterface(ABC):
    """Abstract interface for compiler backends.

    Defines the uniform contract for cache initialization, version
    hashing, compilation, and artifact loading.
    """

    # Backend identifier; concrete adaptors override this.
    name: str = "abstract"

    @abstractmethod
    def initialize_cache(
        self,
        cache_dir: str,
        disable_cache: bool = False,
        prefix: str = ""
    ) -> None:
        """Set up the backend's cache system.

        Args:
            cache_dir: Cache directory path.
            disable_cache: Whether caching is disabled.
            prefix: Cache prefix (disambiguates multiple instances).
        """
        pass

    @abstractmethod
    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Compute a backend/version hash string.

        Args:
            vllm_config: The vLLM configuration object.

        Returns:
            A version hash string (empty in the abstract base).
        """
        return ""

    @abstractmethod
    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        """Compile a computation graph.

        Args:
            graph: The FX graph to compile.
            example_inputs: Example input data.
            compiler_config: Backend-specific configuration.
            runtime_shape: Concrete runtime shape, if any.
            key: Cache key, if any.

        Returns:
            ``(compiled_function, handle)`` — both None in the base class.
        """
        return None, None

    @abstractmethod
    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None
    ) -> Callable:
        """Load a previously compiled result from a handle.

        Args:
            handle: Backend-specific handle.
            graph: The original computation graph.
            example_inputs: Example inputs.
            graph_index: Index of the (sub)graph.
            runtime_shape: Concrete runtime shape, if any.

        Returns:
            The loaded compiled callable.
        """
        raise NotImplementedError
4. VllmBackend vLLM编译后端
class VllmBackend:
    """
    vLLM-specific compilation backend.

    Implements piecewise compilation (splitting one FX graph into pieces
    compiled independently) and hosts the post-grad optimization passes.
    Several helpers referenced below (_analyze_input_buffers,
    _split_graph_into_pieces, _create_split_graph_module,
    _get_additional_inductor_config, _full_compile) are defined elsewhere.
    """
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # Configuration
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        # NOTE(review): model_tag is presumably a module-level default
        # used when no prefix is supplied — confirm where it is defined.
        self.prefix = prefix or model_tag
        # Post-grad pass manager
        self.post_grad_pass_manager = PostGradPassManager()
        # Compilation state
        self.sym_tensor_indices: List[int] = []
        self.input_buffers: List[torch.Tensor] = []
        # Compilation manager (cache + backend dispatch)
        self.compiler_manager = CompilerManager(
            self.compilation_config, self.vllm_config)
        # Piecewise-compilation state, populated by __call__
        self.graph: Optional[fx.GraphModule] = None
        self.split_gm: Optional[fx.GraphModule] = None
        self.piecewise_graphs: List[SplitItem] = []
        self.returned_callable: Optional[Callable] = None
    def __call__(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any]
    ) -> Callable:
        """
        Compilation entry point.

        Chooses full-graph or piecewise compilation based on
        compilation_config.compilation_level.
        """
        # 1) Record the graph and analyze the example inputs.
        self.graph = graph
        self._analyze_input_buffers(example_inputs)
        # 2) Configure post-processing passes.
        self.configure_post_pass()
        # 3) Pick the compilation strategy.
        if self.compilation_config.compilation_level == "piecewise":
            return self._piecewise_compile(graph, example_inputs)
        else:
            return self._full_compile(graph, example_inputs)
    def _piecewise_compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any]
    ) -> Callable:
        """
        Piecewise compilation: split the large graph into pieces and
        compile each one independently via the compiler manager.
        """
        # 1) Split the graph.
        self.piecewise_graphs = self._split_graph_into_pieces(graph)
        # 2) Build the stitched graph module over the pieces.
        self.split_gm = self._create_split_graph_module(self.piecewise_graphs)
        # 3) Compile every piece.
        # NOTE(review): compiler_manager.compile's full signature is not
        # visible here — the positional compilation_config argument and the
        # num_graphs keyword should be verified against its definition.
        compiled_pieces = []
        for i, split_item in enumerate(self.piecewise_graphs):
            compiled_piece = self.compiler_manager.compile(
                split_item.graph,
                split_item.example_inputs,
                self._get_additional_inductor_config(),
                self.compilation_config,
                graph_index=i,
                num_graphs=len(self.piecewise_graphs)
            )
            compiled_pieces.append(compiled_piece)
        # 4) Chain the pieces: each piece's output feeds the next piece.
        def piecewise_callable(*args):
            current_args = args
            for piece in compiled_pieces:
                current_args = piece(*current_args)
                # Normalize to a tuple so the next piece can be splatted.
                if not isinstance(current_args, tuple):
                    current_args = (current_args,)
            # Unwrap single results for caller convenience.
            return current_args[0] if len(current_args) == 1 else current_args
        self.returned_callable = piecewise_callable
        return piecewise_callable
缓存和监控数据结构
1. CompilationCounter 编译统计
@dataclass
class CompilationCounter:
    """Counters for the compilation pipeline.

    Tracks compile counts, cache hit/miss efficiency, and how many
    compiled artifacts were saved or loaded.
    """

    # Compilation totals
    num_backend_compilations: int = 0       # total backend compilations
    num_inductor_compiles: int = 0          # Inductor compilations
    # Cache totals
    num_cache_hits: int = 0                 # cache hits
    num_cache_misses: int = 0               # cache misses
    # Artifact totals
    num_compiled_artifacts_saved: int = 0   # artifacts written out
    num_compiled_artifacts_loaded: int = 0  # artifacts read back

    @property
    def cache_hit_rate(self) -> float:
        """Hit ratio over all cache lookups; 0.0 before any lookup."""
        lookups = self.num_cache_hits + self.num_cache_misses
        return self.num_cache_hits / lookups if lookups else 0.0

    @property
    def total_cache_requests(self) -> int:
        """Total cache lookups (hits + misses)."""
        return self.num_cache_hits + self.num_cache_misses

    def reset(self) -> None:
        """Zero every counter."""
        self.num_backend_compilations = 0
        self.num_inductor_compiles = 0
        self.num_cache_hits = 0
        self.num_cache_misses = 0
        self.num_compiled_artifacts_saved = 0
        self.num_compiled_artifacts_loaded = 0

    def to_dict(self) -> Dict[str, Any]:
        """Snapshot as a plain dict (short keys, plus derived hit rate)."""
        snapshot = {
            "backend_compilations": self.num_backend_compilations,
            "inductor_compiles": self.num_inductor_compiles,
            "cache_hits": self.num_cache_hits,
            "cache_misses": self.num_cache_misses,
            "cache_hit_rate": self.cache_hit_rate,
            "artifacts_saved": self.num_compiled_artifacts_saved,
            "artifacts_loaded": self.num_compiled_artifacts_loaded,
        }
        return snapshot
2. PerformanceMonitor 性能监控器
class PerformanceMonitor:
    """Performance monitor for the compilation pipeline.

    Records compilation / cache-load durations (bounded history), memory
    usage, and concurrent-compilation counts, and exposes aggregates.
    """

    def __init__(self, max_history: int = 1000):
        self.max_history = max_history
        # Bounded duration histories (oldest entries fall off).
        self.compilation_times: deque = deque(maxlen=max_history)
        self.cache_load_times: deque = deque(maxlen=max_history)
        # Running totals (unbounded, unlike the histories above).
        self.total_compilation_time: float = 0.0
        self.total_cache_load_time: float = 0.0
        # Memory tracking.
        self.peak_memory_usage: int = 0
        self.current_memory_usage: int = 0
        # Concurrency tracking.
        self.active_compilations: int = 0
        self.max_concurrent_compilations: int = 0

    def record_compilation_time(self, duration: float) -> None:
        """Record one compilation duration (seconds)."""
        self.total_compilation_time += duration
        self.compilation_times.append(duration)

    def record_cache_load_time(self, duration: float) -> None:
        """Record one cache-load duration (seconds)."""
        self.total_cache_load_time += duration
        self.cache_load_times.append(duration)

    def record_memory_usage(self, usage: int) -> None:
        """Record a memory-usage sample (bytes), updating the peak."""
        self.current_memory_usage = usage
        if usage > self.peak_memory_usage:
            self.peak_memory_usage = usage

    def start_compilation(self) -> None:
        """Mark a compilation as started (for concurrency stats)."""
        self.active_compilations += 1
        if self.active_compilations > self.max_concurrent_compilations:
            self.max_concurrent_compilations = self.active_compilations

    def end_compilation(self) -> None:
        """Mark a compilation as finished; never goes below zero."""
        self.active_compilations = max(0, self.active_compilations - 1)

    @property
    def avg_compilation_time(self) -> float:
        """Mean over the recorded compilation history (0.0 when empty)."""
        history = self.compilation_times
        return sum(history) / len(history) if history else 0.0

    @property
    def avg_cache_load_time(self) -> float:
        """Mean over the recorded cache-load history (0.0 when empty)."""
        history = self.cache_load_times
        return sum(history) / len(history) if history else 0.0

    def get_statistics(self) -> Dict[str, float]:
        """Return the aggregate statistics as a dict (memory in MiB)."""
        return {
            "avg_compilation_time": self.avg_compilation_time,
            "avg_cache_load_time": self.avg_cache_load_time,
            "total_compilation_time": self.total_compilation_time,
            "total_cache_load_time": self.total_cache_load_time,
            "peak_memory_usage_mb": self.peak_memory_usage / 1024**2,
            "current_memory_usage_mb": self.current_memory_usage / 1024**2,
            "max_concurrent_compilations": self.max_concurrent_compilations,
            "active_compilations": self.active_compilations,
        }
分片编译数据结构
1. SplitItem 分片项
@dataclass
class SplitItem:
    """One piece of a piecewise compilation.

    Holds the subgraph, its example inputs, the boundary nodes, shape
    metadata, and per-piece compile options.
    """

    # Identity
    graph_index: int                     # index of this piece
    graph: fx.GraphModule                # the piece's computation graph
    # Inputs / outputs
    example_inputs: List[Any]            # example inputs
    input_nodes: List[fx.Node]           # boundary input nodes
    output_nodes: List[fx.Node]          # boundary output nodes
    # Shapes
    input_shapes: List[Tuple[int, ...]]  # one shape per example input
    output_shapes: List[Tuple[int, ...]]
    # Compile options
    compile_options: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate internal consistency.

        Raises:
            ValueError: if input_shapes and example_inputs disagree in length.
        """
        if len(self.input_shapes) != len(self.example_inputs):
            raise ValueError("Input shapes and example inputs length mismatch")

    @property
    def memory_estimate(self) -> int:
        """Rough memory footprint in bytes, assuming float16 (2 bytes/elem).

        Computed in pure Python so the result is an exact int — the old
        np.prod-based version returned a NumPy scalar (and a float for
        empty shapes) despite the declared int return type.
        """
        total_elements = 0
        for shape in self.input_shapes + self.output_shapes:
            count = 1
            for dim in shape:
                count *= dim
            total_elements += count
        return total_elements * 2

    @property
    def complexity_score(self) -> float:
        """Crude complexity score: node count + 0.5 * data-flow edge count.

        BUGFIX: fx.Graph exposes .nodes but has no .edges attribute — the
        old code raised AttributeError. Count each Node input as one
        data-flow edge via Node.all_input_nodes instead.
        """
        nodes = list(self.graph.graph.nodes)
        num_nodes = len(nodes)
        num_edges = sum(len(node.all_input_nodes) for node in nodes)
        return num_nodes + 0.5 * num_edges
2. PostGradPassManager Pass管理器
class PostGradPassManager:
    """Manager for post-grad optimization passes.

    Registers passes, decides which ones are enabled via pass_config,
    applies them in registration order, and records per-pass timings.
    """

    def __init__(self):
        # Registered passes, applied in registration order.
        self.registered_passes: List[InductorPass] = []
        # Feature toggles consulted by _should_apply_pass.
        self.pass_config: Dict[str, Any] = {
            "enable_fusion": True,
            "enable_memory_planning": True,
            "enable_kernel_optimization": True,
        }
        # Per-pass execution durations (seconds).
        self.pass_execution_times: Dict[str, List[float]] = {}

    def register_pass(self, pass_instance: InductorPass) -> None:
        """Register an optimization pass.

        Uses setdefault so re-registering a pass with the same name does
        not wipe its existing timing history (the old code clobbered it).
        """
        self.registered_passes.append(pass_instance)
        self.pass_execution_times.setdefault(pass_instance.name, [])

    def apply_passes(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any]
    ) -> fx.GraphModule:
        """Apply every enabled registered pass, best-effort.

        A failing pass is logged and skipped so the remaining passes still
        run (deliberate best-effort semantics, preserved from before).

        Args:
            graph: Input computation graph.
            example_inputs: Example inputs forwarded to each pass.

        Returns:
            The graph after all successful passes.
        """
        current_graph = graph
        for pass_instance in self.registered_passes:
            if not self._should_apply_pass(pass_instance):
                continue
            # perf_counter is monotonic and high-resolution — time.time()
            # can jump on clock adjustments and skew duration stats.
            start_time = time.perf_counter()
            try:
                current_graph = pass_instance.apply(current_graph, example_inputs)
                execution_time = time.perf_counter() - start_time
                self.pass_execution_times[pass_instance.name].append(execution_time)
                # Lazy %-args: no string formatting unless DEBUG is on.
                logger.debug("Applied pass %s in %.3fs",
                             pass_instance.name, execution_time)
            except Exception as e:
                logger.warning("Pass %s failed: %s", pass_instance.name, e)
                # Continue with the remaining passes.
        return current_graph

    def _should_apply_pass(self, pass_instance: InductorPass) -> bool:
        """Return True if the pass is enabled by pass_config.

        Unknown pass names default to enabled.
        """
        pass_name = pass_instance.name
        if pass_name == "fusion_pass":
            return self.pass_config.get("enable_fusion", True)
        elif pass_name == "memory_planning_pass":
            return self.pass_config.get("enable_memory_planning", True)
        elif pass_name == "kernel_optimization_pass":
            return self.pass_config.get("enable_kernel_optimization", True)
        return True
数据流映射关系
1. 编译流程数据映射
def compilation_data_flow(
    pytorch_model: nn.Module,                 # original PyTorch model
    example_inputs: List[torch.Tensor],       # example inputs
    compilation_config: CompilationConfig     # compilation config
) -> Callable:                                # the compiled callable
    """Show the full data-flow mapping from a PyTorch model to a compiled result.

    Steps: trace to FX → optimize graph → compute cache key → compile
    (Inductor or eager) → store in cache. The helpers
    (apply_graph_optimizations, compute_compilation_hash,
    save_compilation_cache) are defined elsewhere.
    """
    # Step 1: trace the model into an FX graph (no autograd bookkeeping).
    with torch.no_grad():
        fx_graph = torch.fx.symbolic_trace(pytorch_model)
    # Step 2: graph optimizations and transforms.
    optimized_graph = apply_graph_optimizations(fx_graph)
    # Step 3: cache-key computation.
    cache_key = compute_compilation_hash(
        optimized_graph, example_inputs, compilation_config)
    # Step 4: invoke the compiler.
    if compilation_config.use_inductor:
        # BUGFIX: CompilationConfig does not declare enable_dynamic_shapes;
        # read it defensively so configs without the attribute still work
        # (defaulting to static shapes).
        compiled_fn = torch.compile(
            optimized_graph,
            mode="reduce-overhead",
            dynamic=getattr(compilation_config, "enable_dynamic_shapes", False)
        )
    else:
        compiled_fn = optimized_graph  # eager mode: no compilation
    # Step 5: persist to the cache.
    save_compilation_cache(cache_key, compiled_fn)
    return compiled_fn
2. 缓存系统数据组织
class CompilationCacheLayout:
    """Layout of the layered compilation cache.

    Tiers, fastest first: in-memory dict, on-disk pickles, compiled
    artifacts. Metadata loading is handled by _load_cache_metadata
    (defined elsewhere).
    """

    def __init__(self, cache_dir: str):
        self.cache_dir = Path(cache_dir)
        # Tiered storage, fastest tier first.
        self.memory_cache: Dict[str, Any] = {}
        self.disk_cache_dir = self.cache_dir / "disk"
        self.artifacts_dir = self.cache_dir / "artifacts"
        # Metadata bookkeeping.
        self.cache_metadata = self._load_cache_metadata()

    def _get_cache_hierarchy(self, cache_key: str) -> List[Tuple[str, str]]:
        """Locate cache_key across every tier.

        Returns:
            ``[(cache_type, cache_path), ...]`` ordered fastest-first;
            only tiers that actually contain the key are listed.
        """
        tiers: List[Tuple[str, str]] = []
        # 1) Memory tier.
        if cache_key in self.memory_cache:
            tiers.append(("memory", cache_key))
        # 2) Disk tier.
        on_disk = self.disk_cache_dir / f"{cache_key}.pkl"
        if on_disk.exists():
            tiers.append(("disk", str(on_disk)))
        # 3) Compiled-artifact tier.
        artifact = self.artifacts_dir / cache_key
        if artifact.exists():
            tiers.append(("artifact", str(artifact)))
        return tiers
版本演进说明
| 版本 | 主要变更 | 数据结构影响 | 兼容性 | 迁移建议 |
|---|---|---|---|---|
| v0.1.x | 基础编译支持 | 简单配置结构 | 不兼容 | 已废弃 |
| v0.2.x | 引入Inductor | 添加编译器接口 | 向后兼容 | 建议升级 |
| v0.3.x | 缓存系统 | 分层缓存结构 | 部分兼容 | 需要重建缓存 |
| v0.4.x | 分片编译 | 分片数据结构 | 向后兼容 | 新增分片配置 |
| 当前版本 | 性能监控优化 | 监控数据结构 | 向后兼容 | 推荐监控功能 |
内存布局优化
1. 编译缓存内存对齐
# 缓存友好的数据结构设计
@dataclass
class OptimizedCacheEntry:
"""
内存优化的缓存条目结构
"""
__slots__ = ['key', 'compiled_fn', 'metadata', 'access_count', 'last_access']
key: str # 缓存键
compiled_fn: Callable # 编译后的函数
metadata: Dict[str, Any] # 元数据信息
access_count: int = 0 # 访问次数
last_access: float = 0.0 # 最后访问时间
def update_access(self):
"""更新访问统计"""
self.access_count += 1
self.last_access = time.time()
2. 批处理编译内存池
class CompilationMemoryPool:
    """Object pool for compilation temporaries.

    Reusing FX graph modules, tensors, and dicts cuts allocation churn
    during repeated compilations.
    """

    def __init__(self, pool_size: int = 100):
        # One free list per pooled type.
        self.fx_graph_pool: List[fx.GraphModule] = []
        self.tensor_pool: List[torch.Tensor] = []
        self.dict_pool: List[Dict[str, Any]] = []
        # Cap on each free list.
        self.max_pool_size = pool_size
        # Reuse statistics.
        self.pool_hits = 0
        self.pool_misses = 0

    def get_graph_module(self) -> fx.GraphModule:
        """Hand out a pooled GraphModule, allocating a fresh one if empty."""
        if not self.fx_graph_pool:
            self.pool_misses += 1
            return fx.GraphModule(nn.Module(), fx.Graph())
        self.pool_hits += 1
        return self.fx_graph_pool.pop()

    def return_graph_module(self, graph: fx.GraphModule):
        """Wipe a GraphModule's graph and park it for reuse, if room."""
        if len(self.fx_graph_pool) >= self.max_pool_size:
            return  # pool full: let the object be garbage-collected
        # Clear state before recycling.
        graph.graph.clear()
        self.fx_graph_pool.append(graph)

    @property
    def hit_rate(self) -> float:
        """Fraction of requests served from the pool (0.0 with none)."""
        requests = self.pool_hits + self.pool_misses
        return self.pool_hits / requests if requests > 0 else 0.0
这些数据结构为Compilation模块提供了完整的类型系统和内存管理方案,支持从简单编译到复杂分片编译的各种场景。