概述

本文总结了LangGraph框架在实际项目中的最佳实践和优化技巧,涵盖架构设计、性能优化、错误处理、监控运维等各个方面。通过实战经验分享,帮助开发者避免常见陷阱,构建高质量的AI应用。

1. 架构设计最佳实践

1.1 图结构设计原则

1.1.1 单一职责原则

# ❌ 错误示例:节点职责过多
def complex_agent_node(state):
    """一个节点做太多事情"""
    # 数据预处理
    processed_data = preprocess(state["input"])
    
    # LLM调用
    response = llm.invoke(processed_data)
    
    # 工具调用
    tool_results = []
    for tool_call in response.tool_calls:
        result = execute_tool(tool_call)
        tool_results.append(result)
    
    # 结果后处理
    final_result = postprocess(tool_results)
    
    return {"output": final_result}

# ✅ 正确示例:职责分离
def preprocess_node(state):
    """数据预处理节点"""
    return {"processed_input": preprocess(state["input"])}

def llm_node(state):
    """LLM推理节点"""
    response = llm.invoke(state["processed_input"])
    return {"llm_response": response}

def tool_execution_node(state):
    """工具执行节点"""
    tool_results = []
    for tool_call in state["llm_response"].tool_calls:
        result = execute_tool(tool_call)
        tool_results.append(result)
    return {"tool_results": tool_results}

def postprocess_node(state):
    """结果后处理节点"""
    return {"output": postprocess(state["tool_results"])}

# 图构建
graph = StateGraph(AgentState)
graph.add_node("preprocess", preprocess_node)
graph.add_node("llm", llm_node)
graph.add_node("tools", tool_execution_node)
graph.add_node("postprocess", postprocess_node)

# 线性流程
graph.add_edge("preprocess", "llm")
graph.add_edge("llm", "tools")
graph.add_edge("tools", "postprocess")

1.1.2 状态设计最佳实践

from typing import TypedDict, Annotated, List, Optional, Dict, Any
from langgraph.graph.message import add_messages

# ✅ 良好的状态设计
class AgentState(TypedDict):
    """智能体状态定义
    
    设计原则:
    1. 字段语义清晰
    2. 类型注解完整
    3. 合理使用reducer
    4. 避免嵌套过深
    """
    
    # === 核心数据 ===
    messages: Annotated[List[BaseMessage], add_messages]  # 消息历史
    current_task: Optional[str]                           # 当前任务
    
    # === 执行状态 ===
    remaining_steps: int                                  # 剩余步数
    is_complete: bool                                     # 是否完成
    
    # === 中间结果 ===
    tool_results: List[Dict[str, Any]]                   # 工具结果
    analysis_cache: Dict[str, Any]                       # 分析缓存
    
    # === 元数据 ===
    execution_metadata: Dict[str, Any]                   # 执行元数据
    performance_metrics: Dict[str, float]                # 性能指标

# ❌ 避免的状态设计
class BadAgentState(TypedDict):
    # 字段名不清晰
    data: Any
    stuff: Dict
    
    # 嵌套过深
    nested_complex_structure: Dict[str, Dict[str, List[Dict[str, Any]]]]
    
    # 缺少类型注解
    some_field: Any
    
    # 不合理的reducer使用
    single_value_with_list_reducer: Annotated[str, add_messages]  # 错误!

1.1.3 条件路由设计

# ✅ 清晰的条件路由
def should_continue_analysis(state: AgentState) -> str:
    """决定是否继续分析
    
    返回值说明:
    - "continue": 继续分析
    - "finalize": 完成分析
    - "error": 处理错误
    """
    if state.get("error"):
        return "error"
    
    if state["remaining_steps"] <= 0:
        return "finalize"
    
    if state["is_complete"]:
        return "finalize"
    
    return "continue"

# 路由映射清晰
graph.add_conditional_edges(
    "analysis",
    should_continue_analysis,
    {
        "continue": "deeper_analysis",
        "finalize": "generate_report", 
        "error": "error_handler"
    }
)

# ❌ 避免复杂的条件逻辑
def complex_routing(state):
    # 避免在路由函数中进行复杂计算
    complex_analysis = perform_heavy_computation(state)  # 不好!
    
    # 避免多层嵌套条件
    if state["condition1"]:
        if state["condition2"]:
            if state["condition3"]:
                return "path1"
            else:
                return "path2"
        else:
            return "path3"
    else:
        return "path4"

1.2 模块化设计

1.2.1 可复用组件设计

# ✅ 可复用的节点组件
class RetryableNode:
    """可重试的节点包装器"""
    
    def __init__(
        self,
        node_func: Callable,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_exceptions: tuple = (Exception,)
    ):
        self.node_func = node_func
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_exceptions = retry_exceptions
    
    def __call__(self, state: dict) -> dict:
        """执行节点,支持重试"""
        last_exception = None
        
        for attempt in range(self.max_retries + 1):
            try:
                return self.node_func(state)
            except self.retry_exceptions as e:
                last_exception = e
                if attempt < self.max_retries:
                    time.sleep(self.retry_delay * (2 ** attempt))  # 指数退避
                    continue
                else:
                    break
        
        # 重试失败,返回错误状态
        return {
            "error": f"Node failed after {self.max_retries} retries: {last_exception}",
            "retry_count": self.max_retries
        }

# 使用示例
def unreliable_api_call(state):
    """可能失败的API调用"""
    response = external_api.call(state["query"])
    return {"api_response": response}

