# PyTorch-06: Functorch Functional Transforms

## Module Overview

Functorch is PyTorch's functional-programming extension. It provides composable function transforms, including automatic vectorization (vmap), functional automatic differentiation (grad), and ahead-of-time compilation support. Functions are treated as first-class values, so transforms can be freely composed as higher-order functions, and this machinery is an important part of the PyTorch 2.0 compilation stack. The public API is exposed as `torch.func`.

## Core Features

- vmap (vectorizing map): turns a per-sample function into a batched one, with no hand-written loops
- grad (functional gradients): pure-functional gradient computation, composable into higher-order derivatives
- jvp/vjp: Jacobian-vector products (forward-mode AD) and vector-Jacobian products (reverse-mode AD)
- hessian: Hessian matrix computation
- jacrev/jacfwd: full Jacobian computation in reverse/forward mode (see the short sketch after this list)
- functionalize: rewrites in-place operations into functional form
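A minimal sketch of jacrev/jacfwd on an elementwise function, whose Jacobian is just a diagonal matrix:

```python
import torch
from torch import func

def f(x):
    return x.sin()

x = torch.randn(4)

# Reverse-mode vs forward-mode Jacobians of the same function.
# jacrev tends to win when outputs are fewer than inputs; jacfwd the opposite.
J_rev = func.jacrev(f)(x)
J_fwd = func.jacfwd(f)(x)

assert torch.allclose(J_rev, J_fwd)
assert torch.allclose(J_rev, torch.diag(x.cos()))  # d/dx sin(x) = cos(x), elementwise
```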
## Architecture

```mermaid
flowchart TB
    subgraph UserAPI["User API"]
        A1[torch.func.vmap]
        A2[torch.func.grad]
        A3[torch.func.jvp/vjp]
        A4[torch.func.hessian]
    end

    subgraph TransformEngine["Transform Engine"]
        B1[VmapTransform]
        B2[GradTransform]
        B3[JvpTransform]
        B4[FunctionalizationTransform]
    end

    subgraph BatchingRules["Batching Rules"]
        C1[BatchingRules]
        C2[VmapFallback]
        C3[GenerateVmapPlumbing]
    end

    subgraph DispatcherIntegration["Dispatcher Integration"]
        D1[FuncTorchBatched]
        D2[FuncTorchVmapMode]
        D3[FuncTorchGradWrapper]
        D4[Functionalize]
    end

    subgraph CompilerIntegration["Compiler Integration"]
        E1[TorchDynamo]
        E2[AOTAutograd]
        E3[Inductor]
    end

    A1 --> B1
    A2 --> B2
    A3 --> B3
    A4 --> B2
    B1 --> C1
    B1 --> C2
    B2 --> D3
    B3 --> D3
    C1 --> D1
    C2 --> D2
    D1 --> E1
    D2 --> E2
    D3 --> E3

    style B1 fill:#e8f5e9
    style B2 fill:#e8f5e9
    style D1 fill:#e1f5ff
    style E1 fill:#fff4e1
```
## Core Transforms

### vmap: Automatic Vectorization

#### Basic Concept

vmap turns a function written for a single sample into one that works on a whole batch.

```python
import torch
from torch import func

# Original function: L2 norm of a single vector
def compute_norm(x):
    return x.pow(2).sum().sqrt()

# For batched data, the traditional approach loops over samples
batch_x = torch.randn(100, 3)  # 100 three-dimensional vectors
norms_loop = torch.stack([compute_norm(x) for x in batch_x])

# vmap vectorizes automatically
norms_vmap = func.vmap(compute_norm)(batch_x)

assert torch.allclose(norms_loop, norms_vmap)
```
#### How vmap Works

```mermaid
sequenceDiagram
    autonumber
    participant User as User code
    participant Vmap as func.vmap
    participant Transform as VmapTransform
    participant Dispatcher as Dispatcher
    participant BatchRule as BatchingRule
    participant Kernel as Original kernel

    User->>Vmap: vmap(f)(batched_input)
    Vmap->>Transform: create VmapTransform
    Transform->>Transform: set batch_size and in_dims
    Transform->>Dispatcher: call f(input)
    Note over Transform: enters VmapMode context
    Dispatcher->>Dispatcher: detects FuncTorchBatched key
    Dispatcher->>BatchRule: look up batching rule
    alt a dedicated batching rule exists
        BatchRule->>BatchRule: run batched logic
        BatchRule-->>Dispatcher: batched result
    else fall back to the default implementation
        BatchRule->>Kernel: call the original kernel in a loop
        Kernel-->>BatchRule: per-sample results
        BatchRule->>BatchRule: stack into a batched result
        BatchRule-->>Dispatcher: batched result
    end
    Dispatcher-->>Transform: return result
    Transform->>Transform: handle out_dims
    Transform-->>Vmap: return final result
    Vmap-->>User: batched result
```
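The in_dims/out_dims handling mentioned in the diagram is visible directly at the Python level. A minimal sketch (names and shapes made up for illustration): in_dims picks which dimension of each input is mapped over, and out_dims says where the batch dimension goes in the output.

```python
import torch
from torch import func

x = torch.randn(3, 5)   # 5 samples stored along dim 1
w = torch.randn(3)

def scale(v, w):
    return v * w          # written for a single 3-vector

# Map over dim 1 of x, broadcast w (None), and put the batch dim back at
# position 1 of the output, matching the input layout.
out = func.vmap(scale, in_dims=(1, None), out_dims=1)(x, w)
print(out.shape)  # torch.Size([3, 5])
```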
#### Batching Rules

Each operator needs a batching rule that describes how to handle the batch dimension; operators without one go through the slower fallback shown above. The real rules live in C++, but the idea looks like this in Python:

```python
# Illustrative (Python-level) batching rule for add
def add_batching_rule(left_args, right_args):
    left_x, left_bdim = left_args
    right_x, right_bdim = right_args

    # Both inputs carry a batch dimension
    if left_bdim is not None and right_bdim is not None:
        # Make sure the batch dimensions are in the same position
        if left_bdim != right_bdim:
            left_x = left_x.movedim(left_bdim, 0)
            right_x = right_x.movedim(right_bdim, 0)
            bdim = 0
        else:
            bdim = left_bdim
        # Perform the actual add
        result = torch.add(left_x, right_x)
        return result, bdim
    # Only one input carries a batch dimension
    elif left_bdim is not None:
        result = torch.add(left_x, right_x)
        return result, left_bdim
    elif right_bdim is not None:
        result = torch.add(left_x, right_x)
        return result, right_bdim
    # Neither input is batched (should not happen inside vmap)
    else:
        result = torch.add(left_x, right_x)
        return result, None
```
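As a sanity check, the illustrative rule above can be compared against what vmap produces through the real (C++) batching rules for the same dimension layout. A small sketch reusing add_batching_rule from the block above:

```python
import torch
from torch import func

a = torch.randn(4, 3)   # batched along dim 0
b = torch.randn(3)      # not batched

# Apply the illustrative rule by hand...
manual, bdim = add_batching_rule((a, 0), (b, None))

# ...and compare with vmap, which dispatches to the real batching rule.
reference = func.vmap(torch.add, in_dims=(0, None))(a, b)

assert bdim == 0
assert torch.allclose(manual, reference)
```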
### grad: Functional Gradients

#### Basic Usage

```python
import torch
from torch import func

def loss_fn(weights, bias, x, y):
    pred = x @ weights + bias
    return ((pred - y) ** 2).mean()

weights = torch.randn(5, requires_grad=True)
bias = torch.randn(1, requires_grad=True)
x = torch.randn(100, 5)
y = torch.randn(100)

# Traditional approach: backward() populates .grad
loss = loss_fn(weights, bias, x, y)
loss.backward()
grad_weights_traditional = weights.grad

# Functional approach: grad returns a new function
grad_fn = func.grad(loss_fn, argnums=0)  # differentiate w.r.t. argument 0
grad_weights_functional = grad_fn(weights, bias, x, y)

assert torch.allclose(grad_weights_traditional, grad_weights_functional)
```
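When the loss value itself is also needed (e.g. for logging), torch.func.grad_and_value avoids a second forward pass. A minimal sketch reusing loss_fn and the tensors from the block above:

```python
from torch import func

# grad_and_value returns the gradient together with the forward value
grad_weights, loss = func.grad_and_value(loss_fn, argnums=0)(weights, bias, x, y)
print(loss.item(), grad_weights.shape)  # scalar loss, torch.Size([5])
```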
#### Higher-Order Derivatives

```python
# Second derivative (the diagonal of the Hessian, for this function)
def scalar_fn(x):
    return (x ** 4).sum()

x = torch.randn(5)

# First derivative
first_grad = func.grad(scalar_fn)(x)          # 4 * x**3

# Second derivative: grad needs a scalar output, so sum the inner gradient.
# For this function the Hessian is diagonal, so the row sums equal the diagonal.
second_grad = func.grad(lambda x: func.grad(scalar_fn)(x).sum())(x)

# Or compute the full Hessian directly
hess = func.hessian(scalar_fn)(x)
hess_diag = torch.diag(hess)

assert torch.allclose(second_grad, hess_diag)  # both equal 12 * x**2
```
### jvp/vjp: Jacobian-Vector and Vector-Jacobian Products

#### JVP (Forward-Mode AD)

```python
def f(x):
    return x ** 3 + 2 * x ** 2 + x

x = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([1.0, 1.0, 1.0])  # tangent vector

# JVP: J(f)(x) @ v
y, jvp_result = func.jvp(f, (x,), (v,))

# Manual check:
# f'(x) = 3x^2 + 4x + 1
# f'([1, 2, 3]) = [8, 21, 40]
# jvp = f'(x) * v = [8, 21, 40]
expected_jvp = 3 * x ** 2 + 4 * x + 1
assert torch.allclose(jvp_result, expected_jvp)
```
#### VJP (Reverse-Mode AD)

```python
def f(x):
    return x ** 2

x = torch.tensor([1.0, 2.0, 3.0])

# VJP: v^T @ J(f)(x)
y, vjp_fn = func.vjp(f, x)
v = torch.tensor([1.0, 1.0, 1.0])
vjp_result = vjp_fn(v)[0]

# Manual check:
# f is elementwise, so J(f)(x) = diag(2x)
# vjp = v^T @ diag([2, 4, 6]) = [2, 4, 6]
expected_vjp = 2 * x
assert torch.allclose(vjp_result, expected_vjp)
```
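Composing vmap with the returned vjp_fn recovers the full Jacobian row by row; this is essentially what jacrev does internally. A minimal sketch:

```python
import torch
from torch import func

def f(x):
    return x ** 2

x = torch.tensor([1.0, 2.0, 3.0])
_, vjp_fn = func.vjp(f, x)

# One VJP with a one-hot cotangent gives one row of the Jacobian;
# vmapping over the identity matrix computes all rows at once.
jacobian = func.vmap(lambda v: vjp_fn(v)[0])(torch.eye(3))
assert torch.allclose(jacobian, torch.diag(2 * x))
```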
## Composing Transforms

### vmap + grad

```python
# Per-sample gradients
def loss_fn(weights, x, y):
    pred = x @ weights
    return (pred - y) ** 2   # scalar loss for a single sample

weights = torch.randn(5)
batch_x = torch.randn(100, 5)  # 100 samples
batch_y = torch.randn(100)

# Method 1: vmap(grad(...)) gives one gradient per sample
per_sample_grads = func.vmap(func.grad(loss_fn, argnums=0), in_dims=(None, 0, 0))(
    weights, batch_x, batch_y
)
print(per_sample_grads.shape)  # torch.Size([100, 5]) - one gradient vector per sample

# Method 2: taking grad of an aggregated batch loss (e.g. the mean over samples)
# has different semantics: it yields a single averaged gradient, not per-sample ones.
```
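The same vmap(grad(...)) pattern extends to nn.Module models via torch.func.functional_call, which runs a module with externally supplied parameters. A minimal sketch with a made-up tiny model, following the pattern from the official per-sample-gradients tutorial:

```python
import torch
import torch.nn as nn
from torch import func

model = nn.Linear(5, 1)   # hypothetical tiny model
params = {k: v.detach() for k, v in model.named_parameters()}

def sample_loss(params, x, y):
    # Run the module with the given parameter dict on a single sample
    pred = func.functional_call(model, params, (x.unsqueeze(0),))
    return (pred.squeeze() - y) ** 2

batch_x, batch_y = torch.randn(100, 5), torch.randn(100)

# Share params across samples (None), map over the data (0)
per_sample_grads = func.vmap(func.grad(sample_loss), in_dims=(None, 0, 0))(
    params, batch_x, batch_y
)
print(per_sample_grads["weight"].shape)  # torch.Size([100, 1, 5])
```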
### jacrev + grad (Second Derivatives)

```python
def f(x):
    return (x ** 4).sum()

x = torch.randn(3)

# The Hessian is the Jacobian of the gradient
def hessian_fn(f):
    return func.jacrev(func.grad(f))

hess = hessian_fn(f)(x)
print(hess.shape)  # torch.Size([3, 3])

# Equivalent to calling hessian directly
hess_direct = func.hessian(f)(x)
assert torch.allclose(hess, hess_direct)
```
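When only a Hessian-vector product is needed, composing jvp with grad (forward-over-reverse) avoids materializing the full Hessian. A minimal sketch:

```python
import torch
from torch import func

def f(x):
    return (x ** 4).sum()

x = torch.randn(3)
v = torch.randn(3)

# Hessian-vector product without building the Hessian: jvp applied to grad(f)
_, hvp = func.jvp(func.grad(f), (x,), (v,))

# Cross-check against the explicit (tiny) Hessian
assert torch.allclose(hvp, func.hessian(f)(x) @ v)
```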
## Compiler Integration

### AOTAutograd

AOTAutograd (Ahead-Of-Time Autograd) uses functorch's tracing machinery to capture the forward and backward graphs ahead of time and hand them to a compiler.

```python
import torch
from torch._functorch.aot_autograd import aot_function

def model(x, weight):
    return torch.nn.functional.linear(x, weight)

# A compiler callback receives the traced FX graph plus example inputs and
# must return a callable. This one just inspects the graph; in production,
# Inductor is plugged in here.
def inspect_compiler(fx_module, example_inputs):
    fx_module.graph.print_tabular()
    return fx_module

# AOT compilation: separate compilers for the forward and backward graphs
compiled_model = aot_function(
    model,
    fw_compiler=inspect_compiler,  # forward-graph compiler
    bw_compiler=inspect_compiler,  # backward-graph compiler
)

x = torch.randn(10, 5, requires_grad=True)
weight = torch.randn(3, 5, requires_grad=True)

# Use the compiled model as usual
output = compiled_model(x, weight)
loss = output.sum()
loss.backward()
```
### Functionalization

Functionalization rewrites in-place operations into functional (out-of-place) ones while preserving semantics, which makes the resulting graphs easier for compilers to optimize.

```python
def inplace_fn(x):
    x.add_(1)  # in-place op
    x.mul_(2)  # in-place op
    return x

functional_fn = func.functionalize(inplace_fn)

x = torch.randn(5)
x_orig = x.clone()

# Semantics are preserved: both versions compute the same values.
# (functionalize copies results back into mutated inputs, so callers
# still observe the input mutation.)
result1 = inplace_fn(x_orig.clone())
result2 = functional_fn(x)
assert torch.equal(result1, result2)

# The payoff shows up when tracing: the captured graph contains only
# out-of-place ops (add/mul instead of add_/mul_).
from torch.fx.experimental.proxy_tensor import make_fx
graph = make_fx(func.functionalize(inplace_fn))(torch.randn(5))
print(graph.code)
```
## Performance Optimization

### Optimized Batching Rules

```python
# An optimized batching rule for matmul (illustrative)
def matmul_batching_rule(left_args, right_args):
    left_x, left_bdim = left_args
    right_x, right_bdim = right_args

    # Use batched matrix multiply (bmm) instead of looping
    if left_bdim == 0 and right_bdim == 0:
        # Both inputs are batched along dim 0
        return torch.bmm(left_x, right_x), 0
    elif left_bdim == 0 and right_bdim is None:
        # Only the left input is batched; let broadcasting handle the rest
        result = torch.matmul(left_x, right_x.unsqueeze(0))
        return result, 0
    # ... other cases
```
### Memory Optimization

```python
# Activation checkpointing trades recomputation for memory
from torch.utils.checkpoint import checkpoint

def expensive_function(x):
    # compute-heavy function with many intermediates
    for _ in range(100):
        x = torch.sin(x) + torch.cos(x)
    return x

# vmap combined with checkpointing
vmapped_fn = func.vmap(lambda x: checkpoint(expensive_function, x, use_reentrant=False))
result = vmapped_fn(torch.randn(1000, 10))
```
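Another memory lever, assuming a recent PyTorch release (the chunk_size keyword was added to torch.func.vmap in newer versions), is to process the mapped dimension in pieces instead of all at once:

```python
import torch
from torch import func

def heavy(x):
    # per-sample computation with large intermediates
    return torch.sin(x) @ torch.cos(x).T

big_batch = torch.randn(4096, 64, 64)

# chunk_size splits the mapped dimension into pieces, lowering peak memory
# at the cost of a few extra kernel launches.
out = func.vmap(heavy, chunk_size=256)(big_batch)
print(out.shape)  # torch.Size([4096, 64, 64])
```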
### Compilation

```python
# torch.compile can optimize a vmapped function
@torch.compile
def fast_vmap_fn(x):
    def single_fn(item):
        return torch.sin(item).sum()
    return func.vmap(single_fn)(x)

# The compiled version fuses operations and reduces kernel launches
x = torch.randn(1000, 100)
result = fast_vmap_fn(x)
```
## Application Examples

### Parameter-Efficient Fine-Tuning (LoRA)

```python
def lora_linear(x, W, A, B, scale):
    # W: pretrained weight [out_features, in_features]
    # A, B: LoRA matrices [rank, in_features], [out_features, rank]
    base_output = torch.nn.functional.linear(x, W)
    # F.linear(x, A) computes x @ A.T -> [..., rank]; B then projects back up
    lora_output = torch.nn.functional.linear(
        torch.nn.functional.linear(x, A), B
    )
    return base_output + scale * lora_output

# Gradients of the LoRA parameters
def compute_lora_grads(loss_fn, A, B, scale, *other_args):
    # differentiate the loss w.r.t. A and B
    grad_fn = func.grad(loss_fn, argnums=(0, 1))
    grad_A, grad_B = grad_fn(A, B, scale, *other_args)
    return grad_A, grad_B

# Run multiple LoRA adapters in parallel with vmap
def multi_lora_forward(x, W, As, Bs, scales):
    # As: [num_adapters, rank, in_features]
    # Bs: [num_adapters, out_features, rank]
    def single_lora(A, B, scale):
        return lora_linear(x, W, A, B, scale)
    outputs = func.vmap(single_lora)(As, Bs, scales)
    return outputs.mean(0)  # average the adapter outputs
```
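A quick smoke test of multi_lora_forward above, with made-up shapes:

```python
import torch

num_adapters, rank, in_features, out_features = 4, 8, 32, 16
x = torch.randn(10, in_features)
W = torch.randn(out_features, in_features)
As = torch.randn(num_adapters, rank, in_features)
Bs = torch.randn(num_adapters, out_features, rank)
scales = torch.full((num_adapters,), 0.5)

out = multi_lora_forward(x, W, As, Bs, scales)
print(out.shape)  # torch.Size([10, 16])
```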
### Neural ODE Solving

```python
# `neural_net` and `odeint` are placeholders for a parameterized network
# and an ODE integrator
def neural_ode_func(t, y, theta):
    # y: [batch_size, state_dim]
    # theta: neural-network parameters
    return neural_net(y, theta)

def solve_ode_batch(y0_batch, t_span, theta):
    # Solve from many initial conditions in parallel with vmap
    def solve_single(y0):
        return odeint(neural_ode_func, y0, t_span, theta)
    return func.vmap(solve_single)(y0_batch)

# Sensitivity of the solution with respect to the initial condition
def sensitivity_analysis(y0, t_span, theta):
    def solve_fn(y0):
        return odeint(neural_ode_func, y0, t_span, theta)
    # Jacobian of the solution w.r.t. the initial condition
    jacobian_fn = func.jacrev(solve_fn)
    return jacobian_fn(y0)
```
### Bayesian Neural Networks

```python
# `neural_net` is again a placeholder for a parameterized forward function
def bayesian_prediction(x, weight_samples):
    # weight_samples: [num_samples, ...] posterior weight samples
    def single_forward(weights):
        return neural_net(x, weights)
    # vmap over the weight samples
    predictions = func.vmap(single_forward)(weight_samples)
    # predictive mean and variance
    mean_pred = predictions.mean(0)
    var_pred = predictions.var(0)
    return mean_pred, var_pred

# Log-likelihood of each weight sample
def compute_log_likelihoods(data_batch, weight_samples, targets):
    def log_likelihood(weights, x, y):
        pred = neural_net(x, weights)
        return -torch.nn.functional.mse_loss(pred, y)
    # outer vmap over weight samples, inner vmap over the data
    log_liks = func.vmap(
        func.vmap(log_likelihood, in_dims=(None, 0, 0)),
        in_dims=(0, None, None)
    )(weight_samples, data_batch, targets)
    return log_liks.sum(-1)  # total log-likelihood per weight sample
```
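If the sampled forward pass itself contains random ops (dropout, reparameterized weight noise), vmap must be told how to treat randomness. A minimal sketch, assuming the randomness="different" mode so each mapped call draws its own noise:

```python
import torch
from torch import func

def noisy_prediction(x):
    # a random op inside the vmapped function
    return x + 0.1 * torch.randn_like(x)

x = torch.zeros(8, 3)

# By default vmap raises on random ops; "different" gives every mapped call
# an independent draw, which is what Monte Carlo estimates need.
preds = func.vmap(noisy_prediction, randomness="different")(x)
print(preds.std(0))  # nonzero: each sample received different noise
```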
## Debugging and Profiling

### Profiling
```python
import torch.profiler

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True
) as prof:
    with torch.profiler.record_function("vmap_operation"):
        result = func.vmap(some_function)(batch_input)

print(prof.key_averages().table(sort_by="cuda_time_total"))
```
### Error Handling

```python
# Validate vmap dimension compatibility before calling
def safe_vmap(fn, in_dims=0, out_dims=0):
    def wrapper(*args):
        # Normalize in_dims so it can be zipped against the arguments
        dims = in_dims if isinstance(in_dims, (tuple, list)) else (in_dims,) * len(args)
        # Check each mapped dimension against the argument's rank
        for i, (arg, dim) in enumerate(zip(args, dims)):
            if isinstance(dim, int) and dim >= arg.ndim:
                raise ValueError(f"in_dims[{i}]={dim} >= arg.ndim={arg.ndim}")
        return func.vmap(fn, in_dims=in_dims, out_dims=out_dims)(*args)
    return wrapper
```