TensorRT-LLM API 深度分析

核心 API 概览

TensorRT-LLM 提供了分层的 API 架构,从高级的 LLM API 到底层的 Runtime API,满足不同层次的使用需求。

1. 高级 LLM API

1.1 LLM 类 - 主要入口点

位置: tensorrt_llm/llmapi/llm.py

class LLM(_TorchLLM):
    """TensorRT-LLM 的主要入口类,提供简化的 generate() API"""

    def __init__(self,
                 model: Union[str, Path],
                 tokenizer: Optional[Union[str, Path, TokenizerBase, PreTrainedTokenizerBase]] = None,
                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
                 skip_tokenizer_init: bool = False,
                 trust_remote_code: bool = False,
                 tensor_parallel_size: int = 1,
                 dtype: str = "auto",
                 revision: Optional[str] = None,
                 tokenizer_revision: Optional[str] = None,
                 **kwargs: Any) -> None:
        """
        初始化 LLM 实例

        Args:
            model: 模型路径或 HuggingFace 模型名称
            tokenizer: 分词器路径或实例
            tensor_parallel_size: 张量并行大小
            dtype: 数据类型 ("auto", "float16", "bfloat16", "float32")
            **kwargs: 其他配置参数
        """

核心方法分析:

generate() 方法

def generate(
    self,
    inputs: Union[PromptInputs, Sequence[PromptInputs]],
    sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[LoRARequest, Sequence[LoRARequest]]] = None,
    prompt_adapter_request: Optional[Union[PromptAdapterRequest, Sequence[PromptAdapterRequest]]] = None,
    kv_cache_retention_config: Optional[Union[KvCacheRetentionConfig, Sequence[KvCacheRetentionConfig]]] = None,
    disaggregated_params: Optional[Union[DisaggregatedParams, Sequence[DisaggregatedParams]]] = None,
    scheduling_params: Optional[Union[SchedulingParams, List[SchedulingParams]]] = None,
) -> Union[RequestOutput, List[RequestOutput]]:
    """
    同步生成文本的核心方法

    功能说明:
    1. 输入预处理和验证
    2. 创建生成请求
    3. 提交到执行器
    4. 等待并收集结果
    5. 后处理和返回

    调用链路:
    generate() -> _generate_non_streaming() -> _submit_requests() -> executor.submit()
    """

实现细节:

# 位置: tensorrt_llm/llmapi/llm.py:280-350
def _generate_non_streaming(self, requests: List[GenerationRequest]) -> List[RequestOutput]:
    """非流式生成的内部实现"""

    # 1. 提交所有请求到执行器
    futures = []
    for request in requests:
        future = self._executor.submit(request)
        futures.append(future)

    # 2. 等待所有结果
    results = []
    for future in tqdm(futures, desc="Generating", disable=not use_tqdm):
        result = future.result()  # 阻塞等待结果
        results.append(result)

    # 3. 转换为 RequestOutput 格式
    outputs = []
    for result in results:
        output = RequestOutput(
            request_id=result.request_id,
            prompt=result.prompt,
            outputs=result.outputs,
            finished=result.finished
        )
        outputs.append(output)

    return outputs

1.2 BaseLLM 基类架构

位置: tensorrt_llm/llmapi/llm.py:108-766

class BaseLLM:
    """所有 LLM 类的基类,定义核心接口和通用功能"""

    def __init__(self, **kwargs):
        # 核心组件初始化
        self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
        self._llm_id = None
        self._tokenizer = None
        self._executor = None
        self.mpi_session = None

        # 参数解析和验证
        self.args = self._parse_args(**kwargs)

        # 初始化执行器
        self._init_executor()

    def _init_executor(self):
        """初始化执行器的核心逻辑"""
        # 1. 创建 MPI 会话(多 GPU 场景)
        if self.args.tensor_parallel_size > 1:
            self.mpi_session = MpiCommSession(
                n_workers=self.args.tensor_parallel_size
            )

        # 2. 创建执行器
        self._executor = GenerationExecutor.create(
            engine=self.args.model,
            executor_config=self._create_executor_config(),
            model_world_size=self.args.tensor_parallel_size,
            mpi_session=self.mpi_session
        )