# 包装为可重试节点
reliable_api_node = RetryableNode(
    unreliable_api_call,
    max_retries=3,
    retry_exceptions=(requests.RequestException, TimeoutError)
)

graph.add_node("api_call", reliable_api_node)

1.2.2 子图组合模式

# ✅ 子图组合设计
def create_data_processing_subgraph() -> CompiledGraph:
    """创建数据处理子图"""
    
    class DataProcessingState(TypedDict):
        raw_data: Any
        processed_data: Any
        validation_errors: List[str]
    
    subgraph = StateGraph(DataProcessingState)
    
    # 数据处理流程
    subgraph.add_node("validate", validate_data_node)
    subgraph.add_node("clean", clean_data_node)
    subgraph.add_node("transform", transform_data_node)
    
    # 流程连接
    subgraph.add_edge(START, "validate")
    subgraph.add_conditional_edges(
        "validate",
        lambda state: "clean" if not state.get("validation_errors") else END,
        {"clean": "clean", END: END}
    )
    subgraph.add_edge("clean", "transform")
    subgraph.add_edge("transform", END)
    
    return subgraph.compile()

def create_main_workflow() -> CompiledGraph:
    """创建主工作流"""
    
    class MainState(TypedDict):
        input_data: Any
        processed_data: Any
        final_result: Any
    
    main_graph = StateGraph(MainState)
    
    # 集成子图
    data_processor = create_data_processing_subgraph()
    
    def process_data_node(state):
        """调用数据处理子图"""
        subgraph_input = {"raw_data": state["input_data"]}
        result = data_processor.invoke(subgraph_input)
        return {"processed_data": result["processed_data"]}
    
    main_graph.add_node("process_data", process_data_node)
    main_graph.add_node("generate_result", generate_result_node)
    
    main_graph.add_edge(START, "process_data")
    main_graph.add_edge("process_data", "generate_result")
    main_graph.add_edge("generate_result", END)
    
    return main_graph.compile()

2. 性能优化最佳实践

2.1 状态管理优化

2.1.1 状态大小控制

# ✅ 状态大小优化
class OptimizedAgentState(TypedDict):
    """优化的状态设计"""
    
    # 只保留必要的消息
    messages: Annotated[List[BaseMessage], add_messages]
    
    # 使用引用而非复制大对象
    document_ids: List[str]  # 而非完整文档内容
    
    # 分页处理大列表
    current_page: int
    page_size: int
    
    # 缓存键而非缓存值
    cache_keys: List[str]

def optimize_message_history(state: OptimizedAgentState) -> OptimizedAgentState:
    """优化消息历史"""
    messages = state["messages"]
    
    # 保留最近的N条消息
    MAX_MESSAGES = 50
    if len(messages) > MAX_MESSAGES:
        # 保留系统消息和最近的用户消息
        system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
        recent_messages = messages[-(MAX_MESSAGES - len(system_messages)):]
        optimized_messages = system_messages + recent_messages
        
        return {"messages": optimized_messages}
    
    return state

# 定期清理状态
graph.add_node("optimize_state", optimize_message_history)

2.1.2 缓存策略

from functools import lru_cache
import hashlib
import pickle

class StateCache:
    """状态缓存管理器"""
    
    def __init__(self, max_size: int = 1000, ttl: int = 3600):
        self.cache = {}
        self.max_size = max_size
        self.ttl = ttl
        self.access_times = {}
    
    def get_cache_key(self, state: dict, node_name: str) -> str:
        """生成缓存键"""
        # 只使用相关字段生成键
        relevant_data = {
            k: v for k, v in state.items() 
            if k in ["query", "context", "parameters"]  # 只缓存相关字段
        }
        
        data_str = pickle.dumps(relevant_data, protocol=pickle.HIGHEST_PROTOCOL)
        hash_key = hashlib.md5(data_str).hexdigest()
        return f"{node_name}:{hash_key}"
    
    def get(self, cache_key: str) -> Optional[Any]:
        """获取缓存"""
        if cache_key in self.cache:
            entry = self.cache[cache_key]
            if time.time() - entry["timestamp"] < self.ttl:
                self.access_times[cache_key] = time.time()
                return entry["value"]
            else:
                del self.cache[cache_key]
        return None
    
    def put(self, cache_key: str, value: Any) -> None:
        """存储缓存"""
        # LRU淘汰
        if len(self.cache) >= self.max_size:
            oldest_key = min(self.access_times.keys(), 
                           key=lambda k: self.access_times[k])
            del self.cache[oldest_key]
            del self.access_times[oldest_key]
        
        self.cache[cache_key] = {
            "value": value,
            "timestamp": time.time()
        }
        self.access_times[cache_key] = time.time()

# 全局缓存实例
state_cache = StateCache()

def cached_expensive_operation(state):
    """带缓存的昂贵操作"""
    cache_key = state_cache.get_cache_key(state, "expensive_operation")
    
    # 尝试从缓存获取
    cached_result = state_cache.get(cache_key)
    if cached_result is not None:
        return {"result": cached_result, "from_cache": True}
    
    # 执行昂贵操作
    result = perform_expensive_computation(state)
    
    # 存储到缓存
    state_cache.put(cache_key, result)
    
    return {"result": result, "from_cache": False}

2.2 并发执行优化

2.2.1 并行节点设计

