1. Checkpoint System Overview
The checkpoint system is one of LangGraph's core features, providing full state persistence and recovery. It lets an agent save state snapshots during execution, enabling advanced capabilities such as pausing, resuming, replay, and time-travel debugging.
1.1 System Architecture
graph TB
    subgraph "Application Layer"
        USER[User application]
        GRAPH[StateGraph]
    end
    subgraph "Checkpoint Interface Layer"
        ABSTRACT[BaseCheckpointSaver]
        METHODS[Core methods]
    end
    subgraph "Implementation Layer"
        MEMORY[InMemorySaver]
        POSTGRES[PostgresSaver]
        SQLITE[SqliteSaver]
    end
    subgraph "Serialization Layer"
        SERDE[SerializerProtocol]
        JSONPLUS[JsonPlusSerializer]
        ENCRYPTED[EncryptedSerializer]
    end
    subgraph "Storage Layer"
        INMEM[In-memory store]
        PGDB[(PostgreSQL)]
        SQLITEDB[(SQLite)]
    end
    USER --> GRAPH
    GRAPH --> ABSTRACT
    ABSTRACT --> MEMORY
    ABSTRACT --> POSTGRES
    ABSTRACT --> SQLITE
    MEMORY --> SERDE
    POSTGRES --> SERDE
    SQLITE --> SERDE
    SERDE --> JSONPLUS
    SERDE --> ENCRYPTED
    MEMORY --> INMEM
    POSTGRES --> PGDB
    SQLITE --> SQLITEDB
    style ABSTRACT fill:#f96,stroke:#333,stroke-width:2px
    style SERDE fill:#69f,stroke:#333,stroke-width:2px
1.2 Core Concepts
1.2.1 Checkpoint Data Structure
class Checkpoint(TypedDict):
    """State snapshot data structure."""
    v: int  # Checkpoint format version, currently 1
    id: str  # Unique checkpoint ID, monotonically increasing (sortable by creation time)
    ts: str  # Timestamp (ISO 8601 format)
    channel_values: dict[str, Any]  # Mapping of channel name to channel value
    channel_versions: ChannelVersions  # Mapping of channel name to channel version
    versions_seen: dict[str, ChannelVersions]  # Versions each node has already seen
    updated_channels: list[str] | None  # Channels updated in this step
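To make the structure concrete, here is a hedged, hand-built example of a minimal checkpoint; real checkpoints are produced by the runtime, and the field values below (including the use of the uuid6 package, which LangGraph itself uses for checkpoint IDs) are purely illustrative.

from datetime import datetime, timezone

import uuid6

checkpoint: Checkpoint = {
    "v": 1,
    "id": str(uuid6.uuid6()),  # time-sortable unique ID
    "ts": datetime.now(timezone.utc).isoformat(),
    "channel_values": {"messages": ["hello"]},
    "channel_versions": {"messages": 1},
    "versions_seen": {"agent": {"messages": 1}},
    "updated_channels": ["messages"],
}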
1.2.2 Checkpoint Metadata
class CheckpointMetadata(TypedDict, total=False):
    """Checkpoint metadata."""
    source: Literal["input", "loop", "update", "fork"]  # Where the checkpoint came from
    step: int  # Step number (-1 for input, 0+ for loop steps)
    parents: dict[str, str]  # Mapping of parent checkpoint IDs
1.2.3 Checkpoint Tuple
class CheckpointTuple(NamedTuple):
    """Complete checkpoint tuple."""
    config: RunnableConfig  # Run configuration
    checkpoint: Checkpoint  # Checkpoint data
    metadata: CheckpointMetadata  # Metadata
    parent_config: RunnableConfig | None  # Parent configuration
    pending_writes: list[PendingWrite] | None  # Pending writes
2. BaseCheckpointSaver: The Base Interface
2.1 Core Interface Definition
class BaseCheckpointSaver(Generic[V]):
    """
    Base class for checkpoint savers.

    Features:
    1. A unified interface for checkpoint operations
    2. Both synchronous and asynchronous APIs
    3. Pluggable serialization
    4. Version management and conflict detection
    """

    serde: SerializerProtocol = JsonPlusSerializer()  # Serializer

    def __init__(
        self,
        *,
        serde: SerializerProtocol | None = None,
    ) -> None:
        """
        Initialize the checkpoint saver.

        Args:
            serde: custom serializer (optional)
        """
        self.serde = maybe_add_typed_methods(serde or self.serde)

    @property
    def config_specs(self) -> list:
        """
        Configuration specification for this checkpoint saver.

        Returns: a list of configuration field specs.
        """
        return []
2.2 Core Method Interfaces
2.2.1 Read Operations
def get(self, config: RunnableConfig) -> Checkpoint | None:
    """
    Fetch checkpoint data.

    Args:
        config: configuration containing thread_id and, optionally, checkpoint_id

    Returns: the checkpoint, or None if not found.
    """
    if value := self.get_tuple(config):
        return value.checkpoint

def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """
    Fetch the full checkpoint tuple.

    Args:
        config: run configuration

    Returns: the checkpoint tuple, or None if not found.

    Note: subclasses must implement this method.
    """
    raise NotImplementedError

def list(
    self,
    config: RunnableConfig | None,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int | None = None,
) -> Iterator[CheckpointTuple]:
    """
    List checkpoints matching the given criteria.

    Args:
        config: base configuration filter
        filter: additional metadata filters
        before: only list checkpoints created before this config
        limit: maximum number of checkpoints to return

    Returns: an iterator of checkpoint tuples.

    Note: subclasses must implement this method.
    """
    raise NotImplementedError
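The sketch below shows how application code typically drives these read methods. It assumes an InMemorySaver that already holds checkpoints for the (illustrative) thread ID "thread-1".

from langgraph.checkpoint.memory import InMemorySaver

saver = InMemorySaver()
config = {"configurable": {"thread_id": "thread-1"}}

# get_tuple returns the latest checkpoint for the thread, or a specific one
# if the config carries a checkpoint_id; get returns only the checkpoint body.
if tup := saver.get_tuple(config):
    print(tup.checkpoint["id"], tup.metadata.get("step"))

# list yields checkpoints newest-first, optionally filtered and limited.
for tup in saver.list(config, limit=5):
    print(tup.config["configurable"]["checkpoint_id"])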
2.2.2 Write Operations
def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """
    Store a checkpoint with its configuration and metadata.

    Args:
        config: configuration of the checkpoint
        checkpoint: the checkpoint to store
        metadata: checkpoint metadata
        new_versions: new channel versions produced by this write

    Returns: the updated configuration after storing.

    Note: subclasses must implement this method.
    """
    raise NotImplementedError

def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """
    Store intermediate writes linked to a checkpoint.

    Args:
        config: configuration of the related checkpoint
        writes: list of writes to store
        task_id: ID of the task that produced the writes
        task_path: path of the task that produced the writes

    Note: subclasses must implement this method.
    """
    raise NotImplementedError
2.2.3 Version Management
def get_next_version(self, current: V | None, channel: None) -> V:
    """
    Generate the next version ID for a channel.

    The default uses integer versions, incremented by 1 on each write.
    Subclasses may override this to use str/int/float versions, as long
    as they are monotonically increasing.

    Args:
        current: current version identifier (int, float, or str)
        channel: deprecated argument, kept for backward compatibility

    Returns: the next version identifier, which must compare greater than the current one.
    """
    if isinstance(current, str):
        raise NotImplementedError
    elif current is None:
        return 1
    else:
        return current + 1
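As a hedged sketch of overriding this hook, the subclass below issues lexicographically increasing string versions (a zero-padded counter plus a random suffix, loosely modeled on what database-backed savers do); the class name and version format are illustrative, not part of the library.

import random

from langgraph.checkpoint.memory import InMemorySaver


class StringVersionSaver(InMemorySaver):
    """Saver variant that issues string channel versions (illustrative)."""

    def get_next_version(self, current: str | None, channel: None = None) -> str:
        # Recover the integer counter from the previous version, if any.
        current_v = int(current.split(".")[0]) if current else 0
        next_v = current_v + 1
        # Zero-pad so that string comparison matches numeric ordering.
        return f"{next_v:032}.{random.random():.16f}"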
2.3 Asynchronous Interface
# Async counterparts of the core methods
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """Asynchronously fetch a checkpoint tuple."""
    raise NotImplementedError

async def alist(
    self,
    config: RunnableConfig | None,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int | None = None,
) -> AsyncIterator[CheckpointTuple]:
    """Asynchronously list checkpoints."""
    raise NotImplementedError

async def aput(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """Asynchronously store a checkpoint."""
    raise NotImplementedError

async def aput_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Asynchronously store writes."""
    raise NotImplementedError
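A minimal sketch of driving the asynchronous interface, assuming an InMemorySaver (whose async methods delegate to the synchronous ones) and an illustrative thread ID.

import asyncio

from langgraph.checkpoint.memory import InMemorySaver


async def main() -> None:
    saver = InMemorySaver()
    config = {"configurable": {"thread_id": "thread-1"}}

    # aget_tuple mirrors get_tuple; it returns None if nothing is stored yet.
    print(await saver.aget_tuple(config))

    # alist mirrors list and yields checkpoints asynchronously, newest first.
    async for tup in saver.alist(config, limit=5):
        print(tup.config["configurable"]["checkpoint_id"])


asyncio.run(main())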
3. InMemorySaver: The In-Memory Implementation
3.1 In-Memory Storage Structure
class InMemorySaver(
    BaseCheckpointSaver[str],
    AbstractContextManager,
    AbstractAsyncContextManager,
):
    """
    In-memory checkpoint saver.

    Data structures:
    - storage: thread_id -> checkpoint_ns -> checkpoint_id -> checkpoint data
    - writes: (thread_id, checkpoint_ns, checkpoint_id) -> pending writes
    - blobs: (thread_id, checkpoint_ns, channel, version) -> channel value
    """

    # Storage structure definitions
    storage: defaultdict[
        str,  # thread_id
        dict[str, dict[str, tuple[tuple[str, bytes], tuple[str, bytes], str | None]]],
        # checkpoint_ns -> checkpoint_id -> (checkpoint, metadata, parent_id)
    ]
    writes: defaultdict[
        tuple[str, str, str],  # (thread_id, checkpoint_ns, checkpoint_id)
        dict[tuple[str, int], tuple[str, str, tuple[str, bytes], str]],
        # (task_id, write_idx) -> (task_id, channel, value, task_path)
    ]
    blobs: dict[
        tuple[str, str, str, str | int | float],  # (thread_id, ns, channel, version)
        tuple[str, bytes],  # (type, serialized_value)
    ]

    def __init__(
        self,
        *,
        serde: SerializerProtocol | None = None,
        factory: type[defaultdict] = defaultdict,
    ) -> None:
        """
        Initialize the in-memory saver.

        Args:
            serde: serializer
            factory: dict factory (allows custom backends such as a Redis-backed dict)
        """
        super().__init__(serde=serde)
        self.storage = factory(lambda: defaultdict(dict))
        self.writes = factory(dict)
        self.blobs = factory()
        self.stack = ExitStack()
        # When a custom factory is used, register its containers with the exit stack
        if factory is not defaultdict:
            self.stack.enter_context(self.storage)
            self.stack.enter_context(self.writes)
            self.stack.enter_context(self.blobs)
3.2 Core Method Implementations
3.2.1 Reading Checkpoints
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """
    Fetch a checkpoint tuple from in-memory storage.

    Steps:
    1. Extract thread_id and checkpoint_ns from the config
    2. If a checkpoint_id is given, load that specific checkpoint
    3. Otherwise, load the latest checkpoint
    4. Load channel values and pending writes
    5. Assemble the full CheckpointTuple
    """
    thread_id: str = config["configurable"]["thread_id"]
    checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "")
    if checkpoint_id := get_checkpoint_id(config):
        # Load a specific checkpoint
        if saved := self.storage[thread_id][checkpoint_ns].get(checkpoint_id):
            checkpoint, metadata, parent_checkpoint_id = saved
            writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
            checkpoint_: Checkpoint = self.serde.loads_typed(checkpoint)
            return CheckpointTuple(
                config=config,
                checkpoint={
                    **checkpoint_,
                    "channel_values": self._load_blobs(
                        thread_id, checkpoint_ns, checkpoint_["channel_versions"]
                    ),
                },
                metadata=self.serde.loads_typed(metadata),
                pending_writes=[
                    (id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
                ],
                parent_config=(
                    {
                        "configurable": {
                            "thread_id": thread_id,
                            "checkpoint_ns": checkpoint_ns,
                            "checkpoint_id": parent_checkpoint_id,
                        }
                    }
                    if parent_checkpoint_id
                    else None
                ),
            )
    else:
        # Load the latest checkpoint
        if checkpoints := self.storage[thread_id][checkpoint_ns]:
            checkpoint_id = max(checkpoints.keys())  # IDs sort by creation time
            checkpoint, metadata, parent_checkpoint_id = checkpoints[checkpoint_id]
            writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
            checkpoint_ = self.serde.loads_typed(checkpoint)
            return CheckpointTuple(
                config={
                    "configurable": {
                        "thread_id": thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": checkpoint_id,
                    }
                },
                checkpoint={
                    **checkpoint_,
                    "channel_values": self._load_blobs(
                        thread_id, checkpoint_ns, checkpoint_["channel_versions"]
                    ),
                },
                metadata=self.serde.loads_typed(metadata),
                pending_writes=[
                    (id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
                ],
                parent_config=(
                    {
                        "configurable": {
                            "thread_id": thread_id,
                            "checkpoint_ns": checkpoint_ns,
                            "checkpoint_id": parent_checkpoint_id,
                        }
                    }
                    if parent_checkpoint_id
                    else None
                ),
            )
def _load_blobs(
    self,
    thread_id: str,
    checkpoint_ns: str,
    versions: ChannelVersions,
) -> dict[str, Any]:
    """
    Load channel values from blob storage.

    Args:
        thread_id: thread ID
        checkpoint_ns: checkpoint namespace
        versions: mapping of channel name to version

    Returns: a dict of deserialized channel values.
    """
    channel_values: dict[str, Any] = {}
    for k, v in versions.items():
        kk = (thread_id, checkpoint_ns, k, v)
        if kk in self.blobs:
            vv = self.blobs[kk]
            if vv[0] != "empty":  # skip empty markers
                channel_values[k] = self.serde.loads_typed(vv)
    return channel_values
3.2.2 Writing Checkpoints
def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """
    Store a checkpoint in memory.

    Steps:
    1. Extract configuration values
    2. Serialize the checkpoint and metadata
    3. Store the checkpoint record
    4. Store channel values in blob storage
    5. Return the updated configuration
    """
    thread_id = config["configurable"]["thread_id"]
    checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
    checkpoint_id = checkpoint["id"]
    # Serialize the data
    serialized_checkpoint = self.serde.dumps_typed(checkpoint)
    serialized_metadata = self.serde.dumps_typed(metadata)
    # Store the checkpoint record; the incoming config still points at the
    # parent checkpoint, so its checkpoint_id (if any) is the parent ID
    self.storage[thread_id][checkpoint_ns][checkpoint_id] = (
        serialized_checkpoint,
        serialized_metadata,
        config["configurable"].get("checkpoint_id"),  # parent checkpoint ID
    )
    # Store channel values as blobs
    for k, v in new_versions.items():
        key = (thread_id, checkpoint_ns, k, v)
        if k in checkpoint["channel_values"]:
            # Store the actual value
            value = checkpoint["channel_values"][k]
            self.blobs[key] = self.serde.dumps_typed(value)
        else:
            # Store an empty marker
            self.blobs[key] = ("empty", b"")
    # Return the updated configuration
    return {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
            "checkpoint_id": checkpoint_id,
        }
    }

def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """
    Store pending writes.

    Steps:
    1. Build the unique key for the target checkpoint
    2. Serialize each written value
    3. Store the writes in the writes dict
    """
    thread_id = config["configurable"]["thread_id"]
    checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
    checkpoint_id = get_checkpoint_id(config)
    key = (thread_id, checkpoint_ns, checkpoint_id)
    if key not in self.writes:
        self.writes[key] = {}
    # Store each write
    for idx, (channel, value) in enumerate(writes):
        write_key = (task_id, idx)
        serialized_value = self.serde.dumps_typed(value)
        self.writes[key][write_key] = (
            task_id,
            channel,
            serialized_value,
            task_path,
        )
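To make the storage layout tangible, here is a hedged, hand-driven round trip through put, put_writes, and get_tuple. It uses the library's empty_checkpoint helper to build a blank checkpoint; the thread and task IDs are illustrative.

from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver

saver = InMemorySaver()
config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}

# Store a blank checkpoint; put returns a config carrying the new checkpoint_id.
saved_config = saver.put(
    config, empty_checkpoint(), {"source": "input", "step": -1, "parents": {}}, {}
)

# Attach a pending write produced by a (hypothetical) task.
saver.put_writes(saved_config, [("messages", "hello")], task_id="task-1")

# Read everything back as a CheckpointTuple.
tup = saver.get_tuple(saved_config)
print(tup.checkpoint["id"], tup.pending_writes)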
3.2.3 Listing Checkpoints
def list(
    self,
    config: RunnableConfig | None,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int | None = None,
) -> Iterator[CheckpointTuple]:
    """
    List checkpoints matching the given criteria.

    Steps:
    1. Determine the search scope (thread_id, checkpoint_ns)
    2. Apply filter conditions
    3. Sort by timestamp
    4. Apply the result limit
    """
    # Determine the search scope
    if config:
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        thread_ids = [thread_id]
    else:
        thread_ids = list(self.storage.keys())
        checkpoint_ns = ""
    # Collect all matching checkpoints
    checkpoints = []
    for tid in thread_ids:
        if checkpoint_ns:
            namespaces = [checkpoint_ns]
        else:
            namespaces = list(self.storage[tid].keys())
        for ns in namespaces:
            for cid, (checkpoint, metadata, parent_id) in self.storage[tid][ns].items():
                checkpoint_data = self.serde.loads_typed(checkpoint)
                metadata_data = self.serde.loads_typed(metadata)
                # Apply metadata filters
                if filter:
                    if not self._matches_filter(metadata_data, filter):
                        continue
                # Apply the "before" condition
                if before:
                    before_id = get_checkpoint_id(before)
                    if before_id and cid >= before_id:
                        continue
                # Assemble the checkpoint tuple
                checkpoint_tuple = CheckpointTuple(
                    config={
                        "configurable": {
                            "thread_id": tid,
                            "checkpoint_ns": ns,
                            "checkpoint_id": cid,
                        }
                    },
                    checkpoint={
                        **checkpoint_data,
                        "channel_values": self._load_blobs(
                            tid, ns, checkpoint_data["channel_versions"]
                        ),
                    },
                    metadata=metadata_data,
                    pending_writes=None,  # loaded on demand
                    parent_config=(
                        {
                            "configurable": {
                                "thread_id": tid,
                                "checkpoint_ns": ns,
                                "checkpoint_id": parent_id,
                            }
                        }
                        if parent_id
                        else None
                    ),
                )
                checkpoints.append(checkpoint_tuple)
    # Sort by timestamp, newest first
    checkpoints.sort(key=lambda x: x.checkpoint["ts"], reverse=True)
    # Apply the limit
    if limit:
        checkpoints = checkpoints[:limit]
    for checkpoint_tuple in checkpoints:
        yield checkpoint_tuple

def _matches_filter(
    self,
    metadata: CheckpointMetadata,
    filter_dict: dict[str, Any],
) -> bool:
    """Check whether the metadata matches all filter conditions."""
    for key, value in filter_dict.items():
        if key not in metadata or metadata[key] != value:
            return False
    return True
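A short usage sketch of the listing parameters, assuming the saver already holds several checkpoints for the (illustrative) thread "thread-1".

from langgraph.checkpoint.memory import InMemorySaver

saver = InMemorySaver()
config = {"configurable": {"thread_id": "thread-1"}}

# Only checkpoints whose metadata matches the filter (regular loop steps here).
loop_steps = list(saver.list(config, filter={"source": "loop"}))

# Only checkpoints created before a known checkpoint, at most three of them.
if loop_steps:
    for tup in saver.list(config, before=loop_steps[0].config, limit=3):
        print(tup.checkpoint["id"], tup.metadata.get("step"))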
4. Serialization Mechanism
4.1 The SerializerProtocol Interface
class SerializerProtocol(Protocol):
    """Serializer protocol."""

    def dumps(self, obj: Any) -> bytes:
        """Serialize an object to bytes."""
        ...

    def loads(self, data: bytes) -> Any:
        """Deserialize an object from bytes."""
        ...

    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
        """Typed serialization: returns a (type tag, payload) pair."""
        ...

    def loads_typed(self, data: tuple[str, bytes]) -> Any:
        """Typed deserialization from a (type tag, payload) pair."""
        ...
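Any object providing these four methods can be plugged into a saver. Below is a hedged sketch of a pickle-based serializer that satisfies the protocol; the class name is illustrative, and pickle is only appropriate for trusted data.

import pickle
from typing import Any


class PickleSerializer:
    """Minimal SerializerProtocol implementation backed by pickle (illustrative)."""

    def dumps(self, obj: Any) -> bytes:
        return pickle.dumps(obj)

    def loads(self, data: bytes) -> Any:
        return pickle.loads(data)

    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
        # The type tag tells loads_typed how to decode the payload.
        return "pickle", pickle.dumps(obj)

    def loads_typed(self, data: tuple[str, bytes]) -> Any:
        type_tag, payload = data
        if type_tag != "pickle":
            raise ValueError(f"Unsupported type tag: {type_tag}")
        return pickle.loads(payload)


# Usage: pass it to any saver, e.g. InMemorySaver(serde=PickleSerializer()).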
4.2 The JsonPlusSerializer Implementation
class JsonPlusSerializer:
    """
    Enhanced JSON serializer.

    Features:
    1. Supports complex Python objects (datetime, UUID, ...)
    2. Preserves type information via type tags
    3. Raises a clear error for objects it cannot encode
    4. Extensible with custom type encoders
    """

    def __init__(
        self,
        *,
        file_path: Path | None = None,
        at_least_one_serializer: bool = True,
    ) -> None:
        self._encoders_by_name: dict[str, tuple[type, Callable[[Any], Any]]] = {}
        self._decoders_by_name: dict[str, Callable[[Any], Any]] = {}
        self._encoders_by_type: dict[type, str] = {}
        # Register the default encoders
        self._register_default_encoders()
        if file_path:
            self._load_from_file(file_path)

    def _register_default_encoders(self) -> None:
        """Register the default type encoders."""
        # datetime
        self._encoders_by_name["datetime"] = (
            datetime,
            lambda obj: obj.isoformat(),
        )
        self._decoders_by_name["datetime"] = lambda data: datetime.fromisoformat(data)
        self._encoders_by_type[datetime] = "datetime"
        # UUID
        self._encoders_by_name["uuid"] = (
            UUID,
            str,
        )
        self._decoders_by_name["uuid"] = UUID
        self._encoders_by_type[UUID] = "uuid"
        # more types ...

    def dumps(self, obj: Any) -> bytes:
        """Serialize an object."""
        try:
            return orjson.dumps(
                obj,
                default=self._encode_object,
                option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS,
            )
        except Exception as e:
            raise SerializationError(f"Failed to serialize object: {e}")

    def loads(self, data: bytes) -> Any:
        """Deserialize an object."""
        try:
            return orjson.loads(data)
        except Exception as e:
            raise SerializationError(f"Failed to deserialize data: {e}")

    def _encode_object(self, obj: Any) -> Any:
        """Encode a complex object."""
        obj_type = type(obj)
        if encoder_name := self._encoders_by_type.get(obj_type):
            _, encoder = self._encoders_by_name[encoder_name]
            return {
                "__type": encoder_name,
                "__value": encoder(obj),
            }
        # Fall back to base-class matching
        for cls, encoder_name in self._encoders_by_type.items():
            if isinstance(obj, cls):
                _, encoder = self._encoders_by_name[encoder_name]
                return {
                    "__type": encoder_name,
                    "__value": encoder(obj),
                }
        raise TypeError(f"Object of type {obj_type} is not serializable")

    def _decode_object(self, obj: Any) -> Any:
        """Decode a complex object."""
        if isinstance(obj, dict) and "__type" in obj:
            type_name = obj["__type"]
            if decoder := self._decoders_by_name.get(type_name):
                return decoder(obj["__value"])
        return obj
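The serializer shipped with the checkpoint package can be exercised directly; the round trip below is a hedged sketch (the import path reflects current releases, and the payload is illustrative).

from datetime import datetime, timezone
from uuid import uuid4

from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

serde = JsonPlusSerializer()
payload = {
    "when": datetime.now(timezone.utc),  # non-JSON types are tagged, not lost
    "id": uuid4(),
    "messages": ["hello", "world"],
}

# dumps_typed returns a (type tag, bytes) pair; loads_typed reverses it.
type_tag, data = serde.dumps_typed(payload)
restored = serde.loads_typed((type_tag, data))
print(restored["when"], restored["id"])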
4.3 The Encrypted Serializer
import base64

from cryptography.fernet import Fernet


class EncryptedSerializer:
    """
    Encrypting serializer.

    Features:
    1. Encrypts data at rest
    2. Uses a single caller-supplied key
    3. Integrity checking (via Fernet's built-in HMAC)
    4. Wraps any inner serializer
    """

    def __init__(
        self,
        key: bytes,
        inner: SerializerProtocol | None = None,
    ) -> None:
        if len(key) != 32:
            raise ValueError("Key must be 32 bytes")
        self.key = key
        self.inner = inner or JsonPlusSerializer()
        # Fernet expects a urlsafe-base64-encoded 32-byte key, so encode the raw key
        self.cipher_suite = Fernet(base64.urlsafe_b64encode(key))

    def dumps(self, obj: Any) -> bytes:
        """Serialize, then encrypt."""
        # Serialize first
        serialized = self.inner.dumps(obj)
        # Then encrypt
        encrypted = self.cipher_suite.encrypt(serialized)
        return encrypted

    def loads(self, data: bytes) -> Any:
        """Decrypt, then deserialize."""
        try:
            # Decrypt first
            decrypted = self.cipher_suite.decrypt(data)
            # Then deserialize
            return self.inner.loads(decrypted)
        except Exception as e:
            raise SerializationError(f"Failed to decrypt/deserialize: {e}")
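A hedged usage sketch: generate a random 32-byte key (in production this would come from a secrets manager) and wrap the JSON serializer defined above.

import os

key = os.urandom(32)  # EncryptedSerializer base64-encodes this into a Fernet key
serde = EncryptedSerializer(key)

token = serde.dumps({"secret": "value"})
print(serde.loads(token))  # -> {'secret': 'value'}

# The encrypted serializer can then be passed to a saver, e.g.:
# saver = InMemorySaver(serde=serde)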
5. Checkpoint Lifecycle Management
5.1 Checkpoint Creation Flow
sequenceDiagram
    participant Graph as StateGraph
    participant Engine as Pregel engine
    participant Saver as CheckpointSaver
    participant Storage as Storage backend
    Graph->>Engine: start execution
    Engine->>Engine: create initial checkpoint
    loop execution loop
        Engine->>Engine: run nodes
        Engine->>Engine: update state
        Engine->>Saver: put(checkpoint)
        Saver->>Storage: store checkpoint data
        Storage-->>Saver: acknowledge
        Saver-->>Engine: return config
        alt pending writes exist
            Engine->>Saver: put_writes(writes)
            Saver->>Storage: store writes
            Storage-->>Saver: acknowledge
        end
        break interrupt requested
            Engine->>Engine: save interrupt state
        end
    end
    Engine-->>Graph: return final result
5.2 Checkpoint Recovery Flow
sequenceDiagram
    participant User as User
    participant Graph as StateGraph
    participant Engine as Pregel engine
    participant Saver as CheckpointSaver
    User->>Graph: resume execution (config)
    Graph->>Engine: start in resume mode
    Engine->>Saver: get_tuple(config)
    Saver-->>Engine: return checkpoint tuple
    Engine->>Engine: restore state and channels
    alt pending writes exist
        Engine->>Engine: apply pending writes
    end
    Engine->>Engine: continue from the saved point
    Engine-->>Graph: return execution result
    Graph-->>User: return final result
5.3 Version Management Mechanism
from datetime import datetime, timezone
from typing import Any

import uuid6
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    CheckpointTuple,
    copy_checkpoint,
)


class VersionManager:
    """
    Version manager: helpers for checkpoint version control.
    """

    def __init__(self, checkpointer: BaseCheckpointSaver):
        self.checkpointer = checkpointer

    def create_checkpoint_id(self) -> str:
        """
        Create a checkpoint ID.

        Uses uuid6 so that IDs are unique and sort by creation time.
        """
        return str(uuid6.uuid6())

    def get_version_chain(
        self,
        config: RunnableConfig,
    ) -> list[CheckpointTuple]:
        """
        Get the checkpoint version chain.

        Returns the checkpoints from newest to oldest.
        """
        chain = []
        current_config = config
        while current_config:
            checkpoint_tuple = self.checkpointer.get_tuple(current_config)
            if not checkpoint_tuple:
                break
            chain.append(checkpoint_tuple)
            current_config = checkpoint_tuple.parent_config
        return chain

    def fork_checkpoint(
        self,
        source_config: RunnableConfig,
        target_thread_id: str,
        metadata_updates: dict[str, Any] | None = None,
    ) -> RunnableConfig:
        """
        Fork a checkpoint into a new thread.

        Args:
            source_config: config of the source checkpoint
            target_thread_id: target thread ID
            metadata_updates: metadata overrides

        Returns: the config of the new checkpoint.
        """
        # Load the source checkpoint
        source_tuple = self.checkpointer.get_tuple(source_config)
        if not source_tuple:
            raise ValueError("Source checkpoint not found")
        # Create a new checkpoint ID
        new_checkpoint_id = self.create_checkpoint_id()
        # Copy the checkpoint data
        new_checkpoint = copy_checkpoint(source_tuple.checkpoint)
        new_checkpoint["id"] = new_checkpoint_id
        new_checkpoint["ts"] = datetime.now(timezone.utc).isoformat()
        # Update the metadata
        new_metadata = dict(source_tuple.metadata)
        new_metadata["source"] = "fork"
        if metadata_updates:
            new_metadata.update(metadata_updates)
        # Store the new checkpoint. The config passed to put identifies the parent
        # checkpoint; a fork into a fresh thread has no parent there, so no
        # checkpoint_id is included. put returns the config of the new checkpoint.
        return self.checkpointer.put(
            {
                "configurable": {
                    "thread_id": target_thread_id,
                    "checkpoint_ns": source_config["configurable"].get("checkpoint_ns", ""),
                }
            },
            new_checkpoint,
            new_metadata,
            new_checkpoint["channel_versions"],
        )
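A short usage sketch of this helper. It assumes checkpointer is an InMemorySaver that already holds checkpoints for the (illustrative) thread "thread-1", for example after running the graph from section 6.

from langgraph.checkpoint.memory import InMemorySaver

checkpointer = InMemorySaver()  # assume the graph has already written checkpoints here
manager = VersionManager(checkpointer)
config = {"configurable": {"thread_id": "thread-1"}}

# Walk the parent chain from the latest checkpoint back to the first one.
for tup in manager.get_version_chain(config):
    print(tup.checkpoint["id"], tup.metadata.get("step"))

# Fork the latest checkpoint into a fresh thread for "what-if" exploration.
forked = manager.fork_checkpoint(config, target_thread_id="thread-1-fork")
print(forked["configurable"]["checkpoint_id"])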
6. Usage Examples and Best Practices
6.1 Basic Usage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, MessagesState

# Create the checkpoint saver
checkpointer = InMemorySaver()


# A trivial node used for illustration
def agent_node(state: MessagesState):
    return {"messages": [{"role": "assistant", "content": "Hello!"}]}


# Build the graph and enable checkpointing
graph = StateGraph(MessagesState)
graph.add_node("agent", agent_node)
graph.set_entry_point("agent")
graph.set_finish_point("agent")
compiled = graph.compile(checkpointer=checkpointer)

# Run the graph (checkpoints are created automatically)
config = {"configurable": {"thread_id": "thread-1"}}
result = compiled.invoke({"messages": [{"role": "user", "content": "Hello"}]}, config)

# List the checkpoint history
for checkpoint in checkpointer.list(config):
    print(f"Checkpoint {checkpoint.checkpoint['id']} at {checkpoint.checkpoint['ts']}")
6.2 Resuming from a Checkpoint
# Resume from a specific checkpoint
checkpoint_id = "01234567-89ab-cdef-0123-456789abcdef"
resume_config = {
    "configurable": {
        "thread_id": "thread-1",
        "checkpoint_id": checkpoint_id,
    }
}

# Continue execution from that checkpoint
result = compiled.invoke(
    {"messages": [{"role": "user", "content": "Continue from here"}]},
    resume_config,
)
6.3 Custom Checkpoint Savers
import redis
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    get_checkpoint_id,
)


class RedisCheckpointSaver(BaseCheckpointSaver):
    """Redis-backed checkpoint saver (simplified; a complete implementation
    must also provide list and put_writes)."""

    def __init__(self, redis_url: str, **kwargs):
        super().__init__(**kwargs)
        self.redis = redis.from_url(redis_url)

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Fetch a checkpoint from Redis."""
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        if checkpoint_id := get_checkpoint_id(config):
            key = f"checkpoint:{thread_id}:{checkpoint_ns}:{checkpoint_id}"
        else:
            # Find the latest checkpoint ID
            pattern = f"checkpoint:{thread_id}:{checkpoint_ns}:*"
            keys = self.redis.keys(pattern)
            if not keys:
                return None
            # Sort by ID to get the latest (checkpoint IDs sort by creation time)
            keys.sort()
            key = keys[-1].decode()
        # Load the checkpoint data
        data = self.redis.hgetall(key)
        if not data:
            return None
        # Deserialize
        checkpoint = self.serde.loads(data[b"checkpoint"])
        metadata = self.serde.loads(data[b"metadata"])
        return CheckpointTuple(
            config=config,
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=None,  # simplified
            pending_writes=None,
        )

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint in Redis."""
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = checkpoint["id"]
        key = f"checkpoint:{thread_id}:{checkpoint_ns}:{checkpoint_id}"
        # Serialize and store
        self.redis.hset(key, mapping={
            "checkpoint": self.serde.dumps(checkpoint),
            "metadata": self.serde.dumps(metadata),
        })
        # Optional expiry
        self.redis.expire(key, 86400 * 7)  # expire after 7 days
        # Return a config that points at the stored checkpoint
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }


# Use the custom saver
redis_saver = RedisCheckpointSaver("redis://localhost:6379/0")
compiled = graph.compile(checkpointer=redis_saver)
7. Performance Optimization and Monitoring
7.1 Performance Optimization Strategies
import asyncio


class OptimizedCheckpointSaver(BaseCheckpointSaver):
    """Checkpoint saver with batching and compression optimizations (sketch)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Enable compression
        self.enable_compression = True
        # Batching parameters
        self.batch_buffer = []
        self.batch_size = 100
        self.batch_timeout = 5.0  # seconds
        # Asynchronous write queue
        self.write_queue = asyncio.Queue(maxsize=1000)
        self.background_writer = None

    async def _background_write_loop(self):
        """Background batched write loop."""
        batch = []
        while True:
            try:
                # Wait for the next item, or time out and flush
                item = await asyncio.wait_for(
                    self.write_queue.get(),
                    timeout=self.batch_timeout,
                )
                batch.append(item)
                # Flush when the batch is full
                if len(batch) >= self.batch_size:
                    await self._flush_batch(batch)
                    batch = []
            except asyncio.TimeoutError:
                # Timed out: flush whatever has accumulated
                if batch:
                    await self._flush_batch(batch)
                    batch = []

    async def _flush_batch(self, batch: list):
        """Flush a batch to the storage backend."""
        # The actual bulk-write logic is backend-specific
        pass

    def put(self, config, checkpoint, metadata, new_versions):
        """Non-blocking checkpoint store."""
        if self.enable_compression:
            checkpoint = self._compress_checkpoint(checkpoint)
        # Enqueue for the background writer
        try:
            self.write_queue.put_nowait((config, checkpoint, metadata, new_versions))
        except asyncio.QueueFull:
            # Queue full: fall back to a synchronous write (backend-specific)
            return self._sync_put(config, checkpoint, metadata, new_versions)
        return config

    def _compress_checkpoint(self, checkpoint: Checkpoint) -> Checkpoint:
        """Compress checkpoint channel values."""
        import gzip

        compressed_checkpoint = dict(checkpoint)
        # Compress the channel values
        if "channel_values" in checkpoint:
            channel_data = self.serde.dumps(checkpoint["channel_values"])
            compressed_data = gzip.compress(channel_data)
            compressed_checkpoint["channel_values"] = {
                "__compressed": True,
                "__data": compressed_data,
            }
        return compressed_checkpoint
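The read path needs the inverse transformation. The helper below is a hypothetical counterpart to _compress_checkpoint (not part of the class above), showing how the compression marker could be detected and undone when a checkpoint is loaded.

import gzip
from typing import Any


def decompress_checkpoint(serde, checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Undo _compress_checkpoint: restore channel_values from the gzip blob."""
    values = checkpoint.get("channel_values")
    if isinstance(values, dict) and values.get("__compressed"):
        restored = dict(checkpoint)
        restored["channel_values"] = serde.loads(gzip.decompress(values["__data"]))
        return restored
    return checkpoint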
7.2 Monitoring and Debugging
import asyncio
import time


class MonitoredCheckpointSaver:
    """Checkpoint saver wrapper that records basic metrics
    (only get_tuple and put are instrumented here)."""

    def __init__(self, inner: BaseCheckpointSaver):
        self.inner = inner
        self.metrics = {
            "gets": 0,
            "puts": 0,
            "list_calls": 0,
            "total_size": 0,
            "avg_latency": 0.0,
        }
        self.start_time = time.time()

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Instrumented get_tuple."""
        start = time.time()
        try:
            result = self.inner.get_tuple(config)
            self.metrics["gets"] += 1
            if result:
                size = len(self.inner.serde.dumps(result.checkpoint))
                self.metrics["total_size"] += size
            return result
        finally:
            # Update the latency metric
            latency = time.time() - start
            self._update_latency(latency)

    def put(self, config, checkpoint, metadata, new_versions):
        """Instrumented put."""
        start = time.time()
        try:
            result = self.inner.put(config, checkpoint, metadata, new_versions)
            self.metrics["puts"] += 1
            # Record the serialized size
            size = len(self.inner.serde.dumps(checkpoint))
            self.metrics["total_size"] += size
            return result
        finally:
            latency = time.time() - start
            self._update_latency(latency)

    def _update_latency(self, latency: float):
        """Update the running average latency."""
        total_ops = self.metrics["gets"] + self.metrics["puts"]
        if total_ops > 1:
            self.metrics["avg_latency"] = (
                self.metrics["avg_latency"] * (total_ops - 1) + latency
            ) / total_ops
        else:
            self.metrics["avg_latency"] = latency

    def get_stats(self) -> dict:
        """Return a snapshot of the collected statistics."""
        uptime = time.time() - self.start_time
        total_ops = self.metrics["gets"] + self.metrics["puts"]
        return {
            **self.metrics,
            "uptime": uptime,
            "ops_per_second": total_ops / uptime if uptime > 0 else 0,
            "total_size_mb": self.metrics["total_size"] / (1024 * 1024),
        }


# Using the monitored saver
base_saver = InMemorySaver()
monitored_saver = MonitoredCheckpointSaver(base_saver)


# Periodically print the statistics
async def print_stats():
    while True:
        stats = monitored_saver.get_stats()
        print(f"Checkpoint Stats: {stats}")
        await asyncio.sleep(30)
8. Summary
The checkpoint system is a key LangGraph feature that provides powerful state management:
8.1 Key Strengths
- Full state persistence: the graph's state can be saved at every step of execution
- Flexible recovery: execution can resume from any saved checkpoint
- Multiple backends: in-memory, PostgreSQL, SQLite, and other storage options
- Version management: complete version history with forking and branching
- Performance options: batching, compression, asynchronous writes, and similar optimizations
8.2 Design Characteristics
- Unified interface: a clear abstraction makes it easy to add new storage backends
- Pluggable serialization: supports multiple serialization strategies, including encryption
- Monitoring friendly: the interface is easy to wrap for metrics and performance analysis
- Fault-tolerant design: error handling and data-consistency guarantees throughout
The checkpoint system provides a solid foundation for reliable, long-running AI applications and is an important building block for production LangGraph deployments.