1. 数据结构概览

vLLM 的数据结构设计围绕推理请求的生命周期管理展开:从用户输入到最终输出的整个流程涉及多个关键数据结构。

1.1 核心数据结构层次

classDiagram
    class PromptType {
        <<Union>>
        +str
        +TextPrompt
        +TokensPrompt
    }
    
    class TextPrompt {
        +str prompt
        +MultiModalDataDict multi_modal_data
        +dict mm_processor_kwargs
    }
    
    class TokensPrompt {
        +List~int~ prompt_token_ids
        +MultiModalDataDict multi_modal_data
        +dict mm_processor_kwargs
    }
    
    class ProcessorInputs {
        +str request_id
        +List~int~ prompt_token_ids
        +str prompt
        +Any multi_modal_inputs
        +SamplingParams|PoolingParams params
        +LoRARequest lora_request
    }
    
    class EngineCoreRequest {
        +str request_id
        +ProcessorInputs inputs
        +int priority
        +float arrival_time
        +RequestStatus status
    }
    
    class Request {
        +str request_id  
        +ProcessorInputs inputs
        +RequestStatus status
        +float arrival_time
        +int priority
        +List~int~ prompt_token_ids
        +List~int~ output_token_ids
        +dict metrics
    }
    
    PromptType <|-- TextPrompt
    PromptType <|-- TokensPrompt
    ProcessorInputs --> Request
    EngineCoreRequest --> Request

2. 请求处理数据结构

2.1 Request 类详细设计

文件位置: vllm/v1/request.py

@dataclass
class Request:
    """
    Full lifecycle state of a single inference request.

    Attributes:
        request_id: Unique identifier of the request.
        inputs: Pre-processed input data.
        status: Current request status.
        arrival_time: Timestamp at which the request arrived.
        priority: Request priority (used for priority scheduling).
        prompt_token_ids: Token IDs of the prompt.
        output_token_ids: Token IDs generated so far.
        spec_token_ids: Speculative-decoding token IDs (optional).
        metrics: Per-request performance metrics.
    """

    request_id: str
    inputs: ProcessorInputs
    status: RequestStatus = RequestStatus.WAITING
    arrival_time: float = field(default_factory=time.time)
    priority: int = 0

    # Token bookkeeping
    prompt_token_ids: List[int] = field(default_factory=list)
    output_token_ids: List[int] = field(default_factory=list)
    spec_token_ids: List[int] = field(default_factory=list)

    # Caching and performance
    encoder_outputs: Optional[Any] = None
    metrics: Dict[str, Any] = field(default_factory=dict)

    # Multi-modal support
    mm_inputs: Optional[MultiModalInputs] = None
    mm_hashes: Optional[List[str]] = None

    def __post_init__(self):
        """Fill in derived fields once the dataclass fields are set."""
        if not self.prompt_token_ids and self.inputs:
            # Fall back to an empty list so the len() call below cannot fail
            # when the processor inputs carry no token IDs (None).
            self.prompt_token_ids = self.inputs.prompt_token_ids or []

        # Seed the metrics dict; these keys overwrite caller-supplied values
        # of the same name.
        self.metrics.update({
            'created_time': self.arrival_time,
            'tokens_computed': 0,
            'total_tokens': len(self.prompt_token_ids)
        })

    @property
    def num_computed_tokens(self) -> int:
        """Number of prompt tokens plus generated output tokens.

        NOTE(review): this does not read ``metrics['tokens_computed']``; it
        is the length of the known sequence, not a progress counter. Confirm
        that this is what callers expect.
        """
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    @property
    def num_total_tokens(self) -> int:
        """Total token count, including speculative tokens."""
        return self.num_computed_tokens + len(self.spec_token_ids)

    def is_finished(self) -> bool:
        """Return True when the request is in one of the FINISHED_* states."""
        return self.status in (
            RequestStatus.FINISHED_STOPPED,
            RequestStatus.FINISHED_LENGTH_CAPPED,
            RequestStatus.FINISHED_ABORTED,
            RequestStatus.FINISHED_IGNORED,
        )

    def get_next_tokens_to_compute(self, max_tokens: int) -> int:
        """Return how many tokens to compute in the next step.

        Args:
            max_tokens: Upper bound on the number of tokens for this step.

        Returns:
            The number of tokens not yet accounted for in
            ``metrics['tokens_computed']``, capped at ``max_tokens`` and
            never negative.
        """
        remaining = self.num_total_tokens - self.metrics['tokens_computed']
        # Clamp at zero: a counter that ran ahead of the token lists, or a
        # negative budget, must not yield a negative token count.
        return max(0, min(remaining, max_tokens))