# ✅ 并行执行设计
def create_parallel_analysis_graph():
    """创建并行分析图"""
    
    class AnalysisState(TypedDict):
        input_data: Any
        sentiment_analysis: Optional[dict]
        entity_extraction: Optional[dict]
        topic_modeling: Optional[dict]
        final_analysis: Optional[dict]
    
    graph = StateGraph(AnalysisState)
    
    # 独立的分析节点(可并行执行)
    def sentiment_analysis_node(state):
        """情感分析(独立执行)"""
        result = analyze_sentiment(state["input_data"])
        return {"sentiment_analysis": result}
    
    def entity_extraction_node(state):
        """实体提取(独立执行)"""
        result = extract_entities(state["input_data"])
        return {"entity_extraction": result}
    
    def topic_modeling_node(state):
        """主题建模(独立执行)"""
        result = model_topics(state["input_data"])
        return {"topic_modeling": result}
    
    def combine_analysis_node(state):
        """合并分析结果"""
        combined = {
            "sentiment": state["sentiment_analysis"],
            "entities": state["entity_extraction"],
            "topics": state["topic_modeling"]
        }
        return {"final_analysis": combined}
    
    # 添加节点
    graph.add_node("sentiment", sentiment_analysis_node)
    graph.add_node("entities", entity_extraction_node)
    graph.add_node("topics", topic_modeling_node)
    graph.add_node("combine", combine_analysis_node)
    
    # 并行执行设置
    graph.add_edge(START, "sentiment")
    graph.add_edge(START, "entities")
    graph.add_edge(START, "topics")
    
    # 等待所有并行任务完成
    graph.add_edge("sentiment", "combine")
    graph.add_edge("entities", "combine")
    graph.add_edge("topics", "combine")
    
    graph.add_edge("combine", END)
    
    return graph.compile()

2.2.2 异步操作优化

import asyncio
from concurrent.futures import ThreadPoolExecutor

class AsyncNodeWrapper:
    """异步节点包装器"""
    
    def __init__(self, async_func: Callable, max_workers: int = 4):
        self.async_func = async_func
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
    
    def __call__(self, state: dict) -> dict:
        """同步调用异步函数"""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        
        try:
            result = loop.run_until_complete(self.async_func(state))
            return result
        finally:
            loop.close()

# 异步节点示例
async def async_api_calls_node(state):
    """并发API调用节点"""
    queries = state["queries"]
    
    async def call_api(query):
        """单个API调用"""
        async with aiohttp.ClientSession() as session:
            async with session.post("/api/analyze", json={"query": query}) as response:
                return await response.json()
    
    # 并发执行所有API调用
    tasks = [call_api(query) for query in queries]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    
    # 处理结果和异常
    successful_results = []
    errors = []
    
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            errors.append({"query_index": i, "error": str(result)})
        else:
            successful_results.append(result)
    
    return {
        "api_results": successful_results,
        "api_errors": errors
    }

# 包装为同步节点
sync_api_node = AsyncNodeWrapper(async_api_calls_node)
graph.add_node("api_calls", sync_api_node)

2.3 内存管理优化

2.3.1 内存使用监控

import psutil
import gc
from typing import Dict, Any

class MemoryMonitor:
    """内存监控器"""
    
    def __init__(self, warning_threshold: float = 0.8, critical_threshold: float = 0.9):
        self.warning_threshold = warning_threshold
        self.critical_threshold = critical_threshold
        self.baseline_memory = self._get_memory_usage()
    
    def _get_memory_usage(self) -> Dict[str, float]:
        """获取内存使用情况"""
        process = psutil.Process()
        memory_info = process.memory_info()
        
        return {
            "rss": memory_info.rss / 1024 / 1024,  # MB
            "vms": memory_info.vms / 1024 / 1024,  # MB
            "percent": process.memory_percent()
        }
    
    def check_memory_usage(self) -> Dict[str, Any]:
        """检查内存使用"""
        current_memory = self._get_memory_usage()
        memory_growth = current_memory["rss"] - self.baseline_memory["rss"]
        
        status = "normal"
        if current_memory["percent"] > self.critical_threshold:
            status = "critical"
        elif current_memory["percent"] > self.warning_threshold:
            status = "warning"
        
        return {
            "status": status,
            "current_usage": current_memory,
            "growth_mb": memory_growth,
            "recommendations": self._get_recommendations(status)
        }
    
    def _get_recommendations(self, status: str) -> List[str]:
        """获取优化建议"""
        if status == "critical":
            return [
                "立即执行垃圾回收",
                "清理状态缓存",
                "减少批处理大小",
                "考虑重启进程"
            ]
        elif status == "warning":
            return [
                "执行垃圾回收",
                "检查内存泄漏",
                "优化状态大小"
            ]
        return []

def memory_cleanup_node(state):
    """内存清理节点"""
    monitor = MemoryMonitor()
    memory_status = monitor.check_memory_usage()
    
    if memory_status["status"] in ["warning", "critical"]:
        # 执行垃圾回收
        gc.collect()
        
        # 清理状态中的大对象
        cleaned_state = {}
        for key, value in state.items():
            if key.startswith("temp_") or key.endswith("_cache"):
                continue  # 跳过临时和缓存数据
            cleaned_state[key] = value
        
        # 添加内存状态信息
        cleaned_state["memory_status"] = memory_status
        
        return cleaned_state
    
    return state

3. 错误处理最佳实践

3.1 分层错误处理

