LangGraph-06-checkpoint-postgres Module: Comprehensive Documentation
0. Module Overview
0.1 Module Responsibilities
The checkpoint-postgres module is the PostgreSQL implementation of the LangGraph checkpoint base interface, providing durable checkpoint storage backed by a PostgreSQL database. It inherits from and implements the BaseCheckpointSaver interface, giving LangGraph applications reliable state persistence, recovery, and history-query capabilities.
0.2 Module Inputs and Outputs
Inputs:
- A PostgreSQL database connection (Connection or ConnectionPool)
- Checkpoint data structures
- RunnableConfig configuration objects
- A serializer (SerializerProtocol)
Outputs:
- Persisted checkpoint data
- Checkpoint history records
- CheckpointTuple objects
- Write-operation records
0.3 Upstream and Downstream Dependencies
Dependencies:
- psycopg3: PostgreSQL database driver
- psycopg_pool: connection pool management
- langgraph.checkpoint.base: base interface definitions
- langgraph.checkpoint.serde: serialization protocol
Downstream consumers:
- The LangGraph core execution engine
- Agent applications that need persistent state
- Multi-turn conversation systems that support state recovery
0.4 Lifecycle
- Initialization: establish the database connection and configure the serializer
- Setup: create the required database tables and indexes
- Runtime: store, query, and manage checkpoints
- Cleanup: close the database connection and release resources (a usage sketch follows this list)
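A minimal sketch of this lifecycle, assuming a locally reachable PostgreSQL instance; the connection string and thread ID are placeholders:
from langgraph.checkpoint.postgres import PostgresSaver

DB_URI = "postgresql://postgres:postgres@localhost:5432/langgraph"  # placeholder

# Initialization + setup: from_conn_string opens the connection and acts as a
# context manager; setup() is idempotent and creates the tables and indexes.
with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
    checkpointer.setup()
    # Runtime: store/query/manage checkpoints, e.g. read back the latest one
    config = {"configurable": {"thread_id": "thread-1"}}
    print(checkpointer.get_tuple(config))  # None until something is stored
# Cleanup: the context manager closes the connection on exit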
0.5 Module Architecture Diagram
flowchart TD
A[PostgresSaver] --> B[BasePostgresSaver]
B --> C[BaseCheckpointSaver]
A --> D[Connection Management]
D --> E[Direct Connection]
D --> F[Connection Pool]
A --> G[SQL Operations]
G --> H[Checkpoint CRUD]
G --> I[Writes Management]
G --> J[Migration]
A --> K[Serialization]
K --> L[SerializerProtocol]
M[AsyncPostgresSaver] --> N[BasePostgresSaver]
M --> O[Async Operations]
P[PostgresStore] --> Q[BaseStore]
P --> R[Vector Search]
P --> S[TTL Management]
T[Database Schema] --> U[checkpoints Table]
T --> V[checkpoint_writes Table]
T --> W[store Table]
T --> X[Indexes]
Architecture notes:
- PostgresSaver: the synchronous checkpoint saver implementation
- AsyncPostgresSaver: the asynchronous checkpoint saver implementation
- BasePostgresSaver: shared base implementation logic
- Connection Management: supports both single-connection and connection-pool modes (see the sketch below)
- SQL Operations: encapsulates all database operation logic
- PostgresStore: provides key-value storage and vector search
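The two connection modes can be constructed as follows; a sketch assuming psycopg / psycopg_pool, with placeholder connection details:
from psycopg import Connection
from psycopg.rows import dict_row
from psycopg_pool import ConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver

DB_URI = "postgresql://postgres:postgres@localhost:5432/langgraph"  # placeholder

# Direct-connection mode: one dedicated connection, configured as the saver expects
conn = Connection.connect(DB_URI, autocommit=True, row_factory=dict_row)
saver_direct = PostgresSaver(conn)

# Connection-pool mode: a connection is borrowed per operation and returned afterwards
pool = ConnectionPool(DB_URI, max_size=10,
                      kwargs={"autocommit": True, "row_factory": dict_row})
saver_pooled = PostgresSaver(pool)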
1. Key Data Structures and UML
1.1 Core Data Structures
# PostgreSQL connection type definition
Conn = Union[Connection[DictRow], ConnectionPool[Connection[DictRow]]]

class PostgresSaver(BasePostgresSaver):
    """PostgreSQL checkpoint saver."""
    conn: Conn                 # database connection
    pipe: Optional[Pipeline]   # pipeline for batching statements
    lock: threading.Lock       # thread lock guarding connection use

class CheckpointRow(TypedDict):
    """Row shape of the checkpoints table."""
    thread_id: str             # thread ID
    checkpoint_ns: str         # checkpoint namespace
    checkpoint_id: str         # checkpoint ID
    parent_checkpoint_id: str  # parent checkpoint ID
    type: str                  # serialization type tag
    checkpoint: bytes          # serialized checkpoint payload
    metadata: bytes            # serialized metadata payload

class WriteRow(TypedDict):
    """Row shape of the checkpoint_writes table."""
    thread_id: str             # thread ID
    checkpoint_ns: str         # checkpoint namespace
    checkpoint_id: str         # checkpoint ID
    task_id: str               # task ID
    idx: int                   # write index
    channel: str               # channel name
    type: str                  # serialization type tag
    value: bytes               # serialized value payload
1.2 Database Table Schema
-- Main checkpoints table
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint BYTEA NOT NULL,
metadata BYTEA NOT NULL,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);
-- Write-records table
CREATE TABLE IF NOT EXISTS checkpoint_writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx SERIAL NOT NULL,
channel TEXT NOT NULL,
type TEXT,
value BYTEA,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);
-- Store table (used by PostgresStore)
CREATE TABLE IF NOT EXISTS store (
prefix TEXT NOT NULL,
key TEXT NOT NULL,
value JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (prefix, key)
);
1.3 Class Diagram
classDiagram
class BaseCheckpointSaver {
<<abstract>>
+serde: SerializerProtocol
+get(config) CheckpointTuple
+put(config, checkpoint, metadata) RunnableConfig
+list(config) Iterator[CheckpointTuple]
+put_writes(config, writes) None
}
class BasePostgresSaver {
<<abstract>>
+serde: SerializerProtocol
+setup() None
+list(config, filter, before, limit) Iterator[CheckpointTuple]
+get_tuple(config) CheckpointTuple
+put(config, checkpoint, metadata, new_versions) RunnableConfig
+put_writes(config, writes, task_id) None
}
class PostgresSaver {
+conn: Conn
+pipe: Pipeline
+lock: threading.Lock
+from_conn_string(conn_string) PostgresSaver
+setup() None
+list(config, filter, before, limit) Iterator[CheckpointTuple]
+get_tuple(config) CheckpointTuple
+put(config, checkpoint, metadata, new_versions) RunnableConfig
+put_writes(config, writes, task_id) None
+delete_thread(thread_id) None
}
class AsyncPostgresSaver {
+conn: AsyncConn
+lock: asyncio.Lock
+from_conn_string(conn_string) AsyncPostgresSaver
+asetup() None
+alist(config, filter, before, limit) AsyncIterator[CheckpointTuple]
+aget_tuple(config) CheckpointTuple
+aput(config, checkpoint, metadata, new_versions) RunnableConfig
+aput_writes(config, writes, task_id) None
}
class PostgresStore {
+conn: Conn
+index_config: IndexConfig
+ttl_config: TTLConfig
+setup() None
+get(namespace, key) Item
+put(namespace, key, value) None
+delete(namespace, key) None
+search(namespace, query) List[SearchItem]
}
BaseCheckpointSaver <|-- BasePostgresSaver
BasePostgresSaver <|-- PostgresSaver
BasePostgresSaver <|-- AsyncPostgresSaver
BaseStore <|-- PostgresStore
PostgresSaver --> "1" Connection : uses
PostgresSaver --> "1" ConnectionPool : uses
AsyncPostgresSaver --> "1" AsyncConnection : uses
PostgresStore --> "1" Connection : uses
Class diagram notes:
- BaseCheckpointSaver: defines the standard checkpoint saver interface
- BasePostgresSaver: common base class for the PostgreSQL implementations
- PostgresSaver / AsyncPostgresSaver: the concrete synchronous / asynchronous implementations
- PostgresStore: key-value storage and search implementation
2. Public API List and Specifications
2.1 Core Checkpoint APIs
2.1.1 setup()
Basic information:
- Name: setup
- Protocol: method call, checkpointer.setup()
- Idempotency: yes (safe to run repeatedly)
Method signature:
def setup(self) -> None:
    """Set up the checkpoint database."""
Core implementation:
def setup(self) -> None:
    """Create the database tables and indexes."""
    # 1) Acquire a database connection
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            # 2) Create the checkpoints table
            cur.execute("""
                CREATE TABLE IF NOT EXISTS checkpoints (
                    thread_id TEXT NOT NULL,
                    checkpoint_ns TEXT NOT NULL DEFAULT '',
                    checkpoint_id TEXT NOT NULL,
                    parent_checkpoint_id TEXT,
                    type TEXT,
                    checkpoint BYTEA NOT NULL,
                    metadata BYTEA NOT NULL,
                    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
                )
            """)
            # 3) Create the checkpoint_writes table
            cur.execute("""
                CREATE TABLE IF NOT EXISTS checkpoint_writes (
                    thread_id TEXT NOT NULL,
                    checkpoint_ns TEXT NOT NULL DEFAULT '',
                    checkpoint_id TEXT NOT NULL,
                    task_id TEXT NOT NULL,
                    idx SERIAL NOT NULL,
                    channel TEXT NOT NULL,
                    type TEXT,
                    value BYTEA,
                    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
                )
            """)
            # 4) Create the supporting index
            cur.execute("""
                CREATE INDEX IF NOT EXISTS checkpoints_thread_ns_idx
                ON checkpoints (thread_id, checkpoint_ns)
            """)
        # 5) Commit the transaction
        conn.commit()
Sequence diagram:
sequenceDiagram
autonumber
participant App as Application
participant PS as PostgresSaver
participant DB as PostgreSQL
App->>PS: setup()
PS->>DB: acquire connection
DB-->>PS: connection established
PS->>DB: CREATE TABLE checkpoints
DB-->>PS: table created
PS->>DB: CREATE TABLE checkpoint_writes
DB-->>PS: table created
PS->>DB: CREATE INDEX
DB-->>PS: index created
PS->>DB: COMMIT
DB-->>PS: transaction committed
PS-->>App: setup complete
2.1.2 put()
Basic information:
- Name: put
- Protocol: method call, checkpointer.put(config, checkpoint, metadata, new_versions)
- Idempotency: no (each new checkpoint_id inserts a row; re-putting the same ID overwrites it)
Method signature:
def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """Store a checkpoint in the database."""
Request parameters:
| Parameter | Type | Required | Description |
|---|---|---|---|
| config | RunnableConfig | yes | run configuration (contains thread_id, etc.) |
| checkpoint | Checkpoint | yes | checkpoint data |
| metadata | CheckpointMetadata | yes | checkpoint metadata |
| new_versions | ChannelVersions | yes | channel version information |
Return value:
| Field | Type | Description |
|---|---|---|
| (return) | RunnableConfig | updated configuration (includes checkpoint_id) |
Core implementation:
def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """Core logic for storing a checkpoint."""
    # 1) Extract configuration
    configurable = config.get("configurable", {})
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = checkpoint["id"]

    # 2) Serialize the payloads
    serialized_checkpoint = self.serde.dumps_typed(checkpoint)
    serialized_metadata = self.serde.dumps_typed(metadata)

    # 3) Prepare the database operation
    with self.lock:
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                # 4) Insert or update the checkpoint (upsert)
                cur.execute("""
                    INSERT INTO checkpoints
                        (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id,
                         type, checkpoint, metadata)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
                    DO UPDATE SET
                        checkpoint = EXCLUDED.checkpoint,
                        metadata = EXCLUDED.metadata
                """, (
                    thread_id,
                    checkpoint_ns,
                    checkpoint_id,
                    configurable.get("checkpoint_id"),  # parent_checkpoint_id
                    serialized_checkpoint[0],           # type tag, needed for loads_typed later
                    serialized_checkpoint[1],           # checkpoint payload
                    serialized_metadata[1],             # metadata payload
                ))
            # 5) Commit the transaction
            conn.commit()

    # 6) Return the updated configuration
    return {
        **config,
        "configurable": {
            **configurable,
            "checkpoint_id": checkpoint_id,
        },
    }
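A hedged usage sketch of put(): empty_checkpoint() from langgraph.checkpoint.base is used here only to produce a structurally valid checkpoint, checkpointer is the saver from the earlier sketches, and the thread ID and metadata values are placeholders.
from langgraph.checkpoint.base import empty_checkpoint

config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
checkpoint = empty_checkpoint()  # a structurally valid, empty checkpoint
metadata = {"source": "input", "step": -1, "writes": {}}

# Store the checkpoint; the returned config carries the new checkpoint_id
next_config = checkpointer.put(config, checkpoint, metadata, new_versions={})
print(next_config["configurable"]["checkpoint_id"])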
2.1.3 get_tuple()
Basic information:
- Name: get_tuple
- Protocol: method call, checkpointer.get_tuple(config)
- Idempotency: yes (read-only operation)
Method signature:
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """Fetch a checkpoint tuple."""
Core implementation:
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """Core logic for fetching a checkpoint tuple."""
    # 1) Parse the configuration
    configurable = config.get("configurable", {})
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = configurable.get("checkpoint_id")

    # 2) Build the query predicates
    if checkpoint_id:
        # Fetch a specific checkpoint
        where_clause = "thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s"
        params = [thread_id, checkpoint_ns, checkpoint_id]
        order_clause = ""
    else:
        # Fetch the latest checkpoint (checkpoint IDs sort lexicographically by time)
        where_clause = "thread_id = %s AND checkpoint_ns = %s"
        params = [thread_id, checkpoint_ns]
        order_clause = "ORDER BY checkpoint_id DESC LIMIT 1"

    # 3) Execute the query
    with self._get_connection() as conn:
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute(f"""
                SELECT * FROM checkpoints
                WHERE {where_clause}
                {order_clause}
            """, params)
            row = cur.fetchone()
            if not row:
                return None

            # 4) Fetch the associated write records
            cur.execute("""
                SELECT task_id, idx, channel, type, value
                FROM checkpoint_writes
                WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s
                ORDER BY task_id, idx
            """, (thread_id, checkpoint_ns, row["checkpoint_id"]))
            writes = cur.fetchall()

    # 5) Deserialize the payloads
    checkpoint = self.serde.loads_typed((row["type"], row["checkpoint"]))
    metadata = self.serde.loads_typed(("json", row["metadata"]))

    # 6) Decode the write records
    pending_writes = []
    for write in writes:
        value = self.serde.loads_typed((write["type"], write["value"]))
        pending_writes.append((write["task_id"], write["channel"], value))

    # 7) Build the parent config
    parent_config = None
    if row["parent_checkpoint_id"]:
        parent_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": row["parent_checkpoint_id"],
            }
        }

    # 8) Return the checkpoint tuple
    return CheckpointTuple(
        config={
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": row["checkpoint_id"],
            }
        },
        checkpoint=checkpoint,
        metadata=metadata,
        parent_config=parent_config,
        pending_writes=pending_writes,
    )
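A read-back sketch, continuing the earlier assumptions: omitting checkpoint_id returns the latest checkpoint on the thread, while supplying one pins a specific checkpoint.
# Latest checkpoint on the thread (no checkpoint_id in the config)
latest = checkpointer.get_tuple({"configurable": {"thread_id": "thread-1"}})
if latest is not None:
    print(latest.checkpoint["id"], latest.metadata)

# A specific checkpoint, pinned by the checkpoint_id returned from put()
pinned = checkpointer.get_tuple(next_config)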
2.1.4 list()
Basic information:
- Name: list
- Protocol: method call, checkpointer.list(config, filter=None, before=None, limit=10)
- Idempotency: yes (read-only operation)
Method signature:
def list(
    self,
    config: RunnableConfig | None,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int = 10,
) -> Iterator[CheckpointTuple]:
    """List checkpoint history."""
Core implementation:
def list(
    self,
    config: RunnableConfig | None,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int = 10,
) -> Iterator[CheckpointTuple]:
    """Core logic for listing checkpoint history."""
    # 1) Build the base query predicates
    where_conditions = []
    params = []
    if config:
        configurable = config.get("configurable", {})
        if thread_id := configurable.get("thread_id"):
            where_conditions.append("thread_id = %s")
            params.append(thread_id)
        if checkpoint_ns := configurable.get("checkpoint_ns"):
            where_conditions.append("checkpoint_ns = %s")
            params.append(checkpoint_ns)

    # 2) Apply metadata filters via parametrized JSONB containment
    #    (assumes metadata is stored as JSONB, as in upstream langgraph-checkpoint-postgres)
    if filter:
        for key, value in filter.items():
            where_conditions.append("metadata @> %s::jsonb")
            params.append(json.dumps({key: value}))

    # 3) Apply the before condition
    if before:
        before_id = before.get("configurable", {}).get("checkpoint_id")
        if before_id:
            where_conditions.append("checkpoint_id < %s")
            params.append(before_id)

    # 4) Assemble the full query
    where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

    # 5) Execute the query; fetch all rows up front so the cursor can be
    #    reused for the per-checkpoint writes queries below
    with self._get_connection() as conn:
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute(f"""
                SELECT * FROM checkpoints
                WHERE {where_clause}
                ORDER BY checkpoint_id DESC
                LIMIT %s
            """, params + [limit])
            rows = cur.fetchall()

            # 6) Process each row and yield results
            for row in rows:
                # Fetch the associated write records
                cur.execute("""
                    SELECT task_id, idx, channel, type, value
                    FROM checkpoint_writes
                    WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s
                    ORDER BY task_id, idx
                """, (row["thread_id"], row["checkpoint_ns"], row["checkpoint_id"]))
                writes = cur.fetchall()

                # Deserialize the payloads
                checkpoint = self.serde.loads_typed((row["type"], row["checkpoint"]))
                metadata = self.serde.loads_typed(("json", row["metadata"]))

                # Decode the write records
                pending_writes = []
                for write in writes:
                    value = self.serde.loads_typed((write["type"], write["value"]))
                    pending_writes.append((write["task_id"], write["channel"], value))

                # Build the parent config
                parent_config = None
                if row["parent_checkpoint_id"]:
                    parent_config = {
                        "configurable": {
                            "thread_id": row["thread_id"],
                            "checkpoint_ns": row["checkpoint_ns"],
                            "checkpoint_id": row["parent_checkpoint_id"],
                        }
                    }

                # Yield the checkpoint tuple
                yield CheckpointTuple(
                    config={
                        "configurable": {
                            "thread_id": row["thread_id"],
                            "checkpoint_ns": row["checkpoint_ns"],
                            "checkpoint_id": row["checkpoint_id"],
                        }
                    },
                    checkpoint=checkpoint,
                    metadata=metadata,
                    parent_config=parent_config,
                    pending_writes=pending_writes,
                )
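A history-paging sketch, under the same assumptions as the earlier examples; the metadata filter key is illustrative:
config = {"configurable": {"thread_id": "thread-1"}}

# Newest-first history, capped at 5 entries, optionally filtered by metadata
for ckpt in checkpointer.list(config, filter={"source": "loop"}, limit=5):
    print(ckpt.config["configurable"]["checkpoint_id"], ckpt.metadata)

# Page backwards: fetch entries strictly older than the last one seen
page = list(checkpointer.list(config, limit=5))
if page:
    older = list(checkpointer.list(config, before=page[-1].config, limit=5))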
2.2 Write Management APIs
2.2.1 put_writes()
Basic information:
- Name: put_writes
- Protocol: method call, checkpointer.put_writes(config, writes, task_id)
- Idempotency: no (each call appends new write records)
Method signature:
def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Store write records."""
Core implementation:
def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Core logic for storing write records."""
    # 1) Extract configuration
    configurable = config.get("configurable", {})
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = configurable["checkpoint_id"]

    # 2) Prepare the rows for batch insertion
    write_records = []
    for idx, (channel, value) in enumerate(writes):
        # Serialize the value
        serialized_value = self.serde.dumps_typed(value)
        write_records.append((
            thread_id,
            checkpoint_ns,
            checkpoint_id,
            task_id,
            idx,
            channel,
            serialized_value[0],  # type tag
            serialized_value[1],  # payload
        ))

    # 3) Batch-insert the write records
    with self.lock:
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                cur.executemany("""
                    INSERT INTO checkpoint_writes
                        (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, write_records)
            # 4) Commit the transaction
            conn.commit()
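A usage sketch: the config passed in must already carry the checkpoint_id returned by put(); the channel name and task ID below are placeholders.
# next_config comes from a previous put() call and contains checkpoint_id
checkpointer.put_writes(
    next_config,
    writes=[("messages", {"role": "ai", "content": "hello"})],  # (channel, value) pairs
    task_id="task-1",  # placeholder task identifier
)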
3. Core Algorithms and Flows
3.1 Connection Management
Purpose: manage acquisition and release of PostgreSQL connections, supporting both single-connection and connection-pool modes. Input: a connection object (Connection or ConnectionPool). Output: a usable database connection. Complexity: O(1).
@contextmanager
def _get_connection(self) -> Iterator[Connection[DictRow]]:
    """Core logic for acquiring a database connection."""
    # 1) Determine the connection type
    if isinstance(self.conn, ConnectionPool):
        # Pool mode: borrow a connection from the pool
        with self.conn.connection() as conn:
            # Ensure the connection is configured as expected
            conn.autocommit = True
            conn.row_factory = dict_row
            yield conn
    elif isinstance(self.conn, Connection):
        # Direct mode: use the existing connection
        if self.conn.closed:
            raise ConnectionError("the database connection is closed")
        # Verify the connection configuration
        if not hasattr(self.conn, 'autocommit') or not self.conn.autocommit:
            self.conn.autocommit = True
        if self.conn.row_factory != dict_row:
            self.conn.row_factory = dict_row
        yield self.conn
    else:
        raise TypeError(f"unsupported connection type: {type(self.conn)}")
Key design points:
- Connection-pool support: connections are automatically borrowed from and returned to the pool
- Configuration checks: ensures autocommit=True and row_factory=dict_row
- Error handling: detects connection state and raises clear error messages
3.2 Serialization Management
Purpose: centralize serialization and deserialization of checkpoint data. Input: Python objects or serialized bytes. Output: serialized bytes or Python objects. Complexity: O(n), where n is the data size.
class SerializationManager:
    """Serialization manager."""

    def __init__(self, serde: SerializerProtocol):
        self.serde = serde

    def serialize_checkpoint(self, checkpoint: Checkpoint) -> tuple[str, bytes]:
        """Serialize checkpoint data."""
        # 1) Typed serialization
        type_name, serialized_data = self.serde.dumps_typed(checkpoint)
        # 2) Validate the result
        if not isinstance(serialized_data, bytes):
            raise SerializationError("serialized payload must be bytes")
        return type_name, serialized_data

    def deserialize_checkpoint(self, type_name: str, data: bytes) -> Checkpoint:
        """Deserialize checkpoint data."""
        # 1) Typed deserialization
        try:
            checkpoint = self.serde.loads_typed((type_name, data))
        except Exception as e:
            raise DeserializationError(f"deserialization failed: {e}")
        # 2) Validate the result
        if not isinstance(checkpoint, dict) or "v" not in checkpoint:
            raise DeserializationError("invalid checkpoint data structure")
        return checkpoint

    def batch_serialize_writes(self, writes: Sequence[tuple[str, Any]]) -> List[tuple[str, str, bytes]]:
        """Serialize write records in bulk."""
        serialized_writes = []
        for channel, value in writes:
            try:
                type_name, serialized_value = self.serde.dumps_typed(value)
                serialized_writes.append((channel, type_name, serialized_value))
            except Exception as e:
                # Log the channel whose value failed to serialize
                logger.warning(f"failed to serialize write record - channel: {channel}, error: {e}")
                # Either skip the record or substitute a default value
                continue
        return serialized_writes
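The typed round trip these helpers build on can be shown directly with JsonPlusSerializer, the default serializer shipped with the checkpoint package:
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

serde = JsonPlusSerializer()

# dumps_typed returns a (type_tag, payload_bytes) pair...
type_tag, payload = serde.dumps_typed({"step": 1, "values": [1, 2, 3]})

# ...and loads_typed inverts it, which is exactly what the checkpoints
# table persists in its (type, checkpoint) columns
restored = serde.loads_typed((type_tag, payload))
assert restored == {"step": 1, "values": [1, 2, 3]}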
3.3 Transaction Management
Purpose: guarantee atomicity and consistency of checkpoint operations. Input: a sequence of database operations. Output: the transaction result. Complexity: O(m), where m is the number of operations.
class TransactionManager:
    """Transaction manager (schematic: SQL column lists are elided)."""

    def __init__(self, connection_getter: Callable):
        self.get_connection = connection_getter

    def execute_checkpoint_transaction(
        self,
        checkpoint_data: CheckpointData,
        writes_data: List[WriteData],
    ) -> bool:
        """Execute a checkpoint transaction."""
        try:
            with self.get_connection() as conn:
                with conn.transaction():  # begin the transaction
                    cursor = conn.cursor()
                    # 1) Insert or update the checkpoint
                    cursor.execute("""
                        INSERT INTO checkpoints (...) VALUES (...)
                        ON CONFLICT (...) DO UPDATE SET ...
                    """, checkpoint_data.to_params())
                    # 2) Batch-insert the write records
                    if writes_data:
                        cursor.executemany("""
                            INSERT INTO checkpoint_writes (...) VALUES (...)
                        """, [w.to_params() for w in writes_data])
                    # 3) Verify the insertion
                    cursor.execute("""
                        SELECT COUNT(*) FROM checkpoints
                        WHERE thread_id = %s AND checkpoint_id = %s
                    """, (checkpoint_data.thread_id, checkpoint_data.checkpoint_id))
                    if cursor.fetchone()[0] != 1:
                        raise TransactionError("checkpoint insert verification failed")
                    # The transaction commits automatically (managed by conn.transaction())
                    return True
        except Exception as e:
            logger.error(f"checkpoint transaction failed: {e}")
            # The transaction is rolled back automatically
            return False
4. Module-Level Architecture and Sequence Diagrams
4.1 Overall Data-Flow Architecture
flowchart LR
subgraph "LangGraph Core"
A[StateGraph] --> B[CompiledStateGraph]
B --> C[Pregel]
end
subgraph "Checkpoint Layer"
C --> D[BaseCheckpointSaver]
D --> E[PostgresSaver]
end
subgraph "Database Layer"
E --> F[Connection Manager]
F --> G[PostgreSQL DB]
E --> H[Serialization]
H --> I[JsonPlusSerializer]
E --> J[Transaction Manager]
J --> G
end
subgraph "Storage Schema"
G --> K[checkpoints Table]
G --> L[checkpoint_writes Table]
G --> M[Indexes]
end
subgraph "Async Support"
N[AsyncPostgresSaver] --> O[Async Connection]
O --> G
end
4.2 Full Checkpoint Lifecycle Sequence Diagram
sequenceDiagram
autonumber
participant App as Application
participant Graph as StateGraph
participant Saver as PostgresSaver
participant Conn as Connection Manager
participant DB as PostgreSQL
participant Ser as Serializer
Note over App,Ser: 1. Initialization phase
App->>Saver: PostgresSaver(conn, serde)
Saver->>Conn: initialize connection management
Saver->>Saver: set up thread lock
Note over App,Ser: 2. Setup phase
App->>Saver: setup()
Saver->>Conn: get_connection()
Conn->>DB: establish connection
DB-->>Conn: connection ready
Saver->>DB: CREATE TABLE checkpoints
Saver->>DB: CREATE TABLE checkpoint_writes
Saver->>DB: CREATE INDEX
DB-->>Saver: tables and indexes created
Note over App,Ser: 3. Runtime phase - store checkpoint
Graph->>Saver: put(config, checkpoint, metadata, versions)
Saver->>Ser: dumps_typed(checkpoint)
Ser-->>Saver: serialized checkpoint
Saver->>Ser: dumps_typed(metadata)
Ser-->>Saver: serialized metadata
Saver->>Conn: get_connection()
Conn->>DB: acquire connection
DB-->>Conn: connection ready
Saver->>DB: BEGIN TRANSACTION
Saver->>DB: INSERT INTO checkpoints
DB-->>Saver: insert succeeded
Saver->>DB: COMMIT
DB-->>Saver: transaction committed
Saver-->>Graph: return updated config
Note over App,Ser: 4. Runtime phase - store write records
Graph->>Saver: put_writes(config, writes, task_id)
Saver->>Ser: batch-serialize write values
Ser-->>Saver: serialization complete
Saver->>DB: BEGIN TRANSACTION
Saver->>DB: INSERT INTO checkpoint_writes (batch)
DB-->>Saver: batch insert succeeded
Saver->>DB: COMMIT
DB-->>Saver: transaction committed
Note over App,Ser: 5. Query phase
App->>Saver: get_tuple(config)
Saver->>Conn: get_connection()
Conn->>DB: acquire connection
Saver->>DB: SELECT FROM checkpoints
DB-->>Saver: checkpoint rows
Saver->>DB: SELECT FROM checkpoint_writes
DB-->>Saver: write-record rows
Saver->>Ser: loads_typed(checkpoint_data)
Ser-->>Saver: deserialized checkpoint
Saver->>Ser: loads_typed(metadata)
Ser-->>Saver: deserialized metadata
Saver->>Ser: batch-deserialize write values
Ser-->>Saver: deserialization complete
Saver->>Saver: build CheckpointTuple
Saver-->>App: return complete checkpoint tuple
4.3 Asynchronous Operation Sequence Diagram
sequenceDiagram
autonumber
participant App as Async Application
participant ASaver as AsyncPostgresSaver
participant AConn as Async Connection
participant DB as PostgreSQL
App->>ASaver: asetup()
ASaver->>AConn: acquire async connection
AConn->>DB: establish async connection
DB-->>AConn: connection established
ASaver->>DB: create tables and indexes asynchronously
DB-->>ASaver: creation complete
App->>ASaver: aput(config, checkpoint, metadata, versions)
ASaver->>ASaver: serialize data asynchronously
ASaver->>AConn: acquire async connection
AConn->>DB: acquire connection
ASaver->>DB: async INSERT
DB-->>ASaver: insert complete
ASaver-->>App: return result
App->>ASaver: aget_tuple(config)
ASaver->>AConn: acquire async connection
ASaver->>DB: async SELECT
DB-->>ASaver: query result
ASaver->>ASaver: deserialize asynchronously
ASaver-->>App: return CheckpointTuple
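A minimal async sketch matching this diagram, assuming the async context-manager form of from_conn_string (in current langgraph-checkpoint-postgres releases the setup step shown as asetup() above is exposed as an awaitable setup() method); the connection string and thread ID are placeholders:
import asyncio
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

DB_URI = "postgresql://postgres:postgres@localhost:5432/langgraph"  # placeholder

async def main() -> None:
    async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
        await checkpointer.setup()  # create tables and indexes, idempotent
        config = {"configurable": {"thread_id": "thread-1"}}
        # Async read of the latest checkpoint on the thread
        print(await checkpointer.aget_tuple(config))

asyncio.run(main())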
5. Exception Handling and Performance Optimization
5.1 Exception-Handling Strategies
5.1.1 Connection Exception Handling
class ConnectionManager:
    """Connection manager with retrying error handling."""

    def __init__(self, conn: Conn, max_retries: int = 3, retry_delay: float = 1.0):
        self.conn = conn
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    @contextmanager
    def get_connection_with_retry(self) -> Iterator[Connection[DictRow]]:
        """Acquire a connection with retries."""
        last_exception = None
        for attempt in range(self.max_retries):
            try:
                with self._get_connection() as conn:
                    yield conn
                    return
            except (ConnectionError, OperationalError) as e:
                last_exception = e
                logger.warning(f"connection failed on attempt {attempt + 1}: {e}")
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay * (2 ** attempt))  # exponential backoff
            except Exception as e:
                # Non-connection errors propagate immediately
                logger.error(f"non-connection error: {e}")
                raise
        # All retries exhausted
        raise ConnectionError(f"connection failed after {self.max_retries} retries: {last_exception}")

    def health_check(self) -> bool:
        """Connection health check."""
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1")
                    return cur.fetchone() is not None
        except Exception as e:
            logger.error(f"health check failed: {e}")
            return False
5.1.2 Serialization Exception Handling
class SafeSerializationManager:
    """Serialization manager with fallback handling."""

    def __init__(self, serde: SerializerProtocol, fallback_serde: SerializerProtocol = None):
        self.serde = serde
        self.fallback_serde = fallback_serde or JsonPlusSerializer()

    def safe_serialize(self, obj: Any) -> tuple[str, bytes]:
        """Serialize with fallback."""
        try:
            return self.serde.dumps_typed(obj)
        except Exception as e:
            logger.warning(f"primary serializer failed, falling back: {e}")
            try:
                return self.fallback_serde.dumps_typed(obj)
            except Exception as fallback_e:
                logger.error(f"fallback serializer also failed: {fallback_e}")
                # Last resort: store the error information instead
                error_info = {
                    "error": "serialization_failed",
                    "original_type": type(obj).__name__,
                    "primary_error": str(e),
                    "fallback_error": str(fallback_e),
                }
                return self.fallback_serde.dumps_typed(error_info)

    def safe_deserialize(self, type_name: str, data: bytes) -> Any:
        """Deserialize with fallback."""
        try:
            result = self.serde.loads_typed((type_name, data))
            # Detect the error placeholder written by safe_serialize
            if isinstance(result, dict) and result.get("error") == "serialization_failed":
                logger.warning("detected placeholder for data that failed to serialize")
                return None
            return result
        except Exception as e:
            logger.error(f"deserialization failed: {e}")
            # Try the fallback serializer
            try:
                return self.fallback_serde.loads_typed((type_name, data))
            except Exception:
                logger.error("fallback deserialization also failed, returning None")
                return None
5.2 Performance Optimization Strategies
5.2.1 Batching Optimization
class BatchOptimizer:
    """Write batching optimizer."""

    def __init__(self, batch_size: int = 100, flush_interval: float = 5.0):
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.pending_writes = []
        self.last_flush = time.time()
        self.lock = threading.Lock()

    def add_write(self, write_data: WriteData):
        """Queue a write operation for batching."""
        with self.lock:
            self.pending_writes.append(write_data)
            # Flush if the batch is full or the interval has elapsed
            if (len(self.pending_writes) >= self.batch_size or
                    time.time() - self.last_flush >= self.flush_interval):
                self._flush_writes()

    def _flush_writes(self):
        """Flush the pending write queue."""
        if not self.pending_writes:
            return
        writes_to_flush = self.pending_writes[:]
        self.pending_writes.clear()
        self.last_flush = time.time()
        try:
            self._execute_batch_writes(writes_to_flush)
        except Exception as e:
            logger.error(f"batch write failed: {e}")
            # Optionally re-queue the failed writes
            self.pending_writes.extend(writes_to_flush)

    def _execute_batch_writes(self, writes: List[WriteData]):
        """Execute the batched writes."""
        # Group by thread ID and checkpoint ID
        grouped_writes = defaultdict(list)
        for write in writes:
            key = (write.thread_id, write.checkpoint_id)
            grouped_writes[key].append(write)
        # Bulk-insert each group
        for (thread_id, checkpoint_id), group_writes in grouped_writes.items():
            self._batch_insert_writes(group_writes)
5.2.2 Connection Pool Optimization
class OptimizedConnectionPool:
    """Connection pool with monitoring."""

    def __init__(
        self,
        conn_string: str,
        min_size: int = 2,
        max_size: int = 10,
        max_lifetime: float = 3600,   # 1 hour
        max_idle_time: float = 300,   # 5 minutes
    ):
        self.pool = ConnectionPool(
            conn_string,
            min_size=min_size,
            max_size=max_size,
            kwargs={
                "autocommit": True,
                "row_factory": dict_row,
            },
        )
        self.max_lifetime = max_lifetime
        self.max_idle_time = max_idle_time
        self._monitoring_thread = None
        self._start_monitoring()

    def _start_monitoring(self):
        """Start the pool monitoring thread."""
        def monitor():
            while True:
                try:
                    # Verify connection health
                    self.pool.check()
                    # Clean up expired connections
                    self._cleanup_expired_connections()
                    time.sleep(30)  # check every 30 seconds
                except Exception as e:
                    logger.error(f"pool monitoring error: {e}")

        self._monitoring_thread = threading.Thread(target=monitor, daemon=True)
        self._monitoring_thread.start()

    def _cleanup_expired_connections(self):
        """Clean up expired connections."""
        # More sophisticated cleanup could go here, e.g. based on
        # connection age (max_lifetime) and idle time (max_idle_time)
        pass

    def get_stats(self) -> dict:
        """Return connection-pool statistics."""
        stats = self.pool.get_stats()
        return {
            "pool_size": stats.get("pool_size"),
            "available": stats.get("pool_available"),
            "requests_waiting": stats.get("requests_waiting"),
        }
5.2.3 Query Optimization
class QueryOptimizer:
    """Query optimizer."""

    def __init__(self):
        self.query_cache = {}
        self.cache_ttl = 300  # 5-minute cache TTL

    def get_optimized_list_query(
        self,
        config: RunnableConfig,
        filter_conditions: List[str],
        params: List[Any],
        limit: int,
    ) -> tuple[str, List[Any]]:
        """Build an optimized list query that joins writes in a single round trip."""
        # 1) Base query
        base_query = """
            SELECT c.*,
                   array_agg(
                       json_build_object(
                           'task_id', w.task_id,
                           'idx', w.idx,
                           'channel', w.channel,
                           'type', w.type,
                           'value', w.value
                       ) ORDER BY w.task_id, w.idx
                   ) FILTER (WHERE w.task_id IS NOT NULL) as writes
            FROM checkpoints c
            LEFT JOIN checkpoint_writes w ON (
                c.thread_id = w.thread_id AND
                c.checkpoint_ns = w.checkpoint_ns AND
                c.checkpoint_id = w.checkpoint_id
            )
        """
        # 2) WHERE clause
        where_clause = " AND ".join(filter_conditions) if filter_conditions else "1=1"
        # 3) Grouping and ordering
        group_order_clause = """
            GROUP BY c.thread_id, c.checkpoint_ns, c.checkpoint_id,
                     c.parent_checkpoint_id, c.type, c.checkpoint, c.metadata
            ORDER BY c.checkpoint_id DESC
            LIMIT %s
        """
        # 4) Assemble the full query
        full_query = f"{base_query} WHERE {where_clause} {group_order_clause}"
        full_params = params + [limit]
        return full_query, full_params

    def create_indexes_for_common_queries(self, connection):
        """Create indexes for common query patterns."""
        indexes = [
            # Composite index on thread ID, namespace, and checkpoint ID
            """
            CREATE INDEX IF NOT EXISTS idx_checkpoints_thread_ns_id
            ON checkpoints (thread_id, checkpoint_ns, checkpoint_id DESC)
            """,
            # Composite index for write records
            """
            CREATE INDEX IF NOT EXISTS idx_writes_thread_checkpoint
            ON checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
            """,
            # Metadata query index (GIN; requires the metadata column to be JSONB)
            """
            CREATE INDEX IF NOT EXISTS idx_checkpoints_metadata
            ON checkpoints USING gin (metadata)
            """,
        ]
        with connection.cursor() as cur:
            for index_sql in indexes:
                try:
                    cur.execute(index_sql)
                except Exception as e:
                    logger.warning(f"index creation failed: {e}")
6. Summary
As LangGraph's core persistence component, the checkpoint-postgres module provides an efficient, reliable PostgreSQL implementation of checkpoint storage. Its main characteristics are:
6.1 Core Strengths
- High performance: leverages PostgreSQL's transactional guarantees and index optimizations
- Scalability: supports connection pooling and asynchronous operation
- Data integrity: thorough transaction management and exception handling
- Flexibility: multiple connection modes and configuration options
6.2 Architectural Characteristics
- Layered design: clean separation of interface abstraction from implementation
- Async support: a complete asynchronous implementation
- Batching: efficient bulk data operations
- Observability: thorough logging and performance metrics
6.3 Practical Value
The module gives LangGraph applications that need persistent state a production-grade storage solution, particularly well suited to enterprise scenarios that demand high concurrency and high reliability. Built on PostgreSQL's capabilities, it enables efficient management and querying of checkpoint data.