请求状态枚举

class RequestStatus(str, enum.Enum):
    """Enumeration of request states; members are also ``str`` instances."""
    
    # Waiting states
    WAITING = "waiting"                    # waiting to be scheduled
    PREEMPTED = "preempted"               # preempted
    
    # Running states
    RUNNING = "running"                    # currently running
    SWAPPED = "swapped"                   # swapped out to CPU
    
    # Finished states
    FINISHED_STOPPED = "finished_stopped"        # stopped normally
    FINISHED_LENGTH_CAPPED = "finished_length"   # reached the length limit
    FINISHED_ABORTED = "finished_aborted"        # aborted
    FINISHED_IGNORED = "finished_ignored"        # ignored

2.2 采样参数数据结构

文件位置: vllm/sampling_params.py

classDiagram
    class SamplingParams {
        +int n
        +int best_of
        +float presence_penalty
        +float frequency_penalty
        +float repetition_penalty
        +float temperature
        +float top_p
        +int top_k
        +float min_p
        +bool use_beam_search
        +float length_penalty
        +bool early_stopping
        +List~str~ stop
        +bool ignore_eos
        +int max_tokens
        +int min_tokens
        +int logprobs
        +int prompt_logprobs
        +bool skip_special_tokens
        +bool spaces_between_special_tokens
        +List~LogitsProcessor~ logits_processors
        +Dict~int,float~ logit_bias
        +RequestOutputKind output_kind
        +int seed
        +List~int~ allowed_token_ids
        +StructuredOutputsParams structured_outputs
        
        +validate_sampling_params()
        +update_from_generation_config()
        +clone() SamplingParams
    }
    
    class StructuredOutputsParams {
        +str|dict json
        +str regex
        +List~str~ choice
        +str grammar
        +bool json_object
        +bool disable_fallback
        +str whitespace_pattern
        +str structural_tag
        +str _backend
        +bool _backend_was_auto
        
        +to_dict() dict
        +from_dict() StructuredOutputsParams
    }
    
    SamplingParams --> StructuredOutputsParams

核心实现

@dataclass
class SamplingParams:
    """
    Sampling configuration for text generation.

    Parameter groups:
    1. Basic generation: n, max_tokens, min_tokens
    2. Sampling control: temperature, top_p, top_k
    3. Penalties: presence_penalty, frequency_penalty, repetition_penalty
    4. Stop conditions: stop, ignore_eos
    5. Log probabilities: logprobs, prompt_logprobs
    6. Advanced: logit_bias, allowed_token_ids, structured_outputs
    """

    # Basic parameters
    n: int = 1                           # number of sequences to generate
    max_tokens: Optional[int] = 16       # maximum number of generated tokens
    min_tokens: int = 0                  # minimum number of generated tokens

    # Sampling parameters
    temperature: float = 1.0             # temperature, controls randomness
    top_p: float = 1.0                   # nucleus sampling parameter
    top_k: int = -1                      # top-k sampling parameter (-1 disables)

    # Penalties
    presence_penalty: float = 0.0        # presence penalty
    frequency_penalty: float = 0.0       # frequency penalty
    repetition_penalty: float = 1.0      # repetition penalty

    # Stop conditions
    stop: Union[None, str, List[str]] = None    # stop string(s)
    ignore_eos: bool = False                    # whether to ignore the EOS token

    # Log probabilities
    logprobs: Optional[int] = None              # number of logprobs to return
    prompt_logprobs: Optional[int] = None       # number of prompt logprobs

    # Advanced parameters
    logit_bias: Optional[Dict[int, float]] = None      # logit bias
    allowed_token_ids: Optional[List[int]] = None      # allowed token IDs
    structured_outputs: Optional[StructuredOutputsParams] = None  # structured output

    def validate_sampling_params(self) -> None:
        """Check that the sampling parameters are mutually consistent.

        Raises:
            ValueError: If any parameter is outside its valid range, or if
                ``min_tokens`` exceeds ``max_tokens``.
        """
        if self.temperature < 0:
            raise ValueError("temperature must be non-negative")

        if not 0 <= self.top_p <= 1:
            raise ValueError("top_p must be in [0, 1]")

        if self.top_k < -1 or self.top_k == 0:
            raise ValueError("top_k must be -1 (disabled) or > 0")

        if self.max_tokens is not None and self.max_tokens <= 0:
            raise ValueError("max_tokens must be positive")

        # The checks below run after the original ones so that inputs which
        # were already rejected keep raising the same message.
        if self.n < 1:
            raise ValueError("n must be at least 1")

        if self.repetition_penalty <= 0:
            raise ValueError("repetition_penalty must be greater than zero")

        if self.min_tokens < 0:
            raise ValueError("min_tokens must be non-negative")

        # max_tokens=None means "no explicit cap", so only compare when set.
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError("min_tokens must not exceed max_tokens")