from enum import Enum
from dataclasses import dataclass
from typing import Optional, Dict, Any, List

class ErrorSeverity(Enum):
    """错误严重程度"""
    LOW = "low"           # 可忽略的错误
    MEDIUM = "medium"     # 需要处理但不影响主流程
    HIGH = "high"         # 影响主流程但可恢复
    CRITICAL = "critical" # 致命错误,需要停止执行

@dataclass
class GraphError:
    """图执行错误"""
    code: str
    message: str
    severity: ErrorSeverity
    node_name: Optional[str] = None
    context: Optional[Dict[str, Any]] = None
    recoverable: bool = True
    retry_count: int = 0
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "code": self.code,
            "message": self.message,
            "severity": self.severity.value,
            "node_name": self.node_name,
            "context": self.context,
            "recoverable": self.recoverable,
            "retry_count": self.retry_count
        }

class ErrorHandler:
    """错误处理器"""
    
    def __init__(self):
        self.error_handlers = {}
        self.error_history = []
    
    def register_handler(
        self, 
        error_code: str, 
        handler: Callable[[GraphError, dict], dict]
    ):
        """注册错误处理器"""
        self.error_handlers[error_code] = handler
    
    def handle_error(self, error: GraphError, state: dict) -> dict:
        """处理错误"""
        # 记录错误历史
        self.error_history.append(error)
        
        # 查找对应的处理器
        handler = self.error_handlers.get(error.code)
        if handler:
            try:
                return handler(error, state)
            except Exception as e:
                # 处理器本身出错
                fallback_error = GraphError(
                    code="HANDLER_ERROR",
                    message=f"Error handler failed: {e}",
                    severity=ErrorSeverity.HIGH,
                    recoverable=False
                )
                return self._default_error_handler(fallback_error, state)
        
        # 使用默认处理器
        return self._default_error_handler(error, state)
    
    def _default_error_handler(self, error: GraphError, state: dict) -> dict:
        """默认错误处理器"""
        error_state = state.copy()
        
        # 添加错误信息到状态
        if "errors" not in error_state:
            error_state["errors"] = []
        
        error_state["errors"].append(error.to_dict())
        
        # 根据严重程度决定处理策略
        if error.severity == ErrorSeverity.CRITICAL:
            error_state["should_stop"] = True
        elif error.severity == ErrorSeverity.HIGH:
            error_state["needs_intervention"] = True
        
        return error_state

# 全局错误处理器
global_error_handler = ErrorHandler()

# 注册具体的错误处理器
def handle_api_timeout(error: GraphError, state: dict) -> dict:
    """处理API超时错误"""
    if error.retry_count < 3:
        # 重试策略
        return {
            **state,
            "should_retry": True,
            "retry_delay": 2 ** error.retry_count,  # 指数退避
            "retry_node": error.node_name
        }
    else:
        # 降级策略
        return {
            **state,
            "use_fallback_api": True,
            "api_timeout_handled": True
        }

def handle_validation_error(error: GraphError, state: dict) -> dict:
    """处理验证错误"""
    return {
        **state,
        "validation_failed": True,
        "skip_validation": True,  # 跳过验证继续执行
        "validation_error_details": error.context
    }

# 注册处理器
global_error_handler.register_handler("API_TIMEOUT", handle_api_timeout)
global_error_handler.register_handler("VALIDATION_ERROR", handle_validation_error)

# 错误处理装饰器
def with_error_handling(error_code: str, severity: ErrorSeverity = ErrorSeverity.MEDIUM):
    """错误处理装饰器"""
    def decorator(node_func):
        def wrapper(state):
            try:
                return node_func(state)
            except Exception as e:
                error = GraphError(
                    code=error_code,
                    message=str(e),
                    severity=severity,
                    node_name=node_func.__name__,
                    context={"exception_type": type(e).__name__}
                )
                return global_error_handler.handle_error(error, state)
        return wrapper
    return decorator

# 使用示例
@with_error_handling("API_CALL_ERROR", ErrorSeverity.HIGH)
def api_call_node(state):
    """API调用节点"""
    response = external_api.call(state["query"])
    return {"api_response": response}

3.2 优雅降级策略

class FallbackStrategy:
    """降级策略"""
    
    def __init__(self):
        self.fallback_chains = {}
    
    def register_fallback_chain(self, primary_node: str, fallback_nodes: List[str]):
        """注册降级链"""
        self.fallback_chains[primary_node] = fallback_nodes
    
    def get_fallback_node(self, failed_node: str, attempt: int) -> Optional[str]:
        """获取降级节点"""
        fallback_chain = self.fallback_chains.get(failed_node, [])
        if attempt < len(fallback_chain):
            return fallback_chain[attempt]
        return None

# 全局降级策略
fallback_strategy = FallbackStrategy()

# 注册降级链
fallback_strategy.register_fallback_chain(
    "premium_llm",
    ["standard_llm", "basic_llm", "rule_based_fallback"]
)