关键功能分析:

分词器管理

@property
def tokenizer(self) -> Optional[TokenizerBase]:
    """分词器属性访问器"""
    if hasattr(self, 'input_processor') and self.input_processor:
        return self.input_processor.tokenizer
    return self._tokenizer

def _load_tokenizer(self):
    """分词器加载逻辑"""
    if self.args.skip_tokenizer_init:
        return None

    # 1. 从模型路径推导分词器路径
    tokenizer_path = self.args.tokenizer or self.args.model

    # 2. 创建分词器实例
    from .tokenizer import TokenizerBase
    tokenizer = TokenizerBase.from_pretrained(
        tokenizer_path,
        trust_remote_code=self.args.trust_remote_code
    )

    return tokenizer

2. 执行器 API (Executor)

2.1 GenerationExecutor 抽象基类

位置: tensorrt_llm/executor/executor.py:78-407

class GenerationExecutor(ABC):
    """生成执行器的抽象基类,定义核心接口"""

    def __init__(self,
                 num_postprocess_workers: int = 0,
                 postprocess_tokenizer_dir: Optional[str] = None,
                 is_llm_executor: Optional[bool] = None):
        # 后处理配置
        self.postproc_config = PostprocWorkerConfig(
            num_postprocess_workers=num_postprocess_workers,
            postprocess_tokenizer_dir=postprocess_tokenizer_dir
        )

        # 结果队列
        self.kv_events_queues = IterationResultQueue()
        self.stats_queues = IterationResultQueue()

        # 错误处理
        self._error_queue = Queue()
        self.doing_shutdown = False

        # 客户端 ID 管理
        self._last_client_id: int = 1

    @abstractmethod
    def submit(self, request: GenerationRequest) -> GenerationResult:
        """提交生成请求的抽象方法"""
        pass

    @abstractmethod
    def abort_request(self, request_id: int) -> None:
        """中止请求的抽象方法"""
        pass

核心方法实现:

generate_async() 方法

def generate_async(
    self,
    prompt_token_ids: List[int],
    sampling_params: SamplingParams,
    query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    streaming: bool = False,
    kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
    disaggregated_params: Optional[DisaggregatedParams] = None,
    postproc_params: Optional[PostprocParams] = None,
    multimodal_params: Optional[MultimodalParams] = None,
    scheduling_params: Optional[SchedulingParams] = None,
    cache_salt_id: Optional[int] = None,
    arrival_time: Optional[float] = None,
) -> GenerationResult:
    """
    异步生成方法的核心实现

    功能流程:
    1. 创建生成请求对象
    2. 分配客户端 ID
    3. 提交请求到执行器
    4. 返回 Future 对象
    """

    # 1. 创建请求对象
    request = GenerationRequest(
        client_id=self._get_next_client_id(),
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        lora_request=lora_request,
        streaming=streaming,
        arrival_time=arrival_time or time.time()
    )

    # 2. 提交请求
    result = self.submit(request)

    return result

2.2 GenerationExecutorWorker 实现类

位置: tensorrt_llm/executor/worker.py:41-91