3. 输出数据结构

3.1 RequestOutput 设计

文件位置: vllm/outputs.py

classDiagram
    class RequestOutput {
        +str request_id
        +str prompt  
        +List~int~ prompt_token_ids
        +PromptLogprobs prompt_logprobs
        +List~CompletionOutput~ outputs
        +bool finished
        +RequestMetrics metrics
        +LoRARequest lora_request
        +str encoder_prompt
        +List~int~ encoder_prompt_token_ids
        +int num_cached_tokens
        
        +get_final_output() CompletionOutput
        +is_streaming() bool
    }
    
    class CompletionOutput {
        +int index
        +str text
        +List~int~ token_ids
        +float cumulative_logprob
        +SampleLogprobs logprobs
        +str finish_reason
        +str|int stop_reason
        +LoRARequest lora_request
        
        +finished() bool
    }
    
    class PoolingRequestOutput {
        +str request_id
        +List~PoolingOutput~ outputs
        +List~int~ prompt_token_ids
        +bool finished
        
        +from_base() PoolingRequestOutput
    }
    
    class PoolingOutput {
        +torch.Tensor data
        
        +__eq__() bool
    }
    
    RequestOutput --> CompletionOutput
    PoolingRequestOutput --> PoolingOutput

核心实现

@dataclass
class RequestOutput:
    """
    Output record for an LLM completion request.

    Attributes:
        request_id: Unique ID of the request.
        prompt: Original prompt string.
        prompt_token_ids: Token IDs of the prompt.
        prompt_logprobs: Log probabilities for the prompt tokens.
        outputs: List of generated completion outputs.
        finished: Whether the whole request has finished.
        metrics: Metrics associated with the request.
        lora_request: LoRA request attached to this output, if any.
    """

    request_id: str
    prompt: Optional[str]
    prompt_token_ids: Optional[List[int]]
    prompt_logprobs: Optional[PromptLogprobs]
    outputs: List[CompletionOutput]
    finished: bool
    metrics: Optional[RequestMetrics] = None
    lora_request: Optional[LoRARequest] = None

    # Encoder-decoder support
    encoder_prompt: Optional[str] = None
    encoder_prompt_token_ids: Optional[List[int]] = None

    # Cache related
    num_cached_tokens: int = 0

    def get_final_output(self) -> CompletionOutput:
        """Return the first completion output.

        Raises:
            ValueError: If ``outputs`` is empty.
        """
        completions = self.outputs
        if completions:
            return completions[0]
        raise ValueError("No outputs available")

    def is_streaming(self) -> bool:
        """Return True while the request has not finished."""
        if self.finished:
            return False
        return True

