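LangGraph-06-checkpoint-postgres Module: Comprehensive Documentation

0. Module Overview

0.1 Module Responsibilities

The checkpoint-postgres module is the PostgreSQL implementation of LangGraph's base checkpoint interface. It stores checkpoints persistently in a PostgreSQL database. It subclasses and implements the BaseCheckpointSaver interface, which gives LangGraph applications reliable state persistence, state recovery, and checkpoint history queries.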

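0.2 Module Inputs and Outputs

Inputs

  • PostgreSQL database connection (Connection or ConnectionPool)
  • Checkpoint data structures
  • RunnableConfig configuration objects
  • Serializer (SerializerProtocol)

Outputs

  • Persisted checkpoint data
  • Checkpoint history
  • CheckpointTuple objects
  • Records of pending writes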

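0.3 Upstream and Downstream Dependencies

Dependencies

  • psycopg3: PostgreSQL database driver
  • psycopg_pool: connection pool management
  • langgraph.checkpoint.base: base interface definitions
  • langgraph.checkpoint.serde: serialization protocol

Downstream consumers

  • The LangGraph core execution engine
  • Agent applications that need persistent state
  • Multi-turn conversational systems that support state recovery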

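0.4 Lifecycle

  1. Initialization: establish the database connection and configure the serializer
  2. Setup: create the required database tables and indexes
  3. Runtime: store, query, and manage checkpoints
  4. Cleanup: close database connections and release resources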

0.5 Module Architecture Diagram

flowchart TD
    A[PostgresSaver] --> B[BasePostgresSaver]
    B --> C[BaseCheckpointSaver]
    
    A --> D[Connection Management]
    D --> E[Direct Connection]
    D --> F[Connection Pool]
    
    A --> G[SQL Operations]
    G --> H[Checkpoint CRUD]
    G --> I[Writes Management]
    G --> J[Migration]
    
    A --> K[Serialization]
    K --> L[SerializerProtocol]
    
    M[AsyncPostgresSaver] --> B
    M --> O[Async Operations]
    
    P[PostgresStore] --> Q[BaseStore]
    P --> R[Vector Search]
    P --> S[TTL Management]
    
    T[Database Schema] --> U[checkpoints Table]
    T --> V[checkpoint_writes Table]
    T --> W[store Table]
    T --> X[Indexes]

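Architecture notes

  • PostgresSaver: the synchronous checkpoint saver implementation
  • AsyncPostgresSaver: the asynchronous checkpoint saver implementation
  • BasePostgresSaver: base logic shared by both savers
  • Connection Management: supports both single-connection and connection-pool modes
  • SQL Operations: encapsulates all database operation logic
  • PostgresStore: provides key-value storage and vector search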

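1. Key Data Structures and UML

1.1 Core Data Structures

import threading
from typing import Any, Optional, TypedDict, Union

from langgraph.checkpoint.postgres.base import BasePostgresSaver
from psycopg import Connection, Pipeline
from psycopg.rows import DictRow
from psycopg_pool import ConnectionPool

# PostgreSQL connection type alias
Conn = Union[Connection[DictRow], ConnectionPool[Connection[DictRow]]]

class PostgresSaver(BasePostgresSaver):
    """PostgreSQL checkpoint saver"""
    
    conn: Conn                    # database connection or connection pool
    pipe: Optional[Pipeline]      # optional pipeline for batching statements
    lock: threading.Lock          # serializes access to a shared connection
    
class CheckpointRow(TypedDict):
    """Row shape of the checkpoints table"""
    
    thread_id: str                       # thread ID
    checkpoint_ns: str                   # namespace ("" for the root graph)
    checkpoint_id: str                   # checkpoint ID
    parent_checkpoint_id: Optional[str]  # parent checkpoint ID (NULL for the first checkpoint)
    type: Optional[str]                  # serializer type tag of the checkpoint payload
    checkpoint: bytes                    # serialized checkpoint data
    metadata: dict[str, Any]             # metadata (JSONB, decoded to a dict by psycopg)

class WriteRow(TypedDict):
    """Row shape of the checkpoint_writes table"""
    
    thread_id: str              # thread ID
    checkpoint_ns: str          # namespace
    checkpoint_id: str          # checkpoint ID
    task_id: str                # task ID
    idx: int                    # index of the write within the task
    channel: str                # channel name
    type: str                   # serializer type tag of the value
    value: bytes                # serialized value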
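The two row types are linked by the checkpoint key (thread_id, checkpoint_ns, checkpoint_id). The sketch below groups WriteRow records under the checkpoint they belong to and orders each group by (task_id, idx), which is the same order the write queries in section 2 use. The function name `group_pending_writes` is hypothetical and is not part of the library.

```python
from __future__ import annotations

from collections import defaultdict
from typing import TypedDict


class WriteRow(TypedDict):
    # Mirrors the WriteRow shape above; value stays as raw serialized bytes
    thread_id: str
    checkpoint_ns: str
    checkpoint_id: str
    task_id: str
    idx: int
    channel: str
    type: str
    value: bytes


def group_pending_writes(
    rows: list[WriteRow],
) -> dict[tuple[str, str, str], list[tuple[str, str, bytes]]]:
    """Group write rows by checkpoint key, each group ordered by (task_id, idx)."""
    grouped: dict[tuple[str, str, str], list[WriteRow]] = defaultdict(list)
    for row in rows:
        key = (row["thread_id"], row["checkpoint_ns"], row["checkpoint_id"])
        grouped[key].append(row)
    return {
        key: [
            (r["task_id"], r["channel"], r["value"])
            for r in sorted(group, key=lambda r: (r["task_id"], r["idx"]))
        ]
        for key, group in grouped.items()
    }
```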

1.2 Database Schema

-- Main checkpoints table
CREATE TABLE IF NOT EXISTS checkpoints (
    thread_id TEXT NOT NULL,
    checkpoint_ns TEXT NOT NULL DEFAULT '',
    checkpoint_id TEXT NOT NULL,
    parent_checkpoint_id TEXT,
    type TEXT,
    checkpoint BYTEA NOT NULL,
    metadata JSONB NOT NULL DEFAULT '{}',
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);

-- Pending writes table (idx is supplied by the application, so it is a plain INTEGER)
CREATE TABLE IF NOT EXISTS checkpoint_writes (
    thread_id TEXT NOT NULL,
    checkpoint_ns TEXT NOT NULL DEFAULT '',
    checkpoint_id TEXT NOT NULL,
    task_id TEXT NOT NULL,
    idx INTEGER NOT NULL,
    channel TEXT NOT NULL,
    type TEXT,
    value BYTEA,
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);

-- Store table (used by PostgresStore)
CREATE TABLE IF NOT EXISTS store (
    prefix TEXT NOT NULL,
    key TEXT NOT NULL,
    value JSONB NOT NULL,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (prefix, key)
);
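Two mechanisms make it safe to re-run setup and to store the same checkpoint again: `CREATE TABLE IF NOT EXISTS`, and a composite primary key combined with `ON CONFLICT ... DO UPDATE`. The sketch below reproduces both with Python's built-in sqlite3 module as a stand-in for PostgreSQL. BLOB replaces BYTEA and TEXT replaces JSONB, so this illustrates the key semantics only and is not the module's actual DDL.

```python
import sqlite3

# SQLite stand-in for the checkpoints table: same composite primary key
DDL = """
CREATE TABLE IF NOT EXISTS checkpoints (
    thread_id TEXT NOT NULL,
    checkpoint_ns TEXT NOT NULL DEFAULT '',
    checkpoint_id TEXT NOT NULL,
    parent_checkpoint_id TEXT,
    type TEXT,
    checkpoint BLOB NOT NULL,
    metadata TEXT NOT NULL DEFAULT '{}',
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
)
"""

UPSERT = """
INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, type, checkpoint, metadata)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
DO UPDATE SET checkpoint = excluded.checkpoint, metadata = excluded.metadata
"""


def demo_upsert() -> list:
    conn = sqlite3.connect(":memory:")
    conn.execute(DDL)
    conn.execute(DDL)  # IF NOT EXISTS: running setup twice is harmless
    conn.execute(UPSERT, ("t1", "", "c1", "json", b"v1", '{"step": 1}'))
    conn.execute(UPSERT, ("t1", "", "c1", "json", b"v2", '{"step": 2}'))    # same key: update
    conn.execute(UPSERT, ("t1", "sub", "c1", "json", b"v3", '{"step": 1}'))  # other namespace: new row
    rows = conn.execute(
        "SELECT checkpoint_ns, checkpoint FROM checkpoints ORDER BY checkpoint_ns"
    ).fetchall()
    conn.close()
    return rows
```

The same checkpoint_id can therefore appear once per namespace, and a repeated put for the same key overwrites rather than duplicates.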

1.3 Class Diagram

classDiagram
    class BaseCheckpointSaver {
        <<abstract>>
        +serde: SerializerProtocol
        +get(config) CheckpointTuple
        +put(config, checkpoint, metadata) RunnableConfig
        +list(config) Iterator[CheckpointTuple]
        +put_writes(config, writes) None
    }
    
    class BasePostgresSaver {
        <<abstract>>
        +serde: SerializerProtocol
        +setup() None
        +list(config, filter, before, limit) Iterator[CheckpointTuple]
        +get_tuple(config) CheckpointTuple
        +put(config, checkpoint, metadata, new_versions) RunnableConfig
        +put_writes(config, writes, task_id) None
    }
    
    class PostgresSaver {
        +conn: Conn
        +pipe: Pipeline
        +lock: threading.Lock
        +from_conn_string(conn_string) PostgresSaver
        +setup() None
        +list(config, filter, before, limit) Iterator[CheckpointTuple]
        +get_tuple(config) CheckpointTuple
        +put(config, checkpoint, metadata, new_versions) RunnableConfig
        +put_writes(config, writes, task_id) None
        +delete_thread(thread_id) None
    }
    
    class AsyncPostgresSaver {
        +conn: AsyncConn
        +lock: asyncio.Lock
        +from_conn_string(conn_string) AsyncPostgresSaver
        +setup() None
        +alist(config, filter, before, limit) AsyncIterator[CheckpointTuple]
        +aget_tuple(config) CheckpointTuple
        +aput(config, checkpoint, metadata, new_versions) RunnableConfig
        +aput_writes(config, writes, task_id) None
    }
    
    class PostgresStore {
        +conn: Conn
        +index_config: IndexConfig
        +ttl_config: TTLConfig
        +setup() None
        +get(namespace, key) Item
        +put(namespace, key, value) None
        +delete(namespace, key) None
        +search(namespace, query) List[SearchItem]
    }
    
    BaseCheckpointSaver <|-- BasePostgresSaver
    BasePostgresSaver <|-- PostgresSaver
    BasePostgresSaver <|-- AsyncPostgresSaver
    BaseStore <|-- PostgresStore
    
    PostgresSaver --> "1" Connection : uses
    PostgresSaver --> "1" ConnectionPool : uses
    AsyncPostgresSaver --> "1" AsyncConnection : uses
    PostgresStore --> "1" Connection : uses

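Class diagram notes

  • BaseCheckpointSaver: defines the standard interface for checkpoint savers
  • BasePostgresSaver: common base class for the PostgreSQL implementations
  • PostgresSaver/AsyncPostgresSaver: concrete synchronous and asynchronous implementations
  • PostgresStore: implementation of key-value storage and search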
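The split between a shared base class and thin sync/async subclasses can be sketched without a database. Here the base class builds the SQL and parameters, and each subclass only decides how to execute them and which lock to take (`threading.Lock` or `asyncio.Lock`, as in the class diagram). All class names are hypothetical stand-ins, and the executors are injected callables rather than real psycopg connections.

```python
from __future__ import annotations

import asyncio
import threading
from typing import Any, Awaitable, Callable


class BaseSaverSketch:
    """Shared, I/O-free logic (hypothetical stand-in for BasePostgresSaver)."""

    def _build_get_query(self, thread_id: str, checkpoint_ns: str,
                         checkpoint_id: str | None) -> tuple[str, list[str]]:
        sql = "SELECT * FROM checkpoints WHERE thread_id = %s AND checkpoint_ns = %s"
        params = [thread_id, checkpoint_ns]
        if checkpoint_id:
            sql += " AND checkpoint_id = %s"
            params.append(checkpoint_id)
        else:
            sql += " ORDER BY checkpoint_id DESC LIMIT 1"
        return sql, params


class SyncSaverSketch(BaseSaverSketch):
    def __init__(self, execute: Callable[[str, list[str]], Any]):
        self._execute = execute
        self.lock = threading.Lock()  # one caller at a time on a shared connection

    def get_tuple(self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str | None = None):
        sql, params = self._build_get_query(thread_id, checkpoint_ns, checkpoint_id)
        with self.lock:
            return self._execute(sql, params)


class AsyncSaverSketch(BaseSaverSketch):
    def __init__(self, aexecute: Callable[[str, list[str]], Awaitable[Any]]):
        self._aexecute = aexecute
        self.lock = asyncio.Lock()  # same role, but it does not block the event loop

    async def aget_tuple(self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str | None = None):
        sql, params = self._build_get_query(thread_id, checkpoint_ns, checkpoint_id)
        async with self.lock:
            return await self._aexecute(sql, params)
```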

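2. Public API Reference

2.1 Core Checkpoint API

2.1.1 setup()

Basic information

  • Name: setup
  • Invocation: method call checkpointer.setup()
  • Idempotent: yes (safe to run repeatedly)

Method signature

def setup(self) -> None:
    """Set up the checkpoint database"""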

Core implementation

def setup(self) -> None:
    """Create the database tables and indexes"""
    
    # 1) Acquire a database connection
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            
            # 2) Create the checkpoints table
            cur.execute("""
                CREATE TABLE IF NOT EXISTS checkpoints (
                    thread_id TEXT NOT NULL,
                    checkpoint_ns TEXT NOT NULL DEFAULT '',
                    checkpoint_id TEXT NOT NULL,
                    parent_checkpoint_id TEXT,
                    type TEXT,
                    checkpoint BYTEA NOT NULL,
                    metadata JSONB NOT NULL DEFAULT '{}',
                    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
                )
            """)
            
            # 3) Create the pending-writes table (idx comes from the application)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS checkpoint_writes (
                    thread_id TEXT NOT NULL,
                    checkpoint_ns TEXT NOT NULL DEFAULT '',
                    checkpoint_id TEXT NOT NULL,
                    task_id TEXT NOT NULL,
                    idx INTEGER NOT NULL,
                    channel TEXT NOT NULL,
                    type TEXT,
                    value BYTEA,
                    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
                )
            """)
            
            # 4) Create a secondary index
            cur.execute("""
                CREATE INDEX IF NOT EXISTS checkpoints_thread_ns_idx 
                ON checkpoints (thread_id, checkpoint_ns)
            """)
            
        # 5) Commit (a no-op if the connection is in autocommit mode)
        conn.commit()

Sequence diagram

sequenceDiagram
    autonumber
    participant App as Application
    participant PS as PostgresSaver
    participant DB as PostgreSQL
    
    App->>PS: setup()
    PS->>DB: Acquire connection
    DB-->>PS: Connection established
    PS->>DB: CREATE TABLE checkpoints
    DB-->>PS: Table created
    PS->>DB: CREATE TABLE checkpoint_writes
    DB-->>PS: Table created
    PS->>DB: CREATE INDEX
    DB-->>PS: Index created
    PS->>DB: COMMIT
    DB-->>PS: Transaction committed
    PS-->>App: Setup complete

2.1.2 put()

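Basic information

  • Name: put
  • Invocation: method call checkpointer.put(config, checkpoint, metadata, new_versions)
  • Idempotent: yes for a given checkpoint ID (the ON CONFLICT ... DO UPDATE clause overwrites the existing row); each new checkpoint ID adds a row

Method signature

def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """Store a checkpoint in the database"""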

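Request parameters

Parameter      Type                 Required  Description
config         RunnableConfig       yes       Run configuration (includes thread_id, etc.)
checkpoint     Checkpoint           yes       Checkpoint data
metadata       CheckpointMetadata   yes       Checkpoint metadata
new_versions   ChannelVersions      yes       Channel version information

Return value

Type             Description
RunnableConfig   Updated config (includes the new checkpoint_id)

Core implementation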

def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,  # unused in this simplified version
) -> RunnableConfig:
    """Core logic for storing a checkpoint"""
    
    # 1) Extract identifiers from the config
    configurable = config.get("configurable", {})
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = checkpoint["id"]
    # The incoming config points at the previous checkpoint, which becomes the parent
    parent_checkpoint_id = configurable.get("checkpoint_id")
    
    # 2) Serialize the checkpoint, keeping the type tag so get_tuple() can decode it
    checkpoint_type, checkpoint_bytes = self.serde.dumps_typed(checkpoint)
    
    # 3) Run the statement (the lock guards a shared connection)
    with self.lock:
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                
                # 4) Insert the checkpoint, or overwrite it if the key already exists
                cur.execute("""
                    INSERT INTO checkpoints 
                    (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, 
                     type, checkpoint, metadata)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
                    DO UPDATE SET
                        checkpoint = EXCLUDED.checkpoint,
                        metadata = EXCLUDED.metadata
                """, (
                    thread_id,
                    checkpoint_ns, 
                    checkpoint_id,
                    parent_checkpoint_id,
                    checkpoint_type,   # serializer type tag
                    checkpoint_bytes,  # serialized checkpoint payload
                    Jsonb(metadata),   # JSONB column (from psycopg.types.json import Jsonb)
                ))
            
            # 5) Commit (a no-op if the connection is in autocommit mode)
            conn.commit()
    
    # 6) Return a config that points at the checkpoint just stored
    return {
        **config,
        "configurable": {
            **configurable,
            "checkpoint_id": checkpoint_id,
        }
    }
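The config handling in put() comes down to two rules. First, the incoming `checkpoint_id` (if present) becomes the parent of the new row. Second, the returned config points at the checkpoint that was just written. The database-free sketch below captures both rules; the helper name `advance_config` is hypothetical.

```python
from typing import Any, Dict, Optional, Tuple


def advance_config(config: Dict[str, Any], new_checkpoint_id: str) -> Tuple[Optional[str], Dict[str, Any]]:
    """Return (parent_checkpoint_id, new_config) without mutating the input."""
    configurable = config.get("configurable", {})
    # The checkpoint the caller currently points at becomes the parent
    parent_id = configurable.get("checkpoint_id")
    # Copy both levels so the caller's dicts are left untouched
    new_config = {
        **config,
        "configurable": {**configurable, "checkpoint_id": new_checkpoint_id},
    }
    return parent_id, new_config
```

The first checkpoint in a thread has no parent, which is why parent_checkpoint_id is nullable in the schema.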

2.1.3 get_tuple()

Basic information

  • Name: get_tuple
  • Invocation: method call checkpointer.get_tuple(config)
  • Idempotent: yes (read-only)

Method signature

def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """Fetch a checkpoint tuple"""

Core implementation

def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """Core logic for fetching a checkpoint tuple"""
    
    # 1) Parse the config
    configurable = config.get("configurable", {})
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = configurable.get("checkpoint_id")
    
    # 2) Build the query conditions
    if checkpoint_id:
        # Fetch a specific checkpoint
        where_clause = "thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s"
        params = [thread_id, checkpoint_ns, checkpoint_id]
        order_clause = ""
    else:
        # Fetch the latest checkpoint. Checkpoint IDs are time-ordered strings,
        # so the largest ID is the most recent (casting them to timestamp would fail)
        where_clause = "thread_id = %s AND checkpoint_ns = %s"
        params = [thread_id, checkpoint_ns]
        order_clause = "ORDER BY checkpoint_id DESC LIMIT 1"
    
    # 3) Run the query
    with self._get_connection() as conn:
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute(f"""
                SELECT * FROM checkpoints 
                WHERE {where_clause}
                {order_clause}
            """, params)
            
            row = cur.fetchone()
            if not row:
                return None
            
            # 4) Fetch the pending writes attached to this checkpoint
            cur.execute("""
                SELECT task_id, idx, channel, type, value
                FROM checkpoint_writes
                WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s
                ORDER BY task_id, idx
            """, (thread_id, checkpoint_ns, row["checkpoint_id"]))
            
            writes = cur.fetchall()
    
    # 5) Deserialize using the stored type tag; JSONB metadata arrives as a dict
    checkpoint = self.serde.loads_typed((row["type"], row["checkpoint"]))
    metadata = row["metadata"]
    
    # 6) Decode the pending writes
    pending_writes = []
    for write in writes:
        value = self.serde.loads_typed((write["type"], write["value"]))
        pending_writes.append((write["task_id"], write["channel"], value))
    
    # 7) Build the parent config
    parent_config = None
    if row["parent_checkpoint_id"]:
        parent_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": row["parent_checkpoint_id"]
            }
        }
    
    # 8) Return the checkpoint tuple
    return CheckpointTuple(
        config={
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": row["checkpoint_id"]
            }
        },
        checkpoint=checkpoint,
        metadata=metadata,
        parent_config=parent_config,
        pending_writes=pending_writes
    )
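Ordering by `checkpoint_id DESC` finds the latest checkpoint only because the IDs sort in creation order. LangGraph generates them as time-ordered UUIDs (UUIDv6-style), with the timestamp in the leading characters. The sketch below uses a simplified stand-in, a fixed-width hex timestamp followed by a random suffix, to show why lexicographic order and chronological order coincide. `time_ordered_id` and `latest` are hypothetical helpers, not the library's ID generator.

```python
from __future__ import annotations

import secrets
import time


def time_ordered_id(ts_ns: int | None = None) -> str:
    """Hypothetical time-ordered ID: 16 hex digits of timestamp plus a random suffix."""
    if ts_ns is None:
        ts_ns = time.time_ns()
    # Zero-padding to a fixed width makes string order equal numeric order
    return f"{ts_ns:016x}-{secrets.token_hex(4)}"


def latest(ids: list[str]) -> str:
    # The in-memory equivalent of ORDER BY checkpoint_id DESC LIMIT 1
    return max(ids)
```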

2.1.4 list()

Basic information

  • Name: list
  • Invocation: method call checkpointer.list(config, filter=None, before=None, limit=10)
  • Idempotent: yes (read-only)

Method signature

def list(
    self,
    config: RunnableConfig | None,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int = 10,
) -> Iterator[CheckpointTuple]:
    """List checkpoint history"""

Core implementation

def list(
    self,
    config: RunnableConfig | None,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int = 10,
) -> Iterator[CheckpointTuple]:
    """Core logic for listing checkpoint history"""
    
    # 1) Base conditions from the config
    where_conditions = []
    params = []
    
    if config:
        configurable = config.get("configurable", {})
        if thread_id := configurable.get("thread_id"):
            where_conditions.append("thread_id = %s")
            params.append(thread_id)
        
        # "" is the root namespace, so compare against None rather than truthiness
        if (checkpoint_ns := configurable.get("checkpoint_ns")) is not None:
            where_conditions.append("checkpoint_ns = %s")
            params.append(checkpoint_ns)
    
    # 2) Metadata filter: a single JSONB containment test. Filter keys are
    #    passed as data and are never interpolated into the SQL text.
    if filter:
        where_conditions.append("metadata @> %s")
        params.append(Jsonb(filter))  # from psycopg.types.json import Jsonb
    
    # 3) "before" condition: only checkpoints older than the given ID
    if before:
        before_id = before.get("configurable", {}).get("checkpoint_id")
        if before_id:
            where_conditions.append("checkpoint_id < %s")
            params.append(before_id)
    
    # 4) Assemble the WHERE clause
    where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
    
    # 5) Run the query
    with self._get_connection() as conn:
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute(f"""
                SELECT * FROM checkpoints 
                WHERE {where_clause}
                ORDER BY checkpoint_id DESC
                LIMIT %s
            """, params + [limit])
            
            # Fetch all rows first: re-executing on the same cursor inside the
            # loop would discard any rows that have not been read yet
            rows = cur.fetchall()
            
            # 6) Build and yield one tuple per row
            for row in rows:
                # Fetch the pending writes for this checkpoint
                cur.execute("""
                    SELECT task_id, idx, channel, type, value
                    FROM checkpoint_writes
                    WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s
                    ORDER BY task_id, idx
                """, (row["thread_id"], row["checkpoint_ns"], row["checkpoint_id"]))
                
                writes = cur.fetchall()
                
                # Deserialize the checkpoint; JSONB metadata arrives as a dict
                checkpoint = self.serde.loads_typed((row["type"], row["checkpoint"]))
                metadata = row["metadata"]
                
                # Decode the pending writes
                pending_writes = []
                for write in writes:
                    value = self.serde.loads_typed((write["type"], write["value"]))
                    pending_writes.append((write["task_id"], write["channel"], value))
                
                # Build the parent config
                parent_config = None
                if row["parent_checkpoint_id"]:
                    parent_config = {
                        "configurable": {
                            "thread_id": row["thread_id"],
                            "checkpoint_ns": row["checkpoint_ns"],
                            "checkpoint_id": row["parent_checkpoint_id"]
                        }
                    }
                
                # Yield the checkpoint tuple
                yield CheckpointTuple(
                    config={
                        "configurable": {
                            "thread_id": row["thread_id"],
                            "checkpoint_ns": row["checkpoint_ns"],
                            "checkpoint_id": row["checkpoint_id"]
                        }
                    },
                    checkpoint=checkpoint,
                    metadata=metadata,
                    parent_config=parent_config,
                    pending_writes=pending_writes
                )
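The condition-building part of list() can be isolated and checked without a database. The sketch below treats the root namespace `""` as a real filter value, because a plain truthiness check would skip it. It expresses the metadata filter as one JSONB containment test (`@>`) and never splices filter keys into the SQL text. It sends the filter as a JSON string with an explicit `::jsonb` cast instead of psycopg's `Jsonb` wrapper, so it runs without psycopg installed. `build_list_where` is a hypothetical name.

```python
from __future__ import annotations

import json
from typing import Any


def build_list_where(
    config: dict[str, Any] | None,
    filter: dict[str, Any] | None = None,
    before: dict[str, Any] | None = None,
) -> tuple[str, list[Any]]:
    conditions: list[str] = []
    params: list[Any] = []
    if config:
        configurable = config.get("configurable", {})
        if (thread_id := configurable.get("thread_id")) is not None:
            conditions.append("thread_id = %s")
            params.append(thread_id)
        # "" is a valid namespace (the root graph), so test for None only
        if (ns := configurable.get("checkpoint_ns")) is not None:
            conditions.append("checkpoint_ns = %s")
            params.append(ns)
    if filter:
        # Whole filter as one JSON document; keys travel as data, not SQL
        conditions.append("metadata @> %s::jsonb")
        params.append(json.dumps(filter))
    if before and (before_id := before.get("configurable", {}).get("checkpoint_id")):
        conditions.append("checkpoint_id < %s")
        params.append(before_id)
    return (" AND ".join(conditions) if conditions else "TRUE"), params
```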

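2.2 Write Management API

2.2.1 put_writes()

Basic information

  • Name: put_writes
  • Invocation: method call checkpointer.put_writes(config, writes, task_id)
  • Idempotent: no (each call inserts new rows; repeating a call with the same task_id and writes violates the primary key)

Method signature

def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",
) -> None:
    """Store pending writes"""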

Core implementation

def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = "",  # unused in this simplified version
) -> None:
    """Core logic for storing pending writes"""
    
    # 1) Extract identifiers from the config (checkpoint_id is required here)
    configurable = config.get("configurable", {})
    thread_id = configurable["thread_id"]
    checkpoint_ns = configurable.get("checkpoint_ns", "")
    checkpoint_id = configurable["checkpoint_id"]
    
    # 2) Build the rows for a batch insert
    write_records = []
    for idx, (channel, value) in enumerate(writes):
        # Serialize the value into (type_tag, payload)
        value_type, value_bytes = self.serde.dumps_typed(value)
        
        write_records.append((
            thread_id,
            checkpoint_ns,
            checkpoint_id,
            task_id,
            idx,          # position of the write within this task
            channel,
            value_type,   # type
            value_bytes   # value
        ))
    
    # 3) Batch-insert the writes
    with self.lock:
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                cur.executemany("""
                    INSERT INTO checkpoint_writes 
                    (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, write_records)
            
            # 4) Commit (a no-op if the connection is in autocommit mode)
            conn.commit()
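The `(thread_id, checkpoint_ns, checkpoint_id, task_id, idx)` primary key has a useful side effect. Because `idx` is the write's position within the task's batch, a retried put_writes call produces exactly the same keys. The sketch below assumes an `ON CONFLICT ... DO NOTHING` clause, which the statement above does not have, and uses sqlite3 as a stand-in for PostgreSQL to show that a retry then leaves the table unchanged. The function names are hypothetical, and `json.dumps` stands in for the serializer.

```python
from __future__ import annotations

import json
import sqlite3
from typing import Any, Sequence

DDL = """
CREATE TABLE checkpoint_writes (
    thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '',
    checkpoint_id TEXT NOT NULL, task_id TEXT NOT NULL, idx INTEGER NOT NULL,
    channel TEXT NOT NULL, value TEXT,
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
)
"""


def build_write_records(thread_id: str, checkpoint_ns: str, checkpoint_id: str,
                        task_id: str, writes: Sequence[tuple[str, Any]]) -> list[tuple]:
    # idx is the position in this call's batch, so a retry reproduces the same keys
    return [
        (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, json.dumps(value))
        for idx, (channel, value) in enumerate(writes)
    ]


def store_writes(conn: sqlite3.Connection, records: list[tuple]) -> None:
    conn.executemany(
        """
        INSERT INTO checkpoint_writes
            (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, value)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
        """,
        records,
    )


def demo_retry() -> int:
    conn = sqlite3.connect(":memory:")
    conn.execute(DDL)
    records = build_write_records("t1", "", "c1", "task-a", [("messages", "hi"), ("count", 1)])
    store_writes(conn, records)
    store_writes(conn, records)  # simulated retry of the same call
    (count,) = conn.execute("SELECT COUNT(*) FROM checkpoint_writes").fetchone()
    conn.close()
    return count
```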

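3. Core Algorithms and Flows

3.1 Connection Management

Purpose: acquire and release PostgreSQL connections, supporting both single-connection and connection-pool modes
Input: a connection object (Connection or ConnectionPool)
Output: a usable database connection
Complexity: O(1)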

@contextmanager
def _get_connection(self) -> Iterator[Connection[DictRow]]:
    """Yield a ready-to-use database connection"""
    
    # 1) Dispatch on the connection type
    if isinstance(self.conn, ConnectionPool):
        # Pool mode: borrow a connection; it is returned when the block exits
        with self.conn.connection() as conn:
            # Normalize the connection settings
            conn.autocommit = True
            conn.row_factory = dict_row
            yield conn
    
    elif isinstance(self.conn, Connection):
        # Direct mode: reuse the existing connection
        if self.conn.closed:
            raise ConnectionError("database connection is closed")
        
        # Normalize the connection settings (psycopg only allows changing
        # autocommit while no transaction is open)
        if not self.conn.autocommit:
            self.conn.autocommit = True
        
        if self.conn.row_factory is not dict_row:
            self.conn.row_factory = dict_row
        
        yield self.conn
    
    else:
        raise TypeError(f"unsupported connection type: {type(self.conn)}")

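Key design points

  1. Pool support: connections are borrowed from the pool and returned to it automatically
  2. Settings check: ensures autocommit=True and row_factory=dict_row
  3. Error handling: detects a closed connection and raises a clear error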
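The dispatch logic of _get_connection can be exercised without a database by substituting fake connection and pool objects. In this sketch `FakeConnection` and `FakePool` are hypothetical stand-ins for psycopg's `Connection` and psycopg_pool's `ConnectionPool`. It shows that the pool's context manager returns the connection when the block exits, even if the block raises.

```python
from contextlib import contextmanager
from typing import Iterator


class FakeConnection:
    def __init__(self) -> None:
        self.autocommit = False
        self.closed = False


class FakePool:
    """Hypothetical pool: hands out one connection and counts check-outs."""

    def __init__(self) -> None:
        self._conn = FakeConnection()
        self.checked_out = 0

    @contextmanager
    def connection(self) -> Iterator[FakeConnection]:
        self.checked_out += 1
        try:
            yield self._conn
        finally:
            self.checked_out -= 1  # returned to the pool even if the body raised


@contextmanager
def get_connection(conn) -> Iterator[FakeConnection]:
    if isinstance(conn, FakePool):
        with conn.connection() as c:
            c.autocommit = True
            yield c
    elif isinstance(conn, FakeConnection):
        if conn.closed:
            raise ConnectionError("database connection is closed")
        conn.autocommit = True
        yield conn
    else:
        raise TypeError(f"unsupported connection type: {type(conn)!r}")
```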

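3.2 Serialization Management

Purpose: handle serialization and deserialization of checkpoint data in one place
Input: Python objects or serialized bytes
Output: serialized bytes or Python objects
Complexity: O(n), where n is the size of the data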

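import logging
from typing import Any, Sequence

from langgraph.checkpoint.base import Checkpoint
from langgraph.checkpoint.serde.base import SerializerProtocol

logger = logging.getLogger(__name__)


class SerializationError(Exception):
    """Raised when serialized output has an unexpected form"""


class DeserializationError(Exception):
    """Raised when stored data cannot be decoded into a checkpoint"""


class SerializationManager:
    """Serialization manager"""
    
    def __init__(self, serde: SerializerProtocol):
        self.serde = serde
    
    def serialize_checkpoint(self, checkpoint: Checkpoint) -> tuple[str, bytes]:
        """Serialize checkpoint data"""
        
        # 1) Typed serialization returns (type_tag, payload)
        type_name, serialized_data = self.serde.dumps_typed(checkpoint)
        
        # 2) Validate the result
        if not isinstance(serialized_data, bytes):
            raise SerializationError("serialized payload must be bytes")
        
        return type_name, serialized_data
    
    def deserialize_checkpoint(self, type_name: str, data: bytes) -> Checkpoint:
        """Deserialize checkpoint data"""
        
        # 1) Typed deserialization, using the stored type tag
        try:
            checkpoint = self.serde.loads_typed((type_name, data))
        except Exception as e:
            raise DeserializationError(f"deserialization failed: {e}") from e
        
        # 2) Validate the shape (every Checkpoint carries a format version "v")
        if not isinstance(checkpoint, dict) or "v" not in checkpoint:
            raise DeserializationError("invalid checkpoint data structure")
        
        return checkpoint
    
    def batch_serialize_writes(self, writes: Sequence[tuple[str, Any]]) -> list[tuple[str, str, bytes]]:
        """Serialize a batch of writes"""
        
        serialized_writes = []
        for channel, value in writes:
            try:
                type_name, serialized_value = self.serde.dumps_typed(value)
                serialized_writes.append((channel, type_name, serialized_value))
            except Exception as e:
                # Log the channel whose value failed to serialize
                logger.warning(f"failed to serialize write - channel: {channel}, error: {e}")
                # Skipping drops this write silently; re-raise instead if every
                # write must be persisted
                continue
        
        return serialized_writes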
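The serializers follow a typed round-trip contract. `dumps_typed(obj)` returns a `(type_tag, payload)` pair, and `loads_typed((type_tag, payload))` must receive the same tag back, which is why the tag is stored in the `type` column. The JSON-only class below is a hypothetical minimal stand-in for that contract. LangGraph's JsonPlusSerializer supports far more types.

```python
import json
from typing import Any, Tuple


class JsonTypedSerde:
    """Hypothetical minimal implementation of the dumps_typed/loads_typed pair."""

    def dumps_typed(self, obj: Any) -> Tuple[str, bytes]:
        if isinstance(obj, (bytes, bytearray)):
            return "bytes", bytes(obj)  # raw bytes pass through untouched
        return "json", json.dumps(obj, separators=(",", ":")).encode("utf-8")

    def loads_typed(self, data: Tuple[str, bytes]) -> Any:
        type_name, payload = data
        # The stored tag, not a hard-coded guess, selects the decoder
        if type_name == "bytes":
            return payload
        if type_name == "json":
            return json.loads(payload)
        raise ValueError(f"unknown serialization type: {type_name}")
```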

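3.3 Transaction Management

Purpose: keep checkpoint operations atomic and consistent
Input: a sequence of database operations
Output: the transaction result
Complexity: O(m), where m is the number of operations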

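from __future__ import annotations

import logging
from typing import Callable

logger = logging.getLogger(__name__)


class TransactionManager:
    """Transaction manager"""
    
    def __init__(self, connection_getter: Callable):
        # A context-manager factory such as _get_connection from section 3.1
        self.get_connection = connection_getter
        
    def execute_checkpoint_transaction(
        self, 
        checkpoint_data: CheckpointData,   # hypothetical helper type
        writes_data: list[WriteData]       # hypothetical helper type
    ) -> bool:
        """Run a checkpoint and its writes as one transaction"""
        
        try:
            with self.get_connection() as conn:
                # conn.transaction() opens a transaction block even on an
                # autocommit connection: it commits on normal exit and rolls
                # back if an exception escapes the block
                with conn.transaction(), conn.cursor() as cursor:
                    
                    # 1) Insert or update the checkpoint
                    cursor.execute("""
                        INSERT INTO checkpoints (...) VALUES (...)
                        ON CONFLICT (...) DO UPDATE SET ...
                    """, checkpoint_data.to_params())
                    
                    # 2) Batch-insert the writes
                    if writes_data:
                        cursor.executemany("""
                            INSERT INTO checkpoint_writes (...) VALUES (...)
                        """, [w.to_params() for w in writes_data])
                    
                    # 3) Verify the insert; rows are dicts under dict_row,
                    #    so the count is read through a column alias
                    cursor.execute("""
                        SELECT COUNT(*) AS n FROM checkpoints 
                        WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s
                    """, (checkpoint_data.thread_id, checkpoint_data.checkpoint_ns,
                          checkpoint_data.checkpoint_id))
                    
                    if cursor.fetchone()["n"] != 1:
                        raise TransactionError("checkpoint insert verification failed")  # hypothetical exception type
                    
                    # Committed automatically when the transaction block exits
                    return True
                    
        except Exception as e:
            # The transaction has already been rolled back at this point
            logger.error(f"checkpoint transaction failed: {e}")
            return False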
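The atomicity guarantee can be demonstrated with sqlite3 as a stand-in. In Python's sqlite3 module, using the connection as a context manager commits on success and rolls back on an exception, much like `conn.transaction()` in psycopg. In the sketch, a duplicate primary key among the writes makes the whole save fail, and the checkpoint row inserted earlier in the same transaction is rolled back with it. The table layouts are simplified, and the function names are hypothetical.

```python
from __future__ import annotations

import sqlite3


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE checkpoints (thread_id TEXT, checkpoint_id TEXT, "
                 "checkpoint BLOB, PRIMARY KEY (thread_id, checkpoint_id))")
    conn.execute("CREATE TABLE checkpoint_writes (thread_id TEXT, checkpoint_id TEXT, "
                 "idx INTEGER, channel TEXT, PRIMARY KEY (thread_id, checkpoint_id, idx))")
    return conn


def save_atomically(conn: sqlite3.Connection, checkpoint_row: tuple, write_rows: list[tuple]) -> bool:
    try:
        with conn:  # commit on success, roll back everything on an exception
            conn.execute("INSERT INTO checkpoints VALUES (?, ?, ?)", checkpoint_row)
            conn.executemany("INSERT INTO checkpoint_writes VALUES (?, ?, ?, ?)", write_rows)
        return True
    except sqlite3.Error:
        return False
```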

4. Module-Level Architecture and Sequence Diagrams

4.1 Overall Data-Flow Architecture

flowchart LR
    subgraph "LangGraph Core"
        A[StateGraph] --> B[CompiledStateGraph]
        B --> C[Pregel]
    end
    
    subgraph "Checkpoint Layer"
        C --> D[BaseCheckpointSaver]
        D --> E[PostgresSaver]
    end
    
    subgraph "Database Layer"
        E --> F[Connection Manager]
        F --> G[PostgreSQL DB]
        
        E --> H[Serialization]
        H --> I[JsonPlusSerializer]
        
        E --> J[Transaction Manager]
        J --> G
    end
    
    subgraph "Storage Schema"
        G --> K[checkpoints Table]
        G --> L[checkpoint_writes Table]
        G --> M[Indexes]
    end
    
    subgraph "Async Support"
        N[AsyncPostgresSaver] --> O[Async Connection]
        O --> G
    end

4.2 Full Checkpoint Lifecycle Sequence

sequenceDiagram
    autonumber
    participant App as Application
    participant Graph as StateGraph
    participant Saver as PostgresSaver
    participant Conn as Connection Manager
    participant DB as PostgreSQL
    participant Ser as Serializer
    
    Note over App,Ser: 1. Initialization
    App->>Saver: PostgresSaver(conn, serde)
    Saver->>Conn: Initialize connection management
    Saver->>Saver: Create thread lock
    
    Note over App,Ser: 2. Setup
    App->>Saver: setup()
    Saver->>Conn: get_connection()
    Conn->>DB: Open connection
    DB-->>Conn: Connection ready
    Saver->>DB: CREATE TABLE checkpoints
    Saver->>DB: CREATE TABLE checkpoint_writes
    Saver->>DB: CREATE INDEX
    DB-->>Saver: Tables and index created
    
    Note over App,Ser: 3. Runtime - store a checkpoint
    Graph->>Saver: put(config, checkpoint, metadata, versions)
    Saver->>Ser: dumps_typed(checkpoint)
    Ser-->>Saver: Serialized checkpoint
    Saver->>Ser: dumps_typed(metadata)
    Ser-->>Saver: Serialized metadata
    
    Saver->>Conn: get_connection()
    Conn->>DB: Acquire connection
    DB-->>Conn: Connection ready
    
    Saver->>DB: BEGIN TRANSACTION
    Saver->>DB: INSERT INTO checkpoints
    DB-->>Saver: Insert succeeded
    Saver->>DB: COMMIT
    DB-->>Saver: Transaction committed
    
    Saver-->>Graph: Return updated config
    
    Note over App,Ser: 4. Runtime - store pending writes
    Graph->>Saver: put_writes(config, writes, task_id)
    Saver->>Ser: Serialize writes in a batch
    Ser-->>Saver: Serialization complete
    
    Saver->>DB: BEGIN TRANSACTION
    Saver->>DB: INSERT INTO checkpoint_writes (batch)
    DB-->>Saver: Batch insert succeeded
    Saver->>DB: COMMIT
    DB-->>Saver: Transaction committed
    
    Note over App,Ser: 5. Query
    App->>Saver: get_tuple(config)
    Saver->>Conn: get_connection()
    Conn->>DB: Acquire connection
    
    Saver->>DB: SELECT FROM checkpoints
    DB-->>Saver: Checkpoint row
    Saver->>DB: SELECT FROM checkpoint_writes
    DB-->>Saver: Write rows
    
    Saver->>Ser: loads_typed(checkpoint_data)
    Ser-->>Saver: Deserialized checkpoint
    Saver->>Ser: loads_typed(metadata)
    Ser-->>Saver: Deserialized metadata
    Saver->>Ser: Deserialize writes in a batch
    Ser-->>Saver: Deserialization complete
    
    Saver->>Saver: Build CheckpointTuple
    Saver-->>App: Return complete checkpoint tuple

4.3 Asynchronous Operation Sequence

sequenceDiagram
    autonumber
    participant App as Async Application
    participant ASaver as AsyncPostgresSaver
    participant AConn as Async Connection
    participant DB as PostgreSQL
    
    App->>ASaver: await setup()
    ASaver->>AConn: Acquire async connection
    AConn->>DB: Open async connection
    DB-->>AConn: Connection established
    ASaver->>DB: Create tables and indexes (async)
    DB-->>ASaver: Creation complete
    
    App->>ASaver: aput(config, checkpoint, metadata, versions)
    ASaver->>ASaver: Serialize data
    ASaver->>AConn: Acquire async connection
    AConn->>DB: Get connection
    ASaver->>DB: INSERT (async)
    DB-->>ASaver: Insert complete
    ASaver-->>App: Return result
    
    App->>ASaver: aget_tuple(config)
    ASaver->>AConn: Acquire async connection
    ASaver->>DB: SELECT (async)
    DB-->>ASaver: Query results
    ASaver->>ASaver: Deserialize data
    ASaver-->>App: Return CheckpointTuple

5. Error Handling and Performance Optimization

5.1 Error-Handling Strategies

5.1.1 Connection Errors

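import logging
import time
from contextlib import ExitStack, contextmanager
from typing import Iterator

from psycopg import Connection, OperationalError
from psycopg.rows import DictRow

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Connection acquisition with retries and health checks"""
    
    def __init__(self, conn: Conn, max_retries: int = 3, retry_delay: float = 1.0):
        self.conn = conn
        self.max_retries = max_retries
        self.retry_delay = retry_delay
    
    # _get_connection() is assumed to be the context manager from section 3.1
    
    @contextmanager
    def get_connection_with_retry(self) -> Iterator[Connection[DictRow]]:
        """Acquire a connection with retries, then yield it exactly once"""
        last_exception = None
        
        # Retry only the acquisition step. A @contextmanager generator may
        # yield only once, so an error raised inside the caller's block must
        # not loop back into another attempt.
        for attempt in range(self.max_retries):
            stack = ExitStack()
            try:
                conn = stack.enter_context(self._get_connection())
                break
            except (ConnectionError, OperationalError) as e:
                stack.close()
                last_exception = e
                logger.warning(f"connection attempt {attempt + 1} failed: {e}")
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay * (2 ** attempt))  # exponential backoff
            except Exception as e:
                # Non-connection errors are raised immediately
                stack.close()
                logger.error(f"non-connection error: {e}")
                raise
        else:
            # Every attempt failed
            raise ConnectionError(
                f"connection failed after {self.max_retries} attempts: {last_exception}"
            ) from last_exception
        
        # Exiting the stack releases the connection (returns it to the pool)
        with stack:
            yield conn
    
    def health_check(self) -> bool:
        """Connection health check"""
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1")
                    return cur.fetchone() is not None
        except Exception as e:
            logger.error(f"health check failed: {e}")
            return False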
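The retry policy on its own can be written as a plain function and tested without a database. The policy retries only on connection-type errors, waits `retry_delay * 2**attempt` between attempts, and does not sleep after the last attempt. This sketch is hypothetical; its injectable `sleep` parameter exists only so that tests can record delays instead of waiting.

```python
from __future__ import annotations

import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    op: Callable[[], T],
    retriable: tuple[type[BaseException], ...] = (ConnectionError,),
    max_retries: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    last_exc: BaseException | None = None
    for attempt in range(max_retries):
        try:
            return op()
        except retriable as exc:  # anything else propagates immediately
            last_exc = exc
            if attempt < max_retries - 1:
                sleep(base_delay * (2 ** attempt))  # 1x, 2x, 4x, ...
    raise ConnectionError(f"failed after {max_retries} attempts: {last_exc}") from last_exc
```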

5.1.2 Serialization Errors

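import logging
from typing import Any, Optional

from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

logger = logging.getLogger(__name__)


class SafeSerializationManager:
    """Fault-tolerant serialization manager"""
    
    def __init__(self, serde: SerializerProtocol, fallback_serde: Optional[SerializerProtocol] = None):
        self.serde = serde
        self.fallback_serde = fallback_serde or JsonPlusSerializer()
    
    def safe_serialize(self, obj: Any) -> tuple[str, bytes]:
        """Serialize, falling back to a second serializer on failure"""
        try:
            return self.serde.dumps_typed(obj)
        except Exception as e:
            logger.warning(f"primary serializer failed, using fallback: {e}")
            try:
                return self.fallback_serde.dumps_typed(obj)
            except Exception as fallback_e:
                logger.error(f"fallback serializer also failed: {fallback_e}")
                # Last resort: store a placeholder describing the failure
                # (the original value is lost)
                error_info = {
                    "error": "serialization_failed",
                    "original_type": type(obj).__name__,
                    "primary_error": str(e),
                    "fallback_error": str(fallback_e)
                }
                return self.fallback_serde.dumps_typed(error_info)
    
    def safe_deserialize(self, type_name: str, data: bytes) -> Any:
        """Deserialize, falling back to a second serializer on failure"""
        try:
            result = self.serde.loads_typed((type_name, data))
            
            # Detect the placeholder written by safe_serialize
            if isinstance(result, dict) and result.get("error") == "serialization_failed":
                logger.warning("found placeholder for a value that failed to serialize")
                return None
            
            return result
            
        except Exception as e:
            logger.error(f"deserialization failed: {e}")
            # Try the fallback serializer
            try:
                return self.fallback_serde.loads_typed((type_name, data))
            except Exception:
                logger.error("fallback deserialization also failed, returning None")
                return None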

5.2 Performance Optimization Strategies

5.2.1 Batching

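import logging
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, List

logger = logging.getLogger(__name__)


@dataclass
class WriteData:
    """A single pending write (illustrative structure, not a LangGraph type)"""
    thread_id: str
    checkpoint_ns: str
    checkpoint_id: str
    task_id: str
    idx: int
    channel: str
    value: Any


class BatchOptimizer:
    """Batch write optimizer"""

    def __init__(self, batch_size: int = 100, flush_interval: float = 5.0):
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.pending_writes: List[WriteData] = []
        self.last_flush = time.time()
        self.lock = threading.Lock()

    def add_write(self, write_data: WriteData):
        """Add a write operation to the batch queue"""
        with self.lock:
            self.pending_writes.append(write_data)

            # Flush when the batch is full or the flush interval has elapsed
            # (the interval is only checked when a new write arrives)
            if (len(self.pending_writes) >= self.batch_size or
                    time.time() - self.last_flush >= self.flush_interval):
                self._flush_writes()

    def flush(self):
        """Flush whatever is still pending, e.g. at shutdown"""
        with self.lock:
            self._flush_writes()

    def _flush_writes(self):
        """Flush the write queue; the caller must hold self.lock"""
        if not self.pending_writes:
            return

        writes_to_flush = self.pending_writes[:]
        self.pending_writes.clear()
        self.last_flush = time.time()

        try:
            self._execute_batch_writes(writes_to_flush)
        except Exception as e:
            logger.error(f"Batch write failed: {e}")
            # Optionally re-queue the failed writes; without a retry limit the
            # queue grows without bound while the database stays unavailable
            self.pending_writes.extend(writes_to_flush)

    def _execute_batch_writes(self, writes: List[WriteData]):
        """Execute the batched writes"""
        # Group by (thread_id, checkpoint_ns, checkpoint_id), which together
        # identify a single checkpoint
        grouped_writes = defaultdict(list)
        for write in writes:
            key = (write.thread_id, write.checkpoint_ns, write.checkpoint_id)
            grouped_writes[key].append(write)

        # Run one batched insert per checkpoint group
        for group_writes in grouped_writes.values():
            self._batch_insert_writes(group_writes)

    def _batch_insert_writes(self, writes: List[WriteData]):
        """Insert one group of writes; left to a concrete subclass"""
        raise NotImplementedError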
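The `_batch_insert_writes` hook is where a group actually reaches the database. The following is a minimal sketch of one way to implement it. It assumes psycopg 3 and the `checkpoint_writes` column layout (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob), with the primary key on (thread_id, checkpoint_ns, checkpoint_id, task_id, idx). The `PendingWrite` tuple, the `UPSERT_WRITES_SQL` constant and both helper functions are hypothetical names, not LangGraph APIs.

```python
from typing import Any, List, NamedTuple, Sequence, Tuple


# Hypothetical record type; field order mirrors the assumed checkpoint_writes columns
class PendingWrite(NamedTuple):
    thread_id: str
    checkpoint_ns: str
    checkpoint_id: str
    task_id: str
    idx: int
    channel: str
    type: str
    blob: bytes


# Upsert keyed on the assumed primary key, so re-sending a batch after a
# failure overwrites rows that were already written instead of raising
UPSERT_WRITES_SQL = """
    INSERT INTO checkpoint_writes
        (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
    ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
    DO UPDATE SET channel = EXCLUDED.channel,
                  type = EXCLUDED.type,
                  blob = EXCLUDED.blob
"""


def build_write_rows(writes: Sequence[PendingWrite]) -> List[Tuple[Any, ...]]:
    """Turn pending writes into parameter tuples, ordered by (task_id, idx)"""
    ordered = sorted(writes, key=lambda w: (w.task_id, w.idx))
    return [tuple(w) for w in ordered]


def batch_insert_writes(cur, writes: Sequence[PendingWrite]) -> None:
    """Send one group of writes with a single executemany call"""
    if writes:
        # psycopg 3 cursors expose executemany(query, params_seq)
        cur.executemany(UPSERT_WRITES_SQL, build_write_rows(writes))
```

Using an upsert rather than a plain INSERT makes the re-queue-on-failure path in `_flush_writes` safe to retry, at the cost of a slightly more expensive statement.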

5.2.2 Connection Pool Optimization

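import logging
import threading

from psycopg.rows import dict_row
from psycopg_pool import ConnectionPool

logger = logging.getLogger(__name__)


class OptimizedConnectionPool:
    """Connection pool wrapper with background health checks"""

    def __init__(
        self,
        conn_string: str,
        min_size: int = 2,
        max_size: int = 10,
        max_lifetime: float = 3600,  # 1 hour
        max_idle_time: float = 300,  # 5 minutes
        check_interval: float = 30,  # seconds between health checks
    ):
        # psycopg_pool recycles connections itself via max_lifetime and
        # max_idle, so no hand-written cleanup of expired connections is needed
        self.pool = ConnectionPool(
            conn_string,
            min_size=min_size,
            max_size=max_size,
            max_lifetime=max_lifetime,
            max_idle=max_idle_time,
            kwargs={
                "autocommit": True,       # PostgresSaver expects autocommit connections
                "row_factory": dict_row,  # rows come back as dicts
            },
            open=True,
        )
        self.check_interval = check_interval
        self._stop = threading.Event()
        self._monitoring_thread = None
        self._start_monitoring()

    def _start_monitoring(self):
        """Start the pool monitoring thread"""
        def monitor():
            # Event.wait acts as an interruptible sleep and also runs after a
            # failed check, so repeated errors cannot turn into a busy loop
            while not self._stop.wait(self.check_interval):
                try:
                    # Check idle connections and discard broken ones
                    self.pool.check()
                except Exception as e:
                    logger.error(f"Connection pool monitoring error: {e}")

        self._monitoring_thread = threading.Thread(target=monitor, daemon=True)
        self._monitoring_thread.start()

    def get_stats(self) -> dict:
        """Return connection pool statistics"""
        # psycopg_pool reports "pool_size" and "pool_available";
        # there are no "available" or "used" keys, so "used" is derived
        stats = self.pool.get_stats()
        size = stats.get("pool_size", 0)
        available = stats.get("pool_available", 0)
        return {
            "pool_size": size,
            "available": available,
            "used": size - available,  # approximate: includes connections being prepared
        }

    def close(self):
        """Stop monitoring and close the pool"""
        self._stop.set()
        if self._monitoring_thread is not None:
            self._monitoring_thread.join()
        self.pool.close()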
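The pool monitor is an instance of a general periodic background task. The following is a minimal, database-free sketch of that pattern. It assumes the health check can be any zero-argument callable; in the pool wrapper that would be `pool.check`. `PeriodicTask` is a hypothetical helper. Using `threading.Event.wait` as the sleep lets `stop()` interrupt the wait immediately, and it guarantees a pause even after a failed check, so repeated errors cannot spin the CPU.

```python
import logging
import threading
from typing import Callable

logger = logging.getLogger(__name__)


class PeriodicTask:
    """Run fn every `interval` seconds on a daemon thread until stopped"""

    def __init__(self, fn: Callable[[], None], interval: float):
        self.fn = fn
        self.interval = interval
        self.runs = 0
        self.errors = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def _loop(self) -> None:
        # wait() returns False on timeout and True once stop() sets the event
        while not self._stop.wait(self.interval):
            try:
                self.fn()
                self.runs += 1
            except Exception as e:
                # Log and keep going; the next wait() still pauses the loop
                self.errors += 1
                logger.error(f"periodic task failed: {e}")

    def start(self) -> "PeriodicTask":
        self._thread.start()
        return self

    def stop(self, timeout: float = 5.0) -> None:
        # Wakes the loop immediately instead of waiting out the interval
        self._stop.set()
        self._thread.join(timeout)
```

With a real pool this would be started as `PeriodicTask(pool.check, 30).start()`. On shutdown, call `stop()` before `pool.close()` so that no check runs against a closed pool.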

5.2.3 Query Optimization

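import logging
from typing import Any, List

from langchain_core.runnables import RunnableConfig

logger = logging.getLogger(__name__)


class QueryOptimizer:
    """Query optimizer"""
    
    def __init__(self):
        # Reserved for caching generated queries; not used by the methods below
        self.query_cache = {}
        self.cache_ttl = 300  # 5-minute cache
    
    def get_optimized_list_query(
        self, 
        config: RunnableConfig,
        filter_conditions: List[str],
        params: List[Any],
        limit: int
    ) -> tuple[str, List[Any]]:
        """Build a list query that fetches checkpoints and their writes in one round trip"""
        
        # 1) Base query: LEFT JOIN the writes and aggregate them per checkpoint.
        #    checkpoint_writes stores the serialized value in the "blob" column
        base_query = """
            SELECT c.*, 
                   array_agg(
                       json_build_object(
                           'task_id', w.task_id,
                           'idx', w.idx,
                           'channel', w.channel,
                           'type', w.type,
                           'blob', w.blob
                       ) ORDER BY w.task_id, w.idx
                   ) FILTER (WHERE w.task_id IS NOT NULL) as writes
            FROM checkpoints c
            LEFT JOIN checkpoint_writes w ON (
                c.thread_id = w.thread_id AND 
                c.checkpoint_ns = w.checkpoint_ns AND 
                c.checkpoint_id = w.checkpoint_id
            )
        """
        
        # 2) WHERE clause: conditions are trusted SQL fragments with %s
        #    placeholders, and values travel only through params. Qualify columns
        #    with "c." (e.g. "c.thread_id = %s"); both tables share these names
        where_clause = " AND ".join(filter_conditions) if filter_conditions else "1=1"
        
        # 3) Grouping and ordering; (thread_id, checkpoint_ns, checkpoint_id) is
        #    the primary key of checkpoints, so grouping by it alone is also valid
        group_order_clause = """
            GROUP BY c.thread_id, c.checkpoint_ns, c.checkpoint_id, 
                     c.parent_checkpoint_id, c.type, c.checkpoint, c.metadata
            ORDER BY c.checkpoint_id DESC
            LIMIT %s
        """
        
        # 4) Assemble the full query; the LIMIT value is appended last
        full_query = f"{base_query} WHERE {where_clause} {group_order_clause}"
        full_params = params + [limit]
        
        return full_query, full_params
    
    def create_indexes_for_common_queries(self, connection):
        """Create indexes for common queries"""
        indexes = [
            # Composite index on thread ID, namespace and checkpoint ID.
            # Assuming the standard schema, this duplicates the primary key,
            # which PostgreSQL can already scan backwards for DESC ordering
            """
            CREATE INDEX IF NOT EXISTS idx_checkpoints_thread_ns_id 
            ON checkpoints (thread_id, checkpoint_ns, checkpoint_id DESC)
            """,
            
            # Composite index on writes; same columns as the assumed
            # checkpoint_writes primary key, so likely redundant
            """
            CREATE INDEX IF NOT EXISTS idx_writes_thread_checkpoint 
            ON checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
            """,
            
            # Metadata index (a GIN index supports jsonb operators such as @>)
            """
            CREATE INDEX IF NOT EXISTS idx_checkpoints_metadata 
            ON checkpoints USING gin (metadata)
            """,
        ]
        
        # Assumes an autocommit connection: otherwise one failed CREATE INDEX
        # aborts the transaction and every following statement fails as well
        with connection.cursor() as cur:
            for index_sql in indexes:
                try:
                    cur.execute(index_sql)
                except Exception as e:
                    logger.warning(f"Index creation failed: {e}")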
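`get_optimized_list_query` expects `filter_conditions` and `params` to be built beforehand. The following is a minimal sketch of deriving them from a `RunnableConfig`-style dict. It is loosely modeled on the filters a checkpoint list operation typically supports: thread, namespace, checkpoint ID, metadata containment, and a `before` cursor. The function name and its arguments are hypothetical. Columns are qualified with `c.` because the joined `checkpoint_writes` table has columns with the same names. The metadata filter uses the jsonb containment operator `@>`, which a GIN index on `metadata` can serve.

```python
import json
from typing import Any, List, Optional, Tuple


def build_list_filters(
    config: Optional[dict],
    metadata_filter: Optional[dict] = None,
    before_id: Optional[str] = None,
) -> Tuple[List[str], List[Any]]:
    """Hypothetical helper: turn a config and filters into SQL fragments plus params"""
    conditions: List[str] = []
    params: List[Any] = []
    if config:
        configurable = config.get("configurable", {})
        conditions.append("c.thread_id = %s")
        params.append(configurable["thread_id"])
        # An empty string is a valid (root) namespace, so test for None only
        ns = configurable.get("checkpoint_ns")
        if ns is not None:
            conditions.append("c.checkpoint_ns = %s")
            params.append(ns)
        checkpoint_id = configurable.get("checkpoint_id")
        if checkpoint_id:
            conditions.append("c.checkpoint_id = %s")
            params.append(checkpoint_id)
    if metadata_filter:
        # Containment check; the JSON text is cast to jsonb on the server
        conditions.append("c.metadata @> %s::jsonb")
        params.append(json.dumps(metadata_filter))
    if before_id:
        # Pagination cursor: only checkpoints older than before_id
        conditions.append("c.checkpoint_id < %s")
        params.append(before_id)
    return conditions, params
```

The two lists can be passed straight through, for example as `optimizer.get_optimized_list_query(config, conditions, params, limit=10)`. Because every value travels through `params`, nothing user-supplied is ever formatted into the SQL string.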

6. Summary

As LangGraph's core persistence component, the checkpoint-postgres module provides an efficient, reliable PostgreSQL-backed checkpoint store. Its main characteristics are:

6.1 Key Strengths

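  • High performance: builds on PostgreSQL transactions and indexing
  • Scalability: supports connection pools and asynchronous operation
  • Data integrity: transaction management and exception handling throughout
  • Flexibility: works with a single Connection or a ConnectionPool, with configurable options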

6.2 Architectural Characteristics

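  • Layered design: interface abstractions (BaseCheckpointSaver) are cleanly separated from the implementations (BasePostgresSaver, PostgresSaver)
  • Async support: a complete asynchronous implementation (AsyncPostgresSaver)
  • Batch optimization: efficient bulk data operations
  • Monitoring-friendly: thorough logging and performance metrics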

6.3 Practical Value

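For LangGraph applications that need persistent state, this module provides a production-grade storage solution. It is particularly suited to enterprise workloads that require high concurrency and high reliability. Building on PostgreSQL, it manages and queries checkpoint data efficiently.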