TensorRT-LLM Module Deep Dive

1. LLM API Module

1.1 Module Architecture Diagram

graph TB
    subgraph "LLM API Module"
        A[LLM class] --> B[BaseLLM base class]
        A --> C[_TorchLLM]
        A --> D[_TrtLLM]

        B --> E[Input processor]
        B --> F[Tokenizer management]
        B --> G[Executor management]

        E --> H[PromptInputs]
        F --> I[TokenizerBase]
        G --> J[GenerationExecutor]
    end

    subgraph "Configuration Management"
        K[LlmArgs] --> L[TorchLlmArgs]
        K --> M[TrtLlmArgs]
        N[BuildConfig] --> O[Build parameters]
        P[KvCacheConfig] --> Q[Cache configuration]
    end

    subgraph "Request / Response"
        R[RequestOutput] --> S[CompletionOutput]
        T[SamplingParams] --> U[Sampling configuration]
        V[LoRARequest] --> W[Adapter request]
    end

    A --> K
    A --> N
    A --> P
    J --> R
    J --> T
    J --> V

1.2 Core Class Implementation Analysis

BaseLLM Base Class

Location: tensorrt_llm/llmapi/llm.py:108-766

class BaseLLM:
    """Base class for all LLM classes; provides the core functionality and interface."""

    def __init__(self, **kwargs):
        """
        Initialize the base class.

        Core responsibilities:
        1. Argument parsing and validation
        2. Executor class selection
        3. MPI session management
        4. Tokenizer initialization
        """
        # Executor class selection
        self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
        self._llm_id = None

        # Logging level management
        log_level = logger.level
        logger.set_level("info")

        # Argument parsing
        self.args = self._parse_llm_args(**kwargs)

        # Tokenizer initialization
        if not self.args.skip_tokenizer_init:
            self._tokenizer = self._load_tokenizer()

        # Executor initialization
        self._init_executor()

        # Register the cleanup hook
        import weakref
        weakref.finalize(self, self._shutdown_wrapper, weakref.ref(self))

    def _parse_llm_args(self, **kwargs) -> BaseLlmArgs:
        """Parse and validate the LLM arguments."""
        # Backend selection
        if 'backend' in kwargs:
            backend = kwargs['backend']
        else:
            # Infer the backend type from the model path
            backend = self._infer_backend(kwargs.get('model'))

        # Create the matching argument object
        if backend == 'pytorch':
            return TorchLlmArgs(**kwargs)
        elif backend == 'tensorrt':
            return TrtLlmArgs(**kwargs)
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    def _init_executor(self):
        """Initialize the executor."""
        # Create an MPI session for multi-GPU runs
        if self.args.tensor_parallel_size > 1:
            self.mpi_session = MpiCommSession(
                n_workers=self.args.tensor_parallel_size,
                gpus_per_node=self.args.gpus_per_node
            )

        # Create the executor
        self._executor = GenerationExecutor.create(
            engine=self._prepare_engine(),
            executor_config=self._create_executor_config(),
            model_world_size=self.args.tensor_parallel_size,
            mpi_session=getattr(self, 'mpi_session', None),
            tokenizer=self._tokenizer,
            llm_args=self.args
        )
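
As a concrete reference for the initialization path above, here is a minimal construction sketch using the public tensorrt_llm.LLM entry point; the model id and parameter values are illustrative, and passing tensor_parallel_size > 1 is what exercises the MpiCommSession branch in _init_executor.

from tensorrt_llm import LLM

# Constructing the LLM runs the whole __init__ chain shown above:
# argument parsing, tokenizer loading, and executor creation.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # local path or Hugging Face id
    tensor_parallel_size=2,                      # >1 creates an MPI session for multi-GPU
    dtype="float16",
)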

generate() Method Deep Dive

def generate(self, inputs, sampling_params=None, **kwargs):
    """
    Core text-generation method.

    Execution flow:
    1. Input preprocessing and validation
    2. Sampling-parameter handling
    3. Request creation and submission
    4. Result collection and postprocessing
    5. Output formatting
    """

    # 1. Input preprocessing
    processed_inputs = self._preprocess_inputs(inputs)

    # 2. Sampling-parameter handling
    if sampling_params is None:
        sampling_params = SamplingParams()
    processed_sampling_params = self._process_sampling_params(sampling_params)

    # 3. Create the generation requests
    requests = []
    for i, input_data in enumerate(processed_inputs):
        request = GenerationRequest(
            prompt_token_ids=input_data.token_ids,
            sampling_params=processed_sampling_params[i] if isinstance(processed_sampling_params, list) else processed_sampling_params,
            lora_request=kwargs.get('lora_request'),
            prompt_adapter_request=kwargs.get('prompt_adapter_request')
        )
        requests.append(request)

    # 4. Submit the requests and collect the results
    if self._is_streaming_request(kwargs):
        return self._generate_streaming(requests, **kwargs)
    else:
        return self._generate_non_streaming(requests, **kwargs)

def _generate_non_streaming(self, requests, use_tqdm=True):
    """Non-streaming generation."""

    # Submit every request
    futures = []
    for request in requests:
        future = self._executor.submit(request)
        futures.append(future)

    # Wait for all results
    results = []
    for future in tqdm(futures, desc="Generating", disable=not use_tqdm):
        try:
            result = future.result()  # blocking wait
            results.append(result)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    # Convert to the output format
    outputs = []
    for result in results:
        output = RequestOutput(
            request_id=result.request_id,
            prompt=result.prompt,
            outputs=[CompletionOutput(
                index=0,
                text=result.text,
                token_ids=result.token_ids,
                cumulative_logprob=result.cumulative_logprob,
                logprobs=result.logprobs,
                finish_reason=result.finish_reason
            )],
            finished=result.finished
        )
        outputs.append(output)

    return outputs[0] if len(outputs) == 1 else outputs
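
A minimal sketch of driving this flow through the public API; the prompts and sampling values are illustrative.

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
params = SamplingParams(max_tokens=64, temperature=0.8, top_p=0.95)

# A list of prompts yields a list of RequestOutput objects (a single prompt yields one).
outputs = llm.generate(["Hello world", "TensorRT-LLM is"], params)
for out in outputs:
    print(out.outputs[0].text)  # first CompletionOutput of each request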

1.3 Sequence Diagram Analysis

sequenceDiagram
    participant User
    participant LLM
    participant InputProcessor
    participant Tokenizer
    participant Executor
    participant Engine

    User->>LLM: generate("Hello world")

    Note over LLM: Input preprocessing phase
    LLM->>InputProcessor: preprocess_inputs()
    InputProcessor->>Tokenizer: encode("Hello world")
    Tokenizer-->>InputProcessor: [101, 7592, 2088, 102]
    InputProcessor-->>LLM: PromptInputs(token_ids=[...])

    Note over LLM: Request creation phase
    LLM->>LLM: create_generation_request()
    LLM->>Executor: submit(GenerationRequest)

    Note over Executor: Inference execution phase
    Executor->>Engine: enqueue_request()
    Engine->>Engine: forward_pass()
    Engine->>Engine: sampling()
    Engine-->>Executor: GenerationResult

    Note over LLM: Output postprocessing phase
    Executor-->>LLM: GenerationResult
    LLM->>Tokenizer: decode([102, 2003, 1037, 3231])
    Tokenizer-->>LLM: " is a test"
    LLM-->>User: RequestOutput("Hello world is a test")
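
The same request path also supports streaming. A minimal sketch, assuming the async-iterable result that generate_async(..., streaming=True) exposes in the LLM API; the model id and sampling values are illustrative.

import asyncio
from tensorrt_llm import LLM, SamplingParams

async def stream_demo() -> None:
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # illustrative model id
    params = SamplingParams(max_tokens=32)
    # streaming=True turns the result into an async iterable of partial RequestOutputs.
    async for chunk in llm.generate_async("Hello world", sampling_params=params, streaming=True):
        print(chunk.outputs[0].text)  # partial output text at this point in the stream

asyncio.run(stream_demo())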

2. Executor Module

2.1 Module Architecture Diagram

graph TB
    subgraph "Executor Abstraction Layer"
        A[GenerationExecutor] --> B[Abstract interface definition]
        A --> C[Shared functionality]
    end

    subgraph "Proxy Executor"
        D[GenerationExecutorProxy] --> E[Multi-process management]
        D --> F[Inter-process communication]
        D --> G[Load balancing]
    end

    subgraph "Worker Executor"
        H[GenerationExecutorWorker] --> I[Inference engine management]
        H --> J[Request handling]
        H --> K[Result delivery]
    end

    subgraph "Base Worker"
        L[BaseWorker] --> M[Engine setup]
        L --> N[Postprocessing management]
        L --> O[Error handling]
    end

    subgraph "Supporting Components"
        P[IPC queues] --> Q[Inter-process messaging]
        R[ManagedThread] --> S[Background task management]
        T[PostprocWorker] --> U[Parallel postprocessing]
    end

    D -.->|inherits| A
    H -.->|inherits| A
    H -.->|inherits| L
    D --> P
    H --> R
    H --> T

2.2 GenerationExecutor Abstract Base Class

Location: tensorrt_llm/executor/executor.py:78-407

class GenerationExecutor(ABC):
    """Abstract base class for generation executors."""

    def __init__(self,
                 num_postprocess_workers: int = 0,
                 postprocess_tokenizer_dir: Optional[str] = None,
                 is_llm_executor: Optional[bool] = None):
        """
        Initialize the executor base class.

        Components:
        1. Postprocessing-worker configuration
        2. Result-queue management
        3. Error handling
        4. Client-ID management
        """

        # Postprocessing configuration
        self.postproc_config = PostprocWorkerConfig(
            num_postprocess_workers=num_postprocess_workers,
            postprocess_tokenizer_dir=postprocess_tokenizer_dir
        )

        # Result queues
        self.kv_events_queues = IterationResultQueue()
        self.stats_queues = IterationResultQueue()

        # Error handling
        self._error_queue = Queue()
        self.doing_shutdown = False

        # Client management
        self._last_client_id: int = 1

        # Register the cleanup hook
        atexit.register(self.shutdown)

    @abstractmethod
    def submit(self, request: GenerationRequest) -> GenerationResult:
        """Submit a generation request - abstract method."""
        pass

    @abstractmethod
    def abort_request(self, request_id: int) -> None:
        """Abort a request - abstract method."""
        pass

    def generate_async(self, **kwargs) -> GenerationResult:
        """Asynchronous generation."""
        # Create the generation request
        request = GenerationRequest(
            client_id=self._get_next_client_id(),
            prompt_token_ids=kwargs['prompt_token_ids'],
            sampling_params=kwargs['sampling_params'],
            lora_request=kwargs.get('lora_request'),
            streaming=kwargs.get('streaming', False),
            arrival_time=kwargs.get('arrival_time', time.time())
        )

        # Submit it
        return self.submit(request)

    def _get_next_client_id(self) -> int:
        """Return the next client ID."""
        client_id = self._last_client_id
        self._last_client_id += 1
        return client_id

    @staticmethod
    def create(**kwargs) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
        """Factory method that creates an executor instance."""
        from .proxy import GenerationExecutorProxy
        from .worker import GenerationExecutorWorker

        world_size = kwargs.get('world_size', 0)
        model_world_size = kwargs.get('model_world_size', 1)

        # Choose the executor type based on the world size
        if world_size > 1 and world_size >= model_world_size:
            # Multi-process mode
            return GenerationExecutorProxy.create(**kwargs)
        else:
            # Single-process mode
            return GenerationExecutorWorker(**kwargs)

2.3 GenerationExecutorWorker Implementation Details

Location: tensorrt_llm/executor/worker.py:41-91

class GenerationExecutorWorker(BaseWorker):
    """
    Executor worker-process implementation; performs the actual model inference.

    Inheritance chain:
    GenerationExecutor (abstract base)
      -> BaseWorker (base worker)
        -> GenerationExecutorWorker (concrete implementation)

    Core responsibilities:
    1. Inference-engine management and initialization
    2. Multi-threaded background task processing
    3. Request submission and result collection
    4. Error handling and resource cleanup
    """

    class WorkerExit(GeneratorExit):
        """Worker-process exit exception."""
        pass

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        lora_config: Optional[LoraConfig] = None,
        kv_connector_config: Optional[KvCacheConnectorConfig] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
    ) -> None:
        """
        Initialize the worker executor.

        Parameters:
        - engine: TensorRT engine path or Engine object
        - executor_config: executor configuration (batch size, sequence length, ...)
        - batched_logits_processor: batched logits postprocessor
        - postproc_worker_config: postprocessing-worker configuration
        - is_llm_executor: whether this instance backs an LLM executor
        - lora_config: LoRA adapter configuration
        - kv_connector_config: KV-cache connector configuration
        - hf_model_dir: Hugging Face model directory
        - tokenizer: tokenizer instance
        - llm_args: LLM argument configuration
        """

        # Initialize the parent class
        super().__init__(
            engine=engine,
            executor_config=executor_config,
            batched_logits_processor=batched_logits_processor,
            postproc_worker_config=postproc_worker_config,
            is_llm_executor=is_llm_executor,
            lora_config=lora_config,
            kv_connector_config=kv_connector_config,
            hf_model_dir=hf_model_dir,
            tokenizer=tokenizer,
            llm_args=llm_args,
        )

        # Set up the inference engine
        self.setup_engine()

        # Initialize the background threads
        self._init_background_threads()

    def _init_background_threads(self):
        """
        Initialize the background threads.

        Threads:
        1. await_response_thread: listens for inference-engine responses
        2. dispatch_stats_thread: dispatches statistics
        3. dispatch_kv_cache_events_thread: dispatches KV-cache events
        """

        # Response thread - core thread that handles inference results
        self.await_response_thread = ManagedThread(
            self.await_response_task,
            error_queue=self._error_queue,
            name="await_response_thread"
        )

        # Stats dispatch thread - handles performance statistics
        self.dispatch_stats_thread = ManagedThread(
            self.dispatch_stats_task,
            error_queue=self._error_queue,
            name="dispatch_stats_thread"
        )

        # KV-cache event dispatch thread - handles cache events
        self.dispatch_kv_cache_events_thread = ManagedThread(
            self.dispatch_kv_cache_events_task,
            error_queue=self._error_queue,
            name="dispatch_kv_cache_events_thread"
        )

    def await_response_task(self) -> bool:
        """
        Background task that waits for inference responses.

        Responsibilities:
        1. Continuously listen for engine responses
        2. Delegate to the shared response-handling helper
        3. Handle exceptions and error conditions
        4. Support graceful shutdown

        Returns:
            bool: whether the task completed successfully
        """
        return self._await_response_helper()

    def dispatch_stats_task(self) -> bool:
        """
        Dispatch statistics.

        Responsibilities:
        1. Merge iteration statistics with request statistics
        2. Delegate to the shared iteration-result handler
        3. Provide performance-monitoring data

        Returns:
            bool: whether the task succeeded
        """

        # Helper that merges iteration stats with request stats
        def join_iteration_and_request_stats(iteration_stats):
            """Merge iteration statistics with request statistics."""
            if not iteration_stats:
                return iteration_stats

            # Fetch the request statistics
            request_stats_result = self.engine.get_latest_request_stats()
            request_stats = (request_stats_result
                           if request_stats_result is not None
                           else [])

            # Merge the two streams
            return iteration_stats + request_stats

        # Delegate to the shared handler
        return self._iteration_result_task(
            self.stats_queues,
            self.engine.get_latest_iteration_stats,
            self._iter_stats_result,
            join_iteration_and_request_stats
        )

    def dispatch_kv_cache_events_task(self) -> bool:
        """
        Dispatch KV-cache events.

        Responsibilities:
        1. Fetch KV-cache events
        2. Dispatch them to the event queues
        3. Support cache management and monitoring

        Returns:
            bool: whether the task succeeded
        """
        return self._iteration_result_task(
            self.kv_events_queues,
            self.engine.get_latest_kv_cache_events,
            self._iter_kv_events_result,
            lambda x: x  # pass through, no extra processing needed
        )

2.4 Executor Sequence Diagram

sequenceDiagram
    participant LLM
    participant Executor
    participant Engine
    participant BackgroundThread
    participant Result

    LLM->>Executor: submit(GenerationRequest)

    Note over Executor: Request handling
    Executor->>Executor: validate_request()
    Executor->>Result: create GenerationResult()
    Executor->>Engine: enqueue_request()
    Executor-->>LLM: return GenerationResult (Future)

    Note over BackgroundThread: Background processing
    loop Await responses
        BackgroundThread->>Engine: await_responses()
        Engine-->>BackgroundThread: response batch

        loop Handle each response
            BackgroundThread->>BackgroundThread: handle_response()
            BackgroundThread->>Result: update result

            alt Generation finished
                BackgroundThread->>Result: set_finished()
                Result-->>LLM: notify completion
            else Continue generating
                BackgroundThread->>Engine: continue generation
            end
        end
    end
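
The diagram shows the future-style contract: submit() returns a GenerationResult immediately, and a background thread later fills it in. The sketch below illustrates that pattern in isolation with plain Python primitives (threading, queue, Future); it is a simplified stand-in for the worker's ManagedThread/response plumbing, not the library's actual implementation.

import queue
import threading
from concurrent.futures import Future

class ToyWorker:
    """Minimal stand-in for the submit()/background-response pattern."""

    def __init__(self) -> None:
        self._pending: dict[int, Future] = {}
        self._responses: "queue.Queue" = queue.Queue()
        self._next_id = 0
        # Stand-in for await_response_thread.
        self._thread = threading.Thread(target=self._await_responses, daemon=True)
        self._thread.start()

    def submit(self, prompt: str) -> Future:
        """Enqueue a request and return a Future immediately."""
        req_id, self._next_id = self._next_id, self._next_id + 1
        fut: Future = Future()
        self._pending[req_id] = fut
        # A real worker would enqueue into the C++ engine here; we fake the reply.
        self._responses.put((req_id, f"<generated for: {prompt}>"))
        return fut

    def _await_responses(self) -> None:
        """Background loop that completes futures as responses arrive."""
        while True:
            item = self._responses.get()
            if item is None:
                return
            req_id, text = item
            self._pending.pop(req_id).set_result(text)

    def shutdown(self) -> None:
        self._responses.put(None)

worker = ToyWorker()
print(worker.submit("Hello world").result())  # blocks until the background thread replies
worker.shutdown()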

3. Builder Module

3.1 Module Architecture Diagram

graph TB
    subgraph "Builder Core"
        A[Builder class] --> B[Engine building]
        A --> C[Optimization configuration]
        A --> D[Serialization management]
    end

    subgraph "Configuration Management"
        E[BuildConfig] --> F[Sequence-length settings]
        E --> G[Batch settings]
        E --> H[Optimization settings]
        I[BuilderConfig] --> J[TensorRT configuration]
        I --> K[Plugin configuration]
    end

    subgraph "Network Construction"
        L[Network] --> M[Computation-graph construction]
        L --> N[Input/output management]
        L --> O[Plugin integration]
    end

    subgraph "Optimization Pipeline"
        P[Graph optimization] --> Q[Operator fusion]
        P --> R[Memory optimization]
        P --> S[Parallelism optimization]
    end

    subgraph "Engine Management"
        T[Engine] --> U[Serialized engine]
        T --> V[Engine metadata]
        T --> W[Weight management]
    end

    A --> E
    A --> I
    A --> L
    L --> P
    A --> T

3.2 Builder Class Implementation

Location: tensorrt_llm/builder.py:108-478

class Builder:
    """TensorRT engine builder."""

    _ALLOWED_PRECISIONS = [
        'float32', 'float16', 'bfloat16', 'int8', 'fp8'
    ]

    def __init__(self):
        """Initialize the builder."""
        # Create the TensorRT builder
        self._trt_builder = trt.Builder(logger.trt_logger)

        # Timing cache
        self._timing_cache = None

        # Build statistics
        self._build_stats = {}

    def build_engine(self,
                     network: Network,
                     builder_config: BuilderConfig,
                     managed_weights: dict = None) -> trt.IHostMemory:
        """
        Build a TensorRT engine.

        Args:
            network: TensorRT-LLM network object
            builder_config: build configuration
            managed_weights: managed-weights dictionary

        Returns:
            Serialized TensorRT engine

        Build flow:
        1. Validate and apply the configuration
        2. Add optimization profiles
        3. Process weights
        4. Build the engine
        5. Return the serialized result
        """

        logger.info("Starting TensorRT engine build...")
        build_start_time = time.time()

        # 1. Apply the configuration
        builder_config.plugin_config = network.plugin_config
        builder_config.auto_parallel_config = network.auto_parallel_config

        # 2. Add optimization profiles
        if builder_config.trt_builder_config.num_optimization_profiles == 0:
            self._add_optimization_profile(network, builder_config)

        logger.info(f"Total optimization profiles: {builder_config.trt_builder_config.num_optimization_profiles}")

        # 3. Rename weights if required
        if network.named_parameters is not None:
            logger.info("Renaming weights for TensorRT compatibility...")
            self._rename_weights(network, managed_weights)

        # 4. Load the timing cache
        if builder_config.input_timing_cache:
            self._load_timing_cache(builder_config.input_timing_cache)

        # 5. Build the engine
        logger.info("Building TensorRT engine...")
        engine_build_start = time.time()

        serialized_engine = self._trt_builder.build_serialized_network(
            network.trt_network,
            builder_config.trt_builder_config
        )

        if serialized_engine is None:
            raise RuntimeError("Failed to build TensorRT engine")

        engine_build_time = time.time() - engine_build_start
        total_build_time = time.time() - build_start_time

        logger.info(f"Engine build time: {engine_build_time:.2f}s")
        logger.info(f"Total build time: {total_build_time:.2f}s")

        # 6. Save the timing cache
        if builder_config.output_timing_cache:
            self._save_timing_cache(builder_config.output_timing_cache)

        return serialized_engine

    def _add_optimization_profile(self, network: Network, builder_config: BuilderConfig):
        """Add an optimization profile."""

        # Create the optimization profile
        profile = self._trt_builder.create_optimization_profile()

        # Set the shape range for every input tensor
        for input_name in network.get_input_names():
            input_shape = network.get_input_shape(input_name)

            # Compute the shape ranges
            min_shape = self._compute_min_shape(input_name, input_shape, builder_config)
            opt_shape = self._compute_opt_shape(input_name, input_shape, builder_config)
            max_shape = self._compute_max_shape(input_name, input_shape, builder_config)

            logger.debug(f"Input {input_name}: min={min_shape}, opt={opt_shape}, max={max_shape}")

            # Apply the shape range
            profile.set_shape(input_name, min_shape, opt_shape, max_shape)

        # Validate the profile
        if not profile.is_valid():
            raise RuntimeError("Invalid optimization profile")

        # Attach it to the builder configuration
        builder_config.trt_builder_config.add_optimization_profile(profile)

    def _compute_min_shape(self, input_name: str, input_shape: List[int], config: BuilderConfig) -> List[int]:
        """Compute the minimum input shape."""
        min_shape = list(input_shape)

        if 'input_ids' in input_name:
            min_shape[0] = 1  # min batch size
            min_shape[1] = 1  # min sequence length
        elif 'attention_mask' in input_name:
            min_shape[0] = 1
            min_shape[1] = 1
        elif 'position_ids' in input_name:
            min_shape[0] = 1
            min_shape[1] = 1

        return min_shape

    def _compute_opt_shape(self, input_name: str, input_shape: List[int], config: BuilderConfig) -> List[int]:
        """Compute the optimal input shape."""
        opt_shape = list(input_shape)

        if 'input_ids' in input_name:
            opt_shape[0] = config.opt_batch_size
            opt_shape[1] = config.max_input_len // 2
        elif 'attention_mask' in input_name:
            opt_shape[0] = config.opt_batch_size
            opt_shape[1] = config.max_input_len // 2

        return opt_shape

    def _compute_max_shape(self, input_name: str, input_shape: List[int], config: BuilderConfig) -> List[int]:
        """Compute the maximum input shape."""
        max_shape = list(input_shape)

        if 'input_ids' in input_name:
            max_shape[0] = config.max_batch_size
            max_shape[1] = config.max_input_len
        elif 'attention_mask' in input_name:
            max_shape[0] = config.max_batch_size
            max_shape[1] = config.max_seq_len

        return max_shape
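
As a worked example of the three shape helpers, assume a hypothetical BuilderConfig with opt_batch_size=4, max_batch_size=8, and max_input_len=1024 for a dynamic [batch, seq] input_ids tensor; the snippet just reproduces the arithmetic above.

from types import SimpleNamespace

# Hypothetical stand-in exposing only the fields the helpers read.
cfg = SimpleNamespace(opt_batch_size=4, max_batch_size=8,
                      max_input_len=1024, max_seq_len=2048)

min_shape = [1, 1]                                        # _compute_min_shape for input_ids
opt_shape = [cfg.opt_batch_size, cfg.max_input_len // 2]  # -> [4, 512]
max_shape = [cfg.max_batch_size, cfg.max_input_len]       # -> [8, 1024]

print(min_shape, opt_shape, max_shape)
# profile.set_shape('input_ids', min_shape, opt_shape, max_shape) would register these ranges.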

3.3 build() Function Deep Dive

Location: tensorrt_llm/builder.py:1106-1402

def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
    """
    Build an engine from a model and a build configuration.

    Args:
        model: pretrained model object
        build_config: build configuration

    Returns:
        The built Engine object

    Build flow:
    1. Configuration preprocessing
    2. Network construction
    3. Graph optimization
    4. Auto parallelism
    5. Engine compilation
    """

    logger.info(f"Building {model.config.architecture} engine...")
    total_start_time = time.time()

    # 1. Configuration preprocessing
    build_config = copy.deepcopy(build_config)
    build_config.plugin_config.dtype = model.config.dtype
    build_config.update_kv_cache_type(model.config.architecture)

    # Initialize the maximum sequence length
    _init_max_seq_len(model.config, build_config)

    # Validate configuration compatibility
    _validate_build_config(model.config, build_config)

    # 2. Network construction
    logger.info("Building network...")
    network_start_time = time.time()

    # Create the network object
    network = Network()
    network.plugin_config = build_config.plugin_config

    # Enter the network context
    with net_guard(network):
        # Prepare the input arguments
        prepare_input_args = {
            'max_batch_size': build_config.max_batch_size,
            'max_input_len': build_config.max_input_len,
            'max_seq_len': build_config.max_seq_len,
            'max_beam_width': build_config.max_beam_width,
            'max_num_tokens': build_config.max_num_tokens,
            'opt_batch_size': build_config.opt_batch_size,
            'opt_num_tokens': build_config.opt_num_tokens,
        }

        # Special-case configuration
        if build_config.speculative_decoding_mode != SpeculativeDecodingMode.NONE:
            prepare_input_args['max_draft_len'] = build_config.max_draft_len
            prepare_input_args['spec_decoding_is_generation_length_variable'] = True

        # Prepare the model inputs
        inputs = model.prepare_inputs(**prepare_input_args)

        # Forward pass to trace the computation graph
        model(**inputs)

        # Mark debug outputs if enabled
        if build_config.enable_debug_output:
            for name, tensor in model.named_network_outputs():
                network._mark_output(tensor, name, str_dtype_to_trt(model.config.dtype))

    network_build_time = time.time() - network_start_time
    logger.info(f"Network build time: {network_build_time:.2f}s")

    # 3. Graph optimization
    if model.config.architecture != "DecoderModel":
        logger.info("Optimizing network...")
        optimize_start_time = time.time()

        optimize(network)

        optimize_time = time.time() - optimize_start_time
        logger.info(f"Network optimization time: {optimize_time:.2f}s")

    # 4. Create the builder configuration (needed below for the auto-parallel flags)
    builder_config = BuilderConfig.from_build_config(build_config)

    # 5. Auto parallelism
    use_auto_parallel = build_config.auto_parallel_config.enabled
    if use_auto_parallel:
        logger.info("Applying auto parallel...")
        auto_parallel_start_time = time.time()

        config = build_config.auto_parallel_config
        config.builder_flags = builder_config.trt_builder_config.flags

        sharded_networks = auto_parallel(network, config)
        network = sharded_networks[model.config.mapping.rank]

        if not config.debug_mode:
            mapping = network.auto_parallel_config["mapping"]
            model.config.mapping = mapping

        auto_parallel_time = time.time() - auto_parallel_start_time
        logger.info(f"Auto parallel time: {auto_parallel_time:.2f}s")

    # 6. Network visualization (if enabled)
    if build_config.visualize_network:
        logger.info(f"Saving network visualization to {build_config.visualize_network}")
        network.save_visualization(build_config.visualize_network)

    # 7. Engine compilation
    logger.info("Compiling TensorRT engine...")
    compile_start_time = time.time()

    builder = Builder()
    engine_buffer = builder.build_engine(network, builder_config)

    compile_time = time.time() - compile_start_time
    logger.info(f"Engine compilation time: {compile_time:.2f}s")

    # 8. Create the Engine object
    engine = Engine(
        config=model.config,
        engine_buffer=engine_buffer
    )

    total_time = time.time() - total_start_time
    logger.info(f"Total build time: {total_time:.2f}s")

    return engine
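
A minimal end-to-end sketch of driving this function from Python, assuming a LLaMA checkpoint loaded via the model class's from_hugging_face helper; the model id, dtype, and output directory are illustrative, and the trtllm-build CLI wraps the same flow.

from tensorrt_llm import BuildConfig, build
from tensorrt_llm.models import LLaMAForCausalLM

# Load/convert weights into a TensorRT-LLM model object (illustrative model id).
model = LLaMAForCausalLM.from_hugging_face("meta-llama/Llama-2-7b-hf", dtype="float16")

build_config = BuildConfig(max_batch_size=8, max_input_len=1024, max_seq_len=2048)
engine = build(model, build_config)   # the function analyzed above
engine.save("./llama2_7b_engine")     # writes the engine plus its config to the directory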

3.4 Builder Sequence Diagram

sequenceDiagram
    participant User
    participant BuildFunc
    participant Model
    participant Network
    participant Optimizer
    participant Builder
    participant TensorRT

    User->>BuildFunc: build(model, config)

    Note over BuildFunc: Configuration preprocessing
    BuildFunc->>BuildFunc: validate_config()
    BuildFunc->>BuildFunc: init_max_seq_len()

    Note over BuildFunc: Network construction
    BuildFunc->>Network: create Network()
    BuildFunc->>Model: prepare_inputs()
    Model-->>BuildFunc: input tensors
    BuildFunc->>Model: forward(**inputs)
    Model->>Network: build computation graph

    Note over BuildFunc: Graph optimization
    BuildFunc->>Optimizer: optimize(network)
    Optimizer->>Optimizer: operator fusion
    Optimizer->>Optimizer: memory optimization
    Optimizer-->>BuildFunc: optimized network

    Note over BuildFunc: Engine compilation
    BuildFunc->>Builder: create Builder()
    BuildFunc->>Builder: build_engine(network, config)
    Builder->>TensorRT: build_serialized_network()
    TensorRT-->>Builder: serialized engine
    Builder-->>BuildFunc: engine buffer

    BuildFunc-->>User: Engine object

4. Runtime Module

4.1 Module Architecture Diagram

graph TB
    subgraph "Session Management"
        A[Session] --> B[Engine management]
        A --> C[Context management]
        A --> D[Memory management]
    end

    subgraph "Model Runner"
        E[ModelRunner] --> F[Inference execution]
        E --> G[Buffer management]
        E --> H[Batch handling]
    end

    subgraph "Generation Management"
        I[GenerationSession] --> J[Autoregressive generation]
        I --> K[Sampling strategies]
        I --> L[KV-cache management]
    end

    subgraph "Multimodal Support"
        M[MultimodalModelRunner] --> N[Vision encoder]
        M --> O[Audio encoder]
        M --> P[Cross-modal fusion]
    end

    subgraph "Low-Level Runtime"
        Q[TensorRT Runtime] --> R[Engine execution]
        Q --> S[Memory allocation]
        Q --> T[CUDA stream management]
    end

    A --> E
    E --> I
    E --> M
    A --> Q

4.2 Session Class Implementation

Location: tensorrt_llm/runtime/session.py:83-303

class Session:
    """TensorRT runtime session manager."""

    def __init__(self, **kwargs):
        """Use the static factory methods to create a session."""
        pass

    def _init(self, engine_buffer=None):
        """
        Initialize the session.

        Steps:
        1. Create the TensorRT runtime
        2. Deserialize the engine
        3. Create the execution context
        4. Set the optimization profile
        """

        # 1. Create the TensorRT runtime
        self._runtime = trt.Runtime(logger.trt_logger)

        # 2. Deserialize the engine
        if engine_buffer is not None:
            self._engine = self._runtime.deserialize_cuda_engine(engine_buffer)
            if self._engine is None:
                raise RuntimeError("Failed to deserialize TensorRT engine")

        # 3. Create the execution context
        self._context = None
        if not self.engine.streamable_weights_size:
            self.__prepare_execution_contexts()

        return self

    def __prepare_execution_contexts(self):
        """Prepare the execution context."""

        # Create the execution context
        self._context = self.engine.create_execution_context()
        if self._context is None:
            raise RuntimeError("Failed to create execution context")

        # Set the optimization profile
        with _scoped_stream() as stream:
            success = self._context.set_optimization_profile_async(0, stream)
            if not success:
                raise RuntimeError("Failed to set optimization profile")

            # Wait until the profile has been applied
            stream.synchronize()

    @staticmethod
    def from_serialized_engine(engine_buffer) -> 'Session':
        """Create a session from a serialized engine."""
        session = Session()
        return session._init(engine_buffer)

    @staticmethod
    def from_engine(engine: trt.ICudaEngine) -> 'Session':
        """Create a session from an existing engine."""
        session = Session()
        session.engine = engine
        return session._init()

    @property
    def engine(self) -> trt.ICudaEngine:
        """Return the TensorRT engine."""
        return self._engine

    @engine.setter
    def engine(self, engine: trt.ICudaEngine):
        """Set the TensorRT engine."""
        self._engine = engine

    @property
    def context(self) -> trt.IExecutionContext:
        """Return the execution context."""
        if self._context is None:
            raise RuntimeError("Execution context not initialized")
        return self._context

    def run(self,
            inputs: Dict[str, Any],
            outputs: Dict[str, Any],
            stream,
            context=None) -> bool:
        """
        Run inference.

        Args:
            inputs: input tensor dict {name: tensor}
            outputs: output tensor dict {name: tensor}
            stream: CUDA stream
            context: execution context (optional)

        Returns:
            Whether the execution was successfully enqueued

        Execution flow:
        1. Bind input tensor addresses
        2. Bind output tensor addresses
        3. Enqueue the inference asynchronously
        """

        # Use the default context if none was given
        if context is None:
            context = self.context

        # Bind the input tensor addresses
        for tensor_name, tensor in inputs.items():
            if isinstance(tensor, torch.Tensor):
                ptr = tensor.data_ptr()
            else:
                ptr = tensor

            context.set_tensor_address(tensor_name, ptr)

        # Bind the output tensor addresses
        for tensor_name, tensor in outputs.items():
            if isinstance(tensor, torch.Tensor):
                ptr = tensor.data_ptr()
            else:
                ptr = tensor

            context.set_tensor_address(tensor_name, ptr)

        # Enqueue the inference asynchronously
        success = context.execute_async_v3(stream)

        return success

    def set_weight_streaming(self, gpu_weights_percent: float):
        """Configure weight streaming."""
        if not self.engine.streamable_weights_size:
            logger.warning("Engine does not support weight streaming")
            return

        try:
            # Compute the GPU weight budget
            total_weights_size = self.engine.streamable_weights_size
            gpu_weights_budget = int(total_weights_size * gpu_weights_percent)

            # Apply the weight-streaming budget
            self.engine.weight_streaming_budget_v2 = gpu_weights_budget

            # Recreate the execution contexts
            self.__prepare_execution_contexts()

            logger.info(f"Weight streaming enabled: {gpu_weights_percent*100:.1f}% on GPU")

        except Exception as e:
            logger.error(f"Failed to set weight streaming: {e}")
            raise
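
A minimal sketch of driving Session directly, assuming a single-profile engine; the engine path, binding names ('input'/'output'), shapes, and dtypes are placeholders, since real TensorRT-LLM engines expose many more bindings. It only illustrates the from_serialized_engine/run contract.

import torch
from tensorrt_llm.runtime import Session

with open("model.engine", "rb") as f:   # illustrative engine path
    session = Session.from_serialized_engine(f.read())

stream = torch.cuda.Stream()
# Binding names and shapes are engine-specific; these are placeholders.
inputs = {"input": torch.zeros(1, 16, dtype=torch.int32, device="cuda")}
outputs = {"output": torch.empty(1, 16, 768, dtype=torch.float16, device="cuda")}

ok = session.run(inputs, outputs, stream.cuda_stream)  # enqueue only
stream.synchronize()                                   # wait for the GPU work
assert ok, "failed to enqueue inference"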

4.3 ModelRunner Class Implementation

Location: tensorrt_llm/runtime/model_runner.py:515+

class ModelRunner(ModelRunnerMixin):
    """Model runner that wraps the inference execution logic."""

    def __init__(self,
                 session: Session,
                 max_batch_size: int,
                 max_input_len: int,
                 max_seq_len: int,
                 max_beam_width: int = 1,
                 lora_config: Optional[LoraConfig] = None,
                 **kwargs):
        """
        Initialize the model runner.

        Args:
            session: TensorRT session
            max_batch_size: maximum batch size
            max_input_len: maximum input length
            max_seq_len: maximum sequence length
            max_beam_width: maximum beam-search width
        """

        self.session = session
        self.max_batch_size = max_batch_size
        self.max_input_len = max_input_len
        self.max_seq_len = max_seq_len
        self.max_beam_width = max_beam_width
        # Whether per-token logits should be gathered (read below in _init_buffers)
        self.gather_all_token_logits = kwargs.get('gather_all_token_logits', False)

        # Create the CUDA stream
        self.stream = torch.cuda.Stream()

        # Initialize the buffers
        self._init_buffers()

        # Initialize the KV cache
        self._init_kv_cache()

        # LoRA configuration
        self.lora_config = lora_config
        if lora_config:
            self._init_lora_weights()

    def _init_buffers(self):
        """Initialize the inference buffers."""

        # Input buffers
        self.input_ids = torch.zeros(
            (self.max_batch_size, self.max_input_len),
            dtype=torch.int32,
            device='cuda'
        )

        self.input_lengths = torch.zeros(
            (self.max_batch_size,),
            dtype=torch.int32,
            device='cuda'
        )

        self.position_ids = torch.zeros(
            (self.max_batch_size, self.max_input_len),
            dtype=torch.int32,
            device='cuda'
        )

        # Output buffers
        self.output_ids = torch.zeros(
            (self.max_batch_size, self.max_beam_width, self.max_seq_len),
            dtype=torch.int32,
            device='cuda'
        )

        self.sequence_lengths = torch.zeros(
            (self.max_batch_size, self.max_beam_width),
            dtype=torch.int32,
            device='cuda'
        )

        # Logits buffer (only if requested)
        if self.gather_all_token_logits:
            vocab_size = self.session.engine.get_tensor_shape('logits')[-1]
            self.logits = torch.zeros(
                (self.max_batch_size, self.max_seq_len, vocab_size),
                dtype=torch.float32,
                device='cuda'
            )

    def _init_kv_cache(self):
        """Initialize the KV cache."""

        # Fetch the KV-cache configuration
        kv_cache_config = self._get_kv_cache_config()

        if kv_cache_config.cache_type == KVCacheType.PAGED:
            # Paged KV cache
            self._init_paged_kv_cache(kv_cache_config)
        else:
            # Contiguous KV cache
            self._init_continuous_kv_cache(kv_cache_config)

    def generate(self,
                 batch_input_ids: torch.Tensor,
                 sampling_config: SamplingConfig,
                 prompt_table: Optional[torch.Tensor] = None,
                 tasks: Optional[torch.Tensor] = None,
                 lora_uids: Optional[torch.Tensor] = None,
                 **kwargs) -> Dict[str, torch.Tensor]:
        """
        Generate text.

        Args:
            batch_input_ids: batched input token IDs [batch_size, input_len]
            sampling_config: sampling configuration
            prompt_table: prompt table (optional)
            tasks: task IDs (optional)
            lora_uids: LoRA UIDs (optional)

        Returns:
            Dictionary of generation results

        Execution flow:
        1. Input preprocessing
        2. Prepare the inference inputs
        3. Run inference
        4. Postprocess the results
        """

        # 1. Input preprocessing
        batch_size = batch_input_ids.shape[0]
        input_lengths = self._get_input_lengths(batch_input_ids)

        # Validate the input size
        if batch_size > self.max_batch_size:
            raise ValueError(f"Batch size {batch_size} exceeds maximum {self.max_batch_size}")

        # 2. Prepare the inference inputs
        inputs = self._prepare_inputs(
            batch_input_ids=batch_input_ids,
            input_lengths=input_lengths,
            sampling_config=sampling_config,
            prompt_table=prompt_table,
            tasks=tasks,
            lora_uids=lora_uids
        )

        # 3. Run inference
        outputs = self._run_inference(inputs, batch_size)

        # 4. Postprocess the results
        results = self._postprocess_outputs(outputs, batch_size, input_lengths)

        return results

    def _run_inference(self, inputs: Dict[str, torch.Tensor], batch_size: int) -> Dict[str, torch.Tensor]:
        """Run inference."""

        # Allocate the output buffers
        outputs = self._allocate_output_buffers(batch_size)

        # Execute on the CUDA stream
        with torch.cuda.stream(self.stream):
            # Run the TensorRT session
            success = self.session.run(inputs, outputs, self.stream.cuda_stream)

            if not success:
                raise RuntimeError("TensorRT inference execution failed")

        # Wait for completion
        self.stream.synchronize()

        return outputs

    def _prepare_inputs(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Prepare the inference input tensors."""

        batch_input_ids = kwargs['batch_input_ids']
        input_lengths = kwargs['input_lengths']
        sampling_config = kwargs['sampling_config']

        batch_size, input_len = batch_input_ids.shape

        # Input token IDs
        self.input_ids[:batch_size, :input_len] = batch_input_ids

        # Input lengths
        self.input_lengths[:batch_size] = input_lengths

        # Position IDs
        for i in range(batch_size):
            self.position_ids[i, :input_lengths[i]] = torch.arange(input_lengths[i])

        # Assemble the input dict
        inputs = {
            'input_ids': self.input_ids[:batch_size],
            'input_lengths': self.input_lengths[:batch_size],
            'position_ids': self.position_ids[:batch_size],
        }

        # Add the sampling parameters
        inputs.update(self._prepare_sampling_inputs(sampling_config, batch_size))

        # Add the KV-cache inputs
        inputs.update(self._prepare_kv_cache_inputs(batch_size))

        # Add the optional inputs
        if 'prompt_table' in kwargs and kwargs['prompt_table'] is not None:
            inputs['prompt_embedding_table'] = kwargs['prompt_table']

        if 'lora_uids' in kwargs and kwargs['lora_uids'] is not None:
            inputs['lora_ranks'] = kwargs['lora_uids']

        return inputs
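
For comparison with the constructor-level view above, here is a minimal sketch of the higher-level ModelRunner.from_dir entry point, which loads a serialized engine directory and runs batched generation; the engine path, token IDs, and special-token IDs are illustrative.

import torch
from tensorrt_llm.runtime import ModelRunner

runner = ModelRunner.from_dir(engine_dir="./llama2_7b_engine", rank=0)

batch_input_ids = [torch.tensor([1, 15043, 3186], dtype=torch.int32)]  # illustrative token IDs
outputs = runner.generate(
    batch_input_ids,
    max_new_tokens=32,
    end_id=2,          # illustrative EOS id
    pad_id=2,
    return_dict=True,
)
print(outputs["output_ids"].shape)  # [batch, beams, seq_len]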

4.4 Runtime Sequence Diagram

sequenceDiagram
    participant Executor
    participant ModelRunner
    participant Session
    participant TensorRT
    participant CUDA

    Executor->>ModelRunner: generate(batch_input_ids, sampling_config)

    Note over ModelRunner: Input preprocessing
    ModelRunner->>ModelRunner: validate_inputs()
    ModelRunner->>ModelRunner: prepare_inputs()

    Note over ModelRunner: Inference execution
    ModelRunner->>Session: run(inputs, outputs, stream)
    Session->>Session: set_tensor_addresses()
    Session->>TensorRT: execute_async_v3(stream)
    TensorRT->>CUDA: launch_kernels()

    Note over CUDA: GPU computation
    CUDA->>CUDA: attention_computation()
    CUDA->>CUDA: ffn_computation()
    CUDA->>CUDA: output_projection()

    CUDA-->>TensorRT: computation_complete
    TensorRT-->>Session: execution_complete
    Session-->>ModelRunner: inference_results

    Note over ModelRunner: Result postprocessing
    ModelRunner->>ModelRunner: postprocess_outputs()
    ModelRunner-->>Executor: generation_results

5. Quantization Module

5.1 Module Architecture Diagram

graph TB
    subgraph "Quantization Configuration"
        A[QuantConfig] --> B[Algorithm selection]
        A --> C[Quantization parameters]
        A --> D[Calibration settings]
    end

    subgraph "Quantization Algorithms"
        E[Weight quantization] --> F[INT4 AWQ]
        E --> G[INT8 GPTQ]
        E --> H[FP8 quantization]
        E --> I[FP4 quantization]

        J[Activation quantization] --> K[SmoothQuant]
        J --> L[Dynamic quantization]

        M[KV-cache quantization] --> N[INT8 KV cache]
        M --> O[FP8 KV cache]
    end

    subgraph "Quantization Execution"
        P[quantize function] --> Q[Module traversal]
        P --> R[Method selection]
        P --> S[Weight conversion]
        P --> T[Module replacement]
    end

    subgraph "Calibration Flow"
        U[Calibrator] --> V[Data collection]
        U --> W[Statistics computation]
        U --> X[Quantization-parameter generation]
    end

    A --> E
    A --> J
    A --> M
    E --> P
    J --> P
    M --> P
    P --> U

5.2 QuantConfig Configuration Class

Location: tensorrt_llm/quantization/quantize.py

@dataclass
class QuantConfig:
    """Quantization configuration."""

    # Quantization algorithms
    quant_algo: QuantAlgo = QuantAlgo.NO_QUANT
    kv_cache_quant_algo: QuantAlgo = QuantAlgo.NO_QUANT

    # Weight-quantization parameters
    group_size: int = 128
    has_zero_point: bool = True
    pre_quant_scale: bool = False
    exclude_modules: Optional[List[str]] = None

    # Activation-quantization parameters
    activation_scaling_factor: float = 1.0
    weight_scaling_factor: float = 1.0
    smoothquant_val: float = 0.5

    # Calibration parameters
    calib_size: int = 512
    calib_max_seq_length: int = 512
    calib_dataset: str = "cnn_dailymail"

    def __post_init__(self):
        """Post-init validation."""
        # Validate algorithm compatibility
        if self.quant_algo == QuantAlgo.W4A16_AWQ and self.group_size not in [64, 128]:
            raise ValueError("AWQ quantization requires group_size of 64 or 128")

        # Validate the KV-cache quantization algorithm
        if self.kv_cache_quant_algo != QuantAlgo.NO_QUANT:
            if self.kv_cache_quant_algo not in KV_CACHE_QUANT_ALGO_LIST:
                raise ValueError(f"Unsupported KV cache quantization: {self.kv_cache_quant_algo}")

    @property
    def quant_mode(self) -> QuantMode:
        """Return the quantization mode."""
        mode = QuantMode(0)

        # Weight-quantization mode
        if self.quant_algo in [QuantAlgo.W4A16, QuantAlgo.W4A16_AWQ, QuantAlgo.W4A16_GPTQ]:
            mode |= QuantMode.INT4_WEIGHTS
            if self.group_size > 0:
                mode |= QuantMode.PER_GROUP
        elif self.quant_algo in [QuantAlgo.W8A16, QuantAlgo.W8A16_GPTQ]:
            mode |= QuantMode.INT8_WEIGHTS
            mode |= QuantMode.PER_CHANNEL
        elif self.quant_algo == QuantAlgo.FP8:
            mode |= QuantMode.FP8_QDQ
        elif self.quant_algo == QuantAlgo.NVFP4:
            mode |= QuantMode.NVFP4

        # Activation-quantization mode
        if 'A8' in self.quant_algo.value:
            mode |= QuantMode.ACTIVATIONS
            mode |= QuantMode.PER_TOKEN

        # KV-cache quantization mode
        if self.kv_cache_quant_algo == QuantAlgo.INT8:
            mode |= QuantMode.INT8_KV_CACHE
        elif self.kv_cache_quant_algo == QuantAlgo.FP8:
            mode |= QuantMode.FP8_KV_CACHE
        elif self.kv_cache_quant_algo == QuantAlgo.NVFP4:
            mode |= QuantMode.NVFP4_KV_CACHE

        return mode

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return {
            'quant_algo': self.quant_algo.value,
            'kv_cache_quant_algo': self.kv_cache_quant_algo.value,
            'group_size': self.group_size,
            'has_zero_point': self.has_zero_point,
            'activation_scaling_factor': self.activation_scaling_factor,
            'weight_scaling_factor': self.weight_scaling_factor
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'QuantConfig':
        """Create a configuration from a dictionary."""
        return cls(
            quant_algo=QuantAlgo(config_dict['quant_algo']),
            kv_cache_quant_algo=QuantAlgo(config_dict.get('kv_cache_quant_algo', 'NO_QUANT')),
            group_size=config_dict.get('group_size', 128),
            has_zero_point=config_dict.get('has_zero_point', True),
            activation_scaling_factor=config_dict.get('activation_scaling_factor', 1.0),
            weight_scaling_factor=config_dict.get('weight_scaling_factor', 1.0)
        )
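
A minimal sketch of wiring a QuantConfig into the high-level LLM API, which then applies the quantization (including any calibration) during engine build; the model id, algorithm choice, and import paths are assumptions based on recent releases.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import QuantConfig        # import path assumed
from tensorrt_llm.quantization import QuantAlgo    # import path assumed

quant_config = QuantConfig(
    quant_algo=QuantAlgo.W4A16_AWQ,       # 4-bit AWQ weights, FP16 activations
    kv_cache_quant_algo=QuantAlgo.INT8,   # optional INT8 KV cache
    group_size=128,
)

llm = LLM(model="meta-llama/Llama-2-7b-hf", quant_config=quant_config)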

5.3 quantize() Function Implementation

Location: tensorrt_llm/quantization/quantize.py:561-603

def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
    """
    Core model-quantization function.

    Args:
        model: the model to quantize
        quant_config: quantization configuration

    Returns:
        The quantized model

    Quantization flow:
    1. Module traversal and analysis
    2. Quantization-method selection
    3. Weight conversion
    4. Module replacement
    5. KV-cache quantization
    """

    logger.info(f"Starting model quantization with {quant_config.quant_algo}")

    # Statistics
    quantized_modules = 0
    total_modules = 0

    # Walk every module of the model
    for name, module, parent in model.named_modules_with_parent():
        total_modules += 1

        # 1. Determine the per-layer quantization mode
        if isinstance(quant_config, LayerQuantConfig):
            # Mixed-precision quantization
            layer_quant_mode = quant_config.layer_quant_mode(name)
        else:
            # Uniform quantization
            layer_quant_mode = quant_config.quant_mode

        # Skip modules that are not quantized
        if layer_quant_mode == QuantMode(0):
            continue

        # Honor the exclusion list
        if quant_config.exclude_modules:
            if any(excluded in name for excluded in quant_config.exclude_modules):
                logger.debug(f"Skipping excluded module: {name}")
                continue

        # 2. Fetch the per-layer quantization configuration
        layer_quant_cfg = quant_config._get_quant_cfg(name) if hasattr(quant_config, '_get_quant_cfg') else quant_config

        # 3. Pick the quantization method for this mode
        original_module = module

        if layer_quant_mode.has_fp8_qdq():
            logger.debug(f"Applying FP8 quantization to {name}")
            module = fp8_quantize(module, layer_quant_cfg)

        elif layer_quant_mode.has_fp8_rowwise():
            logger.debug(f"Applying FP8 rowwise quantization to {name}")
            module = fp8_rowwise_quantize(module, layer_quant_cfg)

        elif layer_quant_mode.is_qserve_w4a8():
            logger.debug(f"Applying QServe W4A8 quantization to {name}")
            module = qserve_quantize(module, quant_config)

        elif layer_quant_mode.has_nvfp4():
            logger.debug(f"Applying NVFP4 quantization to {name}")
            module = fp4_quantize(module, layer_quant_cfg)

        elif layer_quant_mode.has_act_and_weight_quant():
            logger.debug(f"Applying SmoothQuant to {name}")
            module = smooth_quantize(module, layer_quant_cfg)

        elif layer_quant_mode.is_weight_only():
            if layer_quant_mode.has_per_group_scaling():
                logger.debug(f"Applying weight-only groupwise quantization to {name}")
                module = weight_only_groupwise_quantize(module, layer_quant_cfg, model.config)
            else:
                logger.debug(f"Applying weight-only quantization to {name}")
                module = weight_only_quantize(module, layer_quant_cfg, model.config)

        # 4. Swap in the quantized module
        if module is not original_module:
            quantized_modules += 1

            if parent is not None:
                # Replace the child module
                module_name = name.rsplit('.', 1)[-1]
                setattr(parent, module_name, module)
                logger.debug(f"Replaced module {name}")
            else:
                # Replace the whole model
                model = module
                break

    # 5. KV-cache quantization
    if quant_config.quant_mode.has_kv_cache_quant():
        logger.info("Applying KV cache quantization")
        model = kv_cache_quantize(model)

    # 6. Record the quantization mode on the model
    setattr(model, 'quant_mode', quant_config.quant_mode)

    logger.info(f"Quantization complete: {quantized_modules}/{total_modules} modules quantized")

    return model

def weight_only_quantize(module, quant_config, model_config):
    """Weight-only quantization."""

    if not hasattr(module, 'weight'):
        return module

    # Fetch the weights
    weight = module.weight.data

    # Quantize the weights
    if quant_config.quant_algo == QuantAlgo.W4A16_AWQ:
        # AWQ quantization
        quantized_weight, scales, zeros = awq_quantize_weights(
            weight,
            group_size=quant_config.group_size,
            has_zero_point=quant_config.has_zero_point
        )
    elif quant_config.quant_algo == QuantAlgo.W8A16_GPTQ:
        # GPTQ quantization
        quantized_weight, scales = gptq_quantize_weights(weight)
    else:
        raise ValueError(f"Unsupported weight quantization: {quant_config.quant_algo}")

    # Create the quantized module
    from .layers import WeightOnlyQuantLinear

    quantized_module = WeightOnlyQuantLinear(
        in_features=module.in_features,
        out_features=module.out_features,
        bias=module.bias is not None,
        quant_mode=quant_config.quant_mode,
        group_size=quant_config.group_size
    )

    # Install the quantized weights
    quantized_module.weight.data = quantized_weight
    quantized_module.scales.data = scales

    if quant_config.has_zero_point and 'zeros' in locals():
        quantized_module.zeros.data = zeros

    if module.bias is not None:
        quantized_module.bias.data = module.bias.data

    return quantized_module

def fp8_quantize(module, quant_config):
    """FP8 quantization."""

    if not hasattr(module, 'weight'):
        return module

    # FP8 quantization parameters
    fp8_max = 448.0  # maximum finite value of FP8 E4M3

    # Compute the scaling factor
    weight = module.weight.data
    weight_scale = weight.abs().max() / fp8_max

    # Quantize the weights
    quantized_weight = (weight / weight_scale).to(torch.float8_e4m3fn)

    # Create the FP8 quantized module
    from .layers import FP8QuantLinear

    quantized_module = FP8QuantLinear(
        in_features=module.in_features,
        out_features=module.out_features,
        bias=module.bias is not None
    )

    # Install the quantized weights and scale
    quantized_module.weight.data = quantized_weight
    quantized_module.weight_scale.data = weight_scale

    if module.bias is not None:
        quantized_module.bias.data = module.bias.data

    return quantized_module
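
To make the per-tensor FP8 scaling above concrete, the standalone PyTorch sketch below quantizes a random weight to float8_e4m3fn and dequantizes it back, reporting the reconstruction error; it mirrors the scale computation in fp8_quantize but is otherwise independent of TensorRT-LLM (it requires a PyTorch build with float8 support).

import torch

torch.manual_seed(0)
weight = torch.randn(256, 256, dtype=torch.float32)

fp8_max = 448.0                                   # E4M3 max finite value
scale = weight.abs().max() / fp8_max              # per-tensor scale, as in fp8_quantize
w_fp8 = (weight / scale).to(torch.float8_e4m3fn)  # quantize
w_deq = w_fp8.to(torch.float32) * scale           # dequantize

rel_err = (weight - w_deq).abs().mean() / weight.abs().mean()
print(f"mean relative error: {rel_err:.4f}")      # typically a few percent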

5.4 Quantization Sequence Diagram

sequenceDiagram
    participant User
    participant QuantConfig
    participant Quantizer
    participant Module
    participant Calibrator
    participant QuantModule

    User->>QuantConfig: create QuantConfig(algo=FP8)
    User->>Quantizer: quantize(model, config)

    Note over Quantizer: Module traversal
    loop For each module
        Quantizer->>Module: get module info
        Quantizer->>QuantConfig: get layer config

        alt Calibration required
            Quantizer->>Calibrator: collect statistics
            Calibrator->>Calibrator: compute scales
            Calibrator-->>Quantizer: quantization params
        end

        Quantizer->>Quantizer: select quantization method

        alt FP8 quantization
            Quantizer->>QuantModule: create FP8QuantLinear
            Quantizer->>QuantModule: set quantized weights
        else INT4 quantization
            Quantizer->>QuantModule: create WeightOnlyQuantLinear
            Quantizer->>QuantModule: set quantized weights
        end

        Quantizer->>Module: replace with quantized module
    end

    Note over Quantizer: KV-cache quantization
    alt KV cache quantization enabled
        Quantizer->>Quantizer: kv_cache_quantize()
    end

    Quantizer-->>User: quantized model

This module deep dive has dissected the five core modules of TensorRT-LLM: the LLM API, the executor, the builder, the runtime, and the quantization module. Each module is covered through an architecture diagram, its core class implementations, an analysis of the key methods, and a sequence diagram, providing a comprehensive technical reference for understanding the design principles and implementation details of each module.