TensorRT-LLM-04: Executor Module Deep Dive

1. Module Overview

1.1 Module Positioning

The Executor module is TensorRT-LLM's high-level request scheduling and management layer. It handles asynchronous request submission, batch scheduling, result dispatch, and multi-process / multi-node coordination.

Core responsibilities:

  • Request management: asynchronous request submission and lifecycle management
  • Batch scheduling: dynamic batching and inflight batching
  • Result dispatch: streaming and non-streaming result delivery
  • Multi-process coordination: MPI/Ray inter-process communication
  • Postprocessing: asynchronous detokenization and postprocessing

1.2 Module Architecture

Executor module layout:

tensorrt_llm/executor/
├── executor.py                    # Abstract base class
│   └── GenerationExecutor
│       ├── submit()               # Submit a request
│       ├── generate_async()       # Asynchronous generation
│       ├── generate()             # Synchronous generation
│       └── shutdown()             # Shut down
├── base_worker.py                 # Worker implementation
│   └── BaseWorker(GenerationExecutor)
│       ├── _enqueue_request()     # Enqueue a request
│       ├── _handle_responses()    # Handle responses
│       └── _background_loop()     # Background loop
├── worker.py                      # Worker process entry point
│   └── worker_main()              # Worker main function
├── ray_executor.py                # Ray distributed executor
│   └── RayExecutor
│       ├── submit()               # Submit a request via Ray
│       └── collective_rpc()       # Collective communication
├── request.py                     # Request definitions
│   ├── GenerationRequest         # Generation request
│   ├── LoRARequest               # LoRA request
│   └── PromptAdapterRequest      # Prompt adapter request
├── result.py                      # Result definitions
│   ├── GenerationResult          # Generation result
│   └── IterationResult           # Iteration result
└── postproc_worker.py             # Postprocessing worker
    └── PostprocWorker
        └── _handle_input()        # Handle input

1.3 Core Data Flow

User request → GenerationRequest → submit() → RequestQueue → Worker (C++ Executor)
    → TRT inference engine → Response → ResponseQueue → _handle_responses()
    → PostprocWorker (detokenization) → GenerationResult → user
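
At its core this is a producer/consumer pipeline built on queues: submit() enqueues work and returns immediately, while a background loop pushes finished outputs into a per-request result object. The following is a minimal, self-contained sketch of that pattern in plain Python; the class names and fake "inference" step are illustrative, not the actual TensorRT-LLM internals.

import queue
import threading

request_queue = queue.Queue()   # plays the role of RequestQueue
results = {}                    # client_id -> MiniResult

class MiniResult:
    """Future-like object: the background loop pushes outputs into its queue."""
    def __init__(self):
        self.queue = queue.Queue()

    def result(self):
        return self.queue.get()   # block until the final output arrives

def submit(client_id, prompt_token_ids):
    res = MiniResult()
    results[client_id] = res
    request_queue.put((client_id, prompt_token_ids))   # hand off to the worker
    return res                                          # returned immediately

def engine_loop():
    """Stands in for the Worker / C++ Executor plus _handle_responses()."""
    while True:
        client_id, tokens = request_queue.get()
        output = {"token_ids": tokens + [42], "finish_reason": "length"}  # fake inference
        results[client_id].queue.put(output)            # dispatch back to the caller

threading.Thread(target=engine_loop, daemon=True).start()
print(submit(0, [1, 2, 3]).result())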

2. Core API Deep Dive

2.1 GenerationExecutor.submit()

2.1.1 Function Signature

@abstractmethod
def submit(self, request: GenerationRequest) -> GenerationResult:
    """
    Submit a generation request (low-level API).

    Args:
        request: GenerationRequest object

    Returns:
        GenerationResult: a future-like result object
    """
    pass

2.1.2 GenerationRequest Structure

GenerationRequest

| Field | Type | Required | Description |
|---|---|---|---|
| prompt_token_ids | List[int] | Yes | Input token IDs |
| sampling_params | SamplingParams | Yes | Sampling configuration |
| max_tokens | int | No | Maximum number of generated tokens (can also be set via sampling_params) |
| query_token_ids | List[int] | No | Query token IDs (used for prefix caching) |
| lora_request | LoRARequest | No | LoRA adapter request |
| prompt_adapter_request | PromptAdapterRequest | No | Prompt adapter request |
| streaming | bool | No | Whether to stream results |
| kv_cache_retention_config | KvCacheRetentionConfig | No | KV cache retention configuration |
| multimodal_params | MultimodalParams | No | Multimodal parameters |
| arrival_time | float | No | Request arrival time (used for scheduling analysis) |

GenerationResult Structure

| Field | Type | Description |
|---|---|---|
| request | GenerationRequest | Original request |
| queue | Queue | Result queue (internal use) |
| streaming | bool | Whether the result is streamed |
| done | bool | Whether generation has finished |
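
For orientation, here is a small hedged sketch of how the two structures above fit together with submit(). It assumes an already-created executor (as in section 5.1) and that GenerationRequest accepts these keyword arguments; the exact constructor signature may differ between versions.

from tensorrt_llm import SamplingParams
from tensorrt_llm.executor.request import GenerationRequest

# Assumed keyword arguments mirror the table above; check your installed version.
request = GenerationRequest(
    prompt_token_ids=[1, 15043, 29892],                              # "Hello,"
    sampling_params=SamplingParams(max_tokens=32, temperature=0.8),
    streaming=False,
)

# submit() returns immediately; GenerationResult is a future-like handle
# whose internal queue is filled by the background response loop.
result = executor.submit(request)   # `executor` created as in section 5.1
output = result.result()            # blocks until generation finishes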

2.1.3 Core Implementation (BaseWorker)

def submit(self, request: GenerationRequest) -> GenerationResult:
    # 1. Assign a unique request ID
    request.set_id(self._get_next_client_id())
    self._last_client_id = request.id

    # 2. Extract logprob parameters (used later during detokenization)
    logprob_params = self._get_logprob_params(request)

    # 3. Create the GenerationResult (future-like object)
    result = GenerationResult(
        request,
        background_error_handler=self._handle_background_error,
        executor=self,
        disaggregated_params=request.disaggregated_params,
        logprob_params=logprob_params,
    )

    # 4. Store it in the _results map (for later lookup)
    self._results[request.id] = result

    # 5. Enqueue the request
    self._enqueue_request(request, result.queue)

    return result

2.1.4 _enqueue_request Implementation

def _enqueue_request(self, request: GenerationRequest, result_wait_queue=None) -> int:
    # 1. Build the C++ Executor request object
    executor_request = tllm.Request(
        input_token_ids=request.prompt_token_ids,
        max_tokens=request.sampling_params.max_tokens,
        streaming=request.streaming,
        sampling_config=tllm.SamplingConfig(
            beam_width=request.sampling_params.beam_width if request.sampling_params.use_beam_search else 1,
            temperature=request.sampling_params.temperature,
            top_k=request.sampling_params.top_k,
            top_p=request.sampling_params.top_p,
            # ... other sampling parameters
        ),
        output_config=tllm.OutputConfig(
            return_log_probs=request.sampling_params.logprobs is not None,
            return_context_logits=request.sampling_params.return_context_logits,
            return_generation_logits=request.sampling_params.return_generation_logits,
        ),
        # Advanced features
        lora_config=self._create_lora_config(request.lora_request),
        prompt_tuning_config=self._create_prompt_adapter_config(request.prompt_adapter_request),
        kv_cache_retention_config=request.kv_cache_retention_config,
    )

    # 2. Attach the logits post-processor (for custom sampling logic)
    if request.sampling_params.logits_processor:
        executor_request.logits_post_processor = request.sampling_params.logits_processor

    # 3. Submit to the C++ Executor
    request_id = self.executor.enqueue_request(executor_request)

    # 4. Map request_id to client_id
    self.req_id_to_client_id[request_id] = request.id

    return request_id

2.1.5 Call Chain

GenerationExecutor.submit()
  ├─→ Assign request ID
  │     └─→ _get_next_client_id()
  ├─→ Create GenerationResult
  │     └─→ result.queue = Queue()
  ├─→ _enqueue_request()
  │     ├─→ Build tllm.Request
  │     │     ├─→ Convert SamplingConfig
  │     │     ├─→ Convert OutputConfig
  │     │     └─→ Set logits_processor
  │     ├─→ executor.enqueue_request()  # C++ layer
  │     │     ├─→ Push onto the request queue
  │     │     ├─→ Notify the scheduler
  │     │     └─→ Return request_id
  │     └─→ Map request_id → client_id
  └─→ Return GenerationResult

2.2 GenerationExecutor.generate_async()

2.2.1 Function Signature

def generate_async(
    self,
    prompt_token_ids: List[int],
    sampling_params: SamplingParams,
    streaming: bool = False,
    **kwargs
) -> GenerationResult:
    """
    Asynchronous generation (high-level API).

    Args:
        prompt_token_ids: Input token IDs
        sampling_params: Sampling configuration
        streaming: Whether to stream results

    Returns:
        GenerationResult: future-like object; retrieve the output via result() or aresult()
    """

2.2.2 Core Implementation

def generate_async(self, prompt_token_ids, sampling_params, streaming=False, **kwargs):
    # 1. Validate the inputs
    assert isinstance(prompt_token_ids[0], int)
    assert isinstance(sampling_params, SamplingParams)

    # 2. Initialize the iteration-result queues (first call only)
    self._maybe_initialize_iteration_results()

    # 3. Create the GenerationRequest
    request = GenerationRequest(
        prompt_token_ids,
        sampling_params=sampling_params,
        streaming=streaming,
        query_token_ids=kwargs.get('query_token_ids'),
        lora_request=kwargs.get('lora_request'),
        prompt_adapter_request=kwargs.get('prompt_adapter_request'),
        # ... other parameters
    )

    # 4. Submit the request
    result = self.submit(request)

    return result

2.2.3 Retrieving Results

# Synchronous retrieval (blocking)
result = executor.generate_async(prompt_token_ids, sampling_params)
output = result.result()  # blocks until completion

# Asynchronous retrieval (must run inside an async function)
async def generate():
    result = executor.generate_async(prompt_token_ids, sampling_params)
    output = await result.aresult()  # await asynchronously
    return output

# Streaming retrieval
result = executor.generate_async(prompt_token_ids, sampling_params, streaming=True)
for output in result:  # iterator; yields once per generated token
    print(output.text, end='', flush=True)

2.3 Background Response Handling Loop

2.3.1 _handle_responses()

def _handle_responses(self):
    """
    Background thread: continuously fetch responses from the C++ Executor and dispatch them.
    """
    while not self._shutdown:
        try:
            # 1. Fetch responses from the C++ Executor (blocking, with timeout)
            responses = self.executor.await_responses(timeout=datetime.timedelta(milliseconds=100))

            # 2. Process the responses in a batch
            for response in responses:
                # 2.1 Look up the corresponding client_id
                client_id = self.req_id_to_client_id.get(response.request_id)
                if client_id is None:
                    continue

                # 2.2 Fetch the GenerationResult
                result = self._results.get(client_id)
                if result is None:
                    continue

                # 2.3 Parse the response
                output_token_ids = response.result.output_token_ids[0]  # beam_idx=0
                is_final = response.result.is_final

                # 2.4 Build the CompletionOutput
                completion_output = CompletionOutput(
                    index=0,
                    text="",  # filled in later by the PostprocWorker (detokenization)
                    token_ids=output_token_ids,
                    cumulative_logprob=response.result.cum_log_probs,
                    logprobs=self._get_logprobs(response),
                    finish_reason=response.result.finish_reason if is_final else None,
                )

                # 2.5 Send to the postprocessing queue
                if self.postproc_config.num_postprocess_workers > 0:
                    # Asynchronous postprocessing
                    postproc_input = PostprocWorker.Input(
                        response=response,
                        sampling_params=result.sampling_params,
                        streaming=result.streaming,
                    )
                    pid = client_id % self.postproc_config.num_postprocess_workers
                    self.postproc_queues[pid].put(postproc_input)
                else:
                    # Synchronous postprocessing
                    text = self.tokenizer.decode(output_token_ids)
                    completion_output.text = text
                    result.queue.put(completion_output)

                # 2.6 Mark as finished
                if is_final:
                    self._pop_result(client_id)

        except Exception as e:
            self._error_queue.put(e)
            break

2.3.2 Sequence Diagram

sequenceDiagram
    autonumber
    participant User
    participant Executor as GenerationExecutor
    participant ReqQueue as Request Queue
    participant CppExec as C++ Executor
    participant Scheduler as Batch Scheduler
    participant Runtime as TRT Runtime
    participant RespQueue as Response Queue
    participant BgThread as Background Thread
    participant PostProc as PostprocWorker

    User->>Executor: submit(request)
    activate Executor

    Executor->>Executor: Assign client_id
    Executor->>Executor: Create GenerationResult

    Executor->>CppExec: enqueue_request()
    activate CppExec
    CppExec->>ReqQueue: push(request)
    CppExec-->>Executor: request_id
    deactivate CppExec

    Executor-->>User: GenerationResult
    deactivate Executor

    par Background scheduling and inference
        loop Scheduling loop
            Scheduler->>ReqQueue: Fetch pending requests
            Scheduler->>Scheduler: Build batch (inflight batching)
            Scheduler->>Runtime: Batched inference
            activate Runtime
            Runtime->>Runtime: TRT execution
            Runtime-->>Scheduler: logits
            deactivate Runtime
            Scheduler->>Scheduler: Token sampling
            Scheduler->>RespQueue: push(response)
        end
    and Background response handling
        loop Response loop
            BgThread->>CppExec: await_responses(timeout)
            activate CppExec
            CppExec->>RespQueue: pop(response)
            CppExec-->>BgThread: response
            deactivate CppExec

            BgThread->>BgThread: Look up client_id
            BgThread->>BgThread: Build CompletionOutput

            alt Postprocessing workers enabled
                BgThread->>PostProc: put(response)
                activate PostProc
                PostProc->>PostProc: Detokenize
                PostProc->>PostProc: Build text
                PostProc-->>BgThread: result
                deactivate PostProc
            else No postprocessing workers
                BgThread->>BgThread: Detokenize synchronously
            end

            BgThread->>Executor: result.queue.put(output)
        end
    end

    User->>Executor: result.result()
    Executor->>Executor: queue.get()
    Executor-->>User: CompletionOutput

3. Key Features in Depth

3.1 Dynamic Batching and Scheduling

3.1.1 How Inflight Batching Works

Traditional static batching:
Batch 1: [Req1, Req2, Req3] → inference → [all finished] → Batch 2: [Req4, Req5, Req6]
Problem: once Req1 finishes, its resources cannot be released until Req2 and Req3 also finish.

Inflight batching:
Step 0: [Req1, Req2, Req3] → inference
Step 1: [Req1, Req2, Req3] → inference (Req1 generates its 2nd token)
Step 2: [Req1, Req2, Req3, Req4] → inference (Req4 joins; Req1 generates its 3rd token)
Step 3: [Req2, Req3, Req4] → inference (Req1 finished and removed)
Step 4: [Req2, Req3, Req4, Req5] → inference (Req5 joins)
...

Advantages:
- Higher GPU utilization (batch composition changes dynamically)
- Lower latency (new requests do not have to wait for the whole batch)
- Higher throughput (roughly 2-3x over static batching)

3.1.2 Scheduling Policy Code

from collections import deque

class BatchScheduler:
    """
    Dynamic batching scheduler (implemented in C++; Python pseudocode here)
    """
    def __init__(self, max_batch_size, max_num_tokens):
        self.max_batch_size = max_batch_size
        self.max_num_tokens = max_num_tokens
        self.active_requests = []
        self.pending_requests = deque()

    def schedule_step(self):
        """
        One scheduling step: decide which requests join this inference iteration.
        """
        # 1. Remove finished requests
        self.active_requests = [
            req for req in self.active_requests
            if not req.is_finished()
        ]

        # 2. Try to admit new requests
        while self.pending_requests:
            new_req = self.pending_requests[0]

            # 2.1 Check the batch-size limit
            if len(self.active_requests) >= self.max_batch_size:
                break

            # 2.2 Check the token-count limit (the key constraint)
            current_tokens = sum(req.num_tokens for req in self.active_requests)
            new_tokens = new_req.num_tokens
            if current_tokens + new_tokens > self.max_num_tokens:
                break

            # 2.3 Admit into the batch
            self.active_requests.append(self.pending_requests.popleft())

        # 3. Return the current batch
        return self.active_requests

3.2 Multi-Process / Multi-GPU Coordination

3.2.1 MPI Communication Architecture

Rank 0 (Leader)              Rank 1              Rank 2              Rank 3
    ↓                          ↓                   ↓                   ↓
Executor                   Executor            Executor            Executor
    ↓                          ↓                   ↓                   ↓
TRT Engine                TRT Engine          TRT Engine          TRT Engine
(weight shard TP=0)       (weight shard TP=1)  (weight shard TP=2)  (weight shard TP=3)
    ↓                          ↓                   ↓                   ↓
    └────────────── NCCL AllReduce/AllGather ────────────────────────┘
                           Complete output (Rank 0 only)
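
To make the leader/follower division concrete, here is a minimal mpi4py sketch of the coordination pattern. It is not the actual worker_main; the batch contents and the engine call are placeholders, but the shape of the loop (leader decides, broadcast, all ranks execute, leader collects) mirrors the diagram above.

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def run_engine_shard(batch):
    # Placeholder for executing this rank's TP shard of the engine.
    return [tok + rank for tok in batch]

# The leader (rank 0) owns the request queue; followers wait for broadcast work.
pending = [[1, 2, 3], [4, 5, 6]] if rank == 0 else None

while True:
    # The leader decides the next batch; the decision is broadcast so that every
    # TP rank executes the same step (the NCCL collectives inside the engine
    # require all ranks to participate).
    batch = (pending.pop(0) if pending else None) if rank == 0 else None
    batch = comm.bcast(batch, root=0)
    if batch is None:            # shutdown signal
        break

    partial = run_engine_shard(batch)

    # Only the leader collects the final outputs and answers the client.
    outputs = comm.gather(partial, root=0)
    if rank == 0:
        print("step outputs:", outputs)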

3.2.2 Tensor Parallel Inference Flow

def tensor_parallel_inference(input_ids, tp_size=4):
    """
    Tensor parallel inference (pseudocode)
    """
    # 1. All ranks receive the same input
    # input_ids: [batch, seq_len]

    # 2. Embedding layer (each rank holds vocab_size/tp_size of the embedding table)
    embeddings = model.embedding(input_ids)  # [batch, seq_len, hidden]

    # 3. Transformer layers
    for layer in model.layers:
        # 3.1 Attention (Q/K/V sharded across ranks)
        # Rank 0: heads 0-7
        # Rank 1: heads 8-15
        # Rank 2: heads 16-23
        # Rank 3: heads 24-31
        attn_out = layer.attention(embeddings)  # [batch, seq_len, hidden/tp_size]

        # 3.2 AllReduce to combine the partial results
        attn_out = nccl.all_reduce(attn_out)  # [batch, seq_len, hidden]

        # 3.3 FFN (sharded)
        ffn_out = layer.ffn(attn_out)  # [batch, seq_len, hidden/tp_size]

        # 3.4 AllReduce to combine the partial results
        embeddings = nccl.all_reduce(ffn_out)  # [batch, seq_len, hidden]

    # 4. LM head (each rank holds vocab_size/tp_size)
    logits = model.lm_head(embeddings)  # [batch, seq_len, vocab/tp_size]

    # 5. AllGather the full logits (a collective: all ranks participate, only Rank 0 consumes them)
    logits = nccl.all_gather(logits)  # [batch, seq_len, vocab]

    return logits

3.3 Streaming Results

3.3.1 Streaming Iterator Implementation

class GenerationResult:
    """
    Future-like object holding a generation result.
    """
    def __init__(self, request, queue, streaming):
        self.request = request
        self.queue = queue  # queue that receives CompletionOutput objects
        self.streaming = streaming
        self._done = False
        self._final_output = None

    def __iter__(self):
        """
        Streaming iterator (supports for loops).
        """
        assert self.streaming, "Not a streaming result"

        while True:
            # Fetch the next output from the queue (blocking)
            output = self.queue.get()

            # Check whether generation has finished
            if output.finish_reason is not None:
                self._done = True
                self._final_output = output
                yield output
                break

            yield output

    def result(self, timeout=None):
        """
        Synchronously fetch the final result (non-streaming).
        """
        if self._done:
            return self._final_output

        # Block until done
        while True:
            output = self.queue.get(timeout=timeout)
            if output.finish_reason is not None:
                self._done = True
                self._final_output = output
                return output

    async def aresult(self, timeout=None):
        """
        Asynchronously fetch the final result.
        """
        if self._done:
            return self._final_output

        # Await asynchronously
        while True:
            output = await self.queue.get_async(timeout=timeout)
            if output.finish_reason is not None:
                self._done = True
                self._final_output = output
                return output

3.3.2 Streaming Usage Example

# Streaming generation
result = executor.generate_async(
    prompt_token_ids=[1, 15043, 29892],
    sampling_params=SamplingParams(max_tokens=100),
    streaming=True
)

# Method 1: iterate with a for loop
prev_text = ""
for output in result:
    new_text = output.text[len(prev_text):]  # incremental text
    print(new_text, end='', flush=True)
    prev_text = output.text

# Method 2: async iteration (inside an async function)
async for output in result:
    print(output.text_diff, end='', flush=True)

4. Data Structure UML Diagrams

4.1 Executor Core Class Diagram

classDiagram
    class GenerationExecutor {
        <<abstract>>
        +submit(request) GenerationResult
        +generate_async(...) GenerationResult
        +generate(...) List~GenerationResult~
        +abort_request(request_id)
        +shutdown()
        #_get_next_client_id() int
    }
    
    class BaseWorker {
        +executor: tllm.Executor
        +tokenizer: TokenizerBase
        +result_queue: Queue
        +postproc_queues: List~Queue~
        +_results: Dict~int, GenerationResult~
        +_enqueue_request(request) int
        +_handle_responses()
        -_background_loop()
    }
    
    class RayExecutor {
        +workers: List~RayWorker~
        +request_queue: RayQueue
        +response_queue: RayQueue
        +submit(request) GenerationResult
        +collective_rpc(method, args)
    }
    
    class GenerationRequest {
        +id: int
        +prompt_token_ids: List~int~
        +sampling_params: SamplingParams
        +streaming: bool
        +lora_request: LoRARequest
        +multimodal_params: MultimodalParams
        +arrival_time: float
    }
    
    class GenerationResult {
        +request: GenerationRequest
        +queue: Queue
        +streaming: bool
        +done: bool
        +result(timeout) CompletionOutput
        +aresult(timeout) CompletionOutput
        +__iter__() Iterator
    }
    
    class CompletionOutput {
        +index: int
        +text: str
        +token_ids: List~int~
        +cumulative_logprob: float
        +logprobs: LogProbsResult
        +finish_reason: str
    }
    
    class PostprocWorker {
        +tokenizer: TokenizerBase
        +_pull_pipe: AsyncQueue
        +_push_pipe: AsyncQueue
        +_handle_input(inp) Output
        +_mainloop()
    }
    
    GenerationExecutor <|-- BaseWorker
    GenerationExecutor <|-- RayExecutor
    BaseWorker --> GenerationRequest : receives
    BaseWorker --> GenerationResult : returns
    GenerationResult --> CompletionOutput : yields
    BaseWorker --> PostprocWorker : uses

4.2 Request Lifecycle State Diagram

stateDiagram-v2
    [*] --> Created: generate_async()
    Created --> Pending: submit()

    Pending --> Scheduled: Picked by the scheduler
    Pending --> Pending: Batch full / token limit exceeded

    Scheduled --> ContextPhase: First inference step
    ContextPhase --> Processing: Process all input tokens
    Processing --> GenerationPhase: Context phase finished

    GenerationPhase --> Generating: Autoregressive generation
    Generating --> Generating: Generate the next token

    Generating --> Finished: EOS or max_tokens
    Generating --> Aborted: abort_request()

    Finished --> [*]: Return result
    Aborted --> [*]: Clean up resources

5. Usage Examples

5.1 Basic Asynchronous Generation

from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm import SamplingParams

# 1. Create the executor (normally created internally by the LLM class)
executor = GenerationExecutor.create(
    engine_dir="./llama-3-8b-engine",
    tokenizer_dir="./llama-3-8b",
)

# 2. Prepare the input
prompt_token_ids = [1, 15043, 29892, 920, 526, 366]  # "Hello, how are you"

# 3. Configure sampling
sampling_params = SamplingParams(
    max_tokens=100,
    temperature=0.8,
    top_p=0.95,
)

# 4. Generate asynchronously
result = executor.generate_async(
    prompt_token_ids=prompt_token_ids,
    sampling_params=sampling_params,
)

# 5. Fetch the result
output = result.result()  # blocks until done
print("Generated:", output.text)

5.2 Streaming Generation

# Streaming configuration
result = executor.generate_async(
    prompt_token_ids=prompt_token_ids,
    sampling_params=sampling_params,
    streaming=True,  # enable streaming
)

# Streaming iteration
print("Streaming: ", end='')
for output in result:
    # yields once per generated token
    print(output.text_diff, end='', flush=True)
print()

5.3 Batch Generation

# Batched requests
prompt_token_ids_list = [
    [1, 15043, 29892],  # "Hello,"
    [1, 3532, 825],     # "What is"
    [1, 1724, 437],     # "How to"
]

sampling_params_list = [
    SamplingParams(max_tokens=50, temperature=0.7),
    SamplingParams(max_tokens=100, temperature=0.9),
    SamplingParams(max_tokens=80, temperature=0.8),
]

# Synchronous batch generation
results = executor.generate(
    prompt_token_ids=prompt_token_ids_list,
    sampling_params=sampling_params_list,
)

for i, result in enumerate(results):
    print(f"Result {i}: {result.result().text}")

5.4 LoRA Inference

from tensorrt_llm.executor import LoRARequest

# LoRA request
lora_request = LoRARequest(
    lora_name="math_lora",
    lora_int_id=1,
    lora_path="./lora_weights/math",
)

# Generate with the LoRA adapter
result = executor.generate_async(
    prompt_token_ids=prompt_token_ids,
    sampling_params=sampling_params,
    lora_request=lora_request,
)

output = result.result()

6. Performance Tuning Recommendations

6.1 Batching Optimization

# Configure a larger batch and token budget
executor_config = tllm.ExecutorConfig(
    max_batch_size=256,        # larger batch
    max_num_tokens=8192,       # token budget
    scheduler_config=tllm.SchedulerConfig(
        policy=tllm.SchedulerPolicy.MAX_UTILIZATION,
    ),
)

6.2 Asynchronous Postprocessing

# Enable multiple postprocessing workers
executor = GenerationExecutor.create(
    engine_dir="./engine",
    num_postprocess_workers=4,  # 4 postprocessing processes
    postprocess_tokenizer_dir="./tokenizer",
)

# Benefit: detokenization runs asynchronously in parallel, reducing latency on the main thread
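
A minimal sketch of the idea: detokenization is moved off the response-handling thread into separate processes that communicate through queues. The queue layout and the fake decode() below are illustrative stand-ins, not the actual PostprocWorker implementation.

import multiprocessing as mp

def postproc_main(in_q, out_q):
    # The real worker holds a tokenizer; here decode() is faked for illustration.
    decode = lambda ids: " ".join(f"<{i}>" for i in ids)
    while True:
        item = in_q.get()
        if item is None:                     # shutdown sentinel
            break
        client_id, token_ids, is_final = item
        out_q.put((client_id, decode(token_ids), is_final))

if __name__ == "__main__":
    in_q, out_q = mp.Queue(), mp.Queue()
    worker = mp.Process(target=postproc_main, args=(in_q, out_q), daemon=True)
    worker.start()

    # The response loop only forwards raw token IDs; the text shows up later.
    in_q.put((0, [15043, 29892], True))
    print(out_q.get())                       # (0, '<15043> <29892>', True)
    in_q.put(None)
    worker.join()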

6.3 KV Cache Reuse

from tensorrt_llm.executor import KvCacheRetentionConfig

# Configure KV cache retention
kv_config = KvCacheRetentionConfig(
    max_tokens=1024,  # retain the KV cache for the first 1024 tokens
)

# Useful for multi-turn conversations that share the same prefix
result = executor.generate_async(
    prompt_token_ids=prompt_token_ids,
    sampling_params=sampling_params,
    kv_cache_retention_config=kv_config,
)

7. FAQ

Q1: How do I handle request timeouts?

try:
    output = result.result(timeout=30.0)  # 30-second timeout
except TimeoutError:
    executor.abort_request(result.request.id)

Q2: How do I monitor batching status?

# Fetch statistics from the stats queue
stats = executor.get_latest_stats()
print(f"Active requests: {stats.num_active_requests}")
print(f"Queued requests: {stats.num_queued_requests}")
print(f"Batch size: {stats.current_batch_size}")

Q3: How are requests distributed across multiple GPUs?

  • Tensor Parallel (TP): all GPUs process the same request and compute in parallel (see the sketch below)
  • Pipeline Parallel (PP): different GPUs process different layers
  • Data Parallel (DP): different GPUs process different requests
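
In practice these parallelism modes are usually configured through the high-level LLM API rather than by driving the executor directly. A hedged sketch follows; the argument names match the public tensorrt_llm.LLM interface, but verify them against your installed version.

from tensorrt_llm import LLM, SamplingParams

# Tensor parallel across 4 GPUs: every GPU holds a weight shard and cooperates
# on every request (the executor runs each batch on all TP ranks).
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    # pipeline_parallel_size=2,   # would additionally split layers across GPUs
)

outputs = llm.generate(["Hello, how are you?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)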