@dataclass
class CompletionOutput:
    """
    Data for a single completion output.

    Attributes:
        index: Position of this output within the request.
        text: Generated text.
        token_ids: Generated token ID sequence.
        cumulative_logprob: Cumulative log probability.
        logprobs: Detailed log-probability information.
        finish_reason: Why generation ended (e.g. stop, length limit);
            None while no reason has been recorded.
        stop_reason: Stop string or token ID associated with the stop, if any.
        lora_request: LoRA request attached to this output, if any.
    """

    index: int
    text: str
    token_ids: List[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        reason = self.finish_reason
        return not (reason is None)

4. KV缓存数据结构

4.1 KV缓存管理器设计

classDiagram
    class KVCacheManager {
        +KVCacheConfig kv_cache_config
        +int max_model_len
        +bool enable_caching
        +BlockPool gpu_block_pool
        +BlockPool cpu_block_pool
        +dict~str,KVCacheBlocks~ seq_to_blocks
        +PrefixCache prefix_cache
        
        +allocate_blocks(seq_id, num_tokens) List~int~
        +free_blocks(seq_id) void
        +get_available_blocks() int
        +fork_sequence(parent_id, child_id) void
    }
    
    class KVCacheBlocks {
        +List~int~ block_ids
        +int num_tokens
        +bool is_full
        
        +append_token() void
        +get_block_table() List~int~
        +copy_from(other) void
    }
    
    class BlockPool {
        +int num_blocks
        +int block_size
        +str device
        +Set~int~ free_blocks
        +dict~int,Block~ allocated_blocks
        
        +allocate() int
        +free(block_id) void
        +has_available() bool
        +get_utilization() float
        +get_free_blocks() int
    }
    
    class Block {
        +int block_id
        +int size
        +torch.Tensor data
        +int ref_count
        +bool is_computed
        
        +increment_ref() void
        +decrement_ref() void
        +copy_to(other) void
    }
    
    class PrefixCache {
        +dict~str,List~int~~ hash_to_blocks
        +int max_cache_size
        +LRUCache lru
        
        +get(prefix_hash) Optional~List~int~~
        +put(prefix_hash, blocks) void
        +evict_lru() void
    }
    
    KVCacheManager --> KVCacheBlocks
    KVCacheManager --> BlockPool
    KVCacheManager --> PrefixCache
    BlockPool --> Block

核心实现

class KVCacheManager:
    """
    KV cache manager: tracks which cache blocks belong to which sequence.

    Blocks are drawn from a GPU pool; a CPU pool is also created. The
    original description presents this as the core bookkeeping behind
    PagedAttention.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        enable_caching: bool = True,
        log_stats: bool = False,
    ):
        """
        Initialise the KV cache manager.

        Args:
            kv_cache_config: KV cache configuration; ``block_size``,
                ``num_gpu_blocks`` and ``num_cpu_blocks`` are read from it.
            max_model_len: Maximum model length.
            enable_caching: Whether to enable the prefix cache.
            log_stats: Whether to record statistics.
        """
        self.kv_cache_config = kv_cache_config
        self.max_model_len = max_model_len
        self.enable_caching = enable_caching
        self.log_stats = log_stats

        # Block configuration
        self.block_size = kv_cache_config.block_size
        self.num_gpu_blocks = kv_cache_config.num_gpu_blocks
        self.num_cpu_blocks = kv_cache_config.num_cpu_blocks

        # Block pools
        self.gpu_block_pool = BlockPool(
            num_blocks=self.num_gpu_blocks,
            block_size=self.block_size,
            device="cuda"
        )

        self.cpu_block_pool = BlockPool(
            num_blocks=self.num_cpu_blocks,
            block_size=self.block_size,
            device="cpu"
        )

        # Sequence ID -> blocks held by that sequence
        self.seq_to_blocks: Dict[str, KVCacheBlocks] = {}

        # Prefix cache (None when caching is disabled)
        self.prefix_cache = PrefixCache(
            max_cache_size=1000,
            enable=enable_caching
        ) if enable_caching else None

    def allocate_blocks(
        self,
        seq_id: str,
        num_tokens: int,
        parent_seq_id: Optional[str] = None
    ) -> List[int]:
        """
        Allocate KV cache blocks for a sequence.

        Args:
            seq_id: Sequence identifier.
            num_tokens: Number of tokens that need cache space.
            parent_seq_id: Parent sequence ID (used when forking).

        Returns:
            The list of block IDs now associated with ``seq_id``.

        Raises:
            OutOfMemoryError: If the GPU pool runs out of blocks. Blocks
                taken from the pool during this call are returned to it
                first, and ``seq_to_blocks`` is left untouched.

        Steps:
        1. Compute the number of blocks required.
        2. Reuse the parent sequence's blocks, if a known parent is given.
        3. Otherwise look up reusable blocks in the prefix cache.
        4. Take any remaining blocks from the GPU pool.
        5. Record the sequence -> blocks mapping.
        """
        # Ceiling division: blocks needed to hold num_tokens
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size

        allocated_blocks = []

        # Fork case: start from the parent's block IDs.
        # NOTE(review): the original comment calls this copy-on-write, but no
        # reference counting or copying is visible here -- confirm.
        if parent_seq_id and parent_seq_id in self.seq_to_blocks:
            parent_blocks = self.seq_to_blocks[parent_seq_id]
            allocated_blocks.extend(parent_blocks.block_ids)

        # Prefix cache lookup (skipped whenever a parent ID is supplied).
        # NOTE(review): _compute_prefix_hash is not defined in the visible
        # part of this class -- confirm where it lives.
        if self.prefix_cache and not parent_seq_id:
            prefix_hash = self._compute_prefix_hash(seq_id, num_tokens)
            cached_blocks = self.prefix_cache.get(prefix_hash)
            if cached_blocks:
                allocated_blocks.extend(cached_blocks)

        # Take the remaining blocks from the GPU pool. Newly taken blocks are
        # tracked separately so a failure part-way through can hand them back
        # instead of leaking them.
        blocks_to_allocate = max(0, num_blocks_needed - len(allocated_blocks))
        new_block_ids = []
        try:
            for _ in range(blocks_to_allocate):
                if not self.gpu_block_pool.has_available():
                    raise OutOfMemoryError("Insufficient GPU memory for KV cache")
                new_block_ids.append(self.gpu_block_pool.allocate())
        except Exception:
            # Roll back: return what this call took, then propagate.
            for block_id in new_block_ids:
                self.gpu_block_pool.free(block_id)
            raise
        allocated_blocks.extend(new_block_ids)

        kv_blocks = KVCacheBlocks(
            block_ids=allocated_blocks,
            num_tokens=num_tokens,
            block_size=self.block_size
        )

        # NOTE(review): an existing entry for seq_id is overwritten without
        # freeing its blocks -- confirm callers never allocate twice.
        self.seq_to_blocks[seq_id] = kv_blocks
        return allocated_blocks

    def free_blocks(self, seq_id: str) -> None:
        """
        Release the KV cache blocks held by a sequence.

        Unknown sequence IDs are ignored.

        Args:
            seq_id: ID of the sequence to release.
        """
        if seq_id not in self.seq_to_blocks:
            return

        kv_blocks = self.seq_to_blocks[seq_id]

        # IDs below num_gpu_blocks go back to the GPU pool, the rest to the
        # CPU pool.
        # NOTE(review): this presumes the two pools use disjoint ID ranges --
        # verify against BlockPool.
        for block_id in kv_blocks.block_ids:
            if block_id < self.num_gpu_blocks:
                self.gpu_block_pool.free(block_id)
            else:
                self.cpu_block_pool.free(block_id)

        # Drop the mapping once all blocks are returned
        del self.seq_to_blocks[seq_id]

    def get_available_blocks(self) -> int:
        """Return the number of free blocks across the GPU and CPU pools."""
        return (self.gpu_block_pool.get_free_blocks() +
                self.cpu_block_pool.get_free_blocks())

5. 调度数据结构

5.1 调度器输出设计

classDiagram
    class SchedulerOutput {
        +List~Request~ prefill_requests
        +List~Request~ decode_requests  
        +List~Request~ preempted_requests
        +dict~str,KVCacheBlocks~ req_to_new_blocks
        +dict~str,int~ num_scheduled_tokens
        +SchedulerMetrics metrics
        
        +get_total_tokens() int
        +get_batch_size() int
        +is_empty() bool
    }
    
    class SchedulerMetrics {
        +int num_waiting_requests
        +int num_running_requests
        +int num_preempted_requests
        +float gpu_memory_utilization
        +float scheduling_time
        +int total_tokens_scheduled
        
        +to_dict() dict
        +update_from(other) void
    }
    
    class RequestQueue {
        +PriorityQueue waiting_queue
        +List~Request~ running_requests
        +SchedulingPolicy policy
        
        +add_request(request) void
        +get_next_request() Optional~Request~
        +remove_request(request_id) bool
        +get_queue_size() int
    }
    
    SchedulerOutput --> SchedulerMetrics
    RequestQueue --> Request

6. 多模态数据结构

6.1 多模态输入设计

classDiagram
    class MultiModalDataDict {
        +dict~str,Any~ data
        
        +get_image_data() List~Image~
        +get_audio_data() List~Audio~
        +get_video_data() List~Video~
    }
    
    class MultiModalInputs {
        +str modality_type
        +torch.Tensor embeddings
        +List~int~ token_positions
        +dict~str,Any~ processor_kwargs
        
        +get_embedding_size() int
        +get_num_tokens() int
    }
    
    class ImageInput {
        +PIL.Image image
        +str format
        +Tuple~int,int~ size
        +torch.Tensor tensor
        
        +preprocess() torch.Tensor
        +resize(size) ImageInput
    }
    
    class AudioInput {
        +np.ndarray waveform
        +int sample_rate
        +float duration
        +torch.Tensor features
        
        +extract_features() torch.Tensor
        +resample(rate) AudioInput
    }
    
    MultiModalDataDict --> ImageInput
    MultiModalDataDict --> AudioInput
    MultiModalInputs --> MultiModalDataDict

7. 性能监控数据结构

7.1 指标收集设计

classDiagram
    class RequestMetrics {
        +float arrival_time
        +float first_token_time
        +float last_token_time
        +int prompt_tokens
        +int completion_tokens
        +float time_to_first_token
        +float inter_token_latency
        +List~float~ token_timestamps
        
        +calculate_throughput() float
        +get_total_time() float
    }
    
    class SchedulerStats {
        +int num_running_seqs
        +int num_waiting_seqs
        +int num_preempted_seqs
        +float gpu_cache_usage
        +float cpu_cache_usage
        +int total_num_batched_tokens
        
        +to_prometheus_labels() dict
        +update() void
    }
    
    class ModelRunnerStats {
        +float model_execution_time
        +int batch_size
        +int num_tokens
        +float tokens_per_second
        +float gpu_utilization
        +Dict~str,float~ layer_timings
        
        +add_timing(layer, time) void
        +get_average_latency() float
    }
    
    RequestMetrics <-- Request
    SchedulerStats <-- Scheduler
    ModelRunnerStats <-- ModelRunner

8. 数据结构关系总览

erDiagram
    USER_INPUT ||--o{ TEXT_PROMPT : contains
    USER_INPUT ||--o{ TOKENS_PROMPT : contains
    
    TEXT_PROMPT ||--|| PROCESSOR_INPUTS : "processes to"
    TOKENS_PROMPT ||--|| PROCESSOR_INPUTS : "processes to"
    
    PROCESSOR_INPUTS ||--|| ENGINE_CORE_REQUEST : "creates"
    ENGINE_CORE_REQUEST ||--|| REQUEST : "becomes"
    
    REQUEST ||--o{ KV_CACHE_BLOCKS : "allocates"
    REQUEST ||--|| SAMPLING_PARAMS : "configured by"
    
    SCHEDULER ||--o{ REQUEST : "manages"
    SCHEDULER ||--|| SCHEDULER_OUTPUT : "produces"
    
    SCHEDULER_OUTPUT ||--o{ REQUEST : "contains"
    SCHEDULER_OUTPUT ||--|| MODEL_RUNNER : "sent to"
    
    MODEL_RUNNER ||--|| MODEL_RUNNER_OUTPUT : "produces"
    MODEL_RUNNER_OUTPUT ||--|| REQUEST_OUTPUT : "becomes"
    
    REQUEST_OUTPUT ||--o{ COMPLETION_OUTPUT : "contains"
    
    KV_CACHE_MANAGER ||--o{ KV_CACHE_BLOCKS : "manages"
    KV_CACHE_BLOCKS ||--o{ BLOCK : "composed of"
    
    REQUEST_METRICS ||--|| REQUEST : "tracks"
    SCHEDULER_STATS ||--|| SCHEDULER : "monitors"

9. 关键设计模式

9.1 工厂模式

vLLM 在多处使用工厂模式,根据配置创建合适的实现:

class AttentionBackendFactory:
    """Factory that selects an attention backend implementation."""

    @staticmethod
    def create_backend(model_config: ModelConfig) -> AttentionBackend:
        """
        Build the attention backend selected by ``model_config``.

        ``use_flashattn`` is checked first, then ``use_triton``; when neither
        flag is set the default backend is used. The chosen backend is
        constructed with ``model_config``.
        """
        if model_config.use_flashattn:
            backend_cls = FlashAttentionBackend
        elif model_config.use_triton:
            backend_cls = TritonAttentionBackend
        else:
            backend_cls = DefaultAttentionBackend
        return backend_cls(model_config)

class ExecutorFactory:
    """Factory that selects an executor implementation."""

    @staticmethod
    def create_executor(parallel_config: ParallelConfig) -> Executor:
        """
        Build an executor for ``parallel_config``.

        A tensor-parallel size above 1 selects the multi-GPU executor;
        anything else selects the single-GPU executor.
        """
        needs_multi_gpu = parallel_config.tensor_parallel_size > 1
        executor_cls = MultiGPUExecutor if needs_multi_gpu else SingleGPUExecutor
        return executor_cls(parallel_config)

9.2 观察者模式

用于性能监控和指标收集:

class MetricCollector:
    """Metric collector; the subject side of an observer arrangement."""

    def __init__(self):
        # Observers are notified in registration order.
        self.observers: List[MetricObserver] = []

    def add_observer(self, observer: MetricObserver):
        """Register ``observer`` to receive future metrics."""
        registered = self.observers
        registered.append(observer)

    def notify_observers(self, metric: Metric):
        """Pass ``metric`` to every registered observer's ``update``."""
        registered = self.observers
        for listener in registered:
            listener.update(metric)

class PrometheusObserver(MetricObserver):
    """Observer that forwards metrics to Prometheus counters."""

    # One Counter per metric name, shared by all instances. prometheus_client
    # registers every Counter in a process-wide default registry and rejects
    # a second registration under the same name, so counters must be created
    # once and reused rather than rebuilt on every update.
    _counters = {}

    def update(self, metric: Metric):
        """Increment the counter named ``metric.name`` by ``metric.value``.

        The counter is created on first use and cached afterwards.
        """
        counter = self._counters.get(metric.name)
        if counter is None:
            # Counter requires a documentation string; the metric name is
            # used because no separate description is available here.
            counter = prometheus_client.Counter(metric.name, metric.name)
            self._counters[metric.name] = counter
        counter.inc(metric.value)

9.3 状态机模式

请求状态管理:

class RequestStateMachine:
    """Request state machine: validates and applies status changes."""

    # Allowed target states per source state. FINISHED_* states have no
    # entry, so nothing can leave them. Compared with the earlier table this
    # adds the paths that made FINISHED_LENGTH_CAPPED and FINISHED_IGNORED
    # reachable and lets a running request be aborted or preempted; every
    # transition that was allowed before is still allowed.
    TRANSITIONS = {
        RequestStatus.WAITING: [
            RequestStatus.RUNNING,
            RequestStatus.PREEMPTED,
            RequestStatus.FINISHED_ABORTED,
            RequestStatus.FINISHED_IGNORED,
        ],
        RequestStatus.RUNNING: [
            RequestStatus.FINISHED_STOPPED,
            RequestStatus.FINISHED_LENGTH_CAPPED,
            RequestStatus.FINISHED_ABORTED,
            RequestStatus.PREEMPTED,
            RequestStatus.SWAPPED,
        ],
        RequestStatus.PREEMPTED: [
            RequestStatus.RUNNING,
            RequestStatus.FINISHED_ABORTED,
        ],
        RequestStatus.SWAPPED: [
            RequestStatus.RUNNING,
            RequestStatus.FINISHED_ABORTED,
        ],
    }

    def transition(self, request: Request, new_status: RequestStatus) -> bool:
        """Move ``request`` to ``new_status`` if the transition is allowed.

        Args:
            request: Request whose ``status`` is read and, on success, updated.
            new_status: Desired target status.

        Returns:
            True if the status was changed, False if the transition is not
            listed in ``TRANSITIONS`` (the request is left unchanged).
        """
        if new_status in self.TRANSITIONS.get(request.status, []):
            request.status = new_status
            return True
        return False

以上数据结构分析展示了 vLLM 如何通过精心设计的数据结构来支持高性能的 LLM 推理。每个数据结构都有明确的职责,并通过合理的关系设计实现了高效的数据流转和状态管理。