1. 依赖注入概述
依赖注入(Dependency Injection, DI)是 FastAPI 最强大和创新的特性之一。它允许开发者以声明式的方式定义函数的依赖关系,FastAPI 会自动解析和注入这些依赖,从而实现代码的解耦和复用。
1.1 依赖注入系统架构
graph TB
subgraph "依赖定义层"
A[依赖函数定义] --> B[Depends()]
B --> C[SecurityBase]
C --> D[OAuth2, APIKey等]
end
subgraph "依赖分析层"
E[get_dependant()] --> F[分析函数签名]
F --> G[识别依赖参数]
G --> H[构建依赖树]
H --> I[Dependant对象]
end
subgraph "依赖解析层"
J[solve_dependencies()] --> K[递归解析依赖]
K --> L[缓存机制]
L --> M[作用域管理]
M --> N[生命周期控制]
end
subgraph "依赖注入层"
O[注入到路径操作] --> P[参数绑定]
P --> Q[调用用户函数]
end
A --> E
I --> J
N --> O
1.2 核心概念
依赖(Dependency): 一个函数或类,提供某种功能或资源
依赖项(Dependant): 需要依赖的函数或类
依赖声明: 通过 Depends()
声明依赖关系
依赖解析: FastAPI 自动分析和解决依赖关系的过程
依赖注入: 将解析后的依赖值注入到目标函数的过程
2. Depends() 函数深度分析
2.1 Depends 类的实现
class Depends:
"""
依赖注入标记类
用于标记函数参数是一个依赖注入,需要由 FastAPI 自动解析和注入
Attributes:
dependency: 依赖函数或类
use_cache: 是否使用缓存,同一请求内相同依赖只执行一次
"""
def __init__(
self,
dependency: Optional[Callable[..., Any]] = None, # 依赖函数
*,
use_cache: bool = True, # 是否缓存结果
):
"""
初始化依赖对象
Args:
dependency: 依赖函数。如果为None,将使用参数的类型注解作为依赖
use_cache: 是否缓存依赖结果。在同一请求内,相同依赖只执行一次
Examples:
# 显式依赖函数
def get_db() -> Database:
return Database()
def get_user(user_id: int, db: Database = Depends(get_db)):
return db.get_user(user_id)
# 类型注解依赖(dependency=None)
def get_user(user_id: int, db: Database = Depends()):
return Database().get_user(user_id) # 会自动调用 Database()
"""
self.dependency = dependency
self.use_cache = use_cache
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"
2.2 依赖声明的多种形式
FastAPI 支持多种依赖声明方式:
from fastapi import Depends, FastAPI
from sqlalchemy.orm import Session
from typing import Annotated
app = FastAPI()
# === 1. 函数依赖 ===
def get_database_session() -> Session:
"""
数据库会话依赖
Returns:
Session: SQLAlchemy 数据库会话
"""
db = SessionLocal()
try:
yield db # 使用生成器可以在请求结束后清理资源
finally:
db.close()
# === 2. 类依赖 ===
class DatabaseService:
"""数据库服务类依赖"""
def __init__(self, session: Session = Depends(get_database_session)):
"""
初始化数据库服务
Args:
session: 数据库会话,通过依赖注入获得
"""
self.session = session
def get_user(self, user_id: int):
"""获取用户信息"""
return self.session.query(User).filter(User.id == user_id).first()
# === 3. 子依赖 ===
def get_current_user_id(token: str = Depends(oauth2_scheme)) -> int:
"""
从令牌中提取当前用户ID
Args:
token: OAuth2 令牌,通过依赖注入获得
Returns:
int: 用户ID
Raises:
HTTPException: 当令牌无效时
"""
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id = payload.get("user_id")
if user_id is None:
raise HTTPException(status_code=401, detail="Invalid token")
return user_id
def get_current_user(
user_id: int = Depends(get_current_user_id), # 依赖于 get_current_user_id
db_service: DatabaseService = Depends(DatabaseService) # 依赖于 DatabaseService
) -> User:
"""
获取当前用户对象
这个函数展示了依赖链:
get_current_user -> get_current_user_id -> oauth2_scheme
-> DatabaseService -> get_database_session
Args:
user_id: 用户ID,通过依赖链获得
db_service: 数据库服务,通过依赖注入获得
Returns:
User: 用户对象
"""
user = db_service.get_user(user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user
# === 4. 路径操作中使用依赖 ===
@app.get("/users/me")
async def read_users_me(
current_user: User = Depends(get_current_user) # 自动解析整个依赖链
) -> User:
"""
获取当前用户信息
依赖解析流程:
1. 解析 get_current_user 依赖
2. 解析 get_current_user_id 子依赖
3. 解析 oauth2_scheme 子依赖
4. 解析 DatabaseService 依赖
5. 解析 get_database_session 子依赖
6. 按顺序执行所有依赖函数
7. 注入结果到路径操作函数
Args:
current_user: 当前用户对象,通过复杂依赖链自动注入
Returns:
User: 当前用户信息
"""
return current_user
# === 5. 类型注解依赖(推荐用法)===
@app.get("/items/")
async def read_items(
# 使用 Annotated 类型注解,提供更好的类型提示
db: Annotated[Session, Depends(get_database_session)],
current_user: Annotated[User, Depends(get_current_user)]
) -> List[Item]:
"""
获取项目列表
使用 Annotated 类型注解的优势:
1. 提供完整的类型信息
2. IDE 支持更好
3. 代码更清晰
Args:
db: 数据库会话
current_user: 当前用户
Returns:
List[Item]: 项目列表
"""
return db.query(Item).filter(Item.owner_id == current_user.id).all()
3. Dependant 模型深度分析
3.1 Dependant 数据结构
@dataclass
class Dependant:
"""
依赖对象模型
包含函数或类的所有依赖信息,用于依赖解析和注入
这是 FastAPI 依赖系统的核心数据结构,包含了一个函数
或类的所有参数信息和依赖关系
"""
# === 基本信息 ===
path: str = "" # 路径(仅用于路径操作)
call: Optional[Callable[..., Any]] = None # 要调用的函数或类
name: Optional[str] = None # 依赖名称
# === 参数分类 ===
path_params: List[ModelField] = field(default_factory=list) # 路径参数
query_params: List[ModelField] = field(default_factory=list) # 查询参数
header_params: List[ModelField] = field(default_factory=list) # 请求头参数
cookie_params: List[ModelField] = field(default_factory=list) # Cookie参数
body_params: List[ModelField] = field(default_factory=list) # 请求体参数
form_params: List[ModelField] = field(default_factory=list) # 表单参数
file_params: List[ModelField] = field(default_factory=list) # 文件参数
# === 依赖关系 ===
dependencies: List["Dependant"] = field(default_factory=list) # 子依赖列表
# === 特殊参数 ===
request_param_name: Optional[str] = None # Request 对象参数名
websocket_param_name: Optional[str] = None # WebSocket 对象参数名
http_connection_param_name: Optional[str] = None # HTTPConnection 参数名
response_param_name: Optional[str] = None # Response 对象参数名
background_tasks_param_name: Optional[str] = None# BackgroundTasks 参数名
security_scopes_param_name: Optional[str] = None # SecurityScopes 参数名
# === 安全相关 ===
security_requirements: List[SecurityRequirement] = field(default_factory=list)
# === 缓存控制 ===
use_cache: bool = True # 是否使用缓存
@dataclass
class SecurityRequirement:
"""
安全需求模型
定义访问某个资源所需的安全权限
Attributes:
security_scheme: 安全方案(OAuth2、API Key等)
scopes: 所需的权限范围列表
"""
security_scheme: SecurityBase
scopes: List[str] = field(default_factory=list)
3.2 依赖分析过程 - get_dependant()
get_dependant()
是依赖分析的核心函数,它将一个普通的 Python 函数转换为 Dependant 对象:
def get_dependant(
*,
path: str = "", # URL路径
call: Optional[Callable[..., Any]] = None, # 要分析的函数
dependencies: Optional[Sequence[params.Depends]] = None, # 额外依赖
name: Optional[str] = None, # 依赖名称
security_scopes: Optional[List[str]] = None, # 安全范围
) -> Dependant:
"""
分析函数或类的依赖关系
这个函数是 FastAPI 依赖系统的核心,它:
1. 解析函数签名
2. 分析每个参数的类型和注解
3. 识别各种参数类型(路径、查询、依赖等)
4. 递归分析子依赖
5. 构建完整的依赖树
Args:
path: URL路径模式,用于提取路径参数名
call: 要分析的函数或类
dependencies: 额外的依赖列表
name: 依赖的名称
security_scopes: 安全权限范围
Returns:
Dependant: 包含所有依赖信息的对象
Examples:
# 分析简单函数
def get_user(user_id: int, db: Session = Depends(get_db)):
return db.query(User).get(user_id)
dependant = get_dependant(call=get_user)
# dependant.path_params: [] # 没有路径参数
# dependant.query_params: [ModelField(name="user_id", type_=int)]
# dependant.dependencies: [Dependant(call=get_db)]
"""
# 获取路径中的参数名
path_param_names = get_path_param_names(path)
# 获取函数签名
signature = inspect.signature(call) if call else None
# 创建 Dependant 对象
dependant = Dependant(path=path, call=call, name=name)
# === 处理额外依赖 ===
if dependencies:
for depends in dependencies:
# 递归分析每个子依赖
sub_dependant = get_sub_dependant(
depends=depends,
dependency_overrides_provider=None,
path=path,
name=getattr(depends.dependency, "__name__", None),
security_scopes=security_scopes,
)
dependant.dependencies.append(sub_dependant)
# 如果没有函数可分析,返回基础依赖对象
if not signature:
return dependant
# === 遍历函数的每个参数 ===
for param_name, param in signature.parameters.items():
param_default = param.default
param_annotation = param.annotation if param.annotation != param.empty else Any
# === 处理特殊类型参数 ===
# 1. Request 对象
if lenient_issubclass(param_annotation, Request):
dependant.request_param_name = param_name
continue
# 2. WebSocket 对象
if lenient_issubclass(param_annotation, WebSocket):
dependant.websocket_param_name = param_name
continue
# 3. HTTPConnection 对象
if lenient_issubclass(param_annotation, HTTPConnection):
dependant.http_connection_param_name = param_name
continue
# 4. Response 对象
if lenient_issubclass(param_annotation, Response):
dependant.response_param_name = param_name
continue
# 5. BackgroundTasks
if lenient_issubclass(param_annotation, BackgroundTasks):
dependant.background_tasks_param_name = param_name
continue
# 6. SecurityScopes
if lenient_issubclass(param_annotation, SecurityScopes):
dependant.security_scopes_param_name = param_name
continue
# === 处理 Depends 依赖注入 ===
if isinstance(param_default, params.Depends):
# 这是一个依赖注入参数
sub_dependant = get_sub_dependant(
depends=param_default,
dependency_overrides_provider=None,
name=param_name,
path=path,
security_scopes=security_scopes,
)
dependant.dependencies.append(sub_dependant)
continue
# === 处理各种参数类型 ===
# 路径参数:在路径模式中定义的参数
if param_name in path_param_names:
add_param_to_fields(
field_info=params.Path(),
field_name=param_name,
annotation=param_annotation,
field_list=dependant.path_params,
)
continue
# 显式标记的参数类型
param_field_info = None
if isinstance(param_default, params.Query):
param_field_info = param_default
field_list = dependant.query_params
elif isinstance(param_default, params.Header):
param_field_info = param_default
field_list = dependant.header_params
elif isinstance(param_default, params.Cookie):
param_field_info = param_default
field_list = dependant.cookie_params
elif isinstance(param_default, params.Body):
param_field_info = param_default
field_list = dependant.body_params
elif isinstance(param_default, params.Form):
param_field_info = param_default
field_list = dependant.form_params
elif isinstance(param_default, params.File):
param_field_info = param_default
field_list = dependant.file_params
# 如果有明确的参数类型标记
if param_field_info is not None:
add_param_to_fields(
field_info=param_field_info,
field_name=param_name,
annotation=param_annotation,
field_list=field_list,
)
continue
# === 自动推断参数类型 ===
# 如果没有显式标记,根据类型和默认值推断
if is_scalar_field(field=create_model_field(name=param_name, type_=param_annotation)):
# 标量类型(int, str, bool等)-> 查询参数
add_param_to_fields(
field_info=params.Query(default=param_default),
field_name=param_name,
annotation=param_annotation,
field_list=dependant.query_params,
)
else:
# 复杂类型(Pydantic模型等)-> 请求体参数
add_param_to_fields(
field_info=params.Body(default=param_default),
field_name=param_name,
annotation=param_annotation,
field_list=dependant.body_params,
)
return dependant
def get_sub_dependant(
*,
depends: params.Depends, # Depends 对象
dependency_overrides_provider: Optional[Any] = None, # 依赖覆盖提供者
name: Optional[str] = None, # 依赖名称
path: str = "", # 路径
security_scopes: Optional[List[str]] = None, # 安全范围
) -> Dependant:
"""
获取子依赖的 Dependant 对象
处理 Depends() 标记的参数,递归分析其依赖关系
Args:
depends: Depends 对象,包含依赖函数和配置
dependency_overrides_provider: 用于测试时覆盖依赖
name: 子依赖的名称
path: URL路径
security_scopes: 安全权限范围
Returns:
Dependant: 子依赖的依赖对象
"""
# 获取实际的依赖函数
dependency = depends.dependency
# 如果没有指定依赖函数,使用参数的类型注解
if dependency is None:
# 这种情况下,参数的类型注解就是要调用的类或函数
# 例如: db: Database = Depends() -> 会调用 Database()
dependency = param_annotation
# 检查是否为安全依赖
security_requirement = None
if isinstance(dependency, SecurityBase):
# 这是一个安全依赖(OAuth2、API Key等)
security_requirement = SecurityRequirement(
security_scheme=dependency,
scopes=security_scopes or [],
)
# 递归分析子依赖
sub_dependant = get_dependant(
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
)
# 设置缓存策略
sub_dependant.use_cache = depends.use_cache
# 添加安全需求
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
4. 依赖解析过程 - solve_dependencies()
4.1 依赖解析算法
@dataclass
class SolvedDependency:
"""
依赖解析结果
包含解析后的参数值和相关信息
Attributes:
values: 解析后的参数值字典
errors: 解析过程中的错误列表
background_tasks: 后台任务集合
security_scopes: 安全权限范围
response: 响应对象(用于WebSocket)
"""
values: Dict[str, Any] = field(default_factory=dict)
errors: List[ErrorWrapper] = field(default_factory=list)
background_tasks: Optional[BackgroundTasks] = None
security_scopes: List[str] = field(default_factory=list)
response: Optional[Response] = None
async def solve_dependencies(
*,
request: Union[Request, WebSocket], # 请求对象
dependant: Dependant, # 依赖对象
dependency_overrides_provider: Optional[Any] = None, # 依赖覆盖
async_exit_stack: AsyncExitStack, # 异步上下文栈
embed_body_fields: bool = False, # 是否嵌入body字段
) -> SolvedDependency:
"""
解析所有依赖
这是 FastAPI 依赖注入系统的核心函数,负责:
1. 解析各种类型的参数(路径、查询、头部、Cookie、body等)
2. 递归解析所有子依赖
3. 处理依赖缓存
4. 管理依赖的生命周期
5. 处理安全依赖
解析过程采用深度优先算法,确保子依赖在父依赖之前被解析
Args:
request: HTTP请求或WebSocket连接对象
dependant: 要解析的依赖对象
dependency_overrides_provider: 依赖覆盖提供者(主要用于测试)
async_exit_stack: 异步退出栈,用于管理上下文管理器的生命周期
embed_body_fields: 是否嵌入请求体字段
Returns:
SolvedDependency: 包含所有解析结果的对象
"""
# 初始化解析结果
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
background_tasks = BackgroundTasks()
security_scopes: List[str] = []
response: Optional[Response] = None
# === 1. 解析各种参数类型 ===
# 路径参数解析
if dependant.path_params:
path_values, path_errors = request_params_to_args(
required_params=dependant.path_params,
received_params=request.path_params, # 从URL路径提取
)
values.update(path_values)
errors.extend(path_errors)
# 查询参数解析
if dependant.query_params:
query_values, query_errors = request_params_to_args(
required_params=dependant.query_params,
received_params=request.query_params, # 从查询字符串提取
)
values.update(query_values)
errors.extend(query_errors)
# 请求头参数解析
if dependant.header_params:
header_values, header_errors = request_params_to_args(
required_params=dependant.header_params,
received_params=request.headers, # 从HTTP头提取
)
values.update(header_values)
errors.extend(header_errors)
# Cookie参数解析
if dependant.cookie_params:
cookie_values, cookie_errors = request_params_to_args(
required_params=dependant.cookie_params,
received_params=request.cookies, # 从Cookie提取
)
values.update(cookie_values)
errors.extend(cookie_errors)
# 请求体参数解析
if dependant.body_params:
if isinstance(request, Request): # HTTP请求
body_values, body_errors = await request_body_to_args(
body_fields=dependant.body_params,
received_body=await request.body(),
embed_body_fields=embed_body_fields,
)
values.update(body_values)
errors.extend(body_errors)
# 表单参数解析
if dependant.form_params:
if isinstance(request, Request):
form = await request.form()
form_values, form_errors = await request_form_to_args(
required_params=dependant.form_params,
received_form=form,
)
values.update(form_values)
errors.extend(form_errors)
# === 2. 递归解析子依赖 ===
dependency_cache: Dict[Callable[..., Any], Any] = {}
for sub_dependant in dependant.dependencies:
# 检查是否有依赖覆盖(主要用于测试)
call = sub_dependant.call
if (
dependency_overrides_provider
and call in dependency_overrides_provider.dependency_overrides
):
# 使用覆盖的依赖
call = dependency_overrides_provider.dependency_overrides[call]
# 检查缓存
cache_key = call
if sub_dependant.use_cache and cache_key in dependency_cache:
# 使用缓存的结果
solved_result = dependency_cache[cache_key]
else:
# 需要重新解析
if call != sub_dependant.call:
# 使用覆盖的依赖,重新创建 Dependant 对象
use_sub_dependant = get_dependant(
path="",
call=call,
name=sub_dependant.name,
)
else:
use_sub_dependant = sub_dependant
# 递归解析子依赖
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
# 缓存结果(如果启用缓存)
if sub_dependant.use_cache:
dependency_cache[cache_key] = solved_result
# 处理解析错误
if solved_result.errors:
errors.extend(solved_result.errors)
continue
# === 3. 调用依赖函数 ===
if call:
# 检查是否为协程函数
if iscoroutinefunction(call):
# 异步依赖
sub_response = await call(**solved_result.values)
else:
# 同步依赖,在线程池中运行
sub_response = await run_in_threadpool(call, **solved_result.values)
# === 4. 处理上下文管理器 ===
if hasattr(sub_response, "__aenter__"):
# 异步上下文管理器
sub_response = await async_exit_stack.enter_async_context(sub_response)
elif hasattr(sub_response, "__enter__"):
# 同步上下文管理器
sub_response = async_exit_stack.enter_context(sub_response)
# 将依赖结果添加到参数值中
if sub_dependant.name:
values[sub_dependant.name] = sub_response
# 合并后台任务和安全范围
if solved_result.background_tasks:
background_tasks.add_task_list(solved_result.background_tasks.tasks)
security_scopes.extend(solved_result.security_scopes)
# === 5. 处理特殊参数 ===
# Request对象
if dependant.request_param_name:
values[dependant.request_param_name] = request
# WebSocket对象
if dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
# HTTPConnection对象
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
# BackgroundTasks对象
if dependant.background_tasks_param_name:
values[dependant.background_tasks_param_name] = background_tasks
# SecurityScopes对象
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(scopes=security_scopes)
# Response对象(仅用于WebSocket)
if dependant.response_param_name:
response = Response()
values[dependant.response_param_name] = response
return SolvedDependency(
values=values,
errors=errors,
background_tasks=background_tasks,
security_scopes=security_scopes,
response=response,
)
4.2 依赖解析时序图
sequenceDiagram
participant Handler as 请求处理器
participant Solver as solve_dependencies()
participant ParamParser as 参数解析器
participant DepCall as 依赖调用
participant Cache as 依赖缓存
participant Context as 上下文管理器
Handler->>Solver: solve_dependencies(request, dependant)
loop 各种参数类型
Solver->>ParamParser: request_params_to_args()
ParamParser->>ParamParser: 解析和验证参数
ParamParser-->>Solver: (values, errors)
end
loop 每个子依赖
Solver->>Cache: 检查缓存
alt 缓存存在
Cache-->>Solver: 返回缓存结果
else 缓存不存在
Solver->>Solver: 递归调用 solve_dependencies()
Solver->>DepCall: 调用依赖函数
alt 异步依赖
DepCall->>DepCall: await call(**values)
else 同步依赖
DepCall->>DepCall: run_in_threadpool(call, **values)
end
DepCall-->>Solver: 依赖结果
alt 是上下文管理器
Solver->>Context: enter_async_context() / enter_context()
Context-->>Solver: 管理的资源
end
Solver->>Cache: 缓存结果(如果启用)
end
end
Solver->>Solver: 处理特殊参数 (Request, Response等)
Solver-->>Handler: SolvedDependency(values, errors, ...)
5. 依赖缓存机制
5.1 缓存策略
FastAPI 的依赖缓存有以下特点:
- 请求级别缓存:在同一个请求内,相同依赖只执行一次
- 可选择性:通过
use_cache=False
可以禁用缓存 - 基于函数引用:使用函数对象作为缓存键
- 自动清理:请求结束后自动清理缓存
# 缓存示例
expensive_computation_count = 0
def expensive_computation() -> str:
"""
模拟昂贵的计算
在同一请求内,这个函数只会被调用一次
"""
global expensive_computation_count
expensive_computation_count += 1
time.sleep(2) # 模拟耗时操作
return f"Result {expensive_computation_count}"
@app.get("/cached")
async def test_cache(
# 这两个依赖都会调用 expensive_computation,但只执行一次
result1: str = Depends(expensive_computation),
result2: str = Depends(expensive_computation)
) -> dict:
"""
测试依赖缓存
尽管有两个依赖都需要 expensive_computation 的结果,
但在同一请求内,expensive_computation 只会被执行一次
"""
return {"result1": result1, "result2": result2} # 两个结果相同
# 禁用缓存示例
@app.get("/no-cache")
async def test_no_cache(
result1: str = Depends(expensive_computation),
result2: str = Depends(expensive_computation, use_cache=False) # 禁用缓存
) -> dict:
"""
测试禁用缓存
result2 禁用了缓存,所以 expensive_computation 会被调用两次
"""
return {"result1": result1, "result2": result2} # 可能得到不同结果
5.2 上下文管理器支持
FastAPI 支持依赖作为上下文管理器,自动管理资源的生命周期:
from contextlib import contextmanager, asynccontextmanager
from typing import Iterator, AsyncIterator
# === 同步上下文管理器 ===
@contextmanager
def get_db_session() -> Iterator[Session]:
"""
数据库会话上下文管理器
自动处理会话的创建和清理
"""
session = SessionLocal()
try:
yield session
session.commit() # 自动提交
except Exception:
session.rollback() # 出错时回滚
raise
finally:
session.close() # 自动关闭会话
# === 异步上下文管理器 ===
@asynccontextmanager
async def get_redis_pool() -> AsyncIterator[Redis]:
"""
Redis连接池上下文管理器
自动处理连接的获取和释放
"""
pool = await aioredis.create_pool('redis://localhost')
try:
redis = Redis(connection_pool=pool)
yield redis
finally:
pool.close()
await pool.wait_closed()
# === 在路径操作中使用 ===
@app.get("/users/{user_id}")
async def get_user(
user_id: int,
# FastAPI 自动管理上下文管理器的生命周期
db: Session = Depends(get_db_session),
redis: Redis = Depends(get_redis_pool)
) -> User:
"""
使用上下文管理器依赖
FastAPI 会自动:
1. 进入上下文管理器(获取资源)
2. 执行路径操作函数
3. 退出上下文管理器(清理资源)
即使路径操作函数抛出异常,资源也会被正确清理
"""
# 从缓存中查找用户
cached_user = await redis.get(f"user:{user_id}")
if cached_user:
return User.parse_raw(cached_user)
# 从数据库查找用户
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# 缓存用户信息
await redis.setex(f"user:{user_id}", 3600, user.json())
return user
# 请求处理流程:
# 1. 进入 get_db_session 上下文 -> 创建 session
# 2. 进入 get_redis_pool 上下文 -> 创建 redis 连接
# 3. 执行 get_user 函数
# 4. 退出 get_redis_pool 上下文 -> 关闭 redis 连接
# 5. 退出 get_db_session 上下文 -> 提交/回滚并关闭 session
6. 依赖作用域和生命周期
6.1 依赖的生命周期
FastAPI 中的依赖有明确的生命周期:
graph TD
A[请求到达] --> B[解析依赖]
B --> C[调用依赖函数]
C --> D[缓存结果]
D --> E[注入到路径操作]
E --> F[执行路径操作]
F --> G[清理上下文管理器]
G --> H[清理缓存]
H --> I[请求结束]
6.2 依赖作用域
# === 全局依赖 ===
# 应用级别的依赖,作用于所有路径操作
app = FastAPI(dependencies=[
Depends(verify_api_key), # 全局API密钥验证
Depends(rate_limiter), # 全局限流
])
# === 路由器级别依赖 ===
# 作用于整个路由器的所有路径操作
router = APIRouter(
prefix="/api/v1",
dependencies=[
Depends(get_current_user), # 路由器级别的用户认证
]
)
# === 路径操作级别依赖 ===
@router.get("/users/me")
async def get_current_user_info(
# 路径操作级别的依赖
user: User = Depends(get_current_user),
permissions: List[str] = Depends(get_user_permissions)
):
"""
依赖执行顺序:
1. 全局依赖:verify_api_key, rate_limiter
2. 路由器依赖:get_current_user
3. 路径操作依赖:get_current_user(缓存命中),get_user_permissions
"""
return {"user": user, "permissions": permissions}
7. 安全依赖集成
7.1 安全依赖的特殊处理
FastAPI 对安全依赖(继承自 SecurityBase
的依赖)有特殊处理:
from fastapi.security import OAuth2PasswordBearer, HTTPBearer
# OAuth2 安全方案
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# HTTP Bearer 安全方案
bearer_scheme = HTTPBearer()
def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
"""
从令牌获取当前用户
oauth2_scheme 是一个安全依赖,FastAPI 会:
1. 在 OpenAPI 中标记此端点需要认证
2. 在 Swagger UI 中显示认证按钮
3. 从 Authorization 头中提取令牌
"""
# 验证令牌并返回用户
return verify_token(token)
@app.get("/protected")
async def protected_endpoint(
current_user: User = Depends(get_current_user)
):
"""
受保护的端点
在 OpenAPI 文档中,此端点会显示:
1. 需要 OAuth2 认证
2. 锁图标表示需要认证
3. 在 Swagger UI 中可以输入令牌进行测试
"""
return {"user": current_user.username}
7.2 复合安全依赖
from fastapi.security import SecurityScopes
def verify_permissions(
security_scopes: SecurityScopes,
token: str = Depends(oauth2_scheme)
) -> User:
"""
验证用户权限
Args:
security_scopes: 所需的权限范围
token: 访问令牌
Returns:
User: 验证通过的用户
Raises:
HTTPException: 权限不足或令牌无效
"""
user = verify_token(token)
# 检查用户是否有所需权限
for scope in security_scopes.scopes:
if scope not in user.permissions:
raise HTTPException(
status_code=403,
detail=f"Not enough permissions. Required: {scope}"
)
return user
@app.get("/admin/users")
async def list_users(
# 使用 Security() 指定所需权限
current_user: User = Security(verify_permissions, scopes=["users:read"])
):
"""
需要 users:read 权限的端点
FastAPI 会自动:
1. 将 scopes=["users:read"] 传递给 SecurityScopes
2. 在 OpenAPI 中记录所需权限
3. 在文档中显示权限要求
"""
return get_all_users()
8. 依赖注入的高级特性
8.1 动态依赖
def create_database_dependency(database_url: str):
"""
动态创建数据库依赖
根据配置动态创建不同的数据库连接
"""
engine = create_engine(database_url)
SessionLocal = sessionmaker(bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
return get_db
# 根据环境使用不同的数据库
if settings.ENVIRONMENT == "test":
get_db = create_database_dependency(settings.TEST_DATABASE_URL)
else:
get_db = create_database_dependency(settings.DATABASE_URL)
@app.get("/users")
async def list_users(db: Session = Depends(get_db)):
return db.query(User).all()
8.2 条件依赖
def conditional_dependency(condition: bool):
"""
条件依赖工厂
根据条件返回不同的依赖
"""
if condition:
return expensive_dependency
else:
return simple_dependency
def get_service_dependency():
"""
根据配置选择服务实现
"""
if settings.USE_REDIS:
return get_redis_service
else:
return get_memory_service
@app.get("/data")
async def get_data(
service = Depends(get_service_dependency())
):
return service.get_data()
9. 性能优化和最佳实践
9.1 依赖优化建议
# ✅ 好的做法:使用缓存
@lru_cache()
def get_settings():
"""
配置依赖使用缓存
配置信息通常不会改变,使用缓存可以避免重复创建
"""
return Settings()
# ✅ 好的做法:异步依赖
async def get_async_database():
"""
对于 I/O 密集型操作,使用异步依赖
"""
async with AsyncSession() as session:
yield session
# ❌ 避免的做法:在依赖中执行昂贵操作而不使用缓存
def expensive_calculation(): # 没有 use_cache=True
time.sleep(5) # 昂贵操作
return "result"
# ✅ 改进:使用缓存或者全局变量
@lru_cache()
def cached_expensive_calculation():
time.sleep(5)
return "result"
9.2 依赖组织建议
# dependencies/database.py
async def get_database_session():
"""数据库会话依赖"""
pass
# dependencies/auth.py
async def get_current_user():
"""认证依赖"""
pass
# dependencies/cache.py
async def get_redis_client():
"""缓存依赖"""
pass
# main.py
from dependencies import database, auth, cache
@app.get("/users")
async def list_users(
db: Session = Depends(database.get_database_session),
current_user: User = Depends(auth.get_current_user),
redis: Redis = Depends(cache.get_redis_client)
):
# 清晰的依赖组织
pass
10. 总结
FastAPI 的依赖注入系统是其最强大的特性之一,它通过以下创新实现了优雅的依赖管理:
- 类型驱动:基于 Python 类型提示自动推断依赖关系
- 声明式设计:通过
Depends()
声明式地定义依赖 - 递归解析:支持复杂的依赖链和嵌套依赖
- 智能缓存:请求级别的依赖缓存,提高性能
- 生命周期管理:自动管理上下文管理器的生命周期
- 安全集成:与安全系统深度集成,自动生成文档
- 测试友好:支持依赖覆盖,方便单元测试
这个系统不仅简化了代码编写,还提供了强大的功能和优秀的性能。下一章我们将分析 FastAPI 的安全认证系统实现。