TensorRT-LLM-07-Quantization Module: Deep Dive
1. Module Overview
1.1 Module Positioning
The Quantization module is the quantization-acceleration core of TensorRT-LLM. It provides a range of reduced-precision techniques that significantly cut memory usage and speed up inference while preserving model accuracy.
Core responsibilities:
- Quantization algorithms: FP8, INT8, INT4, and related formats
- Calibration techniques: SmoothQuant, AWQ, GPTQ, etc.
- KV cache quantization: reduces memory footprint
- Weight quantization: shrinks model size
- Activation quantization: lowers compute precision
1.2 Supported Quantization Techniques
| Quantization type | Precision | Speedup | Accuracy retention | Hardware | Typical use case |
|---|---|---|---|---|---|
| FP8 | 8-bit float | 1.5-2x | ~100% | H100/H200 | high-accuracy inference |
| INT8 SmoothQuant | 8-bit integer | 1.5-2x | 99%+ | V100+ | general-purpose acceleration |
| INT4 AWQ | 4-bit integer | 2-3x | 95%+ | A100+ | memory-constrained deployment |
| INT4 GPTQ | 4-bit integer | 2-3x | 95%+ | A100+ | weight compression |
| W8A16 | 8-bit weights, 16-bit activations | 1.2-1.5x | 98%+ | general | weight compression |
| KV Cache INT8 | 8-bit KV | halves KV memory | 99%+ | general | long sequences |
1.3 Module Layout
tensorrt_llm/quantization/
├── __init__.py              # quantization interface exports
├── quantize.py              # unified quantization entry point
├── mode.py                  # QuantMode definition
├── layers.py                # quantized layer implementations
├── functional.py            # quantization functionals
│
├── algorithms/              # quantization algorithms
│   ├── smooth_quant.py      # SmoothQuant implementation
│   ├── awq.py               # AWQ implementation
│   ├── gptq.py              # GPTQ implementation
│   └── fp8_quant.py         # FP8 quantization
│
├── calib/                   # calibration
│   ├── int8/                # INT8 calibration
│   └── calibrator.py        # calibrator implementations
│
└── kernels/                 # quantization kernels
    ├── fp8_gemm_kernel.py   # FP8 GEMM
    ├── int8_gemm_kernel.py  # INT8 GEMM
    └── int4_gemm_kernel.py  # INT4 GEMM
1.4 Quantization Workflow
Original model (FP16, weights + activations)
  → Calibration dataset (representative samples)
  → Calibration (statistical analysis)
  → Quantization parameters (scale / zero-point)
  → Quantized model (compressed representation)
  → TRT engine (optimized inference)
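The "quantization parameters" step above boils down to computing a scale (and, for asymmetric schemes, a zero-point) per tensor, channel, or group. A minimal sketch of the symmetric INT8 and asymmetric INT4 cases, using made-up tensors purely for illustration:

import torch

w = torch.randn(4096, 4096)                     # stand-in weight tensor

# Symmetric per-tensor INT8: map [-amax, amax] onto [-127, 127]
amax = w.abs().max()
scale_int8 = amax / 127.0
q_int8 = torch.round(w / scale_int8).clamp(-128, 127).to(torch.int8)
w_deq = q_int8.float() * scale_int8             # dequantize for an error check

# Asymmetric per-tensor INT4: map [min, max] onto [0, 15]
wmin, wmax = w.min(), w.max()
scale_int4 = (wmax - wmin) / 15.0
zero_point = torch.round(-wmin / scale_int4)
q_int4 = torch.round(w / scale_int4 + zero_point).clamp(0, 15)

print("INT8 mean abs error:", (w - w_deq).abs().mean().item())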
2. Core API Deep Dive
2.1 The Unified quantize() Interface
2.1.1 Function Signature
def quantize(
    model: PretrainedModel,                    # model to quantize
    quant_config: Union[str, QuantConfig],     # quantization configuration
    calib_dataset: Optional[Dataset] = None,   # calibration dataset
    calib_size: int = 512,                     # number of calibration samples
    random_seed: int = 42,                     # random seed
    tokenizer = None,                          # tokenizer
    **kwargs
) -> PretrainedModel:
    """
    Unified quantization interface.

    Args:
        model: pretrained model
        quant_config: quantization configuration ("fp8", "int8_sq", "int4_awq", ...)
        calib_dataset: calibration dataset (required for INT8/INT4)

    Returns:
        The quantized model.
    """
2.1.2 Quantization Configuration
QuantConfig fields
| Field | Type | Description | Example values |
|---|---|---|---|
| quant_algo | QuantAlgo | quantization algorithm | FP8, INT8_SQ, INT4_AWQ |
| kv_cache_quant_algo | QuantAlgo | KV cache quantization | INT8, FP8 |
| exclude_modules | List[str] | modules to skip | ["lm_head"] |
| per_channel | bool | per-channel quantization | True |
| per_token | bool | per-token quantization | True |
| use_plugin | bool | use the quantization plugin | True |
| calib_method | str | calibration method | "max", "percentile" |
2.1.3 Core Implementation
def quantize(model, quant_config, calib_dataset=None, **kwargs):
    # 1. Parse the quantization configuration
    if isinstance(quant_config, str):
        quant_config = _parse_quant_config(quant_config)

    # 2. Dispatch to the selected quantization algorithm
    if quant_config.quant_algo == QuantAlgo.FP8:
        return _quantize_fp8(model, quant_config)
    elif quant_config.quant_algo == QuantAlgo.INT8_SMOOTHQUANT:
        return _quantize_int8_smoothquant(model, quant_config, calib_dataset)
    elif quant_config.quant_algo == QuantAlgo.INT4_AWQ:
        return _quantize_int4_awq(model, quant_config, calib_dataset)
    elif quant_config.quant_algo == QuantAlgo.INT4_GPTQ:
        return _quantize_int4_gptq(model, quant_config, calib_dataset)
    else:
        raise ValueError(f"Unsupported quantization algorithm: {quant_config.quant_algo}")

def _parse_quant_config(config_str: str) -> QuantConfig:
    """Parse a string shortcut into a QuantConfig."""
    config_map = {
        "fp8": QuantConfig(
            quant_algo=QuantAlgo.FP8,
            kv_cache_quant_algo=QuantAlgo.FP8,
        ),
        "int8_sq": QuantConfig(
            quant_algo=QuantAlgo.INT8_SMOOTHQUANT,
            kv_cache_quant_algo=QuantAlgo.INT8,
            per_channel=True,
            per_token=True,
        ),
        "int4_awq": QuantConfig(
            quant_algo=QuantAlgo.INT4_AWQ,
            per_channel=True,
            group_size=128,
        ),
        "w8a16": QuantConfig(
            quant_algo=QuantAlgo.W8A16,
            per_channel=True,
        ),
    }
    # Unknown strings fall back to the default (no-quantization) config
    return config_map.get(config_str, QuantConfig())
2.2 FP8 Quantization
2.2.1 How It Works
FP8 formats (per the OCP 8-bit floating-point specification; FP8 is not an IEEE standard):
- E4M3: 4 exponent bits, 3 mantissa bits (higher precision, range about ±448)
- E5M2: 5 exponent bits, 2 mantissa bits (wider dynamic range, about ±57344, lower precision)
Advantages:
- Accuracy close to FP16 (typically <1% loss)
- Native hardware support (H100/H200 FP8 Tensor Cores)
- No complex calibration needed (per-tensor scales are usually sufficient)
- Applies to both weights and activations
Typical usage:
- Weights: E4M3 (range roughly -448 to 448)
- Activations: E4M3 as well for inference; E5M2 (wider range) is mainly used for gradients during training
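A minimal sketch of per-tensor FP8 E4M3 quantization and dequantization (assumes PyTorch 2.1+ for the torch.float8_e4m3fn dtype; 448 is the E4M3 maximum):

import torch

w = torch.randn(1024, 1024, dtype=torch.float16)

# Per-tensor scale that maps the largest magnitude onto the E4M3 maximum (448)
amax = w.abs().max().float()
scale = 448.0 / amax

w_fp8 = (w.float() * scale).clamp(-448, 448).to(torch.float8_e4m3fn)   # quantize
w_rec = w_fp8.to(torch.float32) / scale                                # dequantize

rel_err = (w.float() - w_rec).abs().mean() / w.float().abs().mean()
print(f"mean relative error: {rel_err.item():.4f}")                    # small, thanks to E4M3's 3 mantissa bits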
2.2.2 Implementation
def _quantize_fp8(model: PretrainedModel, quant_config: QuantConfig):
    """FP8 quantization."""
    # 1. Record the quantization mode on the model config
    model.config.quantization = quant_config
    quant_mode = QuantMode.from_quant_algo(QuantAlgo.FP8)

    # 2. Walk all linear layers
    for name, module in model.named_modules():
        if isinstance(module, (Linear, ColumnLinear, RowLinear)):
            # 2.1 Skip excluded modules (e.g. lm_head)
            if any(exclude in name for exclude in quant_config.exclude_modules):
                continue

            # 2.2 Convert to an FP8 quantized layer
            fp8_module = _convert_to_fp8_layer(module, quant_config)

            # 2.3 Swap it into the parent module
            parent = model
            attrs = name.split('.')
            for attr in attrs[:-1]:
                parent = getattr(parent, attr)
            setattr(parent, attrs[-1], fp8_module)

    # 3. Enable KV cache quantization if requested
    if quant_config.kv_cache_quant_algo == QuantAlgo.FP8:
        _enable_fp8_kv_cache(model)

    return model

def _convert_to_fp8_layer(module, quant_config):
    """Convert a linear layer to its FP8 equivalent."""
    # 1. Create the FP8 linear layer
    fp8_layer = FP8Linear(
        in_features=module.in_features,
        out_features=module.out_features,
        bias=module.bias is not None,
        dtype=module.dtype,
        tp_group=getattr(module, 'tp_group', None),
        tp_size=getattr(module, 'tp_size', 1),
    )

    # 2. Convert the weights to FP8
    with torch.no_grad():
        # 2.1 Original FP16 weights
        weight_fp16 = module.weight.data  # [out_features, in_features]

        # 2.2 Per-tensor scale; FP8 E4M3 range is [-448, 448]
        max_val = weight_fp16.abs().max()
        scale = 448.0 / max_val

        # 2.3 Quantize the weights
        weight_fp8 = (weight_fp16 * scale).clamp(-448, 448)
        weight_fp8 = weight_fp8.to(torch.float8_e4m3fn)  # requires PyTorch 2.1+

        # 2.4 Store the weights and the scale
        fp8_layer.weight.data = weight_fp8
        fp8_layer.weight_scale = Parameter(torch.tensor(scale))

        # 2.5 Copy the bias, if any
        if module.bias is not None:
            fp8_layer.bias.data = module.bias.data

    return fp8_layer
class FP8Linear(Module):
    """FP8 quantized linear layer."""

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__()
        # FP8 weights
        self.weight = Parameter(
            shape=(out_features, in_features),
            dtype=torch.float8_e4m3fn,  # FP8 E4M3
        )
        # Weight scale
        self.weight_scale = Parameter(
            shape=(1,),
            dtype=torch.float32,
        )
        # Activation scale (computed dynamically here; usually calibrated offline)
        self.activation_scale = Parameter(
            shape=(1,),
            dtype=torch.float32,
        )
        if bias:
            self.bias = Parameter(
                shape=(out_features,),
                dtype=torch.float16,
            )

    def forward(self, input: Tensor) -> Tensor:
        # 1. Dynamic per-tensor activation scale (E4M3 maximum is 448)
        input_max = input.abs().max()
        act_scale = 448.0 / input_max

        # 2. Quantize the activations to FP8 E4M3
        input_fp8 = (input * act_scale).clamp(-448, 448).to(torch.float8_e4m3fn)

        # 3. Matmul. Schematic reference path: upcast and run a regular GEMM;
        #    production kernels run FP8 GEMM directly on Hopper Tensor Cores.
        output = torch.matmul(input_fp8.to(torch.float16), self.weight.to(torch.float16).T)

        # 4. Dequantize back to FP16
        output = (output / (self.weight_scale * act_scale)).to(torch.float16)

        # 5. Add the bias
        if hasattr(self, 'bias'):
            output = output + self.bias
        return output
2.3 INT8 SmoothQuant
2.3.1 How It Works
Core idea of SmoothQuant:
1. Observation: activation distributions are uneven and contain outlier channels
2. Solution: a mathematically equivalent transform migrates the quantization difficulty from activations to weights
3. Formula: Y = (X diag(s)^(-1)) · (diag(s) W), where s is the per-channel smoothing factor
Steps:
1. Collect activation statistics
2. Compute the smoothing factors s
3. Adjust the weights: W' = diag(s) · W
4. Adjust the activations: X' = X · diag(s)^(-1)
5. Quantize W' and X'
Advantages:
- Smoother activation distribution
- Smaller quantization error
- Only lightweight calibration statistics are needed
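A small numerical sketch of the equivalence Y = (X diag(s)^(-1)) · (diag(s) W): the output stays the same, but the activation outlier is shrunk before quantization. Toy tensors and an arbitrary alpha = 0.5, with W stored [in, out] for readability (unlike the [out, in] layout used in the module code); this is not the library's implementation:

import torch

torch.manual_seed(0)
X = torch.randn(8, 4)
X[:, 2] *= 50.0                      # channel 2 carries an activation outlier
W = torch.randn(4, 6)

alpha = 0.5
s = X.abs().amax(dim=0) ** alpha / W.abs().amax(dim=1) ** (1 - alpha)

X_smooth = X / s                     # X' = X · diag(s)^(-1)
W_smooth = W * s.unsqueeze(1)        # W' = diag(s) · W

print(torch.allclose(X @ W, X_smooth @ W_smooth, atol=1e-4))   # True: outputs match
print(X.abs().max().item(), X_smooth.abs().max().item())       # outlier magnitude reduced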
2.3.2 Implementation
def _quantize_int8_smoothquant(model, quant_config, calib_dataset):
    """INT8 SmoothQuant quantization."""
    # 1. Collect activation statistics
    act_scales = _collect_activation_scales(model, calib_dataset)

    # 2. Compute the smoothing factors
    smooth_scales = _compute_smooth_scales(model, act_scales, alpha=0.5)

    # 3. Apply the smoothing transform
    _apply_smooth_transform(model, smooth_scales)

    # 4. Quantize weights and set up activation quantization
    _quantize_weights_int8(model, quant_config)
    _setup_activation_quantization(model, quant_config)

    return model

def _collect_activation_scales(model, calib_dataset):
    """Collect activation range statistics."""
    scales = {}
    hooks = []

    def hook_fn(name):
        def hook(module, input, output):
            # Record the maximum absolute activation value
            if name not in scales:
                scales[name] = []
            scales[name].append(input[0].abs().max().item())
        return hook

    # 1. Register forward hooks on all linear layers
    for name, module in model.named_modules():
        if isinstance(module, (Linear, ColumnLinear, RowLinear)):
            handle = module.register_forward_hook(hook_fn(name))
            hooks.append(handle)

    # 2. Run forward passes to collect statistics
    model.eval()
    with torch.no_grad():
        for batch in calib_dataset:
            input_ids = batch['input_ids']
            model(input_ids)

    # 3. Remove the hooks
    for handle in hooks:
        handle.remove()

    # 4. Average the recorded scales per layer
    final_scales = {}
    for name, scale_list in scales.items():
        final_scales[name] = np.mean(scale_list)
    return final_scales
def _compute_smooth_scales(model, act_scales, alpha=0.5):
    """
    Compute the smoothing factors.

    Args:
        alpha: smoothing strength in [0, 1];
               0.0 keeps activations untouched, 1.0 moves all difficulty onto the weights
    """
    smooth_scales = {}
    for name, module in model.named_modules():
        if isinstance(module, (Linear, ColumnLinear, RowLinear)):
            # 1. Weight statistics (per input channel)
            weight = module.weight.data  # [out_features, in_features]
            weight_scales = weight.abs().max(dim=0).values

            # 2. Activation statistics
            #    (per-tensor here for simplicity; the SmoothQuant paper uses per-channel maxima)
            act_scale = act_scales.get(name, 1.0)

            # 3. Smoothing factor, as in the SmoothQuant paper:
            #    s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
            smooth_scale = (act_scale ** alpha) / (weight_scales ** (1 - alpha))

            # 4. Clamp to a reasonable range
            smooth_scale = torch.clamp(smooth_scale, 0.1, 10.0)
            smooth_scales[name] = smooth_scale

    return smooth_scales

def _apply_smooth_transform(model, smooth_scales):
    """Apply the smoothing transform."""
    for name, module in model.named_modules():
        if name in smooth_scales:
            scale = smooth_scales[name]

            # 1. Scale the weights per input channel: W' = diag(s) * W
            module.weight.data = module.weight.data * scale.unsqueeze(0)

            # 2. Fold the inverse scaling into the preceding LayerNorm/RMSNorm, if any,
            #    so the activations arrive already divided by s
            prev_norm = _find_previous_norm_layer(model, name)
            if prev_norm is not None:
                # norm_weight' = norm_weight / s
                prev_norm.weight.data = prev_norm.weight.data / scale
class INT8Linear(Module):
    """INT8 quantized linear layer."""

    def __init__(self, in_features, out_features, **kwargs):
        super().__init__()
        # Quantized weights (INT8)
        self.qweight = Parameter(
            shape=(out_features, in_features),
            dtype=torch.int8,
        )
        # Weight quantization scale (per channel or per tensor)
        self.weight_scale = Parameter(
            shape=(out_features, 1) if kwargs.get('per_channel') else (1,),
            dtype=torch.float32,
        )
        # Static activation scale (unused by the dynamic path below)
        self.input_scale = Parameter(
            shape=(1,),
            dtype=torch.float32,
        )

    def forward(self, input: Tensor) -> Tensor:
        # 1. Dynamic per-tensor activation quantization
        input_scale = input.abs().max() / 127.0
        input_int8 = torch.round(input / input_scale).clamp(-128, 127).to(torch.int8)

        # 2. INT8 matmul (reference path; production code calls an INT8 Tensor Core GEMM)
        output_int32 = torch.matmul(input_int8.to(torch.int32), self.qweight.T.to(torch.int32))

        # 3. Dequantize (weight_scale flattened so it broadcasts over the output channels)
        output = output_int32.to(torch.float32) * (input_scale * self.weight_scale.view(-1))
        return output.to(torch.float16)
2.4 INT4 AWQ Quantization
2.4.1 How It Works
AWQ (Activation-aware Weight Quantization):
1. Observation: weight channels differ in importance
2. Strategy: protect the salient channels and quantize the rest aggressively
3. Method: channel importance is derived from activation magnitudes
Core algorithm:
1. Collect activation statistics: A = {a1, a2, ..., an}
2. Compute channel importance: importance_j = mean(|a_j|)
3. Compute scaling factors: s_j = importance_j^α (salient channels get larger scales)
4. Apply the scaling: W' = diag(s) * W, X' = X * diag(s)^(-1)
5. Quantize W' to INT4
Group-wise quantization:
- Weights are quantized in groups (e.g., 128 weights per group)
- Each group gets its own quantization parameters
- This reduces quantization error
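A rough sketch of why group-wise quantization helps: giving each 128-element group its own scale tracks local value ranges far better than one scale for the whole row. Toy data, symmetric INT4 with integer range [-8, 7]; this is not the library's kernel:

import torch

torch.manual_seed(0)
w = torch.randn(4096) * torch.rand(4096)          # heterogeneous magnitudes across the row

def int4_sym_error(x, group_size):
    err = 0.0
    for start in range(0, x.numel(), group_size):
        g = x[start:start + group_size]
        scale = g.abs().max() / 7.0 + 1e-8        # symmetric INT4 scale
        q = torch.round(g / scale).clamp(-8, 7)
        err += ((q * scale - g) ** 2).sum().item()
    return err / x.numel()

print("per-tensor MSE :", int4_sym_error(w, w.numel()))
print("group-128 MSE  :", int4_sym_error(w, 128))   # noticeably smaller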
2.4.2 Implementation
def _quantize_int4_awq(model, quant_config, calib_dataset):
    """INT4 AWQ quantization."""
    # 1. Collect activation statistics
    act_stats = _collect_activation_statistics(model, calib_dataset)

    # 2. Compute the AWQ scaling factors
    awq_scales = _compute_awq_scales(model, act_stats, alpha=0.5)

    # 3. Apply the AWQ transform
    _apply_awq_transform(model, awq_scales)

    # 4. Group-wise INT4 weight quantization
    _quantize_weights_int4_grouped(model, quant_config)

    return model

def _compute_awq_scales(model, act_stats, alpha=0.5):
    """Compute the AWQ scaling factors."""
    awq_scales = {}
    for name, module in model.named_modules():
        if isinstance(module, (Linear, ColumnLinear, RowLinear)):
            # 1. Activation statistics
            if name in act_stats:
                act_mean = act_stats[name]['mean']  # mean activation magnitude per channel

                # 2. Importance: channels with larger activations matter more
                importance = act_mean

                # 3. Scaling factor: salient channels are scaled up before quantization,
                #    which protects them against INT4 rounding error
                scale = importance ** alpha
                scale = scale / scale.mean()  # normalize

                awq_scales[name] = scale
    return awq_scales
def _quantize_weights_int4_grouped(model, quant_config):
    """Group-wise INT4 weight quantization."""
    group_size = quant_config.group_size  # default 128

    for name, module in model.named_modules():
        if isinstance(module, (Linear, ColumnLinear, RowLinear)):
            # 1. Original weights
            weight = module.weight.data  # [out_features, in_features]
            out_features, in_features = weight.shape

            # 2. Quantize group by group along the in_features dimension
            num_groups = (in_features + group_size - 1) // group_size
            qweight = torch.zeros(out_features, in_features // 2, dtype=torch.uint8)  # two 4-bit values per byte
            scales = torch.zeros(out_features, num_groups, dtype=torch.float16)
            zeros = torch.zeros(out_features, num_groups, dtype=torch.float16)

            for g in range(num_groups):
                start_idx = g * group_size
                end_idx = min((g + 1) * group_size, in_features)

                # 2.1 Current group
                group_weight = weight[:, start_idx:end_idx]  # [out_features, group_size]

                # 2.2 Quantization parameters (per output channel)
                group_min = group_weight.min(dim=1, keepdim=True).values
                group_max = group_weight.max(dim=1, keepdim=True).values
                # Unsigned INT4 range: 0-15
                scale = (group_max - group_min).clamp(min=1e-8) / 15.0
                zero_point = -group_min / scale

                # 2.3 Quantize
                qweight_group = torch.round(group_weight / scale + zero_point).clamp(0, 15).to(torch.uint8)

                # 2.4 Pack two 4-bit values into one byte
                if end_idx - start_idx == group_size:  # full groups only (tail group omitted for brevity)
                    packed = qweight_group[:, ::2] | (qweight_group[:, 1::2] << 4)
                    qweight[:, start_idx // 2:end_idx // 2] = packed

                # 2.5 Store the quantization parameters
                scales[:, g] = scale.squeeze()
                zeros[:, g] = zero_point.squeeze()

            # 3. Build the INT4 layer
            int4_layer = INT4Linear(
                in_features=in_features,
                out_features=out_features,
                group_size=group_size,
            )
            int4_layer.qweight.data = qweight
            int4_layer.scales.data = scales
            int4_layer.zeros.data = zeros

            # 4. Replace the original layer
            _replace_module(model, name, int4_layer)
class INT4Linear(Module):
    """Group-wise INT4 quantized linear layer."""

    def __init__(self, in_features, out_features, group_size=128):
        super().__init__()
        num_groups = (in_features + group_size - 1) // group_size

        # Packed INT4 weights (two 4-bit values per byte)
        self.qweight = Parameter(
            shape=(out_features, in_features // 2),
            dtype=torch.uint8,
        )
        # Per-group quantization parameters
        self.scales = Parameter(
            shape=(out_features, num_groups),
            dtype=torch.float16,
        )
        self.zeros = Parameter(
            shape=(out_features, num_groups),
            dtype=torch.float16,
        )
        self.group_size = group_size
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, input: Tensor) -> Tensor:
        # Dispatch to a custom CUDA kernel for the INT4 matmul
        return int4_linear_forward(
            input,
            self.qweight,
            self.scales,
            self.zeros,
            self.group_size
        )
def int4_linear_forward(input, qweight, scales, zeros, group_size):
    """
    INT4 linear forward. Production code calls an optimized C++/CUDA kernel
    (e.g., a CUTLASS-based INT4 GEMM); the dequantize-then-matmul below is a
    numerically equivalent reference path.
    """
    out_features, packed_in = qweight.shape
    in_features = packed_in * 2

    # 1. Unpack two 4-bit values from each byte
    unpacked = torch.empty(out_features, in_features, dtype=torch.float16, device=qweight.device)
    unpacked[:, ::2] = (qweight & 0x0F).to(torch.float16)
    unpacked[:, 1::2] = (qweight >> 4).to(torch.float16)

    # 2. Dequantize per group: w = (q - zero_point) * scale
    expanded_scales = scales.repeat_interleave(group_size, dim=1)[:, :in_features]
    expanded_zeros = zeros.repeat_interleave(group_size, dim=1)[:, :in_features]
    weight = (unpacked - expanded_zeros) * expanded_scales

    # 3. Regular matmul against the dequantized weight
    return torch.matmul(input.to(torch.float16), weight.T)
3. Key Features in Depth
3.1 Precision Comparison
3.1.1 Numeric Range Comparison
# Numeric ranges of the supported precisions
precisions = {
    "FP16": {
        "range": "±65504",
        "precision": "~3-4 significant decimal digits",
        "memory": "2 bytes",
    },
    "FP8 E4M3": {
        "range": "±448",
        "precision": "~2 significant decimal digits",
        "memory": "1 byte",
    },
    "FP8 E5M2": {
        "range": "±57344",
        "precision": "~1-2 significant decimal digits",
        "memory": "1 byte",
    },
    "INT8": {
        "range": "-128 to 127",
        "precision": "integer",
        "memory": "1 byte",
    },
    "INT4": {
        "range": "0 to 15 (unsigned)",
        "precision": "integer",
        "memory": "0.5 bytes",
    }
}
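Translating these per-element sizes into model-level numbers, a rough sketch for a hypothetical 7B-parameter model (weights only; per-group scales and other metadata are ignored, and the figures are illustrative):

# Approximate weight memory for a 7B-parameter model under each precision
num_params = 7e9
bytes_per_elem = {"FP16": 2.0, "FP8": 1.0, "INT8": 1.0, "INT4": 0.5}

for name, nbytes in bytes_per_elem.items():
    print(f"{name:5s}: {num_params * nbytes / 1e9:.1f} GB")
# FP16 ≈ 14 GB, FP8/INT8 ≈ 7 GB, INT4 ≈ 3.5 GB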
3.1.2 Quantization Error Analysis
def analyze_quantization_error(original_weights, quantized_weights):
    """Analyze quantization error."""
    # 1. Mean squared error
    mse = torch.mean((original_weights - quantized_weights) ** 2)

    # 2. Signal-to-noise ratio
    signal_power = torch.mean(original_weights ** 2)
    noise_power = mse
    snr_db = 10 * torch.log10(signal_power / noise_power)

    # 3. Relative error (small epsilon avoids division by zero)
    relative_error = torch.abs(original_weights - quantized_weights) / (torch.abs(original_weights) + 1e-8)
    mean_relative_error = torch.mean(relative_error)

    return {
        "mse": mse.item(),
        "snr_db": snr_db.item(),
        "mean_relative_error": mean_relative_error.item(),
    }

# Typical results:
# FP8: SNR ~40 dB, relative error ~1%
# INT8 SmoothQuant: SNR ~35 dB, relative error ~2%
# INT4 AWQ: SNR ~25 dB, relative error ~5%
3.2 KV Cache Quantization
3.2.1 How It Works
KV cache characteristics:
1. Large memory footprint (dominant for long sequences)
2. Relatively stable value ranges
3. Comparatively insensitive to quantization
Quantization strategy:
1. Dynamic per-token quantization
2. Independent scales per head
3. INT8 or FP8 storage format
Memory savings:
- FP16 KV cache: 2 bytes/element
- INT8 KV cache: 1 byte/element
- 50% memory saved
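To make the saving concrete, a back-of-the-envelope sketch for a hypothetical Llama-7B-like configuration (32 layers, 32 KV heads, head size 128; the numbers are illustrative only):

# KV cache size = 2 (K and V) * layers * heads * head_size * seq_len * batch * bytes_per_elem
layers, heads, head_size = 32, 32, 128
batch, seq_len = 8, 4096

def kv_cache_gb(bytes_per_elem):
    return 2 * layers * heads * head_size * seq_len * batch * bytes_per_elem / 1e9

print(f"FP16 KV cache: {kv_cache_gb(2):.1f} GB")   # ~17 GB
print(f"INT8 KV cache: {kv_cache_gb(1):.1f} GB")   # ~8.6 GB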
3.2.2 Implementation
class QuantizedKVCache(Module):
    """Quantized KV cache."""

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_size: int,
        max_seq_len: int,
        dtype: str = "int8",
    ):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.max_seq_len = max_seq_len
        self.dtype = dtype

        # Quantized KV cache storage
        if dtype == "int8":
            cache_dtype = torch.int8
            self.qkv_cache = torch.zeros(
                num_layers, 2, max_seq_len, num_heads, head_size,
                dtype=cache_dtype
            )
            # Quantization scales, kept per token position so that entries written
            # earlier are not invalidated when new tokens arrive
            self.kv_scales = torch.ones(
                num_layers, 2, max_seq_len,
                dtype=torch.float32
            )

    def store_kv(self, layer_idx: int, key: Tensor, value: Tensor, seq_pos: int):
        """Store a quantized K/V entry for one position."""
        # 1. Per-token quantization scales
        k_scale = key.abs().max() / 127.0
        v_scale = value.abs().max() / 127.0

        # 2. Quantize
        k_quantized = torch.round(key / k_scale).clamp(-128, 127).to(torch.int8)
        v_quantized = torch.round(value / v_scale).clamp(-128, 127).to(torch.int8)

        # 3. Store the quantized values
        self.qkv_cache[layer_idx, 0, seq_pos] = k_quantized
        self.qkv_cache[layer_idx, 1, seq_pos] = v_quantized

        # 4. Store the per-position scales
        self.kv_scales[layer_idx, 0, seq_pos] = k_scale
        self.kv_scales[layer_idx, 1, seq_pos] = v_scale

    def get_kv(self, layer_idx: int, seq_len: int) -> Tuple[Tensor, Tensor]:
        """Return the dequantized K/V tensors."""
        # 1. Quantized data
        k_quantized = self.qkv_cache[layer_idx, 0, :seq_len]  # [seq_len, num_heads, head_size]
        v_quantized = self.qkv_cache[layer_idx, 1, :seq_len]

        # 2. Dequantize with the per-position scales
        k_scale = self.kv_scales[layer_idx, 0, :seq_len].view(-1, 1, 1)
        v_scale = self.kv_scales[layer_idx, 1, :seq_len].view(-1, 1, 1)
        key = k_quantized.to(torch.float16) * k_scale
        value = v_quantized.to(torch.float16) * v_scale

        return key, value
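A short usage sketch of the cache above (illustrative shapes only), checking that the INT8 round trip stays close to the original values:

import torch

cache = QuantizedKVCache(num_layers=32, num_heads=32, head_size=128, max_seq_len=2048)

k = torch.randn(32, 128)          # [num_heads, head_size] for one token
v = torch.randn(32, 128)
cache.store_kv(layer_idx=0, key=k, value=v, seq_pos=0)

k_deq, v_deq = cache.get_kv(layer_idx=0, seq_len=1)
print((k - k_deq[0].float()).abs().max().item())   # small INT8 rounding error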
3.3 Quantization-Aware Training vs Post-Training Quantization
3.3.1 Comparison
| Method | Pros | Cons | When to use |
|---|---|---|---|
| Post-training quantization (PTQ) | simple, fast, no retraining | larger accuracy loss | moderate accuracy requirements |
| Quantization-aware training (QAT) | small accuracy loss, best quality | requires retraining, time-consuming | strict accuracy requirements |
3.3.2 QAT Example
class QATLinear(Module):
    """Quantization-aware training linear layer."""

    def __init__(self, in_features, out_features, quant_bits=8):
        super().__init__()
        self.weight = Parameter(torch.randn(out_features, in_features))
        self.quant_bits = quant_bits

        # Learnable quantization parameters
        self.weight_scale = Parameter(torch.ones(1))
        self.weight_zero_point = Parameter(torch.zeros(1))

    def quantize_weight(self, weight):
        """Fake quantization (simulates quantization during training)."""
        # 1. Quantization range
        qmin = 0
        qmax = 2 ** self.quant_bits - 1

        # 2. Quantize
        scale = self.weight_scale
        zero_point = self.weight_zero_point

        # 3. Fake-quantize: quantized values in the forward pass, FP32 gradients in the backward pass
        qweight = torch.round(weight / scale + zero_point).clamp(qmin, qmax)
        dequant_weight = (qweight - zero_point) * scale

        # 4. Straight-Through Estimator (STE):
        #    forward uses the quantized value, backward passes the gradient through unchanged
        return weight + (dequant_weight - weight).detach()

    def forward(self, input):
        # During training, use fake quantization
        if self.training:
            qweight = self.quantize_weight(self.weight)
        else:
            # Real deployments would switch to pre-quantized integer weights here;
            # the fake-quantized path is reused to keep the example simple
            qweight = self.quantize_weight(self.weight)
        return torch.matmul(input, qweight.T)
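To see the straight-through estimator at work outside any class, a self-contained sketch (plain PyTorch, symmetric 4-bit fake quantization; the names are made up for illustration):

import torch

w = torch.randn(8, 16, requires_grad=True)
scale = (w.abs().max() / 7.0).detach()          # symmetric 4-bit: integer range [-8, 7]

q = torch.round(w / scale).clamp(-8, 7)
w_fake = w + (q * scale - w).detach()           # forward: quantized value; backward: identity

loss = w_fake.sum()
loss.backward()
print(torch.unique(w.grad))                     # tensor([1.]) — the gradient passed straight through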
4. Usage Examples
4.1 FP8 Quantization Deployment
from tensorrt_llm import LLM
from tensorrt_llm.quantization import quantize

# 1. Load the model
llm = LLM("meta-llama/Llama-3-8B")

# 2. FP8 quantization (no calibration data required)
quantized_model = quantize(llm.model, "fp8")

# 3. Build the engine
llm.model = quantized_model
engine = llm._build_engine()

# 4. Inference
output = llm.generate("Hello world", max_tokens=100)

# Expected gains:
# - Memory usage: ~50% reduction
# - Inference speed: 1.5-2x faster (H100)
# - Accuracy loss: <1%
4.2 INT4 AWQ Quantization
from datasets import load_dataset
from transformers import AutoTokenizer

# 1. Prepare the calibration dataset
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8B")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

def preprocess(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

calib_dataset = dataset.map(preprocess, batched=True).select(range(128))

# 2. INT4 AWQ quantization
llm = LLM("meta-llama/Llama-3-8B")
quantized_model = quantize(
    llm.model,
    quant_config="int4_awq",
    calib_dataset=calib_dataset,
    calib_size=128,
)

# 3. Build the engine
llm.model = quantized_model
engine = llm._build_engine()

# Expected gains:
# - Memory usage: ~75% reduction
# - Inference speed: 2-3x faster
# - Accuracy loss: ~5%
4.3 Mixed-Precision Quantization
from tensorrt_llm.quantization import QuantConfig, QuantAlgo

# Custom quantization configuration
quant_config = QuantConfig(
    # INT4 weights, FP16 activations
    quant_algo=QuantAlgo.INT4_AWQ,
    # INT8 KV cache
    kv_cache_quant_algo=QuantAlgo.INT8,
    # Skip accuracy-sensitive layers
    exclude_modules=["lm_head"],
    # Group-wise quantization
    group_size=128,
    per_channel=True,
)

# Apply the quantization
quantized_model = quantize(llm.model, quant_config, calib_dataset)
4.4 Evaluating Quantization Quality
def evaluate_quantization_quality(original_model, quantized_model, test_dataset):
    """Compare the quantized model's outputs against the original model."""
    from sklearn.metrics import mean_squared_error
    import numpy as np

    original_outputs = []
    quantized_outputs = []

    # 1. Collect logits from both models
    for batch in test_dataset:
        with torch.no_grad():
            orig_out = original_model(batch["input_ids"])
            quant_out = quantized_model(batch["input_ids"])
        original_outputs.append(orig_out.logits.cpu().numpy())
        quantized_outputs.append(quant_out.logits.cpu().numpy())

    # 2. Compute the metrics
    orig_concat = np.concatenate(original_outputs, axis=0)
    quant_concat = np.concatenate(quantized_outputs, axis=0)

    # Mean squared error
    mse = mean_squared_error(orig_concat.flatten(), quant_concat.flatten())

    # Cosine similarity
    cosine_sim = np.dot(orig_concat.flatten(), quant_concat.flatten()) / \
        (np.linalg.norm(orig_concat.flatten()) * np.linalg.norm(quant_concat.flatten()))

    # Top-1 prediction agreement with the original model
    orig_pred = np.argmax(orig_concat, axis=-1)
    quant_pred = np.argmax(quant_concat, axis=-1)
    top1_agreement = np.mean(orig_pred == quant_pred)

    return {
        "mse": mse,
        "cosine_similarity": cosine_sim,
        "top1_agreement": top1_agreement,
    }

# Usage
results = evaluate_quantization_quality(original_model, quantized_model, test_dataset)
print(f"MSE: {results['mse']:.6f}")
print(f"Cosine Similarity: {results['cosine_similarity']:.4f}")
print(f"Top-1 Agreement: {results['top1_agreement']:.4f}")
5. Performance Tuning Recommendations
5.1 Choosing a Quantization Algorithm
# Choose a quantization algorithm based on hardware and requirements
def select_quantization_strategy(model_size, hardware, precision_requirement):
    """Simple quantization-strategy selection guide."""
    strategies = {
        # High-accuracy scenarios
        "high_precision": {
            "H100/H200": "fp8",                # native support, fastest
            "A100": "int8_smoothquant",        # general purpose, good accuracy
            "V100": "w8a16",                   # weight-only quantization, FP16 activations
        },
        # Memory-constrained scenarios
        "memory_constrained": {
            "A100+": "int4_awq",               # activation-aware, better accuracy
            "V100+": "int4_gptq",              # general purpose, high compression
        },
        # Very large models
        "large_model": {
            "multi_gpu": "int4_awq",           # significant memory savings
            "single_gpu": "int8_smoothquant",  # balances accuracy and speed
        }
    }
    return strategies.get(precision_requirement, {}).get(hardware, "int8_smoothquant")

# Usage
strategy = select_quantization_strategy(
    model_size="70B",
    hardware="A100+",
    precision_requirement="memory_constrained"
)
print(f"Recommended strategy: {strategy}")
5.2 Calibration Dataset Optimization
def create_optimal_calib_dataset(model_name, domain="general"):
    """Build a calibration dataset matched to the target domain."""
    datasets = {
        "general": [
            "wikitext-2-raw-v1",
            "c4",
            "openwebtext",
        ],
        "code": [
            "codeparrot/github-code",
            "bigcode/the-stack",
        ],
        "math": [
            "hendrycks/math",
            "gsm8k",
        ],
        "chat": [
            "alpaca",
            "sharegpt",
        ]
    }

    # Combine several datasets
    combined_dataset = []
    for dataset_name in datasets[domain]:
        # (loading simplified; some datasets also require a config name)
        dataset = load_dataset(dataset_name, split="train")
        # Uniformly sample a small subset
        sampled = dataset.shuffle(seed=42).select(range(128))
        combined_dataset.extend(sampled)

    return combined_dataset

# Key principles:
# 1. Diversity: cover different text types
# 2. Representativeness: reflect the actual usage scenarios
# 3. Length distribution: include samples of different lengths
# 4. Moderate size: 128-512 samples is usually enough
5.3 Post-Quantization Optimization
def optimize_quantized_model(quantized_model):
    """Optimize the model after quantization."""
    # 1. Layer fusion
    fused_model = fuse_quantized_layers(quantized_model)

    # 2. Constant folding
    folded_model = fold_quantization_constants(fused_model)

    # 3. Dead-code elimination
    optimized_model = eliminate_unused_quantization_ops(folded_model)

    return optimized_model

def fuse_quantized_layers(model):
    """
    Fuse quantized layers,
    e.g. QuantLinear + QuantReLU → QuantLinearReLU.
    """
    for name, module in model.named_modules():
        if isinstance(module, QuantLinear):
            # Look for a following activation function
            next_module = get_next_module(model, name)
            if isinstance(next_module, QuantReLU):
                # Fuse into a single kernel
                fused = QuantLinearReLU(module, next_module)
                replace_module(model, name, fused)
    return model
6. FAQ
Q1: What hardware does FP8 quantization require?
- GPUs with FP8 Tensor Cores, i.e. H100/H200 (Hopper); Ada-generation GPUs such as L40S also expose FP8
- Other GPUs can emulate FP8 but gain no speedup from it
Q2: Why does INT4 quantization need a calibration dataset?
- INT4 has a very small dynamic range, so the quantization parameters must be accurate
- Calibration data is used to gather weight and activation statistics
- The AWQ/GPTQ algorithms specifically require activation statistics
Q3: How do I evaluate the accuracy loss after quantization?
# Standard evaluation metrics
metrics = [
    "perplexity",         # perplexity (language modeling)
    "bleu_score",         # BLEU score (generation tasks)
    "accuracy",           # accuracy (classification tasks)
    "cosine_similarity"   # output similarity vs. the original model
]

# Acceptable accuracy loss (rule of thumb):
# FP8: <1%
# INT8: <3%
# INT4: <8%
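For the first metric, a minimal perplexity sketch with Hugging Face Transformers (assumes a causal LM model and tokenizer are already loaded and texts is a list of evaluation strings; sliding-window handling of long documents is omitted):

import math
import torch

def perplexity(model, tokenizer, texts, device="cuda"):
    model.eval()
    nll_sum, token_count = 0.0, 0
    for text in texts:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device)
        with torch.no_grad():
            # labels=input_ids makes the model return the mean cross-entropy loss
            out = model(**enc, labels=enc["input_ids"])
        n_tokens = enc["input_ids"].numel()
        nll_sum += out.loss.item() * n_tokens
        token_count += n_tokens
    return math.exp(nll_sum / token_count)

# Compare: perplexity(original_model, tokenizer, texts) vs perplexity(quantized_model, tokenizer, texts)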
Q4: How should group_size be chosen?
- Smaller group_size (64-128): better accuracy, more metadata and compute overhead
- Larger group_size (256-512): lower accuracy, less overhead
- Recommendation: 128 (a good accuracy/performance balance)
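The overhead side of that trade-off is easy to quantify: each group stores its own scale and zero-point, so smaller groups mean more metadata relative to the packed 4-bit weights. A quick sketch, assuming an FP16 scale plus an FP16 zero-point per group as in the layout of section 2.4:

# Metadata overhead of group-wise INT4 for one [4096 x 4096] weight matrix
out_features, in_features = 4096, 4096
weight_bytes = out_features * in_features // 2              # packed INT4: 0.5 byte per weight

for group_size in (64, 128, 256, 512):
    num_groups = in_features // group_size
    meta_bytes = out_features * num_groups * 2 * 2           # FP16 scale + FP16 zero per group
    print(f"group_size={group_size:4d}: metadata = {100 * meta_bytes / weight_bytes:.1f}% of the weights")
# group_size=64 → 12.5%, 128 → 6.2%, 256 → 3.1%, 512 → 1.6%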
Q5: How can a quantized model be fine-tuned?
# QLoRA: quantized base model + LoRA fine-tuning
quantized_model = quantize(base_model, "int4_awq")
lora_model = add_lora_adapters(quantized_model, rank=16)

# Only the LoRA adapter parameters are trained; the quantized weights stay frozen
for param in quantized_model.parameters():
    param.requires_grad = False
# (assuming lora_model.parameters() yields only the adapter parameters)
for param in lora_model.parameters():
    param.requires_grad = True