def create_resilient_llm_node():
    """创建具有降级能力的LLM节点"""
    
    def resilient_llm_node(state):
        """具有降级能力的LLM节点"""
        current_node = state.get("current_llm_node", "premium_llm")
        attempt = state.get("llm_attempt", 0)
        
        try:
            if current_node == "premium_llm":
                response = premium_llm.invoke(state["query"])
            elif current_node == "standard_llm":
                response = standard_llm.invoke(state["query"])
            elif current_node == "basic_llm":
                response = basic_llm.invoke(state["query"])
            else:  # rule_based_fallback
                response = rule_based_response(state["query"])
            
            return {
                "llm_response": response,
                "used_llm": current_node,
                "llm_attempt": attempt
            }
        
        except Exception as e:
            # 尝试降级
            fallback_node = fallback_strategy.get_fallback_node(current_node, attempt)
            
            if fallback_node:
                return {
                    **state,
                    "current_llm_node": fallback_node,
                    "llm_attempt": attempt + 1,
                    "llm_error": str(e),
                    "should_retry_llm": True
                }
            else:
                # 所有降级选项都失败
                return {
                    **state,
                    "llm_failed": True,
                    "llm_error": str(e),
                    "final_attempt": attempt
                }
    
    return resilient_llm_node

# 条件路由支持降级
def llm_routing_condition(state):
    """LLM路由条件"""
    if state.get("should_retry_llm"):
        return "retry_llm"
    elif state.get("llm_failed"):
        return "handle_llm_failure"
    else:
        return "continue"

# 图构建
graph.add_node("llm", create_resilient_llm_node())
graph.add_conditional_edges(
    "llm",
    llm_routing_condition,
    {
        "retry_llm": "llm",  # 重试(使用降级节点)
        "handle_llm_failure": "llm_failure_handler",
        "continue": "next_step"
    }
)

4. 监控与可观测性

4.1 性能监控

import time
from dataclasses import dataclass, field
from typing import Dict, List, Any
from collections import defaultdict

@dataclass
class NodeMetrics:
    """节点性能指标"""
    name: str
    execution_count: int = 0
    total_duration: float = 0.0
    min_duration: float = float('inf')
    max_duration: float = 0.0
    error_count: int = 0
    last_execution: Optional[float] = None
    
    @property
    def avg_duration(self) -> float:
        return self.total_duration / self.execution_count if self.execution_count > 0 else 0.0
    
    @property
    def success_rate(self) -> float:
        return (self.execution_count - self.error_count) / self.execution_count if self.execution_count > 0 else 0.0

class GraphMetricsCollector:
    """图执行指标收集器"""
    
    def __init__(self):
        self.node_metrics: Dict[str, NodeMetrics] = {}
        self.execution_history: List[Dict[str, Any]] = []
        self.current_execution: Optional[Dict[str, Any]] = None
    
    def start_execution(self, execution_id: str):
        """开始执行追踪"""
        self.current_execution = {
            "id": execution_id,
            "start_time": time.time(),
            "nodes_executed": [],
            "total_duration": 0.0,
            "success": True
        }
    
    def start_node_execution(self, node_name: str) -> str:
        """开始节点执行追踪"""
        execution_id = f"{node_name}_{int(time.time() * 1000)}"
        
        if node_name not in self.node_metrics:
            self.node_metrics[node_name] = NodeMetrics(name=node_name)
        
        return execution_id
    
    def end_node_execution(self, node_name: str, execution_id: str, success: bool = True):
        """结束节点执行追踪"""
        end_time = time.time()
        start_time = float(execution_id.split('_')[-1]) / 1000
        duration = end_time - start_time
        
        metrics = self.node_metrics[node_name]
        metrics.execution_count += 1
        metrics.total_duration += duration
        metrics.min_duration = min(metrics.min_duration, duration)
        metrics.max_duration = max(metrics.max_duration, duration)
        metrics.last_execution = end_time
        
        if not success:
            metrics.error_count += 1
        
        # 记录到当前执行
        if self.current_execution:
            self.current_execution["nodes_executed"].append({
                "name": node_name,
                "duration": duration,
                "success": success,
                "timestamp": end_time
            })
    
    def end_execution(self, success: bool = True):
        """结束执行追踪"""
        if self.current_execution:
            self.current_execution["end_time"] = time.time()
            self.current_execution["total_duration"] = (
                self.current_execution["end_time"] - self.current_execution["start_time"]
            )
            self.current_execution["success"] = success
            
            self.execution_history.append(self.current_execution)
            self.current_execution = None
    
    def get_performance_report(self) -> Dict[str, Any]:
        """获取性能报告"""
        total_executions = len(self.execution_history)
        successful_executions = sum(1 for ex in self.execution_history if ex["success"])
        
        node_performance = {}
        for name, metrics in self.node_metrics.items():
            node_performance[name] = {
                "execution_count": metrics.execution_count,
                "avg_duration": metrics.avg_duration,
                "min_duration": metrics.min_duration,
                "max_duration": metrics.max_duration,
                "success_rate": metrics.success_rate,
                "error_count": metrics.error_count
            }
        
        return {
            "total_executions": total_executions,
            "success_rate": successful_executions / total_executions if total_executions > 0 else 0,
            "node_performance": node_performance,
            "recent_executions": self.execution_history[-10:]  # 最近10次执行
        }

# 全局指标收集器
metrics_collector = GraphMetricsCollector()

# 性能监控装饰器
def with_performance_monitoring(node_func):
    """性能监控装饰器"""
    def wrapper(state):
        node_name = node_func.__name__
        execution_id = metrics_collector.start_node_execution(node_name)
        
        try:
            result = node_func(state)
            metrics_collector.end_node_execution(node_name, execution_id, success=True)
            return result
        except Exception as e:
            metrics_collector.end_node_execution(node_name, execution_id, success=False)
            raise
    
    return wrapper