class GenerationExecutorWorker(BaseWorker):
    """执行器工作进程实现,负责实际的模型推理"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # 设置推理引擎
        self.setup_engine()

        # 启动后台线程
        self.await_response_thread = ManagedThread(
            self.await_response_task,
            error_queue=self._error_queue,
            name="await_response_thread"
        )

        self.dispatch_stats_thread = ManagedThread(
            self.dispatch_stats_task,
            error_queue=self._error_queue,
            name="dispatch_stats_thread"
        )

核心功能实现:

submit() 方法

def submit(self, request: GenerationRequest) -> GenerationResult:
    """提交请求到底层推理引擎"""

    # 1. 验证请求参数
    self._validate_request(request)

    # 2. 创建结果对象
    result = GenerationResult(
        request_id=request.client_id,
        prompt=request.prompt_token_ids
    )

    # 3. 注册结果对象
    self._results[request.client_id] = result

    # 4. 提交到底层引擎
    backend_request_id = self.engine.enqueue_request(
        prompt_token_ids=request.prompt_token_ids,
        sampling_config=request.sampling_params.to_backend_config(),
        lora_config=request.lora_request.to_backend_config() if request.lora_request else None
    )

    # 5. 建立映射关系
    self._client_id_to_request_id[request.client_id] = backend_request_id

    return result

await_response_task() 后台任务

def await_response_task(self):
    """等待推理结果的后台任务"""
    while not self.doing_shutdown:
        try:
            # 1. 从引擎获取响应
            responses = self.engine.await_responses(timeout=0.1)

            # 2. 处理每个响应
            for response in responses:
                self._handle_response(response)

        except Exception as e:
            logger.error(f"Error in await_response_task: {e}")
            self._error_queue.put(e)
            break

def _handle_response(self, response):
    """处理单个推理响应"""
    # 1. 查找对应的结果对象
    backend_request_id = response.request_id
    client_id = self._find_client_id(backend_request_id)

    if client_id not in self._results:
        return

    result = self._results[client_id]

    # 2. 更新结果
    if response.has_error():
        result.set_exception(RequestError(response.error_msg))
    else:
        # 添加新的输出 token
        result.add_output_tokens(response.output_token_ids)

        # 检查是否完成
        if response.is_final:
            result.set_finished()
            # 清理资源
            del self._results[client_id]
            del self._client_id_to_request_id[client_id]

3. 构建器 API (Builder)

3.1 Builder 类

位置: tensorrt_llm/builder.py:108-478

class Builder:
    """TensorRT 引擎构建器"""

    _ALLOWED_PRECISIONS = [
        'float32', 'float16', 'bfloat16', 'int8', 'fp8'
    ]

    def __init__(self):
        self._trt_builder = trt.Builder(logger.trt_logger)
        self._timing_cache = None

    def build_engine(self,
                     network: Network,
                     builder_config: BuilderConfig,
                     managed_weights: dict = None) -> trt.IHostMemory:
        """
        构建 TensorRT 引擎的核心方法

        Args:
            network: TensorRT-LLM 网络对象
            builder_config: 构建配置
            managed_weights: 托管权重字典

        Returns:
            序列化的 TensorRT 引擎
        """

        # 1. 设置插件配置
        builder_config.plugin_config = network.plugin_config
        builder_config.auto_parallel_config = network.auto_parallel_config

        # 2. 添加优化配置文件
        if builder_config.trt_builder_config.num_optimization_profiles == 0:
            self._add_optimization_profile(network, builder_config)

        # 3. 重命名权重(如果需要)
        if network.named_parameters is not None:
            self._rename_weights(network, managed_weights)

        # 4. 构建引擎
        tik = time.time()
        engine = self._trt_builder.build_serialized_network(
            network.trt_network,
            builder_config.trt_builder_config
        )

        tok = time.time()
        logger.info(f'Build TensorRT engine Took: {tok - tik:.2f} s')

        return engine

关键方法分析:

_add_optimization_profile() 方法

def _add_optimization_profile(self, network: Network, builder_config: BuilderConfig):
    """添加优化配置文件"""

    # 1. 创建优化配置文件
    profile = self._trt_builder.create_optimization_profile()

    # 2. 为每个输入张量设置形状范围
    for input_name in network.get_input_names():
        input_shape = network.get_input_shape(input_name)

        # 最小形状
        min_shape = self._get_min_shape(input_name, input_shape, builder_config)
        # 最优形状
        opt_shape = self._get_opt_shape(input_name, input_shape, builder_config)
        # 最大形状
        max_shape = self._get_max_shape(input_name, input_shape, builder_config)

        profile.set_shape(input_name, min_shape, opt_shape, max_shape)

    # 3. 添加到构建配置
    builder_config.trt_builder_config.add_optimization_profile(profile)

3.2 BuildConfig 配置类

位置: tensorrt_llm/builder.py:481-570

@dataclass
class BuildConfig:
    """TensorRT-LLM 引擎构建配置"""

    # 序列长度配置
    max_input_len: int = 1024
    max_seq_len: int = None

    # 批次配置
    opt_batch_size: int = 8
    max_batch_size: int = 2048
    max_beam_width: int = 1
    max_num_tokens: int = 8192
    opt_num_tokens: int = None

    # KV 缓存配置
    kv_cache_type: KVCacheType = None

    # 优化配置
    strongly_typed: bool = True
    profiling_verbosity: str = 'layer_names_only'
    enable_debug_output: bool = False

    # 投机解码配置
    max_draft_len: int = 0
    speculative_decoding_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE

    # 高级配置
    use_refit: bool = False
    weight_sparsity: bool = False
    weight_streaming: bool = False

    def update_kv_cache_type(self, architecture: str):
        """根据模型架构更新 KV 缓存类型"""
        if self.kv_cache_type is None:
            # 默认使用分页 KV 缓存
            self.kv_cache_type = KVCacheType.PAGED

3.3 build() 函数

位置: tensorrt_llm/builder.py:1106-1402

def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
    """
    从给定模型和构建配置构建引擎

    Args:
        model: 预训练模型对象
        build_config: 构建配置

    Returns:
        构建好的引擎对象
    """

    tic = time.time()

    # 1. 深拷贝配置避免修改输入
    build_config = copy.deepcopy(build_config)
    build_config.plugin_config.dtype = model.config.dtype
    build_config.update_kv_cache_type(model.config.architecture)

    # 2. 初始化最大序列长度
    _init_max_seq_len(model.config, build_config)

    # 3. 验证和调整配置
    _validate_build_config(model.config, build_config)

    # 4. 创建网络
    network = Network()
    network.plugin_config = build_config.plugin_config

    # 5. 构建模型网络
    with net_guard(network):
        # 准备输入参数
        prepare_input_args = {
            'max_batch_size': build_config.max_batch_size,
            'max_input_len': build_config.max_input_len,
            'max_seq_len': build_config.max_seq_len,
            'max_beam_width': build_config.max_beam_width,
            'max_num_tokens': build_config.max_num_tokens,
        }

        # 准备模型输入
        inputs = model.prepare_inputs(**prepare_input_args)

        # 前向传播构建网络
        model(**inputs)

    # 6. 网络优化
    if model.config.architecture != "DecoderModel":
        optimize(network)

    # 7. 自动并行处理(如果启用)
    if build_config.auto_parallel_config.enabled:
        sharded_networks = auto_parallel(network, build_config.auto_parallel_config)
        network = sharded_networks[model.config.mapping.rank]

    # 8. 创建构建器配置
    builder_config = BuilderConfig.from_build_config(build_config)

    # 9. 构建引擎
    builder = Builder()
    engine_buffer = builder.build_engine(network, builder_config)

    # 10. 创建引擎对象
    engine = Engine(
        config=model.config,
        engine_buffer=engine_buffer
    )

    toc = time.time()
    logger.info(f'Total time of building {model.config.architecture}: {toc - tic:.2f} s')

    return engine

4. 运行时 API (Runtime)

4.1 Session 类

位置: tensorrt_llm/runtime/session.py:83-303

class Session:
    """TensorRT 运行时会话管理器"""

    def __init__(self, **kwargs):
        # 使用 Session.from_serialized_engine 创建会话
        pass

    def _init(self, engine_buffer=None):
        """初始化 TensorRT 引擎和上下文"""

        # 1. 创建 TensorRT 运行时
        self._runtime = trt.Runtime(logger.trt_logger)

        # 2. 反序列化引擎
        if engine_buffer is not None:
            self._engine = self.runtime.deserialize_cuda_engine(engine_buffer)

        # 3. 创建执行上下文
        self._context = None
        if not self.engine.streamable_weights_size:
            self.__prepare_execution_contexts()

        return self

    def __prepare_execution_contexts(self):
        """准备执行上下文"""
        self._context = self.engine.create_execution_context()
        assert self._context is not None, "Failed to create an execution context!"

        # 设置优化配置文件
        with _scoped_stream() as stream:
            self._context.set_optimization_profile_async(0, stream)

    @staticmethod
    def from_serialized_engine(engine) -> 'Session':
        """从序列化引擎创建会话"""
        session = Session()
        return session._init(engine)

    def run(self,
            inputs: Dict[str, Any],
            outputs: Dict[str, Any],
            stream,
            context=None) -> bool:
        """
        运行 TensorRT 引擎

        Args:
            inputs: 输入张量字典
            outputs: 输出张量字典
            stream: CUDA 流
            context: 执行上下文

        Returns:
            是否成功入队
        """

        # 使用默认上下文(如果未指定)
        if context is None:
            context = self.context

        # 设置输入张量地址
        for tensor_name in inputs:
            tensor = inputs[tensor_name]
            ptr = tensor.data_ptr() if isinstance(tensor, torch.Tensor) else tensor
            context.set_tensor_address(tensor_name, ptr)

        # 设置输出张量地址
        for tensor_name in outputs:
            tensor = outputs[tensor_name]
            ptr = tensor.data_ptr() if isinstance(tensor, torch.Tensor) else tensor
            context.set_tensor_address(tensor_name, ptr)

        # 异步执行
        ok = context.execute_async_v3(stream)
        return ok

4.2 ModelRunner 类

位置: tensorrt_llm/runtime/model_runner.py:515+

class ModelRunner(ModelRunnerMixin):
    """模型运行器,封装推理执行逻辑"""

    def __init__(self,
                 session: Session,
                 max_batch_size: int,
                 max_input_len: int,
                 max_seq_len: int,
                 **kwargs):

        self.session = session
        self.max_batch_size = max_batch_size
        self.max_input_len = max_input_len
        self.max_seq_len = max_seq_len

        # 初始化缓冲区
        self._init_buffers()

        # 初始化 KV 缓存
        self._init_kv_cache()

    def generate(self,
                 batch_input_ids: torch.Tensor,
                 sampling_config: SamplingConfig,
                 **kwargs) -> Dict[str, torch.Tensor]:
        """
        生成文本的核心方法

        Args:
            batch_input_ids: 批次输入 token IDs
            sampling_config: 采样配置

        Returns:
            生成结果字典
        """

        # 1. 输入预处理
        batch_size = batch_input_ids.shape[0]
        input_lengths = self._get_input_lengths(batch_input_ids)

        # 2. 准备输入张量
        inputs = self._prepare_inputs(
            batch_input_ids=batch_input_ids,
            input_lengths=input_lengths,
            sampling_config=sampling_config
        )

        # 3. 执行推理
        outputs = self._run_inference(inputs)

        # 4. 后处理
        results = self._postprocess_outputs(outputs, batch_size)

        return results

    def _run_inference(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """执行推理的内部方法"""

        # 1. 分配输出缓冲区
        outputs = self._allocate_output_buffers(inputs)

        # 2. 运行会话
        with torch.cuda.stream(self.stream):
            ok = self.session.run(inputs, outputs, self.stream.cuda_stream)
            if not ok:
                raise RuntimeError("TensorRT execution failed")

        # 3. 同步等待完成
        self.stream.synchronize()

        return outputs

5. 量化 API

5.1 QuantConfig 配置类

位置: tensorrt_llm/quantization/quantize.py

@dataclass
class QuantConfig:
    """量化配置类"""

    quant_algo: QuantAlgo = QuantAlgo.NO_QUANT
    kv_cache_quant_algo: QuantAlgo = QuantAlgo.NO_QUANT

    # 权重量化参数
    group_size: int = 128
    has_zero_point: bool = True
    pre_quant_scale: bool = False

    # 激活量化参数
    activation_scaling_factor: float = 1.0
    weight_scaling_factor: float = 1.0

    def to_dict(self) -> Dict[str, Any]:
        """转换为字典格式"""
        return {
            'quant_algo': self.quant_algo.value,
            'kv_cache_quant_algo': self.kv_cache_quant_algo.value,
            'group_size': self.group_size,
            'has_zero_point': self.has_zero_point
        }

5.2 quantize() 函数

位置: tensorrt_llm/quantization/quantize.py:561-603

def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
    """
    模型量化的核心函数

    Args:
        model: 要量化的模型
        quant_config: 量化配置

    Returns:
        量化后的模型
    """

    # 遍历模型的所有模块
    for name, module, parent in model.named_modules_with_parent():

        # 1. 确定层级量化模式
        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
            layer_quant_mode = quant_config.layer_quant_mode(name)
        else:
            layer_quant_mode = quant_config.layer_quant_mode

        if layer_quant_mode == QuantMode(0):
            continue

        # 2. 获取层级量化配置
        layer_quant_cfg = quant_config._get_quant_cfg(name)

        # 3. 根据量化模式选择量化方法
        if layer_quant_mode.has_fp8_qdq():
            module = fp8_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_fp8_rowwise():
            module = fp8_rowwise_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_qserve_w4a8():
            module = qserve_quantize(module, quant_config)
        elif layer_quant_mode.has_nvfp4():
            module = fp4_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_act_and_weight_quant():
            module = smooth_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_weight_only():
            if layer_quant_mode.has_per_group_scaling():
                module = weight_only_groupwise_quantize(module, layer_quant_cfg, model.config)
            else:
                module = weight_only_quantize(module, layer_quant_cfg, model.config)

        # 4. 替换模块
        if parent is not None:
            module_name = name.rsplit('.', 1)[-1]
            setattr(parent, module_name, module)
        else:
            model = module
            break

    # 5. KV 缓存量化
    if quant_config.quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model)

    # 6. 设置量化模式
    setattr(model, 'quant_mode', quant_config.quant_mode)

    return model

6. API 调用链路分析

6.1 完整推理调用链

sequenceDiagram
    participant User
    participant LLM
    participant BaseLLM
    participant Executor
    participant Worker
    participant Engine
    participant TensorRT

    User->>LLM: generate(inputs, sampling_params)
    LLM->>BaseLLM: generate()

    Note over BaseLLM: 输入预处理
    BaseLLM->>BaseLLM: 检查输入格式 (unbatched/batched)
    BaseLLM->>BaseLLM: 转换为 PromptInputs

    Note over BaseLLM: 异步请求处理
    loop 每个输入
        BaseLLM->>BaseLLM: generate_async(request_inputs)
        BaseLLM->>Executor: 提交异步请求
        Executor->>Worker: submit(GenerationRequest)
        Worker->>Engine: enqueue_request()
        Engine->>TensorRT: 执行推理
    end

    Note over BaseLLM: 等待结果
    BaseLLM->>BaseLLM: 等待所有 futures 完成
    Worker-->>BaseLLM: 返回 RequestOutput
    BaseLLM-->>User: 返回生成结果

6.2 详细调用链路代码分析

6.2.1 LLM.generate() 入口方法

位置: tensorrt_llm/llmapi/llm.py:241-319

def generate(
    self,
    inputs: Union[PromptInputs, Sequence[PromptInputs]],
    sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[LoRARequest, Sequence[LoRARequest]]] = None,
    # ... 其他参数
) -> Union[RequestOutput, List[RequestOutput]]:
    """
    同步生成文本的主入口方法

    功能说明:
    1. 输入格式检查和标准化
    2. 批量请求处理
    3. 异步请求提交
    4. 结果收集和返回
    """

    # 1. 输入格式检查和标准化
    unbatched = not isinstance(inputs, list)
    if not unbatched:
        if isinstance(inputs[0], int):
            unbatched = True  # 处理 token ids 输入

    if unbatched:
        inputs = [inputs]  # 转换为批量格式

    # 2. 转换为标准 PromptInputs 格式
    inputs = [prompt_inputs(i) for i in inputs]

    # 3. 辅助函数:从批量参数中获取单个项目
    def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
        if isinstance(maybe_batched, list):
            return maybe_batched[pos]
        else:
            return maybe_batched

    # 4. 提交异步请求
    futures = []
    for i, request_inputs in enumerate(inputs):
        future = self.generate_async(
            request_inputs,
            sampling_params=_item_at(sampling_params, i),
            lora_request=_item_at(lora_request, i),
            # ... 其他参数
            streaming=False
        )
        futures.append(future)

    # 5. 等待所有结果完成
    for future in tqdm(futures, desc="Processed requests",
                      dynamic_ncols=True, disable=not use_tqdm):
        future.result()  # 阻塞等待结果

    # 6. 返回结果
    if unbatched:
        futures = futures[0]  # 单个输入返回单个结果

    return futures

6.2.2 generate_async() 异步处理方法

位置: tensorrt_llm/llmapi/llm.py:322-354

@nvtx_range_debug("LLM.generate_async", color="green", category="LLM")
def generate_async(
    self,
    inputs: PromptInputs,
    sampling_params: Optional[SamplingParams] = None,
    # ... 其他参数
) -> RequestOutput:
    """
    异步生成方法,处理单个请求

    功能说明:
    1. 参数验证和默认值设置
    2. 输入预处理
    3. 创建生成请求
    4. 提交到执行器
    5. 返回 Future 对象
    """

    # 1. 检查是否正在关闭
    if hasattr(self, '_executor') and self._executor is None:
        raise RuntimeError("LLM is shutting down or has been shut down")

    # 2. 设置默认采样参数
    if sampling_params is None:
        sampling_params = SamplingParams()

    # 3. 输入预处理
    processed_inputs = self.input_processor.process(
        inputs,
        sampling_params=sampling_params
    )

    # 4. 提交到执行器
    result = self._executor.submit(generation_request)

    # 5. 包装为 RequestOutput
    return RequestOutput._from_generation_result(
        result,
        prompt=inputs.text if hasattr(inputs, 'text') else None,
        tokenizer=self.tokenizer
    )

6.3 执行器层调用链路

6.3.1 GenerationExecutor.create() 工厂方法

位置: tensorrt_llm/executor/executor.py:356-370

@staticmethod
def create(**kwargs) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
    """
    工厂方法创建执行器实例

    功能说明:
    1. 根据配置选择执行器类型
    2. 支持单进程和多进程模式
    3. 自动处理 MPI 通信
    """
    from .proxy import GenerationExecutorProxy
    from .worker import GenerationExecutorWorker

    # 获取配置参数
    world_size = kwargs.get('world_size', 0)
    model_world_size = kwargs.get('model_world_size', 1)

    # 根据世界大小选择执行器类型
    if world_size > 1 and world_size >= model_world_size:
        # 多进程模式:使用代理执行器
        logger.info(f"Creating GenerationExecutorProxy with {world_size} workers")
        return GenerationExecutorProxy.create(**kwargs)
    else:
        # 单进程模式:使用工作执行器
        logger.info("Creating GenerationExecutorWorker for single process")
        return GenerationExecutorWorker(**kwargs)

6.4 模型构建调用链

graph TB
    A[用户调用 trtllm-build] --> B[build函数]
    B --> C[配置预处理]
    C --> D[网络构建]
    D --> E[图优化]
    E --> F[引擎编译]
    F --> G[序列化保存]

6.5 量化处理调用链

graph LR
    A[QuantConfig] --> B[quantize函数]
    B --> C[模块遍历]
    C --> D[量化方法选择]
    D --> E[权重转换]
    E --> F[模块替换]
    F --> G[量化模型]

7. 关键数据结构

7.1 GenerationRequest

@dataclass
class GenerationRequest:
    """生成请求数据结构"""
    client_id: int
    prompt_token_ids: List[int]
    sampling_params: SamplingParams
    lora_request: Optional[LoRARequest] = None
    streaming: bool = False
    arrival_time: float = field(default_factory=time.time)

7.2 RequestOutput

@dataclass
class RequestOutput:
    """请求输出数据结构"""
    request_id: int
    prompt: str
    outputs: List[CompletionOutput]
    finished: bool
    metrics: Optional[RequestMetrics] = None

7.3 SamplingParams

@dataclass
class SamplingParams:
    """采样参数数据结构"""
    max_tokens: int = 16
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 0
    beam_width: int = 1
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0

这个 API 深度分析涵盖了 TensorRT-LLM 的核心 API 层次结构、关键方法实现、调用链路和数据结构,为深入理解和使用该框架提供了详细的技术参考。