PyTorch-05: The TorchGen Code Generator
Module Overview
TorchGen is PyTorch's code generation tool: it reads YAML configuration files and Python templates and emits C++ and Python binding code automatically. As a core component of the PyTorch build system, it produces operator C++ implementations, Python bindings, type signatures, and more.
Core Responsibilities
- Operator code generation: produce C++ operator implementations from native_functions.yaml
- Python bindings: generate the Python/C++ binding code behind torch._C
- Type signatures: generate .pyi type-hint files
- Dispatcher registration: generate the code that registers operators with the Dispatcher
- Backend adaptation: generate specialized code for each backend (CPU/CUDA/XPU)
Architecture
flowchart TB
  subgraph Inputs["Input files"]
    A1[native_functions.yaml<br/>operator definitions]
    A2[derivatives.yaml<br/>gradient definitions]
    A3[templates/*.cpp<br/>C++ templates]
    A4[templates/*.py.in<br/>Python templates]
  end
  subgraph Core["TorchGen core"]
    B1[gen.py<br/>main entry point]
    B2[model.py<br/>data model]
    B3[api/*.py<br/>API generators]
    B4[code_template.py<br/>template engine]
  end
  subgraph Generators["Code generators"]
    C1[gen_backend_stubs.py<br/>backend stubs]
    C2[gen_python_functions.py<br/>Python bindings]
    C3[gen_autograd.py<br/>autograd code]
    C4[gen_lazy_tensor.py<br/>lazy tensor code]
  end
  subgraph Outputs["Output files"]
    D1[aten/src/ATen/*.cpp<br/>C++ implementations]
    D2[torch/_C/*.cpp<br/>Python bindings]
    D3[torch/*.pyi<br/>type signatures]
    D4[build/*.h<br/>header files]
  end
  A1 --> B1
  A2 --> B1
  A3 --> B4
  A4 --> B4
  B1 --> B2
  B2 --> B3
  B3 --> C1
  B3 --> C2
  B3 --> C3
  B3 --> C4
  C1 --> D1
  C2 --> D2
  C3 --> D1
  C4 --> D1
  B4 --> D1
  B4 --> D2
  B4 --> D3
  B4 --> D4
  style B1 fill:#e8f5e9
  style B2 fill:#e8f5e9
  style C1 fill:#e1f5ff
  style C2 fill:#e1f5ff
  style D1 fill:#fff4e1
Core Data Model
NativeFunction
Represents a single operator defined in native_functions.yaml.
@dataclass(frozen=True)
class NativeFunction:
    func: FunctionSchema               # function signature
    use_c10_dispatcher: bool           # whether the c10 dispatcher is used
    python_module: str | None          # Python module name
    category_override: str | None      # category override
    variants: set[Variant]             # variants (function/method)
    structured: bool                   # structured kernel?
    structured_delegate: str | None    # structured delegate
    out: NativeFunction | None         # out= variant
    inplace: NativeFunction | None     # in-place variant
    device_check: DeviceCheckType      # device-check behavior
    device_guard: bool                 # whether a device guard is needed
    kernel: str | None                 # kernel name
    dispatch: dict[DispatchKey, str]   # dispatch table
    autogen: list[str]                 # variants to auto-generate
    cpp_no_default_args: set[str]      # C++ arguments without default values
    is_abstract: bool                  # abstract function?
    has_composite_implicit_autograd_kernel: bool
    has_composite_explicit_autograd_kernel: bool
    has_composite_explicit_autograd_non_functional_kernel: bool

    @property
    def root_name(self) -> str:
        # base operator name without the in-place underscore, e.g. "add"
        return self.func.name.name.base
FunctionSchema
Represents the complete signature of a function.
@dataclass(frozen=True)
class FunctionSchema:
    name: OperatorName           # operator name
    arguments: Arguments         # argument list
    returns: tuple[Return, ...]  # return values

    def signature(self, *, strip_default: bool = False) -> str:
        # Renders the full signature string, e.g.:
        # "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
        ...
OperatorName
@dataclass(frozen=True)
class OperatorName:
    name: BaseName      # base name
    overload_name: str  # overload name

    def unambiguous_name(self) -> str:
        # Unambiguous name; it is embedded in generated C++ identifiers,
        # hence "_" rather than "." as the separator
        if self.overload_name:
            return f"{self.name}_{self.overload_name}"
        return str(self.name)
BaseName
@dataclass(frozen=True)
class BaseName:
    base: str              # base name, e.g. "add"
    inplace: bool = False  # in-place variant, e.g. "add_"

    def __str__(self) -> str:
        return self.base + ("_" if self.inplace else "")
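The naming model can be exercised on its own. Below is a minimal, runnable sketch of the two simplified dataclasses above (field names follow this excerpt; the real torchgen classes carry extra validation), joining overload names with `_` so the result is usable inside generated C++ identifiers:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BaseName:
    base: str              # e.g. "add"
    inplace: bool = False  # in-place variants carry a trailing underscore

    def __str__(self) -> str:
        return self.base + ("_" if self.inplace else "")

@dataclass(frozen=True)
class OperatorName:
    name: BaseName
    overload_name: str

    def unambiguous_name(self) -> str:
        # overloads are disambiguated with "_", giving a valid C++ identifier
        if self.overload_name:
            return f"{self.name}_{self.overload_name}"
        return str(self.name)

add = OperatorName(BaseName("add"), "Tensor")
add_inplace = OperatorName(BaseName("add", inplace=True), "Tensor")
print(add.unambiguous_name())          # add_Tensor
print(add_inplace.unambiguous_name())  # add__Tensor
```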
Code Generation Flow
Main flow sequence diagram
sequenceDiagram
  autonumber
  participant Build as Build system
  participant Gen as gen.py
  participant Model as model.py
  participant API as api/*
  participant Template as CodeTemplate
  participant Output as Output files
  Build->>Gen: python -m torchgen.gen
  Gen->>Gen: parse command-line arguments
  Gen->>Model: load native_functions.yaml
  Model->>Model: parse the YAML structure
  loop for each operator definition
    Model->>Model: build a NativeFunction object
    Model->>Model: parse the FunctionSchema
    Model->>Model: parse the DispatchKey mapping
  end
  Model-->>Gen: List[NativeFunction]
  Gen->>Model: load derivatives.yaml
  Model->>Model: parse the gradient definitions
  Model-->>Gen: List[Derivative]
  Gen->>API: generate the various APIs
  loop for each code generator
    API->>Template: prepare template variables
    Template->>Template: render C++/Python code
    Template-->>API: generated code string
    API->>Output: write output files
  end
  Gen-->>Build: code generation finished
Operator Definition Parsing
An example from native_functions.yaml:
- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  variants: function, method
  structured: true
  dispatch:
    CPU: add_cpu
    CUDA: add_cuda
    MPS: add_mps
  autogen: add.out

- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
  variants: method
  structured_delegate: add.Tensor
  dispatch:
    CPU: add__cpu
    CUDA: add__cuda
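The `func` string packs the name, overload, argument list, and return types into a single line. A minimal sketch of splitting it into those parts (a deliberate simplification of `FunctionSchema.parse`; the real parser also tokenizes arguments, alias annotations such as `Tensor(a!)`, and default values):

```python
def split_func_decl(decl: str):
    """Split 'name.overload(args) -> returns' at its top level.

    Hypothetical helper for illustration only; not a torchgen API.
    """
    sig, _, returns = decl.partition(" -> ")
    name_part, _, args = sig.partition("(")
    base, _, overload = name_part.partition(".")
    return base, overload, args.rstrip(")"), returns

parts = split_func_decl(
    "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
)
print(parts[0], parts[1], parts[3])  # add Tensor Tensor
```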
The parsing step:
def parse_native_function(es: dict) -> NativeFunction:
    # 1. Parse the function signature
    func_schema = FunctionSchema.parse(es['func'])
    # 2. Parse the dispatch mapping
    dispatch = {}
    if 'dispatch' in es:
        for key, kernel in es['dispatch'].items():
            dispatch[DispatchKey[key]] = kernel
    # 3. Parse the variants
    variants = set()
    if 'variants' in es:
        for variant in es['variants'].split(','):
            variants.add(Variant[variant.strip()])
    # 4. Build the NativeFunction object
    return NativeFunction(
        func=func_schema,
        dispatch=dispatch,
        variants=variants,
        structured=es.get('structured', False),
        # ... remaining fields
    )
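The dispatch/variants handling can be tried on an already-loaded entry (a plain dict here, so no YAML dependency; the two enums are cut-down stand-ins for torchgen's real `DispatchKey` and `Variant`):

```python
from enum import Enum

# cut-down stand-ins for torchgen's enums
class DispatchKey(Enum):
    CPU = "CPU"
    CUDA = "CUDA"

class Variant(Enum):
    function = "function"
    method = "method"

def parse_entry(es: dict):
    """Extract the dispatch table and variants from one YAML entry."""
    dispatch = {DispatchKey[k]: v for k, v in es.get("dispatch", {}).items()}
    variants = {Variant[v.strip()] for v in es.get("variants", "function").split(",")}
    return dispatch, variants

entry = {
    "func": "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
    "variants": "function, method",
    "dispatch": {"CPU": "add_cpu", "CUDA": "add_cuda"},
}
dispatch, variants = parse_entry(entry)
print(dispatch[DispatchKey.CPU])   # add_cpu
print(Variant.method in variants)  # True
```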
Main Code Generators
1. gen_backend_stubs.py
Generates backend-specific stub code.
def gen_backend_stubs(
    native_functions: Sequence[NativeFunction],
    backend_name: str,
    dispatch_key: DispatchKey,
    output_dir: str,
):
    # Generate a backend stub for each native function
    for f in native_functions:
        if dispatch_key not in f.dispatch:
            continue
        # Function declaration (emitted into the generated header)
        sig = CppSignatureGroup.from_native_function(f)
        decl = f"TORCH_API {sig.signature()} {f.dispatch[dispatch_key]}({sig.arguments()});"
        # Function implementation (forwards to the kernel)
        impl = f"""
{sig.defn()} {{
  // auto-generated backend stub
  return at::native::{f.dispatch[dispatch_key]}({sig.call_args()});
}}
"""
        # Write to the output file
        write_cpp_file(f"{output_dir}/{f.root_name}.cpp", impl)
2. gen_python_functions.py
Generates the Python binding code.
def gen_python_binding(f: NativeFunction) -> str:
    # 1. Build the Python signature
    python_sig = PythonSignature.from_native_function(f)
    # 2. Build the C++ dispatch lambda
    cpp_sig = CppSignature.from_native_function(f)
    lambda_def = f"""
auto dispatch_{f.func.name.unambiguous_name()} =
  []({cpp_sig.arguments()}) -> {cpp_sig.returns()} {{
    pybind11::gil_scoped_release no_gil;
    return {cpp_sig.call_expression()};
  }};
"""
    # 3. Build the argument-parsing code
    parser_code = f"""
static PythonArgParser parser({{
  "{python_sig.signature_str()}",
}}, /*traceable=*/true);
ParsedArgs<{len(python_sig.arguments)}> parsed_args;
auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
"""
    # 4. Build the dispatch call
    dispatch_code = f"""
return wrap(dispatch_{f.func.name.unambiguous_name()}(
  {python_sig.binding_expressions()}
));
"""
    return lambda_def + parser_code + dispatch_code
3. gen_autograd.py
Generates autograd Function code.
def gen_autograd_function(f: NativeFunction, derivative: Derivative) -> str:
    cpp_sig = CppSignature.from_native_function(f)
    # 1. Generate the backward Function (invoked during the backward pass)
    backward_code = f"""
struct {f.root_name.title()}Backward : public TraceableFunction {{
  using TraceableFunction::TraceableFunction;
  variable_list apply(variable_list&& grads) override {{
    // gradient computation code
    {generate_gradient_code(derivative)}
  }}
  std::string name() const override {{
    return "{f.root_name.title()}Backward";
  }}
}};
"""
    # 2. Generate the wrapper function
    wrapper_code = f"""
Tensor {f.root_name}_autograd({cpp_sig.arguments()}) {{
  auto result = at::{f.root_name}({cpp_sig.call_args()});
  if (requires_grad({{self, other}})) {{
    auto grad_fn = std::make_shared<{f.root_name.title()}Backward>();
    grad_fn->save_variables({{self, other}});
    set_history(result, grad_fn);
  }}
  return result;
}}
"""
    return backward_code + wrapper_code
Template Engine
CodeTemplate
TorchGen processes code templates with a custom template engine.
class CodeTemplate:
    def __init__(self, template: str):
        self.template = template

    def substitute(self, env: dict[str, Any]) -> str:
        # 1. Substitute ${variable} placeholders
        result = self.template
        for key, value in env.items():
            placeholder = "${" + key + "}"
            result = result.replace(placeholder, str(value))
        # 2. Handle conditionals: $if{condition}...${endif}
        result = self._process_conditionals(result, env)
        # 3. Handle loops: $for{item in items}...${endfor}
        result = self._process_loops(result, env)
        return result

    def _process_conditionals(self, text: str, env: dict) -> str:
        # Handle the $if{condition}...${endif} syntax
        import re
        pattern = r'\$if\{([^}]+)\}(.*?)\$\{endif\}'
        def replace_conditional(match):
            condition = match.group(1)
            content = match.group(2)
            # Evaluate the condition against the template environment
            # (templates are trusted input, so eval is acceptable here)
            try:
                if eval(condition, env):
                    return content
                return ""
            except Exception:
                return ""
        return re.sub(pattern, replace_conditional, text, flags=re.DOTALL)
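The `${...}` substitution step can be exercised in isolation. A toy version follows (regex-based; the real `CodeTemplate` also supports bare `$identifier` placeholders and indentation-aware list expansion):

```python
import re

class MiniTemplate:
    """Toy stand-in for torchgen's CodeTemplate: ${key} substitution only."""

    def __init__(self, template: str):
        self.template = template

    def substitute(self, env: dict) -> str:
        def lookup(match: re.Match) -> str:
            value = env[match.group(1)]
            # list values expand to one line per element, as torchgen does
            if isinstance(value, list):
                return "\n".join(str(v) for v in value)
            return str(value)

        return re.sub(r"\$\{(\w+)\}", lookup, self.template)

t = MiniTemplate("Tensor ${name}(${args});")
print(t.substitute({"name": "add", "args": "const Tensor& self"}))
# Tensor add(const Tensor& self);
```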
Template Example
aten/src/ATen/templates/Functions.cpp:
// Warning: auto-generated file
// @generated from ${template_path}
#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace at {

$for{f in native_functions}
${cpp_function_definition(f)}
${endfor}

} // namespace at
Generated output:
// @generated from aten/src/ATen/templates/Functions.cpp
#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace at {

Tensor add(const Tensor& self, const Tensor& other, const Scalar& alpha) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor");
  return op.typed<Tensor(const Tensor&, const Tensor&, const Scalar&)>()
      .call(self, other, alpha);
}

Tensor& add_(Tensor& self, const Tensor& other, const Scalar& alpha) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add_", "Tensor");
  return op.typed<Tensor&(Tensor&, const Tensor&, const Scalar&)>()
      .call(self, other, alpha);
}

// ... more functions

} // namespace at
Key API Generation
Python API Generation
// torch/_C/_VariableFunctions.cpp
PyObject* THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // argument parser
  static PythonArgParser parser({
    "add(Tensor input, Tensor other, *, Scalar alpha=1)",
    "add(Tensor input, Scalar other, Scalar alpha=1)",
  }, /*traceable=*/true);
  ParsedArgs<3> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  switch (_r.idx) {
    case 0: {
      // add(Tensor, Tensor, Scalar)
      auto dispatch_add = [](const Tensor& self, const Tensor& other, const Scalar& alpha) {
        pybind11::gil_scoped_release no_gil;
        return self.add(other, alpha);
      };
      return wrap(dispatch_add(_r.tensor(0), _r.tensor(1), _r.scalar(2)));
    }
    case 1: {
      // add(Tensor, Scalar, Scalar)
      auto dispatch_add = [](const Tensor& self, const Scalar& other, const Scalar& alpha) {
        pybind11::gil_scoped_release no_gil;
        return self.add(other, alpha);
      };
      return wrap(dispatch_add(_r.tensor(0), _r.scalar(1), _r.scalar(2)));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
Type Signature Generation
# torch/_C/__init__.pyi
def add(
    input: Tensor,
    other: Union[Tensor, Number],
    *,
    alpha: Number = 1,
    out: Optional[Tensor] = None
) -> Tensor: ...

def add_(
    self: Tensor,
    other: Union[Tensor, Number],
    *,
    alpha: Number = 1
) -> Tensor: ...
Performance Optimization
1. Incremental generation
TorchGen supports incremental code generation: only files whose inputs changed are regenerated.
import os

def should_regenerate_file(input_files: list[str], output_file: str) -> bool:
    if not os.path.exists(output_file):
        return True
    output_mtime = os.path.getmtime(output_file)
    for input_file in input_files:
        if os.path.getmtime(input_file) > output_mtime:
            return True
    return False
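The mtime comparison is easy to check with temporary files. A self-contained sketch (the function is restated here so the snippet runs on its own):

```python
import os
import tempfile
import time

def should_regenerate_file(input_files, output_file):
    # regenerate when the output is missing or older than any input
    if not os.path.exists(output_file):
        return True
    output_mtime = os.path.getmtime(output_file)
    return any(os.path.getmtime(f) > output_mtime for f in input_files)

with tempfile.TemporaryDirectory() as d:
    src = os.path.join(d, "native_functions.yaml")
    out = os.path.join(d, "Functions.cpp")
    open(src, "w").close()
    assert should_regenerate_file([src], out)  # output missing
    open(out, "w").close()
    # force the input's mtime past the output's
    future = time.time() + 10
    os.utime(src, (future, future))
    assert should_regenerate_file([src], out)  # input newer than output
```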
2. Parallel generation
Independent generators can run in parallel.
import multiprocessing

def run_generator(gen_fn, args):
    # thin wrapper so Pool.starmap can apply (generator, args) pairs
    gen_fn(*args)

def generate_all_files():
    tasks = [
        (gen_backend_stubs, backend_args),
        (gen_python_functions, python_args),
        (gen_autograd, autograd_args),
    ]
    with multiprocessing.Pool() as pool:
        pool.starmap(run_generator, tasks)
3. Caching
Parsed results can be cached to avoid repeated work.
import functools
import yaml

@functools.lru_cache(maxsize=None)
def parse_native_functions(yaml_path: str) -> list[NativeFunction]:
    with open(yaml_path) as f:
        data = yaml.safe_load(f)
    return [parse_native_function(item) for item in data]
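The effect of the cache is visible with a call counter (a toy `parse` stand-in for the YAML load above; only the first call with a given path does real work):

```python
import functools

calls = {"n": 0}

@functools.lru_cache(maxsize=None)
def parse(path: str) -> str:
    # stand-in for the expensive YAML parse; counts real invocations
    calls["n"] += 1
    return f"parsed:{path}"

parse("native_functions.yaml")
parse("native_functions.yaml")  # cache hit: no second parse
print(calls["n"])  # 1
```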
Extension and Customization
Adding a New Backend
- Define a DispatchKey:
# torchgen/model.py
class DispatchKey(Enum):
    # ... existing keys
    MyBackend = auto()
- Add a code generator:
# torchgen/gen_my_backend.py
def gen_my_backend_kernels(native_functions: Sequence[NativeFunction]):
    for f in native_functions:
        if DispatchKey.MyBackend in f.dispatch:
            generate_kernel_for_function(f)
- Declare it in native_functions.yaml:
- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  dispatch:
    MyBackend: add_my_backend
Custom Templates
# Define a new template processor
class MyCodeTemplate(CodeTemplate):
    def __init__(self, template_path: str):
        with open(template_path) as f:
            super().__init__(f.read())

    def generate_for_function(self, f: NativeFunction) -> str:
        env = {
            'function_name': f.func.name.name.base,
            'arguments': f.func.arguments,
            'returns': f.func.returns,
            # ... more variables
        }
        return self.substitute(env)
Build Integration
CMake Integration
# In CMakeLists.txt
set(TORCHGEN_OUTPUTS
  ${CMAKE_CURRENT_BINARY_DIR}/Functions.cpp
  ${CMAKE_CURRENT_BINARY_DIR}/Functions.h
  ${CMAKE_CURRENT_BINARY_DIR}/python_functions.cpp
)

add_custom_command(
  OUTPUT ${TORCHGEN_OUTPUTS}
  COMMAND ${Python_EXECUTABLE} -m torchgen.gen
    --source-path ${CMAKE_CURRENT_SOURCE_DIR}
    --output-dir ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS
    ${CMAKE_CURRENT_SOURCE_DIR}/native_functions.yaml
    ${CMAKE_CURRENT_SOURCE_DIR}/derivatives.yaml
  COMMENT "Generating PyTorch operators"
)

add_custom_target(generate_pytorch_ops
  DEPENDS ${TORCHGEN_OUTPUTS}
)
Dependency Graph
graph TD
A[native_functions.yaml] --> B[TorchGen]
C[derivatives.yaml] --> B
D[templates/*.cpp] --> B
E[templates/*.py.in] --> B
B --> F[ATen C++ code]
B --> G[Python bindings]
B --> H[Type signatures]
F --> I[libtorch.so]
G --> I
H --> J[torch package]
I --> K[PyTorch runtime]
J --> K