# 使用示例
@with_performance_monitoring
def monitored_node(state):
    """被监控的节点"""
    # 模拟一些处理时间
    time.sleep(0.1)
    return {"processed": True}

4.2 日志记录最佳实践

import logging
import json
from datetime import datetime
from typing import Any, Dict, Optional

class StructuredLogger:
    """结构化日志记录器"""
    
    def __init__(self, name: str, level: int = logging.INFO):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)
        
        # 配置结构化日志格式
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
    
    def log_node_execution(
        self,
        node_name: str,
        state_summary: Dict[str, Any],
        execution_time: float,
        success: bool = True,
        error: Optional[str] = None
    ):
        """记录节点执行日志"""
        log_data = {
            "event_type": "node_execution",
            "node_name": node_name,
            "execution_time": execution_time,
            "success": success,
            "state_summary": state_summary,
            "timestamp": datetime.utcnow().isoformat()
        }
        
        if error:
            log_data["error"] = error
        
        if success:
            self.logger.info(f"Node executed: {json.dumps(log_data)}")
        else:
            self.logger.error(f"Node failed: {json.dumps(log_data)}")
    
    def log_graph_execution(
        self,
        execution_id: str,
        total_time: float,
        nodes_executed: List[str],
        success: bool = True,
        final_state_summary: Optional[Dict[str, Any]] = None
    ):
        """记录图执行日志"""
        log_data = {
            "event_type": "graph_execution",
            "execution_id": execution_id,
            "total_time": total_time,
            "nodes_executed": nodes_executed,
            "success": success,
            "timestamp": datetime.utcnow().isoformat()
        }
        
        if final_state_summary:
            log_data["final_state_summary"] = final_state_summary
        
        if success:
            self.logger.info(f"Graph execution completed: {json.dumps(log_data)}")
        else:
            self.logger.error(f"Graph execution failed: {json.dumps(log_data)}")

# 创建日志记录器
graph_logger = StructuredLogger("langgraph.execution")

def create_state_summary(state: Dict[str, Any]) -> Dict[str, Any]:
    """创建状态摘要(避免记录敏感或大量数据)"""
    summary = {}
    
    for key, value in state.items():
        if key.startswith("_") or key in ["password", "token", "secret"]:
            summary[key] = "[REDACTED]"
        elif isinstance(value, (str, int, float, bool)):
            summary[key] = value
        elif isinstance(value, list):
            summary[key] = f"[List with {len(value)} items]"
        elif isinstance(value, dict):
            summary[key] = f"[Dict with {len(value)} keys]"
        else:
            summary[key] = f"[{type(value).__name__}]"
    
    return summary

# 日志记录装饰器
def with_logging(logger: StructuredLogger):
    """日志记录装饰器"""
    def decorator(node_func):
        def wrapper(state):
            node_name = node_func.__name__
            start_time = time.time()
            
            try:
                result = node_func(state)
                execution_time = time.time() - start_time
                
                # 记录成功执行
                logger.log_node_execution(
                    node_name=node_name,
                    state_summary=create_state_summary(state),
                    execution_time=execution_time,
                    success=True
                )
                
                return result
            
            except Exception as e:
                execution_time = time.time() - start_time
                
                # 记录执行失败
                logger.log_node_execution(
                    node_name=node_name,
                    state_summary=create_state_summary(state),
                    execution_time=execution_time,
                    success=False,
                    error=str(e)
                )
                
                raise
        
        return wrapper
    return decorator

# 使用示例
@with_logging(graph_logger)
def logged_node(state):
    """带日志记录的节点"""
    return {"processed": True}

5. 部署与运维最佳实践

5.1 配置管理

from typing import Any, Dict, Optional
import os
import yaml
from dataclasses import dataclass

@dataclass
class GraphConfig:
    """图配置"""
    
    # === 执行配置 ===
    max_steps: int = 100
    step_timeout: float = 30.0
    enable_checkpoints: bool = True
    
    # === 性能配置 ===
    max_workers: int = 4
    memory_limit_mb: int = 1024
    enable_caching: bool = True
    cache_size: int = 1000
    
    # === 监控配置 ===
    enable_metrics: bool = True
    enable_logging: bool = True
    log_level: str = "INFO"
    
    # === 外部服务配置 ===
    llm_config: Dict[str, Any] = None
    database_config: Dict[str, Any] = None
    
    def __post_init__(self):
        if self.llm_config is None:
            self.llm_config = {}
        if self.database_config is None:
            self.database_config = {}

class ConfigManager:
    """配置管理器"""
    
    def __init__(self, config_file: Optional[str] = None):
        self.config_file = config_file or "graph_config.yaml"
        self.config = self._load_config()
    
    def _load_config(self) -> GraphConfig:
        """加载配置"""
        config_data = {}
        
        # 从文件加载
        if os.path.exists(self.config_file):
            with open(self.config_file, 'r') as f:
                config_data = yaml.safe_load(f) or {}
        
        # 环境变量覆盖
        env_overrides = self._load_env_overrides()
        config_data.update(env_overrides)
        
        return GraphConfig(**config_data)
    
    def _load_env_overrides(self) -> Dict[str, Any]:
        """从环境变量加载覆盖配置"""
        overrides = {}
        
        # 定义环境变量映射
        env_mappings = {
            "GRAPH_MAX_STEPS": ("max_steps", int),
            "GRAPH_STEP_TIMEOUT": ("step_timeout", float),
            "GRAPH_MAX_WORKERS": ("max_workers", int),
            "GRAPH_MEMORY_LIMIT": ("memory_limit_mb", int),
            "GRAPH_LOG_LEVEL": ("log_level", str),
            "GRAPH_ENABLE_CHECKPOINTS": ("enable_checkpoints", bool),
            "GRAPH_ENABLE_CACHING": ("enable_caching", bool),
        }
        
        for env_var, (config_key, type_func) in env_mappings.items():
            value = os.getenv(env_var)
            if value is not None:
                try:
                    if type_func == bool:
                        overrides[config_key] = value.lower() in ('true', '1', 'yes', 'on')
                    else:
                        overrides[config_key] = type_func(value)
                except ValueError:
                    print(f"Warning: Invalid value for {env_var}: {value}")
        
        return overrides
    
    def get_config(self) -> GraphConfig:
        """获取配置"""
        return self.config
    
    def reload_config(self):
        """重新加载配置"""
        self.config = self._load_config()

