PyTorch-04-torch.nn模块与JIT编译
torch.nn模块
模块概览
torch.nn提供神经网络构建的高层API,包括层(Layer)、容器(Container)、损失函数和参数初始化等;优化器则位于torch.optim。核心基类是nn.Module,所有神经网络模型都继承自它。
nn.Module架构
classDiagram
class Module {
-dict _parameters
-dict _buffers
-dict _modules
-OrderedDict _forward_hooks
-OrderedDict _backward_hooks
-bool training
+forward(*input) Tensor
+__call__(*input) Tensor
+parameters() Iterator
+named_parameters() Iterator
+state_dict() OrderedDict
+load_state_dict(dict)
+train(bool) Module
+eval() Module
+to(device) Module
+cuda() Module
+cpu() Module
}
class Linear {
+weight Parameter
+bias Parameter
+in_features int
+out_features int
+forward(input) Tensor
}
class Conv2d {
+weight Parameter
+bias Parameter
+in_channels int
+out_channels int
+kernel_size tuple
+forward(input) Tensor
}
class Sequential {
-ModuleList _modules
+forward(input) Tensor
}
Module <|-- Linear
Module <|-- Conv2d
Module <|-- Sequential
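类图中Module的常用API可以用下面的最小示例串起来看(补充示例,仅作说明):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# parameters()/named_parameters():遍历自动注册的参数
for name, p in model.named_parameters():
    print(name, p.shape)          # weight: [2, 4], bias: [2]

# state_dict()/load_state_dict():参数的保存与恢复
sd = model.state_dict()
model.load_state_dict(sd)

# train()/eval()切换模式,to()/cuda()/cpu()迁移设备
model.eval()
if torch.cuda.is_available():
    model = model.to('cuda')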
nn.Module核心机制
参数注册
class MyLinear(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
# Parameter自动注册到_parameters
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.randn(out_features))
def forward(self, x):
return x.matmul(self.weight.t()) + self.bias
__setattr__拦截(简化示意,真实实现还会处理buffer重新赋值、重名检查等情况):
def __setattr__(self, name, value):
    if isinstance(value, Parameter):
        # 注册到_parameters字典
        self._parameters[name] = value
    elif isinstance(value, Module):
        # 注册到_modules字典
        self._modules[name] = value
    else:
        # 普通属性(buffer需通过register_buffer显式注册到_buffers)
        object.__setattr__(self, name, value)
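下面用一个小例子验证上述注册机制(补充示例):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)                     # Module -> 注册到_modules
        self.scale = nn.Parameter(torch.ones(1))      # Parameter -> 注册到_parameters
        self.register_buffer('step', torch.zeros(1))  # buffer需显式注册到_buffers

    def forward(self, x):
        return self.fc(x) * self.scale

net = Net()
print([name for name, _ in net.named_parameters()])  # ['scale', 'fc.weight', 'fc.bias']
print([name for name, _ in net.named_buffers()])     # ['step']
print([name for name, _ in net.named_modules()])     # ['', 'fc']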
forward与__call__
class Module:
def __call__(self, *input, **kwargs):
# 1. 执行forward pre-hooks
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
input = result
# 2. 调用forward
result = self.forward(*input, **kwargs)
# 3. 执行forward hooks
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
# 4. 执行backward hooks(注册到autograd)
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = _WrappedHook(hook)
grad_fn.register_hook(wrapper)
return result
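借助上述hook机制,可以在不修改forward代码的情况下观察中间输出(补充示例):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

def print_shape(module, inputs, output):
    # forward hook:在对应子模块的forward执行之后被调用
    print(type(module).__name__, tuple(output.shape))

handles = [m.register_forward_hook(print_shape) for m in model]
model(torch.randn(3, 8))

for h in handles:
    h.remove()  # hook用完及时移除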
常用Layer实现
Linear层
class Linear(Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# 初始化权重
self.weight = Parameter(torch.empty(out_features, in_features))
if bias:
self.bias = Parameter(torch.empty(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
# Kaiming uniform初始化
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
return F.linear(input, self.weight, self.bias)
F.linear实现:
// aten/src/ATen/native/Linear.cpp
Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias) {
// input: [*, in_features]
// weight: [out_features, in_features]
// output: [*, out_features]
auto output = input.matmul(weight.t());
if (bias.has_value()) {
output.add_(*bias);
}
return output;
}
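可以在Python侧验证F.linear与matmul+bias的等价关系(补充示例):

import torch
import torch.nn.functional as F

x = torch.randn(5, 3)    # [*, in_features]
w = torch.randn(4, 3)    # [out_features, in_features]
b = torch.randn(4)

out = F.linear(x, w, b)  # [*, out_features]
ref = x.matmul(w.t()) + b
print(torch.allclose(out, ref, atol=1e-6))  # True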
Conv2d层
class Conv2d(Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, groups=1, bias=True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)  # _pair把标量扩展为2元组: 3 -> (3, 3)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
# 权重形状: [out_channels, in_channels/groups, kernel_h, kernel_w]
self.weight = Parameter(torch.empty(
out_channels, in_channels // groups, *self.kernel_size
))
if bias:
self.bias = Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def forward(self, input):
return F.conv2d(input, self.weight, self.bias,
self.stride, self.padding, self.dilation, self.groups)
Conv2d实现(CUDA,简化示意):
// 简化示意:实际的cuDNN卷积前向实现位于aten/src/ATen/native/cudnn/目录,并经由at::convolution统一分发
Tensor conv2d_cuda(
const Tensor& input, // [N, C_in, H, W]
const Tensor& weight, // [C_out, C_in/groups, kH, kW]
const Tensor& bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
int64_t groups) {
// 计算输出形状
auto output_size = conv_output_size(
input.sizes(), weight.sizes(), stride, padding, dilation
);
Tensor output = at::empty(output_size, input.options());
// 调用cuDNN
cudnnConvolutionForward(
handle, &alpha,
input_desc, input.data_ptr(),
filter_desc, weight.data_ptr(),
conv_desc,
algorithm,
workspace, workspace_size,
&beta,
output_desc, output.data_ptr()
);
if (bias.defined()) {
output.add_(bias.view({1, -1, 1, 1}));
}
return output;
}
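上面conv_output_size对应的输出尺寸公式为 H_out = floor((H + 2*padding - dilation*(kernel-1) - 1) / stride) + 1,可以用下面的小例子核对(补充示例,conv_out是说明用的辅助函数):

import torch
import torch.nn as nn

def conv_out(h, k, s, p, d=1):
    # 与PyTorch文档中Conv2d输出尺寸公式一致
    return (h + 2 * p - d * (k - 1) - 1) // s + 1

conv = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 3, 32, 32)
print(conv(x).shape)                 # torch.Size([1, 16, 16, 16])
print(conv_out(32, k=3, s=2, p=1))   # 16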
BatchNorm2d层
class BatchNorm2d(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1,
affine=True, track_running_stats=True):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.ones(num_features))
self.bias = Parameter(torch.zeros(num_features))
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
def forward(self, input):
if self.training:
# 训练模式:计算batch统计量
mean = input.mean([0, 2, 3])
var = input.var([0, 2, 3], unbiased=False)
# 更新running stats(真实实现中running_var使用无偏方差估计更新)
with torch.no_grad():
    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
    self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
    self.num_batches_tracked += 1
else:
# 推理模式:使用running stats
mean = self.running_mean
var = self.running_var
# 归一化
output = (input - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + self.eps)
# 缩放和偏移
if self.affine:
output = output * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
return output
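可以把这份简化实现与官方nn.BatchNorm2d在训练模式下对照验证(补充示例):

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4, 16, 16)

bn = nn.BatchNorm2d(4)
bn.train()
ref = bn(x)

# 手动按batch统计量归一化(与训练模式行为一致)
mean = x.mean([0, 2, 3])
var = x.var([0, 2, 3], unbiased=False)
out = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + bn.eps)
out = out * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

print(torch.allclose(ref, out, atol=1e-5))  # True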
模型构建示例
Sequential容器
# 简单顺序模型
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
# 带命名的Sequential(OrderedDict需从collections导入)
from collections import OrderedDict
model = nn.Sequential(OrderedDict([
('fc1', nn.Linear(784, 256)),
('relu1', nn.ReLU()),
('fc2', nn.Linear(256, 128)),
('relu2', nn.ReLU()),
('fc3', nn.Linear(128, 10))
]))
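接上面的命名Sequential示例,子模块既可以按索引访问,也可以按名字访问(补充示例):

print(model[0])     # 索引访问:fc1对应的Linear(784, 256)
print(model.fc1)    # 属性名访问
for name, m in model.named_children():
    print(name, type(m).__name__)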
自定义模型
class ResNetBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = x
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += residual
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, num_blocks, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
# 使用ModuleList管理多个block
self.blocks = nn.ModuleList([
ResNetBlock(64) for _ in range(num_blocks)
])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(64, num_classes)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x)
for block in self.blocks:
x = block(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
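模型搭好后可以用随机输入做一次shape检查,并统计参数量(补充示例):

model = ResNet(num_blocks=2, num_classes=10)
x = torch.randn(4, 3, 224, 224)
print(model(x).shape)  # torch.Size([4, 10])

num_params = sum(p.numel() for p in model.parameters())
print(f'参数量: {num_params}')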
TorchScript与JIT编译
概览
TorchScript是PyTorch的静态图表示,支持:
- 脚本化(Script):解析Python代码生成IR
- 跟踪(Trace):记录实际执行的算子序列
- 优化:图融合、死代码消除、常量折叠
- 序列化:保存为.pt文件,跨平台部署
Script模式
@torch.jit.script
def fused_relu(x, y):
# JIT编译器解析Python代码
z = x + y
return z.relu()
# 查看生成的图IR
print(fused_relu.graph)
# graph(%x : Tensor, %y : Tensor):
# %z : Tensor = aten::add(%x, %y)
# %output : Tensor = aten::relu(%z)
# return (%output)
# 优化后(融合,示意:实际的fuser会把add+relu收进一个融合子图节点,
# 如prim::TensorExprGroup/prim::CudaFusionGroup,并编译为单个kernel)
print(fused_relu.graph_for(torch.randn(10), torch.randn(10)))
# graph(%x : Tensor, %y : Tensor):
#   %output : Tensor = prim::TensorExprGroup_0(%x, %y)
#   return (%output)
Script化Module
class MyCell(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
# Script化
scripted_cell = torch.jit.script(MyCell())
# 保存
scripted_cell.save("my_cell.pt")
# 加载(可在C++中加载)
loaded_cell = torch.jit.load("my_cell.pt")
Trace模式
def trace_example(x, y):
if x.sum() > 0: # 动态控制流
return x + y
else:
return x * y
# 跟踪(记录一次执行路径)
x = torch.randn(3, 4)
y = torch.randn(3, 4)
traced_fn = torch.jit.trace(trace_example, (x, y))
# 问题:trace不记录控制流
# traced_fn只固化跟踪时实际执行的那条分支(trace过程中会给出TracerWarning),
# 之后无论输入如何都走同一条路径
# 解决:使用Script模式,支持数据依赖的控制流
@torch.jit.script
def script_example(x, y):
if bool(x.sum() > 0): # Script支持控制流
return x + y
else:
return x * y
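下面的对比可以直观看到差异:trace只保留了跟踪时走过的加法分支,而script能正确处理两条分支(补充示例,trace过程会给出TracerWarning):

import torch

x_pos = torch.ones(3, 4)    # sum > 0
x_neg = -torch.ones(3, 4)   # sum < 0
y = torch.randn(3, 4)

traced_fn = torch.jit.trace(trace_example, (x_pos, y))       # 跟踪时执行x+y分支
print(torch.allclose(traced_fn(x_neg, y), x_neg + y))        # True:仍走加法分支
print(torch.allclose(script_example(x_neg, y), x_neg * y))   # True:script按条件走乘法分支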
IR表示
Graph IR
# Python代码
def forward(x, weight, bias):
h = torch.matmul(x, weight)
h = h + bias
return torch.relu(h)
# 对应的Graph IR
graph(%x : Tensor, %weight : Tensor, %bias : Tensor):
  %alpha : int = prim::Constant[value=1]()
  %h.1 : Tensor = aten::matmul(%x, %weight)
  %h.2 : Tensor = aten::add(%h.1, %bias, %alpha)
  %output : Tensor = aten::relu(%h.2)
  return (%output)
IR节点类型:
- aten::xxx: ATen算子
- prim::xxx: 原语操作(if/loop/Constant等)
- prim::TupleConstruct: 构造tuple
- prim::GetAttr: 获取模块属性
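可以直接遍历Graph对象查看这些节点(补充示例):

import torch

@torch.jit.script
def f(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    return torch.relu(torch.matmul(x, w) + b)

for node in f.graph.nodes():
    print(node.kind())  # 例如 prim::Constant、aten::matmul、aten::add、aten::relu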
优化Pass
# 原始图
graph:
%1 = aten::add(%x, %y)
%2 = aten::relu(%1)
return (%2)
# Fusion优化后(示意:实际生成的是融合子图节点,如prim::TensorExprGroup)
graph:
  %1 = prim::TensorExprGroup_0(%x, %y)  # add+relu融合为单个CUDA kernel
  return (%1)
# 常量折叠
graph:
%a = prim::Constant[value=2]()
%b = prim::Constant[value=3]()
%c = aten::add(%a, %b)
%d = aten::mul(%x, %c)
return (%d)
# 优化后
graph:
%c = prim::Constant[value=5]() # 2+3在编译时计算
%d = aten::mul(%x, %c)
return (%d)
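常量折叠可以通过JIT的常量传播pass直接观察(补充示例;torch._C._jit_pass_constant_propagation属于内部API,不同版本行为可能有差异):

import torch

@torch.jit.script
def g(x: torch.Tensor):
    return x * (2 + 3)

print(g.graph)  # 视版本而定,2+3可能仍以算子形式出现
torch._C._jit_pass_constant_propagation(g.graph)
print(g.graph)  # 乘数被折叠为prim::Constant[value=5]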
性能优化
图融合
# 未融合:3个kernel launch
@torch.jit.script
def unfused(x, y, z):
a = x + y
b = a * z
return b.relu()
# JIT的fuser(如NNC/nvFuser)可将这3个逐点算子融合为1个kernel
# fused_kernel: out[i] = relu((x[i] + y[i]) * z[i])
Frozen Modules
model = MyModel()
model.eval()
# Freeze:内联常量参数和buffer
frozen_model = torch.jit.freeze(torch.jit.script(model))
# 效果:
# 1. 参数从GetAttr变为Constant
# 2. 允许更多常量折叠优化
# 3. 减少内存访问
CUDA Graph Capture
# 捕获CUDA kernel序列为静态图
model = MyModel().cuda()
x = torch.randn(32, 3, 224, 224, device='cuda')
# Warmup
for _ in range(10):
y = model(x)
# Capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y = model(x)
# Replay(减少kernel launch开销)
g.replay()
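接上例,capture之后x、y都是静态张量:重放前需要把新数据拷入x,结果从y中取出(补充示例):

new_batch = torch.randn(32, 3, 224, 224, device='cuda')
x.copy_(new_batch)   # 新输入写进捕获时使用的静态输入张量
g.replay()           # 以极低的launch开销重放整段kernel序列
result = y.clone()   # y同样是静态输出,需及时拷走结果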
模型部署
导出ONNX
import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)  # 新版torchvision建议改用weights参数
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
# 导出ONNX
torch.onnx.export(
model,
dummy_input,
"resnet18.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
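导出后可以用onnxruntime做一次数值校验(补充示例,假设已另行安装onnxruntime包):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("resnet18.onnx", providers=['CPUExecutionProvider'])
onnx_out = sess.run(None, {'input': dummy_input.numpy()})[0]

with torch.no_grad():
    torch_out = model(dummy_input)

print(np.allclose(onnx_out, torch_out.numpy(), atol=1e-4))  # 应为True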
C++部署(LibTorch)
#include <torch/script.h>
int main() {
// 加载TorchScript模型
torch::jit::script::Module module;
module = torch::jit::load("model.pt");
module.eval();
// 准备输入
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::randn({1, 3, 224, 224}));
// 前向传播
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
return 0;
}
移动端部署
# 优化移动端模型
from torch.utils.mobile_optimizer import optimize_for_mobile
model = MyModel()
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
# 移动端优化
optimized_model = optimize_for_mobile(traced_script_module)
optimized_model._save_for_lite_interpreter("model_mobile.ptl")
Android使用(pytorch_android_lite,.ptl模型需用LiteModuleLoader加载):
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;

Module module = LiteModuleLoader.load("model_mobile.ptl");
Tensor input = Tensor.fromBlob(inputData, shape);
Tensor output = module.forward(IValue.from(input)).toTensor();