1. Project Architecture Design Best Practices

1.1 State Design Principles

1.1.1 State Structure Design

"""
State design best practices
"""

from typing import TypedDict, Annotated, List, Optional, Dict, Any
from langgraph.graph import add_messages
from langchain_core.messages import AnyMessage

# ❌ Poor state design
class BadState(TypedDict):
    everything: dict  # Dumps all data into a single dict
    data: Any         # Uses a vague, catch-all type

# ✅ Good state design
class GoodState(TypedDict):
    """
    Characteristics of a well-designed state:
    1. Clear field definitions with type annotations
    2. Sensible grouping of related data
    3. Necessary metadata
    4. Appropriate defaults
    """

    # Core business data
    messages: Annotated[List[AnyMessage], add_messages]
    user_context: Dict[str, str]
    task_result: Optional[str]

    # Flow control
    current_stage: str
    retry_count: int
    max_retries: int

    # Metadata
    session_id: str
    created_at: str
    updated_at: str

    # Optional state
    debug_info: Optional[Dict[str, Any]]
    performance_metrics: Optional[Dict[str, float]]
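
Since a TypedDict is a plain dict at runtime, constructing an initial state is direct assignment. A minimal sketch (field values are illustrative):

from datetime import datetime

initial_state: GoodState = {
    "messages": [],
    "user_context": {"locale": "en-US"},
    "task_result": None,
    "current_stage": "init",
    "retry_count": 0,
    "max_retries": 3,
    "session_id": "session-001",
    "created_at": datetime.now().isoformat(),
    "updated_at": datetime.now().isoformat(),
    "debug_info": None,
    "performance_metrics": None,
}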

# State design guidelines
class StateDesignPrinciples:
    """
    Core guidelines for state design

    1. Single responsibility: every field has a clear purpose
    2. Type safety: use concrete type annotations
    3. Forward compatibility: mark optional fields with Optional
    4. Extensibility: leave room to grow
    5. Serializability: ensure the state can be serialized by the checkpoint system
    """

    @staticmethod
    def validate_state_design(state_class: type) -> List[str]:
        """Validate that a state class is reasonably designed"""
        issues = []

        # Check type annotations
        annotations = getattr(state_class, '__annotations__', {})
        if not annotations:
            issues.append("Missing type annotations")

        # Check field naming
        for field_name in annotations:
            if not field_name.replace('_', '').isalnum():
                issues.append(f"Field name {field_name} contains special characters")

            if field_name.startswith('__'):
                issues.append(f"Field name {field_name} should not start with a double underscore")

        return issues
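
A quick usage sketch: run the validator against the state schema at import time or in a unit test and fail fast on design issues.

design_issues = StateDesignPrinciples.validate_state_design(GoodState)
if design_issues:
    raise TypeError(f"State design issues: {design_issues}")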

1.1.2 State Update Strategies

"""
State update best practices
"""

from datetime import datetime
import logging
import traceback

def good_node_function(state: GoodState) -> dict:
    """
    Characteristics of a well-designed node function:
    1. Updates only the fields it needs to
    2. Keeps the state consistent
    3. Provides a clear update trail
    """

    retry_count = state.get("retry_count", 0)

    # Execute the business logic
    try:
        # perform_business_logic is a placeholder for your actual logic
        result = perform_business_logic(state)

        # Return a minimal state update
        return {
            "task_result": result,
            "current_stage": "completed",
            "updated_at": datetime.now().isoformat(),
            "performance_metrics": {
                "processing_time": 1.23,
                "memory_usage": 45.6
            }
        }

    except Exception as e:
        # State update on error
        return {
            "current_stage": "error",
            "retry_count": retry_count + 1,
            "updated_at": datetime.now().isoformat(),
            "debug_info": {
                "error": str(e),
                "error_type": type(e).__name__,
                "traceback": traceback.format_exc()
            }
        }

def bad_node_function(state: GoodState) -> dict:
    """
    ❌ Practices to avoid:
    1. Updating unrelated fields
    2. Mutating the input state directly
    3. Returning a copy of the entire state
    """

    # Mistake 1: mutating the input state
    state["current_stage"] = "processing"  # ❌

    # Mistake 2: returning unnecessary fields
    return dict(state)  # ❌ returns the whole state

# Summary of state update best practices
class StateUpdateBestPractices:
    """A collection of state update best practices"""

    @staticmethod
    def create_minimal_update(updates: dict) -> dict:
        """Create a minimal state update"""
        # Drop None values
        return {k: v for k, v in updates.items() if v is not None}

    @staticmethod
    def add_timestamp(update: dict) -> dict:
        """Add a timestamp to a state update"""
        update["updated_at"] = datetime.now().isoformat()
        return update

    @staticmethod
    def validate_update(update: dict, state_schema: type) -> bool:
        """Validate that a state update is legal"""
        schema_fields = getattr(state_schema, '__annotations__', {})

        for key in update:
            if key not in schema_fields:
                logging.warning(f"Unknown state field: {key}")
                return False

        return True
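
The three helpers compose naturally at the end of a node function. A minimal sketch:

def finishing_node(state: GoodState) -> dict:
    update = {"task_result": "done", "current_stage": "completed", "debug_info": None}

    # Strip None values, stamp the update, then validate it against the schema
    update = StateUpdateBestPractices.create_minimal_update(update)
    update = StateUpdateBestPractices.add_timestamp(update)
    assert StateUpdateBestPractices.validate_update(update, GoodState)

    return update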

1.2 Node Design Patterns

1.2.1 Single-Responsibility Nodes

"""
Node design best practice: the single responsibility principle
"""

from datetime import datetime

# ProcessingState is assumed to be a TypedDict defined elsewhere with at least
# raw_data, is_valid, and validation_errors fields.

# ✅ Good node design: single responsibility
def data_validator_node(state: ProcessingState) -> dict:
    """
    Data validation node: responsible only for validation

    Responsibilities:
    1. Validate the input data format
    2. Check data integrity
    3. Return the validation result
    """

    data = state["raw_data"]

    # Focus on validation logic
    validation_errors = []

    if not data:
        validation_errors.append("Data is empty")

    if not isinstance(data, dict):
        validation_errors.append("Data is not in the expected format")

    # Return the validation result
    return {
        "validation_errors": validation_errors,
        "is_valid": len(validation_errors) == 0,
        "validation_timestamp": datetime.now().isoformat()
    }

def data_processor_node(state: ProcessingState) -> dict:
    """
    Data processing node: responsible only for transformation

    Responsibilities:
    1. Transform the data format
    2. Execute the business logic
    3. Produce the processing result
    """

    if not state.get("is_valid", False):
        return {"error": "Data failed validation and cannot be processed"}

    raw_data = state["raw_data"]

    # Focus on processing logic (transform_data is a placeholder)
    processed_data = transform_data(raw_data)

    return {
        "processed_data": processed_data,
        "processing_timestamp": datetime.now().isoformat()
    }

# ❌ Design to avoid: a node with many responsibilities
def bad_all_in_one_node(state: ProcessingState) -> dict:
    """
    Poor design: one node doing far too much
    """

    # Validate, process, save, notify — too many responsibilities
    data = state["raw_data"]

    # Validate
    if not validate_data(data):
        return {"error": "Validation failed"}

    # Process
    processed = process_data(data)

    # Save
    save_to_database(processed)

    # Notify
    send_notification(processed)

    return {"result": "everything done"}  # unclear responsibility

1.2.2 Error Handling Patterns

"""
Node error handling best practices
"""

from typing import Any
import logging
import requests

class NodeErrorHandler:
    """Node error handler"""

    @staticmethod
    def with_retry(max_retries: int = 3):
        """Retry decorator"""
        def decorator(node_func):
            def wrapper(state: dict) -> dict:
                retry_count = state.get("retry_count", 0)

                if retry_count >= max_retries:
                    return {
                        "error": f"Retry limit exceeded ({max_retries})",
                        "retry_count": retry_count,
                        "status": "failed"
                    }

                try:
                    result = node_func(state)
                    # Reset the retry counter on success
                    result["retry_count"] = 0
                    return result

                except Exception as e:
                    logging.error(f"Node execution failed: {e}")
                    return {
                        "error": str(e),
                        "retry_count": retry_count + 1,
                        "status": "retry_needed"
                    }

            return wrapper
        return decorator

    @staticmethod
    def safe_execute(node_func, state: dict, fallback_value: Any = None) -> dict:
        """Safely execute a node function"""
        try:
            return node_func(state)
        except Exception as e:
            logging.error(f"Node execution raised an exception: {e}")
            return {
                "error": str(e),
                "fallback_result": fallback_value,
                "status": "error",
                "error_type": type(e).__name__
            }

# Usage example
@NodeErrorHandler.with_retry(max_retries=3)
def robust_api_call_node(state: dict) -> dict:
    """API call node with retry support"""

    url = state["api_url"]
    payload = state["request_payload"]

    # API call that may fail
    response = requests.post(url, json=payload, timeout=10)
    response.raise_for_status()

    return {
        "api_response": response.json(),
        "status": "success"
    }
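
Because the wrapper reports "retry_needed" through the state instead of raising, a conditional edge can route failed calls back into the same node until the retry budget is exhausted. A minimal sketch, with an ApiState TypedDict defined here for illustration:

from typing import TypedDict
from langgraph.graph import StateGraph, START, END

class ApiState(TypedDict, total=False):
    api_url: str
    request_payload: dict
    api_response: dict
    retry_count: int
    status: str
    error: str

def retry_router(state: ApiState) -> str:
    # "success", "retry_needed", or "failed", as set by the node
    return state.get("status", "failed")

builder = StateGraph(ApiState)
builder.add_node("api_call", robust_api_call_node)
builder.add_edge(START, "api_call")
builder.add_conditional_edges("api_call", retry_router, {
    "success": END,
    "retry_needed": "api_call",
    "failed": END,
})
api_graph = builder.compile()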

import time
import traceback

def graceful_node(state: dict) -> dict:
    """Example node that handles errors gracefully"""

    start_time = time.time()

    try:
        # Main logic (complex_operation is a placeholder; ValidationError is
        # assumed to be a domain-specific exception defined elsewhere)
        result = complex_operation(state["input_data"])

        return {
            "result": result,
            "status": "success",
            "execution_time": time.time() - start_time
        }

    except ValidationError as e:
        # Expected business error
        return {
            "error": f"Data validation failed: {e}",
            "error_code": "VALIDATION_ERROR",
            "status": "failed",
            "recoverable": True
        }

    except ConnectionError as e:
        # Network error, retryable
        return {
            "error": f"Network connection failed: {e}",
            "error_code": "CONNECTION_ERROR",
            "status": "retry_needed",
            "recoverable": True
        }

    except Exception as e:
        # Unexpected error
        return {
            "error": f"Unknown error: {e}",
            "error_code": "UNKNOWN_ERROR",
            "status": "failed",
            "recoverable": False,
            "debug_info": traceback.format_exc()
        }

1.3 Graph Structure Design

1.3.1 Conditional Routing Best Practices

"""
Conditional routing design best practices
"""

from typing import Literal

# BusinessState and WorkflowState are assumed to be TypedDicts defined elsewhere.

# ✅ A clear conditional routing function
def business_logic_router(state: BusinessState) -> Literal["process", "validate", "error"]:
    """
    Business logic router

    Design principles:
    1. Keep routing logic clear and explicit
    2. Cover every possible case
    3. Provide a default branch
    4. Constrain return values with type hints
    """

    # Error state check
    if state.get("has_error", False):
        return "error"

    # Validation state check
    validation_status = state.get("validation_status", "pending")
    if validation_status == "failed":
        return "validate"
    elif validation_status == "passed":
        return "process"

    # Default route
    return "validate"

def complex_workflow_router(state: WorkflowState) -> Literal["step1", "step2", "step3", "complete", "error"]:
    """
    Complex workflow router

    Handles routing decisions for a multi-step workflow
    """

    current_step = state.get("current_step", 0)
    has_errors = bool(state.get("errors", []))

    if has_errors:
        return "error"

    if current_step == 0:
        return "step1"
    elif current_step == 1:
        return "step2"
    elif current_step == 2:
        return "step3"
    else:
        return "complete"

# ❌ Routing design to avoid
def bad_router(state: dict) -> str:
    """
    Example of poor routing design
    """

    # Problem 1: overly complex logic
    if (state.get("status") == "ok" and
        len(state.get("data", [])) > 0 and
        state.get("user_tier") in ["vip", "premium"] and
        datetime.now().hour < 18):
        return "complex_processing"

    # Problem 2: magic numbers
    if state.get("score", 0) > 0.85:
        return "high_quality"

    # Problem 3: no default branch
    # What happens when none of the conditions match? (The function
    # implicitly returns None, which the graph cannot route.)

# Routing design best practices
class RoutingBestPractices:
    """Routing design best practices"""

    @staticmethod
    def create_state_based_router(state_field: str, route_map: dict):
        """Factory for routers driven by a state field"""
        def router(state: dict) -> str:
            field_value = state.get(state_field)
            return route_map.get(field_value, route_map.get("default", "error"))
        return router

    @staticmethod
    def create_threshold_router(score_field: str, thresholds: dict):
        """Factory for threshold-based routers"""
        def router(state: dict) -> str:
            score = state.get(score_field, 0)

            for threshold, route in sorted(thresholds.items(), reverse=True):
                if score >= threshold:
                    return route

            return "default"
        return router

# Usage examples
quality_router = RoutingBestPractices.create_threshold_router(
    score_field="quality_score",
    thresholds={
        0.9: "excellent",
        0.7: "good", 
        0.5: "acceptable",
        0.0: "poor"
    }
)

status_router = RoutingBestPractices.create_state_based_router(
    state_field="processing_status",
    route_map={
        "pending": "process",
        "completed": "finalize",
        "error": "handle_error",
        "default": "initialize"
    }
)
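
These factory-built routers plug straight into add_conditional_edges; the path-map keys must cover every string the router can return. A minimal sketch, assuming builder is a StateGraph under construction and the target node names exist:

builder.add_conditional_edges(
    "score_quality",      # node whose result is being routed
    quality_router,
    {
        "excellent": "fast_track",
        "good": "standard_track",
        "acceptable": "review_track",
        "poor": "rework",
        "default": "rework",
    },
)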

2. Performance Optimization in Practice

2.1 Checkpoint Optimization Strategies

"""
Checkpoint performance optimization in practice
"""

from langgraph.checkpoint.base import BaseCheckpointSaver
import time

class OptimizedCheckpointSaver(BaseCheckpointSaver):
    """Checkpoint saver with batched writes"""

    def __init__(self, batch_size: int = 10, batch_timeout: float = 1.0):
        super().__init__()
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout
        self.pending_writes = []
        self.last_flush = time.time()

    def put(self, config, checkpoint, metadata, new_versions):
        """Batched write optimization"""

        # Queue the write
        self.pending_writes.append({
            "config": config,
            "checkpoint": checkpoint,
            "metadata": metadata,
            "new_versions": new_versions,
            "timestamp": time.time()
        })

        # Decide whether to flush
        should_flush = (
            len(self.pending_writes) >= self.batch_size or
            time.time() - self.last_flush > self.batch_timeout
        )

        if should_flush:
            self._flush_batch()

        return config

    def _flush_batch(self):
        """Flush the pending batch"""
        if not self.pending_writes:
            return

        # Batched write logic
        batch = self.pending_writes[:]
        self.pending_writes.clear()
        self.last_flush = time.time()

        # Actual persistence (left to a concrete implementation)
        self._batch_write_to_storage(batch)

class CheckpointOptimizer:
    """Checkpoint optimization utilities"""

    @staticmethod
    def minimize_state_size(state: dict) -> dict:
        """Minimize the size of the state"""
        optimized_state = {}

        for key, value in state.items():
            # Drop large temporary data
            if key.startswith("_temp_") or key.startswith("_cache_"):
                continue

            # Compress large strings
            if isinstance(value, str) and len(value) > 1000:
                optimized_state[key] = compress_string(value)
            else:
                optimized_state[key] = value

        return optimized_state

    @staticmethod
    def should_create_checkpoint(state: dict, last_checkpoint_time: float) -> bool:
        """Decide intelligently whether to create a checkpoint"""

        current_time = time.time()
        time_since_last = current_time - last_checkpoint_time

        # Time-based checkpoint creation
        if time_since_last > 60:  # one minute
            return True

        # Change-based checkpoint creation
        significant_changes = [
            state.get("task_completed", False),
            state.get("error_occurred", False),
            state.get("user_interaction", False)
        ]

        return any(significant_changes)

def compress_string(text: str) -> dict:
    """String compression helper"""
    import gzip
    import base64

    compressed = gzip.compress(text.encode('utf-8'))
    encoded = base64.b64encode(compressed).decode('ascii')

    return {
        "type": "compressed_string",
        "data": encoded,
        "original_size": len(text)
    }
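
The matching decompression helper is worth keeping alongside so compressed values can be restored when the state is rehydrated. A minimal sketch for the dict format above:

def decompress_string(blob: dict) -> str:
    """Restore a string produced by compress_string"""
    import gzip
    import base64

    assert blob.get("type") == "compressed_string"
    compressed = base64.b64decode(blob["data"])
    return gzip.decompress(compressed).decode('utf-8')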

2.2 Memory Management Best Practices

"""
Memory management best practices
"""

import gc
import psutil
import weakref
from typing import Any, Dict, List

class MemoryManager:
    """Memory management utilities"""

    def __init__(self, memory_limit_mb: int = 1024):
        self.memory_limit_mb = memory_limit_mb
        self.large_objects: List[weakref.ref] = []

    def monitor_memory_usage(self) -> Dict[str, float]:
        """Report current memory usage"""
        process = psutil.Process()
        memory_info = process.memory_info()

        return {
            "rss_mb": memory_info.rss / 1024 / 1024,  # resident (physical) memory
            "vms_mb": memory_info.vms / 1024 / 1024,  # virtual memory
            "percent": process.memory_percent(),       # memory usage percentage
            "available_mb": psutil.virtual_memory().available / 1024 / 1024
        }

    def cleanup_if_needed(self) -> bool:
        """Clean up memory when the limit is exceeded"""
        memory_stats = self.monitor_memory_usage()

        if memory_stats["rss_mb"] > self.memory_limit_mb:
            # Perform memory cleanup
            return self._perform_cleanup()

        return False

    def _perform_cleanup(self) -> bool:
        """Perform memory cleanup"""
        initial_memory = self.monitor_memory_usage()["rss_mb"]

        # Drop dead weak references
        self.large_objects = [ref for ref in self.large_objects if ref() is not None]

        # Force garbage collection
        gc.collect()

        final_memory = self.monitor_memory_usage()["rss_mb"]
        freed_mb = initial_memory - final_memory

        return freed_mb > 0

    def register_large_object(self, obj: Any) -> None:
        """Register a large object for tracking"""
        self.large_objects.append(weakref.ref(obj))

class StateMemoryOptimizer:
    """State memory optimizer"""

    @staticmethod
    def optimize_message_history(messages: List[Any], max_messages: int = 50) -> List[Any]:
        """Cap the length of the message history"""
        if len(messages) <= max_messages:
            return messages

        # Keep the most recent messages
        return messages[-max_messages:]

    @staticmethod
    def compress_large_data(state: dict, size_threshold: int = 1000) -> dict:
        """Replace large values with compact placeholders

        Note: this drops the actual payload; persist the original data
        elsewhere if it is still needed.
        """
        optimized = {}

        for key, value in state.items():
            if isinstance(value, str) and len(value) > size_threshold:
                optimized[key] = {
                    "type": "large_string",
                    "size": len(value),
                    "hash": hash(value),
                    "compressed": True
                }
            elif isinstance(value, (list, dict)) and len(str(value)) > size_threshold:
                optimized[key] = {
                    "type": "large_object",
                    "size": len(str(value)),
                    "compressed": True
                }
            else:
                optimized[key] = value

        return optimized
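
One caveat for the messages field from section 1.1: it is governed by the add_messages reducer, so returning a sliced list appends rather than replaces, and stale entries must be deleted explicitly with RemoveMessage. A minimal sketch of a trimming node:

from langchain_core.messages import RemoveMessage

def trim_history_node(state: GoodState) -> dict:
    max_messages = 50
    overflow = state["messages"][:-max_messages]
    # One RemoveMessage per stale entry; the add_messages reducer performs the deletion
    return {"messages": [RemoveMessage(id=m.id) for m in overflow]}

For plain, reducer-less fields, the list returned by optimize_message_history can be used directly.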

# Using memory optimization in a node
memory_manager = MemoryManager(memory_limit_mb=512)

def memory_optimized_node(state: dict) -> dict:
    """Memory-optimized node implementation"""

    # Check memory before execution
    memory_manager.cleanup_if_needed()

    # Optimize the input state
    optimized_state = StateMemoryOptimizer.compress_large_data(state)

    try:
        # Main logic (process_data is a placeholder)
        result = process_data(optimized_state)

        # Register large objects
        if isinstance(result, dict) and len(str(result)) > 1000:
            memory_manager.register_large_object(result)

        return {"result": result}

    finally:
        # Clean up after execution
        memory_manager.cleanup_if_needed()

2.3 Concurrency Optimization

"""
Concurrency optimization in practice
"""

import asyncio
import concurrent.futures
import time
from typing import List, Callable, Any

class ConcurrentNodeExecutor:
    """Concurrent node executor"""

    def __init__(self, max_workers: int = 4):
        self.max_workers = max_workers
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    def execute_parallel_nodes(
        self,
        node_functions: List[Callable],
        shared_state: dict
    ) -> List[dict]:
        """Execute multiple nodes in parallel"""

        futures = []
        for node_func in node_functions:
            future = self.executor.submit(self._safe_execute_node, node_func, shared_state.copy())
            futures.append(future)

        results = []
        for future in concurrent.futures.as_completed(futures):
            try:
                result = future.result(timeout=30)  # 30-second timeout
                results.append(result)
            except Exception as e:
                results.append({"error": str(e)})

        return results

    def _safe_execute_node(self, node_func: Callable, state: dict) -> dict:
        """Safely execute a single node"""
        try:
            return node_func(state)
        except Exception as e:
            return {
                "error": str(e),
                "node_function": node_func.__name__
            }

class AsyncStateProcessor:
    """Asynchronous state processor"""

    def __init__(self):
        self.results_cache = {}

    async def process_states_async(self, states: List[dict]) -> List[dict]:
        """Process multiple states asynchronously"""

        tasks = []
        for i, state in enumerate(states):
            task = asyncio.create_task(self._process_single_state(state, i))
            tasks.append(task)

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Normalize exception results
        processed_results = []
        for result in results:
            if isinstance(result, Exception):
                processed_results.append({"error": str(result)})
            else:
                processed_results.append(result)

        return processed_results

    async def _process_single_state(self, state: dict, index: int) -> dict:
        """Process a single state"""

        # Simulate asynchronous processing
        await asyncio.sleep(0.1)  # simulated I/O

        # Cache lookup
        state_hash = hash(str(sorted(state.items())))
        if state_hash in self.results_cache:
            return self.results_cache[state_hash]

        # Actual processing
        result = {
            "processed_state": state,
            "index": index,
            "timestamp": time.time()
        }

        # Cache the result
        self.results_cache[state_hash] = result

        return result

# Example: concurrency-optimized tool node
class OptimizedToolNode:
    """Optimized tool node"""

    def __init__(self, tools: List[Any], max_concurrent: int = 3):
        self.tools = {tool.name: tool for tool in tools}
        self.max_concurrent = max_concurrent
        self.semaphore = asyncio.Semaphore(max_concurrent)

    async def execute_tools_async(self, tool_calls: List[dict]) -> List[dict]:
        """Execute tool calls concurrently"""

        tasks = []
        for tool_call in tool_calls:
            task = asyncio.create_task(self._execute_single_tool(tool_call))
            tasks.append(task)

        results = await asyncio.gather(*tasks, return_exceptions=True)
        return [self._format_result(r) for r in results]

    async def _execute_single_tool(self, tool_call: dict) -> dict:
        """Execute a single tool call"""

        async with self.semaphore:  # bound the concurrency
            tool_name = tool_call["name"]
            tool_args = tool_call.get("args", {})

            if tool_name not in self.tools:
                raise ValueError(f"Unknown tool: {tool_name}")

            tool = self.tools[tool_name]

            # Run the synchronous tool in the default thread pool
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, tool.invoke, tool_args)

            return {
                "tool_call_id": tool_call["id"],
                "tool_name": tool_name,
                "result": result
            }

    def _format_result(self, result: Any) -> dict:
        """Format a result"""
        if isinstance(result, Exception):
            return {"error": str(result)}
        return result
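
Driving the node from synchronous code is a single asyncio.run call. A minimal sketch, assuming my_tools is a list of LangChain tools and the tool-call dicts follow the {"id", "name", "args"} shape used above:

tool_node = OptimizedToolNode(tools=my_tools, max_concurrent=3)

tool_calls = [
    {"id": "call-1", "name": "search", "args": {"query": "langgraph"}},
    {"id": "call-2", "name": "search", "args": {"query": "checkpointing"}},
]

results = asyncio.run(tool_node.execute_tools_async(tool_calls))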

3. Error Handling and Debugging Strategies

3.1 A Comprehensive Error Handling System

"""
A comprehensive error handling system
"""

import time
import traceback
import logging
import requests
from enum import Enum
from typing import Optional, Dict, Any, List, Callable
from contextlib import contextmanager

class ErrorSeverity(Enum):
    """Error severity levels"""
    LOW = "low"           # warning level, does not affect the main flow
    MEDIUM = "medium"     # error level, affects the current operation
    HIGH = "high"         # serious error, affects the whole session
    CRITICAL = "critical" # critical error, requires immediate attention

class LangGraphError(Exception):
    """Base class for custom LangGraph errors"""

    def __init__(
        self,
        message: str,
        error_code: str = "UNKNOWN",
        severity: ErrorSeverity = ErrorSeverity.MEDIUM,
        context: Optional[Dict[str, Any]] = None,
        recoverable: bool = True
    ):
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.severity = severity
        self.context = context or {}
        self.recoverable = recoverable
        self.timestamp = time.time()

class NodeExecutionError(LangGraphError):
    """Node execution error"""

    def __init__(self, node_name: str, original_error: Exception, **kwargs):
        super().__init__(
            f"Node '{node_name}' execution failed: {str(original_error)}",
            error_code="NODE_EXECUTION_ERROR",
            **kwargs
        )
        self.node_name = node_name
        self.original_error = original_error

class StateValidationError(LangGraphError):
    """State validation error"""

    def __init__(self, validation_errors: List[str], **kwargs):
        super().__init__(
            f"State validation failed: {'; '.join(validation_errors)}",
            error_code="STATE_VALIDATION_ERROR",
            **kwargs
        )
        self.validation_errors = validation_errors

class ErrorHandler:
    """Unified error handler"""

    def __init__(self, logger: logging.Logger = None):
        self.logger = logger or logging.getLogger(__name__)
        self.error_counts = {}
        self.recovery_strategies = {}

    def register_recovery_strategy(
        self,
        error_type: type,
        strategy: Callable[[Exception, dict], dict]
    ):
        """Register an error recovery strategy"""
        self.recovery_strategies[error_type] = strategy

    @contextmanager
    def error_context(self, operation: str, state: dict):
        """Error-handling context manager (logs, attempts recovery, then re-raises)"""
        try:
            yield
        except Exception as e:
            self._handle_error(e, operation, state)
            raise

    def _handle_error(self, error: Exception, operation: str, state: dict):
        """Handle an error"""

        # Log the error
        error_info = {
            "operation": operation,
            "error_type": type(error).__name__,
            "error_message": str(error),
            "state_keys": list(state.keys()),
            "traceback": traceback.format_exc(),
            "timestamp": time.time()
        }

        self.logger.error(f"Error in operation '{operation}': {error_info}")

        # Update the error count
        error_key = f"{operation}:{type(error).__name__}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1

        # Look for a recovery strategy
        for error_type, strategy in self.recovery_strategies.items():
            if isinstance(error, error_type):
                try:
                    recovery_result = strategy(error, state)
                    self.logger.info(f"Applied recovery strategy for {error_type.__name__}")
                    return recovery_result
                except Exception as recovery_error:
                    self.logger.error(f"Recovery strategy failed: {recovery_error}")

# Example recovery strategy
def api_call_recovery_strategy(error: Exception, state: dict) -> dict:
    """Recovery strategy for failed API calls"""

    if "timeout" in str(error).lower():
        # Timeout: back off the retry delay
        return {
            "retry_delay": state.get("retry_delay", 1) * 2,
            "should_retry": True,
            "recovery_action": "increase_timeout"
        }

    elif "rate limit" in str(error).lower():
        # Rate limiting: wait longer
        return {
            "retry_delay": 60,  # wait one minute
            "should_retry": True,
            "recovery_action": "wait_for_rate_limit"
        }

    else:
        # Anything else: fall back to a degraded path
        return {
            "should_retry": False,
            "use_fallback": True,
            "recovery_action": "use_fallback_service"
        }

# Using the error handler in a node
error_handler = ErrorHandler()
error_handler.register_recovery_strategy(requests.RequestException, api_call_recovery_strategy)

def robust_api_node(state: dict) -> dict:
    """API node with thorough error handling"""

    with error_handler.error_context("api_call", state):
        try:
            # Main API call logic
            response = requests.get(
                state["api_url"],
                params=state["params"],
                timeout=state.get("timeout", 10)
            )
            response.raise_for_status()

            return {
                "api_response": response.json(),
                "status": "success",
                "timestamp": time.time()
            }

        except requests.RequestException as e:
            # Apply the recovery strategy
            recovery_info = api_call_recovery_strategy(e, state)

            raise NodeExecutionError(
                node_name="robust_api_node",
                original_error=e,
                context=recovery_info,
                severity=ErrorSeverity.MEDIUM,
                recoverable=recovery_info.get("should_retry", False)
            )

3.2 Debugging and Monitoring Tools

"""
Debugging and monitoring tools
"""

import time
import json
import logging
import threading
from typing import Dict, List, Any, Callable, Optional
from dataclasses import dataclass

@dataclass
class ExecutionMetrics:
    """Execution metrics"""
    node_name: str
    start_time: float
    end_time: float
    duration: float
    memory_before: float
    memory_after: float
    success: bool
    error_message: Optional[str] = None

    @property
    def memory_delta(self) -> float:
        return self.memory_after - self.memory_before

class GraphDebugger:
    """Graph debugger"""

    def __init__(self, enabled: bool = True):
        self.enabled = enabled
        self.execution_log: List[Dict[str, Any]] = []
        self.node_metrics: Dict[str, List[ExecutionMetrics]] = {}
        self.state_history: List[Dict[str, Any]] = []
        self.lock = threading.Lock()

    def log_node_execution(
        self,
        node_name: str,
        input_state: dict,
        output_state: dict,
        execution_time: float,
        error: Optional[Exception] = None
    ):
        """Log a node execution"""

        if not self.enabled:
            return

        with self.lock:
            log_entry = {
                "timestamp": time.time(),
                "node_name": node_name,
                "input_state_keys": list(input_state.keys()),
                "output_state_keys": list(output_state.keys()),
                "execution_time": execution_time,
                "success": error is None,
                "error": str(error) if error else None
            }

            self.execution_log.append(log_entry)

    def log_state_change(self, old_state: dict, new_state: dict, operation: str):
        """Log a state change"""

        if not self.enabled:
            return

        with self.lock:
            changes = self._detect_state_changes(old_state, new_state)

            change_entry = {
                "timestamp": time.time(),
                "operation": operation,
                "changes": changes,
                "state_size_before": len(str(old_state)),
                "state_size_after": len(str(new_state))
            }

            self.state_history.append(change_entry)
    
    def _detect_state_changes(self, old_state: dict, new_state: dict) -> Dict[str, Dict[str, Any]]:
        """Detect changes between two states"""

        changes = {}

        # Check for added fields
        for key in new_state:
            if key not in old_state:
                changes[key] = {"action": "added", "new_value": new_state[key]}

        # Check for removed fields
        for key in old_state:
            if key not in new_state:
                changes[key] = {"action": "removed", "old_value": old_state[key]}

        # Check for modified fields
        for key in old_state:
            if key in new_state and old_state[key] != new_state[key]:
                changes[key] = {
                    "action": "modified",
                    "old_value": old_state[key],
                    "new_value": new_state[key]
                }

        return changes
    
    def get_performance_report(self) -> Dict[str, Any]:
        """Build a performance report"""

        if not self.execution_log:
            return {"message": "No execution data available"}

        # Per-node execution time statistics
        node_times = {}
        node_success_rates = {}

        for entry in self.execution_log:
            node_name = entry["node_name"]
            exec_time = entry["execution_time"]
            success = entry["success"]

            if node_name not in node_times:
                node_times[node_name] = []
                node_success_rates[node_name] = {"success": 0, "total": 0}

            node_times[node_name].append(exec_time)
            node_success_rates[node_name]["total"] += 1
            if success:
                node_success_rates[node_name]["success"] += 1

        # Compute summary statistics
        performance_stats = {}
        for node_name, times in node_times.items():
            success_rate = node_success_rates[node_name]["success"] / node_success_rates[node_name]["total"]

            performance_stats[node_name] = {
                "avg_execution_time": sum(times) / len(times),
                "min_execution_time": min(times),
                "max_execution_time": max(times),
                "total_executions": len(times),
                "success_rate": success_rate,
                "total_time": sum(times)
            }

        return {
            "total_executions": len(self.execution_log),
            "unique_nodes": len(node_times),
            "node_performance": performance_stats,
            "state_changes": len(self.state_history)
        }
    
    def export_debug_data(self, filepath: str):
        """Export debug data"""

        debug_data = {
            "execution_log": self.execution_log,
            "state_history": self.state_history,
            "performance_report": self.get_performance_report(),
            "export_timestamp": time.time()
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(debug_data, f, indent=2, ensure_ascii=False)

def debug_node(debugger: GraphDebugger):
    """Node debugging decorator (factory form so the debugger can be injected)"""

    def decorator(node_func: Callable):
        def wrapper(state: dict) -> dict:
            start_time = time.time()
            node_name = node_func.__name__
            input_state = state.copy()
            error = None
            result: dict = {}

            try:
                result = node_func(state)
                return result

            except Exception as e:
                error = e
                result = {"error": str(e)}
                raise

            finally:
                execution_time = time.time() - start_time

                debugger.log_node_execution(
                    node_name=node_name,
                    input_state=input_state,
                    output_state=result,
                    execution_time=execution_time,
                    error=error
                )

        return wrapper
    return decorator

# Usage example
debugger = GraphDebugger(enabled=True)

@debug_node(debugger)
def monitored_processing_node(state: dict) -> dict:
    """A monitored processing node"""

    # Simulated processing logic
    time.sleep(0.1)  # simulated expensive operation

    return {
        "processed_data": f"Processed: {state.get('input_data', 'N/A')}",
        "processing_timestamp": time.time()
    }

# Generate performance reports periodically
def generate_periodic_report(debugger: GraphDebugger, interval: int = 300):
    """Generate performance reports on a fixed interval"""

    def report_worker():
        while True:
            time.sleep(interval)
            report = debugger.get_performance_report()

            # Write the report to the log
            logging.info("=== Performance Report ===")
            logging.info(json.dumps(report, indent=2))

            # Export the detailed data
            timestamp = int(time.time())
            debugger.export_debug_data(f"debug_report_{timestamp}.json")

    thread = threading.Thread(target=report_worker, daemon=True)
    thread.start()

4. Production Deployment Practices

4.1 Production Configuration

"""
Production configuration best practices
"""

import os
import logging
from typing import Dict, Any, Optional, Callable
from dataclasses import dataclass
# Class names and constructor signatures below depend on the installed
# langgraph-checkpoint-postgres / langgraph store package versions.
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.store.postgres import PostgresStore

@dataclass
class ProductionConfig:
    """Production configuration"""

    # Database configuration
    postgres_uri: str
    redis_uri: Optional[str] = None

    # Performance configuration
    max_workers: int = 4
    request_timeout: int = 30
    checkpoint_batch_size: int = 10

    # Security configuration
    api_key_required: bool = True
    rate_limit_per_minute: int = 60
    max_message_length: int = 10000

    # Monitoring configuration
    enable_metrics: bool = True
    log_level: str = "INFO"
    sentry_dsn: Optional[str] = None

    # Resource limits
    memory_limit_mb: int = 1024
    max_concurrent_sessions: int = 100

    @classmethod
    def from_env(cls) -> 'ProductionConfig':
        """Create the configuration from environment variables"""

        return cls(
            postgres_uri=os.getenv("POSTGRES_URI", ""),
            redis_uri=os.getenv("REDIS_URI"),
            max_workers=int(os.getenv("MAX_WORKERS", "4")),
            request_timeout=int(os.getenv("REQUEST_TIMEOUT", "30")),
            api_key_required=os.getenv("API_KEY_REQUIRED", "true").lower() == "true",
            rate_limit_per_minute=int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")),
            enable_metrics=os.getenv("ENABLE_METRICS", "true").lower() == "true",
            log_level=os.getenv("LOG_LEVEL", "INFO"),
            sentry_dsn=os.getenv("SENTRY_DSN"),
            memory_limit_mb=int(os.getenv("MEMORY_LIMIT_MB", "1024"))
        )

class ProductionGraphFactory:
    """Production graph factory"""

    def __init__(self, config: ProductionConfig):
        self.config = config
        self._setup_logging()
        self._setup_monitoring()

    def _setup_logging(self):
        """Configure logging"""

        logging.basicConfig(
            level=getattr(logging, self.config.log_level),
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler('langgraph_production.log')
            ]
        )

        # Sentry integration
        if self.config.sentry_dsn:
            import sentry_sdk
            sentry_sdk.init(dsn=self.config.sentry_dsn)

    def _setup_monitoring(self):
        """Configure monitoring"""

        if self.config.enable_metrics:
            # Integrate with Prometheus or a similar monitoring system here
            pass
    
    def create_checkpointer(self) -> PostgresSaver:
        """Create a production-grade checkpoint saver

        Note: the saver is constructed from a psycopg connection pool here;
        the pool sizing is an illustrative choice, and the exact construction
        API varies between langgraph-checkpoint-postgres releases.
        """
        from psycopg_pool import ConnectionPool

        pool = ConnectionPool(self.config.postgres_uri, max_size=10)
        checkpointer = PostgresSaver(pool)
        checkpointer.setup()  # create tables on first run
        return checkpointer

    def create_store(self) -> PostgresStore:
        """Create a production-grade store (same construction caveats apply)"""
        from psycopg_pool import ConnectionPool

        pool = ConnectionPool(self.config.postgres_uri, max_size=5)
        store = PostgresStore(pool)
        store.setup()
        return store
    
    def create_production_graph(self, graph_builder_func: Callable) -> Any:
        """Create a production-ready graph"""

        # Create the base components
        checkpointer = self.create_checkpointer()
        store = self.create_store()

        # Build the graph
        graph = graph_builder_func()

        # Compile the graph with production settings
        compiled = graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=None,  # production rarely needs interactive interrupts
            interrupt_after=None,
            debug=False  # disable debugging in production
        )

        # Wrap with production safeguards
        return ProductionGraphWrapper(compiled, self.config)

class ProductionGraphWrapper:
    """Production graph wrapper"""

    def __init__(self, graph: Any, config: ProductionConfig):
        self.graph = graph
        self.config = config
        self.session_count = 0  # note: guard with a lock if served from multiple threads
        self.request_count = 0

    def invoke_with_safety(self, input_data: dict, config: dict) -> dict:
        """Invoke the graph with safety limits"""

        # Enforce the concurrent session limit
        if self.session_count >= self.config.max_concurrent_sessions:
            raise Exception("Max concurrent sessions exceeded")

        # Enforce the input size limit
        input_str = str(input_data)
        if len(input_str) > self.config.max_message_length:
            raise Exception("Input message too large")

        # Update the counters
        self.session_count += 1
        self.request_count += 1

        try:
            # Enforce a timeout (SIGALRM only works on Unix, in the main thread)
            import signal

            def timeout_handler(signum, frame):
                raise TimeoutError("Request timeout")

            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(self.config.request_timeout)

            try:
                result = self.graph.invoke(input_data, config)
                return result

            finally:
                signal.alarm(0)  # cancel the alarm

        finally:
            self.session_count -= 1
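
Because SIGALRM is unavailable on Windows and inside worker threads, a portable alternative is to run the invocation in a thread pool and bound it with the future's own deadline. A minimal sketch:

import concurrent.futures

def invoke_with_thread_timeout(graph, input_data: dict, config: dict, timeout: int) -> dict:
    """Portable timeout: run the invocation in a worker thread and wait with a deadline."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(graph.invoke, input_data, config)
        # Raises concurrent.futures.TimeoutError if the deadline passes;
        # note the worker thread itself is not forcibly killed.
        return future.result(timeout=timeout)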

4.2 Containerized Deployment

# Dockerfile: production-optimized multi-stage build
FROM python:3.11-slim as builder

# Install build-time system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Set the working directory
WORKDIR /app

# Copy dependency manifests
COPY requirements.txt .
COPY pyproject.toml .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Production stage
FROM python:3.11-slim as production

# Create a non-root user
RUN useradd --create-home --shell /bin/bash langgraph

# Install runtime dependencies
RUN apt-get update && apt-get install -y \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy the application and its dependencies
WORKDIR /app
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin

# Copy the application code
COPY . .

# Set permissions
RUN chown -R langgraph:langgraph /app
USER langgraph

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Environment variables
ENV PYTHONPATH=/app
ENV LANGGRAPH_ENV=production

# Startup command
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]

# docker-compose.production.yml
version: '3.8'

services:
  langgraph-app:
    build:
      context: .
      dockerfile: Dockerfile
      target: production
    ports:
      - "8000:8000"
    environment:
      - POSTGRES_URI=postgresql://user:password@postgres:5432/langgraph
      - REDIS_URI=redis://redis:6379
      - LOG_LEVEL=INFO
      - ENABLE_METRICS=true
      - MAX_WORKERS=4
      - MEMORY_LIMIT_MB=1024
    depends_on:
      - postgres
      - redis
    restart: unless-stopped
    deploy:
      resources:
        limits:
          memory: 1G
          cpus: '2'
        reservations:
          memory: 512M
          cpus: '1'
    volumes:
      - ./logs:/app/logs
    networks:
      - langgraph-network

  postgres:
    image: postgres:15
    environment:
      - POSTGRES_DB=langgraph
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
    volumes:
      - postgres_data:/var/lib/postgresql/data
    ports:
      - "5432:5432"
    restart: unless-stopped
    networks:
      - langgraph-network

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    restart: unless-stopped
    networks:
      - langgraph-network

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
      - ./ssl:/etc/ssl/certs:ro
    depends_on:
      - langgraph-app
    restart: unless-stopped
    networks:
      - langgraph-network

volumes:
  postgres_data:

networks:
  langgraph-network:
    driver: bridge

4.3 Monitoring and Alerting

"""
Production monitoring and alerting
"""

import time
import logging
import threading
from typing import Dict, List, Callable, Any
from dataclasses import dataclass, field
from collections import defaultdict, deque

@dataclass
class MetricPoint:
    """A single metric data point"""
    timestamp: float
    value: float
    labels: Dict[str, str] = field(default_factory=dict)

class MetricsCollector:
    """Metrics collector"""

    def __init__(self, max_points: int = 1000):
        self.max_points = max_points
        self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_points))
        self.lock = threading.Lock()

    def record_metric(self, name: str, value: float, labels: Dict[str, str] = None):
        """Record a metric"""
        with self.lock:
            point = MetricPoint(
                timestamp=time.time(),
                value=value,
                labels=labels or {}
            )
            self.metrics[name].append(point)
    
    def get_metric_stats(self, name: str, window_seconds: int = 300) -> Dict[str, float]:
        """Compute statistics for a metric"""
        with self.lock:
            if name not in self.metrics:
                return {}

            current_time = time.time()
            cutoff_time = current_time - window_seconds

            # Keep only points inside the time window
            recent_points = [
                p for p in self.metrics[name]
                if p.timestamp >= cutoff_time
            ]

            if not recent_points:
                return {}

            values = [p.value for p in recent_points]

            return {
                "count": len(values),
                "avg": sum(values) / len(values),
                "min": min(values),
                "max": max(values),
                "latest": values[-1],
                "window_seconds": window_seconds
            }

class HealthChecker:
    """Health checker"""

    def __init__(self):
        self.checks: Dict[str, Callable[[], bool]] = {}
        self.check_results: Dict[str, Dict] = {}

    def register_check(self, name: str, check_func: Callable[[], bool]):
        """Register a health check"""
        self.checks[name] = check_func

    def run_checks(self) -> Dict[str, Any]:
        """Run all health checks"""
        results = {"overall_healthy": True, "checks": {}}

        for name, check_func in self.checks.items():
            start_time = time.time()

            try:
                is_healthy = check_func()
                check_time = time.time() - start_time

                results["checks"][name] = {
                    "healthy": is_healthy,
                    "response_time": check_time,
                    "timestamp": time.time()
                }

                if not is_healthy:
                    results["overall_healthy"] = False

            except Exception as e:
                results["checks"][name] = {
                    "healthy": False,
                    "error": str(e),
                    "response_time": time.time() - start_time,
                    "timestamp": time.time()
                }
                results["overall_healthy"] = False

        return results
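
The Dockerfile's HEALTHCHECK above probes /health, so the checker needs an HTTP surface. A minimal sketch using FastAPI (an assumed choice; any web framework works):

from fastapi import FastAPI, Response

app = FastAPI()
health_checker = HealthChecker()

@app.get("/health")
def health_endpoint(response: Response):
    results = health_checker.run_checks()
    # 503 lets the container orchestrator treat the instance as unhealthy
    if not results["overall_healthy"]:
        response.status_code = 503
    return results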

class AlertManager:
    """Alert manager"""

    def __init__(self):
        self.alert_rules: List['AlertRule'] = []
        self.active_alerts: Dict[str, Dict] = {}
        self.notification_handlers: List[Callable] = []

    def add_alert_rule(self, rule: 'AlertRule'):
        """Add an alert rule"""
        self.alert_rules.append(rule)

    def add_notification_handler(self, handler: Callable[[Dict, bool], None]):
        """Add a notification handler"""
        self.notification_handlers.append(handler)

    def check_alerts(self, metrics_collector: MetricsCollector):
        """Evaluate alert conditions"""

        for rule in self.alert_rules:
            try:
                should_alert = rule.evaluate(metrics_collector)
                alert_key = f"{rule.name}:{rule.metric_name}"

                if should_alert and alert_key not in self.active_alerts:
                    # New alert
                    alert = {
                        "rule_name": rule.name,
                        "metric_name": rule.metric_name,
                        "description": rule.description,
                        "severity": rule.severity,
                        "started_at": time.time(),
                        "acknowledged": False
                    }

                    self.active_alerts[alert_key] = alert
                    self._send_notification(alert)

                elif not should_alert and alert_key in self.active_alerts:
                    # Alert resolved
                    alert = self.active_alerts.pop(alert_key)
                    alert["resolved_at"] = time.time()
                    alert["duration"] = alert["resolved_at"] - alert["started_at"]

                    self._send_notification(alert, resolved=True)

            except Exception as e:
                logging.error(f"Error evaluating alert rule {rule.name}: {e}")

    def _send_notification(self, alert: Dict, resolved: bool = False):
        """Send an alert notification"""

        for handler in self.notification_handlers:
            try:
                handler(alert, resolved)
            except Exception as e:
                logging.error(f"Error sending notification: {e}")

@dataclass
class AlertRule:
    """Alert rule"""

    name: str
    metric_name: str
    condition: str  # "gt", "lt", "eq"
    threshold: float
    window_seconds: int = 300
    description: str = ""
    severity: str = "warning"  # "info", "warning", "critical"

    def evaluate(self, metrics_collector: MetricsCollector) -> bool:
        """Evaluate the alert condition"""

        stats = metrics_collector.get_metric_stats(
            self.metric_name,
            self.window_seconds
        )

        if not stats:
            return False

        current_value = stats.get("avg", 0)  # compare against the window average

        if self.condition == "gt":
            return current_value > self.threshold
        elif self.condition == "lt":
            return current_value < self.threshold
        elif self.condition == "eq":
            return abs(current_value - self.threshold) < 0.001

        return False

# Production monitoring integration
class ProductionMonitor:
    """Production monitor"""

    def __init__(self):
        self.metrics_collector = MetricsCollector()
        self.health_checker = HealthChecker()
        self.alert_manager = AlertManager()

        # Register the basic health checks
        self._setup_basic_health_checks()

        # Register the basic alert rules
        self._setup_basic_alert_rules()

        # Start the monitoring thread
        self._start_monitoring_thread()

    def _setup_basic_health_checks(self):
        """Register the basic health checks"""

        def database_check() -> bool:
            # Check database connectivity
            try:
                # A real implementation should probe the database connection
                return True
            except Exception:
                return False

        def memory_check() -> bool:
            # Check memory usage
            import psutil
            memory_percent = psutil.virtual_memory().percent
            return memory_percent < 90

        self.health_checker.register_check("database", database_check)
        self.health_checker.register_check("memory", memory_check)
    
    def _setup_basic_alert_rules(self):
        """Register the basic alert rules"""

        # Response time alert
        response_time_rule = AlertRule(
            name="high_response_time",
            metric_name="response_time",
            condition="gt",
            threshold=5.0,  # 5 seconds
            description="Average response time is too high",
            severity="warning"
        )

        # Error rate alert
        error_rate_rule = AlertRule(
            name="high_error_rate",
            metric_name="error_rate",
            condition="gt",
            threshold=0.05,  # 5%
            description="Error rate is too high",
            severity="critical"
        )

        self.alert_manager.add_alert_rule(response_time_rule)
        self.alert_manager.add_alert_rule(error_rate_rule)
    
    def _start_monitoring_thread(self):
        """Start the monitoring thread"""

        def monitor_worker():
            while True:
                try:
                    # Evaluate alerts
                    self.alert_manager.check_alerts(self.metrics_collector)
                    time.sleep(30)  # check every 30 seconds
                except Exception as e:
                    logging.error(f"Monitoring error: {e}")
                    time.sleep(60)  # back off after errors

        thread = threading.Thread(target=monitor_worker, daemon=True)
        thread.start()

    def record_request_metrics(self, response_time: float, success: bool):
        """Record per-request metrics"""

        self.metrics_collector.record_metric("response_time", response_time)

        error_rate = 0.0 if success else 1.0
        self.metrics_collector.record_metric("error_rate", error_rate)

        # Record the request count
        self.metrics_collector.record_metric("request_count", 1)
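
Wiring the monitor into the request path is a matter of timing each graph invocation and reporting the outcome. A minimal sketch (the graph object is assumed to exist):

monitor = ProductionMonitor()

def handle_request(graph, input_data: dict, config: dict) -> dict:
    started = time.time()
    try:
        result = graph.invoke(input_data, config)
        monitor.record_request_metrics(time.time() - started, success=True)
        return result
    except Exception:
        monitor.record_request_metrics(time.time() - started, success=False)
        raise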

5. Summary

This document has collected best practices and hard-won lessons for using LangGraph in real projects, covering the full lifecycle from architecture design to production deployment.

5.1 Key Takeaways

  1. State design: clear state structures and minimal state updates
  2. Node design: single responsibility and thorough error handling
  3. Performance: checkpoint optimization, memory management, and concurrency
  4. Error handling: layered error handling and recovery strategies
  5. Debugging and monitoring: comprehensive debugging tools and production monitoring
  6. Deployment and operations: containerized deployment and health checks

5.2 Adoption Advice

  1. Adopt incrementally: start with simple scenarios and apply the practices step by step
  2. Test-driven: build a solid test suite to validate each practice
  3. Improve continuously: keep tuning based on real-world usage
  4. Train the team: make sure everyone understands and follows the practices

These field-tested practices give developers practical guidance for building high-quality, maintainable LangGraph applications.