# 全局配置管理器
config_manager = ConfigManager()

# 配置示例文件 (graph_config.yaml)
EXAMPLE_CONFIG = """
# 执行配置
max_steps: 50
step_timeout: 60.0
enable_checkpoints: true

# 性能配置
max_workers: 8
memory_limit_mb: 2048
enable_caching: true
cache_size: 2000

# 监控配置
enable_metrics: true
enable_logging: true
log_level: "DEBUG"

# LLM配置
llm_config:
  provider: "openai"
  model: "gpt-4"
  temperature: 0.7
  max_tokens: 2000
  timeout: 30

# 数据库配置
database_config:
  host: "localhost"
  port: 5432
  database: "langgraph"
  username: "user"
  password: "password"
  pool_size: 10
"""

def create_configured_graph() -> CompiledGraph:
    """创建配置化的图"""
    config = config_manager.get_config()
    
    # 使用配置创建检查点保存器
    if config.enable_checkpoints:
        if config.database_config.get("host"):
            checkpointer = PostgresCheckpointSaver(**config.database_config)
        else:
            checkpointer = MemorySaver()
    else:
        checkpointer = None
    
    # 创建图
    graph = StateGraph(AgentState)
    
    # 添加节点...
    
    # 编译时应用配置
    compiled_graph = graph.compile(
        checkpointer=checkpointer,
        interrupt_before=None,
        interrupt_after=None,
        debug=(config.log_level == "DEBUG")
    )
    
    return compiled_graph

5.2 健康检查和监控

from typing import Dict, Any, List
import time
import threading
from dataclasses import dataclass
from enum import Enum

class HealthStatus(Enum):
    """健康状态"""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"

@dataclass
class HealthCheck:
    """健康检查结果"""
    name: str
    status: HealthStatus
    message: str
    duration: float
    timestamp: float
    details: Dict[str, Any] = None

class HealthMonitor:
    """健康监控器"""
    
    def __init__(self):
        self.checks = {}
        self.check_history = []
        self.monitoring_thread = None
        self.monitoring_interval = 30  # 30秒检查一次
        self.is_monitoring = False
    
    def register_check(self, name: str, check_func: Callable[[], HealthCheck]):
        """注册健康检查"""
        self.checks[name] = check_func
    
    def run_check(self, name: str) -> HealthCheck:
        """运行单个健康检查"""
        if name not in self.checks:
            return HealthCheck(
                name=name,
                status=HealthStatus.UNHEALTHY,
                message=f"Check '{name}' not found",
                duration=0.0,
                timestamp=time.time()
            )
        
        start_time = time.time()
        try:
            result = self.checks[name]()
            result.duration = time.time() - start_time
            result.timestamp = time.time()
            return result
        except Exception as e:
            return HealthCheck(
                name=name,
                status=HealthStatus.UNHEALTHY,
                message=f"Check failed: {e}",
                duration=time.time() - start_time,
                timestamp=time.time()
            )
    
    def run_all_checks(self) -> Dict[str, HealthCheck]:
        """运行所有健康检查"""
        results = {}
        for name in self.checks:
            results[name] = self.run_check(name)
        
        # 记录历史
        self.check_history.append({
            "timestamp": time.time(),
            "results": results
        })
        
        # 保持历史记录在合理范围内
        if len(self.check_history) > 100:
            self.check_history = self.check_history[-100:]
        
        return results
    
    def get_overall_status(self) -> HealthStatus:
        """获取整体健康状态"""
        results = self.run_all_checks()
        
        if not results:
            return HealthStatus.UNHEALTHY
        
        statuses = [result.status for result in results.values()]
        
        if all(status == HealthStatus.HEALTHY for status in statuses):
            return HealthStatus.HEALTHY
        elif any(status == HealthStatus.UNHEALTHY for status in statuses):
            return HealthStatus.UNHEALTHY
        else:
            return HealthStatus.DEGRADED
    
    def start_monitoring(self):
        """开始监控"""
        if self.is_monitoring:
            return
        
        self.is_monitoring = True
        self.monitoring_thread = threading.Thread(target=self._monitoring_loop)
        self.monitoring_thread.daemon = True
        self.monitoring_thread.start()
    
    def stop_monitoring(self):
        """停止监控"""
        self.is_monitoring = False
        if self.monitoring_thread:
            self.monitoring_thread.join()
    
    def _monitoring_loop(self):
        """监控循环"""
        while self.is_monitoring:
            try:
                self.run_all_checks()
                time.sleep(self.monitoring_interval)
            except Exception as e:
                print(f"Monitoring error: {e}")
                time.sleep(5)  # 错误时短暂等待

# 全局健康监控器
health_monitor = HealthMonitor()

# 具体的健康检查实现
def check_memory_usage() -> HealthCheck:
    """检查内存使用"""
    try:
        import psutil
        memory = psutil.virtual_memory()
        
        if memory.percent > 90:
            status = HealthStatus.UNHEALTHY
            message = f"Memory usage critical: {memory.percent}%"
        elif memory.percent > 80:
            status = HealthStatus.DEGRADED
            message = f"Memory usage high: {memory.percent}%"
        else:
            status = HealthStatus.HEALTHY
            message = f"Memory usage normal: {memory.percent}%"
        
        return HealthCheck(
            name="memory_usage",
            status=status,
            message=message,
            duration=0.0,
            timestamp=time.time(),
            details={
                "total_mb": memory.total / 1024 / 1024,
                "used_mb": memory.used / 1024 / 1024,
                "percent": memory.percent
            }
        )
    except Exception as e:
        return HealthCheck(
            name="memory_usage",
            status=HealthStatus.UNHEALTHY,
            message=f"Failed to check memory: {e}",
            duration=0.0,
            timestamp=time.time()
        )

def check_database_connection() -> HealthCheck:
    """检查数据库连接"""
    try:
        # 这里应该是实际的数据库连接检查
        # 示例代码
        start_time = time.time()
        # db_connection.ping()  # 实际的ping操作
        duration = time.time() - start_time
        
        if duration > 5.0:
            status = HealthStatus.DEGRADED
            message = f"Database response slow: {duration:.2f}s"
        else:
            status = HealthStatus.HEALTHY
            message = f"Database connection OK: {duration:.2f}s"
        
        return HealthCheck(
            name="database_connection",
            status=status,
            message=message,
            duration=duration,
            timestamp=time.time()
        )
    except Exception as e:
        return HealthCheck(
            name="database_connection",
            status=HealthStatus.UNHEALTHY,
            message=f"Database connection failed: {e}",
            duration=0.0,
            timestamp=time.time()
        )

def check_external_api() -> HealthCheck:
    """检查外部API"""
    try:
        import requests
        start_time = time.time()
        
        # 健康检查端点
        response = requests.get("https://api.example.com/health", timeout=10)
        duration = time.time() - start_time
        
        if response.status_code == 200:
            status = HealthStatus.HEALTHY
            message = f"External API OK: {duration:.2f}s"
        else:
            status = HealthStatus.DEGRADED
            message = f"External API error: {response.status_code}"
        
        return HealthCheck(
            name="external_api",
            status=status,
            message=message,
            duration=duration,
            timestamp=time.time(),
            details={
                "status_code": response.status_code,
                "response_time": duration
            }
        )
    except Exception as e:
        return HealthCheck(
            name="external_api",
            status=HealthStatus.UNHEALTHY,
            message=f"External API check failed: {e}",
            duration=0.0,
            timestamp=time.time()
        )

# 注册健康检查
health_monitor.register_check("memory_usage", check_memory_usage)
health_monitor.register_check("database_connection", check_database_connection)
health_monitor.register_check("external_api", check_external_api)

# 启动监控
health_monitor.start_monitoring()

# 健康检查API端点(Flask示例)
from flask import Flask, jsonify

app = Flask(__name__)

@app.route('/health')
def health_endpoint():
    """健康检查端点"""
    overall_status = health_monitor.get_overall_status()
    results = health_monitor.run_all_checks()
    
    response_data = {
        "status": overall_status.value,
        "timestamp": time.time(),
        "checks": {
            name: {
                "status": check.status.value,
                "message": check.message,
                "duration": check.duration,
                "details": check.details
            }
            for name, check in results.items()
        }
    }
    
    status_code = 200 if overall_status == HealthStatus.HEALTHY else 503
    return jsonify(response_data), status_code

@app.route('/metrics')
def metrics_endpoint():
    """指标端点"""
    performance_report = metrics_collector.get_performance_report()
    return jsonify(performance_report)

6. 总结与建议

6.1 核心最佳实践总结

  1. 架构设计

    • 遵循单一职责原则
    • 合理设计状态结构
    • 使用清晰的条件路由
    • 采用模块化组合
  2. 性能优化

    • 控制状态大小
    • 实施缓存策略
    • 支持并行执行
    • 监控内存使用
  3. 错误处理

    • 分层错误处理
    • 优雅降级策略
    • 完善的重试机制
    • 详细的错误上下文
  4. 监控运维

    • 结构化日志记录
    • 全面的性能监控
    • 健康检查机制
    • 配置管理系统

6.2 常见陷阱与避免方法

  1. 状态膨胀:定期清理状态,使用引用而非复制
  2. 内存泄漏:及时释放资源,监控内存使用
  3. 死锁风险:避免循环依赖,设置超时机制
  4. 性能瓶颈:识别热点节点,优化关键路径

6.3 发展趋势与建议

  1. 云原生部署:容器化部署,自动扩缩容
  2. 微服务架构:服务拆分,独立部署
  3. AI Ops集成:智能监控,自动优化
  4. 边缘计算:分布式执行,就近处理

通过遵循这些最佳实践,开发者可以构建出高质量、高性能、高可靠性的LangGraph应用,为用户提供优秀的AI服务体验。



tommie blog