概述
本文总结了LangGraph框架在实际项目中的最佳实践和优化技巧,涵盖架构设计、性能优化、错误处理、监控运维等各个方面。通过实战经验分享,帮助开发者避免常见陷阱,构建高质量的AI应用。
1. 架构设计最佳实践
1.1 图结构设计原则
1.1.1 单一职责原则
# ❌ 错误示例:节点职责过多
def complex_agent_node(state):
    """Anti-pattern: one node that performs every stage of the pipeline."""
    # Stage 1: preprocess the raw input.
    prepared = preprocess(state["input"])
    # Stage 2: call the LLM with the prepared input.
    reply = llm.invoke(prepared)
    # Stage 3: execute each tool call from the LLM response, in order.
    outcomes = [execute_tool(call) for call in reply.tool_calls]
    # Stage 4: post-process the tool results into the final output.
    return {"output": postprocess(outcomes)}
# ✅ 正确示例:职责分离
def preprocess_node(state):
    """Preprocessing node: write the preprocessed input to "processed_input"."""
    prepared = preprocess(state["input"])
    return {"processed_input": prepared}
def llm_node(state):
    """LLM inference node: invoke the model on the preprocessed input."""
    return {"llm_response": llm.invoke(state["processed_input"])}
def tool_execution_node(state):
    """Tool-execution node: run every tool call found on the LLM response."""
    pending_calls = state["llm_response"].tool_calls
    # Results keep the same order as the tool calls.
    return {"tool_results": [execute_tool(call) for call in pending_calls]}
def postprocess_node(state):
    """Post-processing node: turn the tool results into the final output."""
    final_output = postprocess(state["tool_results"])
    return {"output": final_output}
# Build the graph. START and END live in the same package as StateGraph.
from langgraph.graph import END, START, StateGraph

graph = StateGraph(AgentState)
graph.add_node("preprocess", preprocess_node)
graph.add_node("llm", llm_node)
graph.add_node("tools", tool_execution_node)
graph.add_node("postprocess", postprocess_node)

# Linear flow. The edge from START gives the graph its entry point (a graph
# without one cannot be compiled); the edge to END marks where a run finishes.
graph.add_edge(START, "preprocess")
graph.add_edge("preprocess", "llm")
graph.add_edge("llm", "tools")
graph.add_edge("tools", "postprocess")
graph.add_edge("postprocess", END)
1.1.2 状态设计最佳实践
from typing import TypedDict, Annotated, List, Optional, Dict, Any
from langgraph.graph.message import add_messages
# ✅ 良好的状态设计
class AgentState(TypedDict):
    """State schema for the agent graph.

    Design principles:
    1. Field names carry clear meaning.
    2. Every field is fully type-annotated.
    3. Reducers are used where they make sense.
    4. Nesting is kept shallow.
    """

    # --- Core data ---
    # Message history, combined through the add_messages reducer.
    messages: Annotated[List[BaseMessage], add_messages]
    # Task currently being worked on.
    current_task: Optional[str]

    # --- Execution status ---
    # Number of steps still allowed.
    remaining_steps: int
    # Whether the work is finished.
    is_complete: bool

    # --- Intermediate results ---
    # Results returned by tools.
    tool_results: List[Dict[str, Any]]
    # Cached analysis data.
    analysis_cache: Dict[str, Any]

    # --- Metadata ---
    # Metadata about the execution.
    execution_metadata: Dict[str, Any]
    # Numeric performance metrics.
    performance_metrics: Dict[str, float]
# ❌ 避免的状态设计
class BadAgentState(TypedDict):
    """Anti-pattern: a state definition that shows what to avoid."""

    # Field names that say nothing about their content.
    data: Any
    stuff: Dict

    # Nesting that is too deep.
    nested_complex_structure: Dict[str, Dict[str, List[Dict[str, Any]]]]

    # No useful type information.
    some_field: Any

    # Wrong: a list-oriented reducer attached to a single string value.
    single_value_with_list_reducer: Annotated[str, add_messages]
1.1.3 条件路由设计
# ✅ 清晰的条件路由
def should_continue_analysis(state: AgentState) -> str:
    """Choose the route to take after an analysis step.

    Returns:
        "error" when the state holds a truthy "error" value;
        "finalize" when no steps remain or the analysis is complete;
        "continue" otherwise.
    """
    if state.get("error"):
        return "error"
    out_of_steps = state["remaining_steps"] <= 0
    # Short-circuit keeps the original order: the step budget is checked
    # before the completion flag is read.
    if out_of_steps or state["is_complete"]:
        return "finalize"
    return "continue"
# Explicit mapping from each routing result to the node it leads to.
graph.add_conditional_edges(
    "analysis",
    should_continue_analysis,
    {"continue": "deeper_analysis", "finalize": "generate_report", "error": "error_handler"},
)
# ❌ 避免复杂的条件逻辑
def complex_routing(state):
    """Anti-pattern: a routing function that is both heavy and deeply nested.

    The nesting is kept on purpose; this function illustrates what to avoid.
    """
    # Bad: a routing function should not perform expensive computation.
    complex_analysis = perform_heavy_computation(state)
    # Bad: several levels of nested conditions.
    if state["condition1"]:
        if state["condition2"]:
            if state["condition3"]:
                return "path1"
            else:
                return "path2"
        else:
            return "path3"
    else:
        return "path4"
1.2 模块化设计
1.2.1 可复用组件设计
# ✅ 可复用的节点组件
class RetryableNode:
    """Callable wrapper that retries a node function with exponential backoff.

    Args:
        node_func: The node function to call with the state.
        max_retries: Number of retries after the first attempt.
        retry_delay: Base delay in seconds; doubled after every failed attempt.
        retry_exceptions: Exception types that trigger a retry. Any other
            exception propagates to the caller immediately.
    """

    def __init__(
        self,
        node_func: Callable,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_exceptions: tuple = (Exception,)
    ):
        self.node_func = node_func
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_exceptions = retry_exceptions

    def __call__(self, state: dict) -> dict:
        """Run the node, retrying on the configured exceptions.

        Returns:
            The node's result on success. When every attempt fails, a dict
            with "error" (message including the last exception) and
            "retry_count" instead of raising.
        """
        failure = None
        attempt = 0
        while attempt <= self.max_retries:
            try:
                return self.node_func(state)
            except self.retry_exceptions as exc:
                failure = exc
                if attempt == self.max_retries:
                    # Out of retries: fall through to the error state.
                    break
                # Exponential backoff: delay, 2*delay, 4*delay, ...
                time.sleep(self.retry_delay * (2 ** attempt))
            attempt += 1
        return {
            "error": f"Node failed after {self.max_retries} retries: {failure}",
            "retry_count": self.max_retries
        }
# 使用示例
def unreliable_api_call(state):
    """Call the external API with the state's query; the call may fail."""
    return {"api_response": external_api.call(state["query"])}
# Wrap the call so request errors and timeouts are retried up to three times.
reliable_api_node = RetryableNode(
    node_func=unreliable_api_call,
    max_retries=3,
    retry_exceptions=(requests.RequestException, TimeoutError),
)
graph.add_node("api_call", reliable_api_node)
1.2.2 子图组合模式
# ✅ 子图组合设计
def create_data_processing_subgraph() -> CompiledGraph:
    """Build and compile the validate -> clean -> transform subgraph."""

    class DataProcessingState(TypedDict):
        raw_data: Any
        processed_data: Any
        validation_errors: List[str]

    def route_after_validation(state):
        # Finish early when validation reported errors; otherwise go on to
        # cleaning.
        return END if state.get("validation_errors") else "clean"

    subgraph = StateGraph(DataProcessingState)

    # Processing steps.
    subgraph.add_node("validate", validate_data_node)
    subgraph.add_node("clean", clean_data_node)
    subgraph.add_node("transform", transform_data_node)

    # Wiring.
    subgraph.add_edge(START, "validate")
    subgraph.add_conditional_edges(
        "validate",
        route_after_validation,
        {"clean": "clean", END: END},
    )
    subgraph.add_edge("clean", "transform")
    subgraph.add_edge("transform", END)

    return subgraph.compile()
def create_main_workflow() -> CompiledGraph:
    """Build and compile the main workflow, which embeds the data subgraph.

    Flow: START -> process_data -> generate_result -> END.
    """

    class MainState(TypedDict):
        input_data: Any
        processed_data: Any
        final_result: Any

    main_graph = StateGraph(MainState)

    # The compiled subgraph is invoked from inside a regular node.
    data_processor = create_data_processing_subgraph()

    def process_data_node(state):
        """Run the data-processing subgraph on the main state's input.

        Raises:
            ValueError: If the subgraph finished without producing
                "processed_data".
        """
        result = data_processor.invoke({"raw_data": state["input_data"]})
        # The subgraph routes straight to END when validation reports
        # errors, so "processed_data" may never be written. Fail with an
        # explicit message instead of an opaque KeyError.
        if "processed_data" not in result:
            errors = result.get("validation_errors") or []
            raise ValueError(
                f"Data processing produced no output; validation errors: {errors}"
            )
        return {"processed_data": result["processed_data"]}

    main_graph.add_node("process_data", process_data_node)
    main_graph.add_node("generate_result", generate_result_node)

    main_graph.add_edge(START, "process_data")
    main_graph.add_edge("process_data", "generate_result")
    main_graph.add_edge("generate_result", END)

    return main_graph.compile()
2. 性能优化最佳实践
2.1 状态管理优化
2.1.1 状态大小控制
# ✅ 状态大小优化
class OptimizedAgentState(TypedDict):
    """State schema designed to stay small."""

    # Keep only the messages that are needed.
    messages: Annotated[List[BaseMessage], add_messages]

    # References to documents rather than the full document contents.
    document_ids: List[str]

    # Paging information for working through large lists.
    current_page: int
    page_size: int

    # Cache keys rather than the cached values themselves.
    cache_keys: List[str]
def optimize_message_history(state: OptimizedAgentState) -> OptimizedAgentState:
    """Trim the message history to at most MAX_MESSAGES entries.

    All system messages are kept; the remaining budget is filled with the
    most recent non-system messages. When the history is already within the
    limit the state is returned unchanged.

    Returns:
        ``{"messages": trimmed}`` when trimming happened, else ``state``.
    """
    messages = state["messages"]
    MAX_MESSAGES = 50
    if len(messages) <= MAX_MESSAGES:
        return state

    system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
    other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]

    # Selecting only non-system messages for the recent window prevents a
    # system message from appearing twice in the result.
    budget = max(MAX_MESSAGES - len(system_messages), 0)
    # Guard the zero case explicitly: a slice of [-0:] would return the
    # whole list instead of nothing.
    recent_messages = other_messages[-budget:] if budget else []

    # NOTE(review): "messages" is declared with the add_messages reducer, so
    # the returned list may be merged into the existing history rather than
    # replacing it -- confirm against the reducer's documentation.
    return {"messages": system_messages + recent_messages}


# Periodically clean up the state.
graph.add_node("optimize_state", optimize_message_history)
2.1.2 缓存策略
from functools import lru_cache
import hashlib
import pickle
class StateCache:
    """In-memory cache with a time-to-live and least-recently-used eviction.

    Attributes:
        cache: Maps cache key -> {"value": ..., "timestamp": ...}.
        access_times: Maps cache key -> time of the last put or successful
            get; used to pick the eviction victim.
    """

    def __init__(self, max_size: int = 1000, ttl: int = 3600):
        self.cache = {}
        self.max_size = max_size
        self.ttl = ttl
        self.access_times = {}

    def get_cache_key(self, state: dict, node_name: str) -> str:
        """Build a cache key from the node name and the relevant state fields.

        Only "query", "context" and "parameters" take part in the key. The
        items are sorted by field name so that two states with the same
        content but different insertion order map to the same key.
        """
        relevant_fields = ("query", "context", "parameters")
        relevant_data = sorted(
            ((k, v) for k, v in state.items() if k in relevant_fields),
            key=lambda item: item[0],
        )
        data_bytes = pickle.dumps(relevant_data, protocol=pickle.HIGHEST_PROTOCOL)
        # md5 is used only as a fingerprint here, not for security.
        hash_key = hashlib.md5(data_bytes).hexdigest()
        return f"{node_name}:{hash_key}"

    def _discard(self, cache_key: str) -> None:
        """Remove a key from both bookkeeping dicts, if present."""
        self.cache.pop(cache_key, None)
        self.access_times.pop(cache_key, None)

    def get(self, cache_key: str) -> Optional[Any]:
        """Return the cached value, or None when missing or expired."""
        entry = self.cache.get(cache_key)
        if entry is None:
            return None
        now = time.time()
        if now - entry["timestamp"] < self.ttl:
            self.access_times[cache_key] = now
            return entry["value"]
        # Expired: drop the access time as well, otherwise a later eviction
        # could select this key and fail because it is no longer cached.
        self._discard(cache_key)
        return None

    def put(self, cache_key: str, value: Any) -> None:
        """Store a value, evicting the least recently used entry when full."""
        # Overwriting an existing key does not grow the cache, so nothing
        # has to be evicted in that case.
        is_new_key = cache_key not in self.cache
        if is_new_key and len(self.cache) >= self.max_size and self.access_times:
            oldest_key = min(self.access_times, key=self.access_times.get)
            self._discard(oldest_key)
        now = time.time()
        self.cache[cache_key] = {
            "value": value,
            "timestamp": now
        }
        self.access_times[cache_key] = now


# Global cache instance.
state_cache = StateCache()
def cached_expensive_operation(state):
    """Run the expensive computation, reusing a cached result when present.

    Returns:
        A dict with "result" and a "from_cache" flag.
    """
    key = state_cache.get_cache_key(state, "expensive_operation")
    hit = state_cache.get(key)
    if hit is None:
        # Cache miss: compute and remember the result.
        computed = perform_expensive_computation(state)
        state_cache.put(key, computed)
        return {"result": computed, "from_cache": False}
    return {"result": hit, "from_cache": True}
2.2 并发执行优化
2.2.1 并行节点设计
# ✅ 并行执行设计
def create_parallel_analysis_graph():
    """Build a graph that fans out to three analyses and joins the results.

    The three analysis nodes each read only "input_data" and write their own
    key, so they do not depend on one another.
    """

    class AnalysisState(TypedDict):
        input_data: Any
        sentiment_analysis: Optional[dict]
        entity_extraction: Optional[dict]
        topic_modeling: Optional[dict]
        final_analysis: Optional[dict]

    def sentiment_analysis_node(state):
        """Sentiment analysis over the input data."""
        return {"sentiment_analysis": analyze_sentiment(state["input_data"])}

    def entity_extraction_node(state):
        """Entity extraction over the input data."""
        return {"entity_extraction": extract_entities(state["input_data"])}

    def topic_modeling_node(state):
        """Topic modeling over the input data."""
        return {"topic_modeling": model_topics(state["input_data"])}

    def combine_analysis_node(state):
        """Merge the three analysis results into "final_analysis"."""
        return {
            "final_analysis": {
                "sentiment": state["sentiment_analysis"],
                "entities": state["entity_extraction"],
                "topics": state["topic_modeling"]
            }
        }

    graph = StateGraph(AnalysisState)

    branch_nodes = {
        "sentiment": sentiment_analysis_node,
        "entities": entity_extraction_node,
        "topics": topic_modeling_node,
    }
    for node_name, node_func in branch_nodes.items():
        graph.add_node(node_name, node_func)
    graph.add_node("combine", combine_analysis_node)

    # Fan out: every branch starts directly from START.
    for node_name in branch_nodes:
        graph.add_edge(START, node_name)
    # Fan in: every branch feeds the combine node.
    for node_name in branch_nodes:
        graph.add_edge(node_name, "combine")
    graph.add_edge("combine", END)

    return graph.compile()
2.2.2 异步操作优化
import asyncio
from concurrent.futures import ThreadPoolExecutor
class AsyncNodeWrapper:
    """Expose an async node function as a synchronous callable.

    Args:
        async_func: Coroutine function taking the state and returning a dict.
        max_workers: Size of the thread pool used when the wrapper is called
            from a thread that already runs an event loop.
    """

    def __init__(self, async_func: Callable, max_workers: int = 4):
        self.async_func = async_func
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _run_in_fresh_loop(self, state: dict) -> dict:
        """Run the coroutine to completion on a new event loop."""
        # asyncio.run creates the loop, closes it afterwards and does not
        # leave a closed loop installed as the thread's current loop.
        return asyncio.run(self.async_func(state))

    def __call__(self, state: dict) -> dict:
        """Call the async function synchronously and return its result."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so one can be started here.
            return self._run_in_fresh_loop(state)
        # A loop is already running in this thread and cannot be re-entered.
        # Run the coroutine on a worker thread with its own loop and wait
        # for it. This blocks the calling thread until the result is ready.
        return self.executor.submit(self._run_in_fresh_loop, state).result()
# 异步节点示例
async def async_api_calls_node(state):
    """Send one API request per query concurrently and collect the outcomes.

    Returns:
        A dict with "api_results" (successful responses, in query order) and
        "api_errors" (one entry per failed query with its index and message).
    """
    queries = state["queries"]

    async def call_api(session, query):
        """Send a single request using the shared session."""
        # NOTE(review): "/api/analyze" is a relative URL; the session likely
        # needs a base URL or an absolute URL here -- confirm.
        async with session.post("/api/analyze", json={"query": query}) as response:
            return await response.json()

    # One session is shared by all requests instead of opening a new
    # session (and connection pool) for every query.
    async with aiohttp.ClientSession() as session:
        tasks = [call_api(session, query) for query in queries]
        # return_exceptions=True: one failing request does not cancel the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

    successful_results = []
    errors = []
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            errors.append({"query_index": i, "error": str(result)})
        else:
            successful_results.append(result)

    return {
        "api_results": successful_results,
        "api_errors": errors
    }
# Wrap the async node so it can be registered as a synchronous node.
sync_api_node = AsyncNodeWrapper(async_func=async_api_calls_node)
graph.add_node("api_calls", sync_api_node)
2.3 内存管理优化
2.3.1 内存使用监控
import psutil
import gc
from typing import Dict, Any
class MemoryMonitor:
    """Monitor the current process's memory usage against two thresholds.

    Args:
        warning_threshold: Fraction of memory (0.0-1.0) above which the
            status becomes "warning".
        critical_threshold: Fraction of memory (0.0-1.0) above which the
            status becomes "critical".
    """

    def __init__(self, warning_threshold: float = 0.8, critical_threshold: float = 0.9):
        self.warning_threshold = warning_threshold
        self.critical_threshold = critical_threshold
        # Snapshot taken at construction; growth is reported relative to it.
        self.baseline_memory = self._get_memory_usage()

    def _get_memory_usage(self) -> Dict[str, float]:
        """Return the process's rss and vms in MB and its memory percentage."""
        process = psutil.Process()
        memory_info = process.memory_info()
        return {
            "rss": memory_info.rss / 1024 / 1024,  # MB
            "vms": memory_info.vms / 1024 / 1024,  # MB
            "percent": process.memory_percent()
        }

    def check_memory_usage(self) -> Dict[str, Any]:
        """Classify current usage and report growth since construction.

        Returns:
            A dict with "status" ("normal", "warning" or "critical"),
            "current_usage", "growth_mb" and "recommendations".
        """
        current_memory = self._get_memory_usage()
        memory_growth = current_memory["rss"] - self.baseline_memory["rss"]

        # "percent" is on a 0-100 scale while the thresholds are fractions,
        # so convert before comparing. Comparing the raw percentage would
        # flag almost any process as critical.
        usage_fraction = current_memory["percent"] / 100.0
        if usage_fraction > self.critical_threshold:
            status = "critical"
        elif usage_fraction > self.warning_threshold:
            status = "warning"
        else:
            status = "normal"

        return {
            "status": status,
            "current_usage": current_memory,
            "growth_mb": memory_growth,
            "recommendations": self._get_recommendations(status)
        }

    def _get_recommendations(self, status: str) -> List[str]:
        """Return the recommended actions for a status (empty for "normal")."""
        if status == "critical":
            return [
                "立即执行垃圾回收",
                "清理状态缓存",
                "减少批处理大小",
                "考虑重启进程"
            ]
        elif status == "warning":
            return [
                "执行垃圾回收",
                "检查内存泄漏",
                "优化状态大小"
            ]
        return []
def memory_cleanup_node(state):
    """Collect garbage and drop temp/cache keys when memory usage is high.

    Returns:
        The state unchanged when the status is normal; otherwise a new dict
        without keys starting with "temp_" or ending with "_cache", plus a
        "memory_status" entry.
    """
    memory_status = MemoryMonitor().check_memory_usage()
    if memory_status["status"] not in ["warning", "critical"]:
        return state

    gc.collect()

    # Keep everything except temporary and cache entries.
    cleaned_state = {
        key: value
        for key, value in state.items()
        if not (key.startswith("temp_") or key.endswith("_cache"))
    }
    cleaned_state["memory_status"] = memory_status
    return cleaned_state
3. 错误处理最佳实践
3.1 分层错误处理
from enum import Enum
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
class ErrorSeverity(Enum):
    """How serious an error is."""

    LOW = "low"            # can be ignored
    MEDIUM = "medium"      # needs handling but does not affect the main flow
    HIGH = "high"          # affects the main flow but is recoverable
    CRITICAL = "critical"  # fatal; execution should stop


@dataclass
class GraphError:
    """An error raised while executing the graph."""

    code: str
    message: str
    severity: ErrorSeverity
    node_name: Optional[str] = None
    context: Optional[Dict[str, Any]] = None
    recoverable: bool = True
    retry_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the error as a plain dict, with severity as its string value."""
        return {
            "code": self.code,
            "message": self.message,
            "severity": self.severity.value,
            "node_name": self.node_name,
            "context": self.context,
            "recoverable": self.recoverable,
            "retry_count": self.retry_count
        }


class ErrorHandler:
    """Dispatch errors to handlers registered per error code.

    Attributes:
        error_handlers: Maps error code -> handler(error, state) -> state.
        error_history: Every error passed to handle_error, in order.
    """

    def __init__(self):
        self.error_handlers = {}
        self.error_history = []

    def register_handler(
        self,
        error_code: str,
        handler: Callable[[GraphError, dict], dict]
    ):
        """Register (or replace) the handler for an error code."""
        self.error_handlers[error_code] = handler

    def handle_error(self, error: GraphError, state: dict) -> dict:
        """Handle an error and return the resulting state.

        The registered handler for ``error.code`` is used when there is one.
        If that handler itself raises, a non-recoverable HANDLER_ERROR is
        passed to the default handler. Without a registered handler the
        default handler processes the original error.
        """
        self.error_history.append(error)

        handler = self.error_handlers.get(error.code)
        if handler is None:
            return self._default_error_handler(error, state)

        try:
            return handler(error, state)
        except Exception as e:
            # The handler itself failed.
            fallback_error = GraphError(
                code="HANDLER_ERROR",
                message=f"Error handler failed: {e}",
                severity=ErrorSeverity.HIGH,
                recoverable=False
            )
            return self._default_error_handler(fallback_error, state)

    def _default_error_handler(self, error: GraphError, state: dict) -> dict:
        """Record the error in a copy of the state and set severity flags.

        CRITICAL sets "should_stop"; HIGH sets "needs_intervention".
        """
        error_state = state.copy()

        # state.copy() is shallow, so the "errors" list would be shared with
        # the caller's state. Build a new list instead of appending in place
        # so the input state is not modified.
        error_state["errors"] = list(state.get("errors") or []) + [error.to_dict()]

        if error.severity == ErrorSeverity.CRITICAL:
            error_state["should_stop"] = True
        elif error.severity == ErrorSeverity.HIGH:
            error_state["needs_intervention"] = True

        return error_state


# Global error handler.
global_error_handler = ErrorHandler()
# 注册具体的错误处理器
def handle_api_timeout(error: GraphError, state: dict) -> dict:
    """Handle an API timeout: retry with backoff, then fall back.

    Fewer than three retries so far: request a retry of the failed node with
    a delay of 2 ** retry_count. Otherwise: switch to the fallback API.
    """
    if error.retry_count >= 3:
        # Retries exhausted: degrade to the fallback API.
        return {
            **state,
            "use_fallback_api": True,
            "api_timeout_handled": True
        }
    # Retry with exponential backoff.
    return {
        **state,
        "should_retry": True,
        "retry_delay": 2 ** error.retry_count,
        "retry_node": error.node_name
    }
def handle_validation_error(error: GraphError, state: dict) -> dict:
    """Handle a validation error by flagging it and skipping validation."""
    updated = dict(state)
    updated.update({
        "validation_failed": True,
        # Skip validation so execution can continue.
        "skip_validation": True,
        "validation_error_details": error.context
    })
    return updated
# Register each handler under the error code it serves.
global_error_handler.register_handler(error_code="API_TIMEOUT", handler=handle_api_timeout)
global_error_handler.register_handler(error_code="VALIDATION_ERROR", handler=handle_validation_error)
# 错误处理装饰器
def with_error_handling(error_code: str, severity: ErrorSeverity = ErrorSeverity.MEDIUM):
    """Decorator factory that routes node exceptions to the global handler.

    Args:
        error_code: Code assigned to the GraphError built from the exception.
        severity: Severity assigned to that GraphError.

    Returns:
        A decorator. The decorated node returns its normal result, or the
        state produced by ``global_error_handler.handle_error`` when the node
        raises.
    """
    from functools import wraps

    def decorator(node_func):
        # wraps() keeps the node's __name__ and docstring on the wrapper, so
        # code that reads node_func.__name__ (such as other decorators
        # stacked on top) sees the real name instead of "wrapper".
        @wraps(node_func)
        def wrapper(state):
            try:
                return node_func(state)
            except Exception as e:
                error = GraphError(
                    code=error_code,
                    message=str(e),
                    severity=severity,
                    node_name=node_func.__name__,
                    context={"exception_type": type(e).__name__}
                )
                return global_error_handler.handle_error(error, state)
        return wrapper
    return decorator
# 使用示例
@with_error_handling("API_CALL_ERROR", ErrorSeverity.HIGH)
def api_call_node(state):
    """API call node; exceptions are converted by the decorator."""
    return {"api_response": external_api.call(state["query"])}
3.2 优雅降级策略
class FallbackStrategy:
    """Registry of ordered fallback chains, keyed by primary node name."""

    def __init__(self):
        self.fallback_chains = {}

    def register_fallback_chain(self, primary_node: str, fallback_nodes: List[str]):
        """Register the ordered fallbacks to try when primary_node fails."""
        self.fallback_chains[primary_node] = fallback_nodes

    def get_fallback_node(self, failed_node: str, attempt: int) -> Optional[str]:
        """Return the fallback for the given attempt index, or None.

        None is returned when no chain is registered for failed_node or the
        attempt index is past the end of the chain.
        """
        chain = self.fallback_chains.get(failed_node, [])
        return chain[attempt] if attempt < len(chain) else None


# Global fallback strategy.
fallback_strategy = FallbackStrategy()

# Fallback chain for the premium LLM, from most to least capable.
fallback_strategy.register_fallback_chain(
    "premium_llm",
    ["standard_llm", "basic_llm", "rule_based_fallback"],
)
def create_resilient_llm_node():
    """Create an LLM node that degrades through the registered fallback chain.

    Returns:
        A node function. On success it returns the response together with the
        model used. On failure it returns the state updated to try the next
        fallback, or marked as failed once the chain is exhausted.
    """
    # The fallback chain is registered under the primary node's name, so
    # every lookup uses this name together with the attempt counter.
    primary_node = "premium_llm"

    def resilient_llm_node(state):
        """Invoke the currently selected LLM, degrading on failure."""
        current_node = state.get("current_llm_node", primary_node)
        attempt = state.get("llm_attempt", 0)

        try:
            if current_node == "premium_llm":
                response = premium_llm.invoke(state["query"])
            elif current_node == "standard_llm":
                response = standard_llm.invoke(state["query"])
            elif current_node == "basic_llm":
                response = basic_llm.invoke(state["query"])
            else:  # rule_based_fallback
                response = rule_based_response(state["query"])

            return {
                "llm_response": response,
                "used_llm": current_node,
                "llm_attempt": attempt,
                # Clear the flag left by an earlier failed attempt;
                # llm_routing_condition checks it first and would keep
                # routing back to this node.
                "should_retry_llm": False
            }

        except Exception as e:
            # Look the chain up by the primary node. Looking it up by
            # current_node finds no chain once a fallback is active, which
            # ends the degradation after the first fallback.
            fallback_node = fallback_strategy.get_fallback_node(primary_node, attempt)

            if fallback_node:
                return {
                    **state,
                    "current_llm_node": fallback_node,
                    "llm_attempt": attempt + 1,
                    "llm_error": str(e),
                    "should_retry_llm": True
                }
            # Every fallback option has failed.
            return {
                **state,
                "llm_failed": True,
                "llm_error": str(e),
                "final_attempt": attempt,
                # Reset so routing reaches the failure handler, not a retry.
                "should_retry_llm": False
            }

    return resilient_llm_node
# 条件路由支持降级
def llm_routing_condition(state):
    """Route after the LLM node: retry, handle failure, or continue.

    The retry flag is checked before the failure flag.
    """
    if state.get("should_retry_llm"):
        return "retry_llm"
    return "handle_llm_failure" if state.get("llm_failed") else "continue"
# Build the graph: the LLM node routes back to itself to retry with the
# next fallback, to the failure handler, or on to the next step.
graph.add_node("llm", create_resilient_llm_node())
graph.add_conditional_edges(
    "llm",
    llm_routing_condition,
    {"retry_llm": "llm", "handle_llm_failure": "llm_failure_handler", "continue": "next_step"},
)
4. 监控与可观测性
4.1 性能监控
import time
from dataclasses import dataclass, field
from typing import Dict, List, Any
from collections import defaultdict
@dataclass
class NodeMetrics:
    """Accumulated performance metrics for one node."""

    name: str
    execution_count: int = 0
    total_duration: float = 0.0
    min_duration: float = float('inf')
    max_duration: float = 0.0
    error_count: int = 0
    last_execution: Optional[float] = None

    @property
    def avg_duration(self) -> float:
        """Mean duration per execution, or 0.0 before the first execution."""
        return self.total_duration / self.execution_count if self.execution_count > 0 else 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of executions without error, or 0.0 before the first one."""
        return (self.execution_count - self.error_count) / self.execution_count if self.execution_count > 0 else 0.0


class GraphMetricsCollector:
    """Collect per-node and per-run execution metrics.

    Attributes:
        node_metrics: Maps node name -> NodeMetrics.
        execution_history: Finished runs, oldest first.
        current_execution: The run being tracked, or None.
    """

    def __init__(self):
        self.node_metrics: Dict[str, NodeMetrics] = {}
        self.execution_history: List[Dict[str, Any]] = []
        self.current_execution: Optional[Dict[str, Any]] = None
        # Exact start times per execution ID. The ID itself only carries the
        # start time truncated to milliseconds.
        self._node_start_times: Dict[str, float] = {}

    def start_execution(self, execution_id: str):
        """Begin tracking a graph run."""
        self.current_execution = {
            "id": execution_id,
            "start_time": time.time(),
            "nodes_executed": [],
            "total_duration": 0.0,
            "success": True
        }

    def start_node_execution(self, node_name: str) -> str:
        """Begin tracking a node execution and return its execution ID.

        The ID has the form "<node_name>_<start time in ms>".
        """
        started_at = time.time()
        execution_id = f"{node_name}_{int(started_at * 1000)}"
        self._node_start_times[execution_id] = started_at
        if node_name not in self.node_metrics:
            self.node_metrics[node_name] = NodeMetrics(name=node_name)
        return execution_id

    def end_node_execution(self, node_name: str, execution_id: str, success: bool = True):
        """Finish tracking a node execution and update its metrics."""
        end_time = time.time()
        start_time = self._node_start_times.pop(execution_id, None)
        if start_time is None:
            # ID not issued by this collector: fall back to the millisecond
            # timestamp embedded in the ID.
            start_time = float(execution_id.split('_')[-1]) / 1000
        duration = end_time - start_time

        # Create the metrics entry on demand so an ID obtained elsewhere
        # does not fail with KeyError.
        metrics = self.node_metrics.setdefault(node_name, NodeMetrics(name=node_name))
        metrics.execution_count += 1
        metrics.total_duration += duration
        metrics.min_duration = min(metrics.min_duration, duration)
        metrics.max_duration = max(metrics.max_duration, duration)
        metrics.last_execution = end_time
        if not success:
            metrics.error_count += 1

        if self.current_execution:
            self.current_execution["nodes_executed"].append({
                "name": node_name,
                "duration": duration,
                "success": success,
                "timestamp": end_time
            })

    def end_execution(self, success: bool = True):
        """Finish the current run, if any, and move it to the history."""
        if self.current_execution:
            self.current_execution["end_time"] = time.time()
            self.current_execution["total_duration"] = (
                self.current_execution["end_time"] - self.current_execution["start_time"]
            )
            self.current_execution["success"] = success
            self.execution_history.append(self.current_execution)
            self.current_execution = None

    def get_performance_report(self) -> Dict[str, Any]:
        """Return run-level and per-node statistics plus the last 10 runs."""
        total_executions = len(self.execution_history)
        successful_executions = sum(1 for ex in self.execution_history if ex["success"])

        node_performance = {}
        for name, metrics in self.node_metrics.items():
            node_performance[name] = {
                "execution_count": metrics.execution_count,
                "avg_duration": metrics.avg_duration,
                # A node that has started but never finished still has the
                # float('inf') default; report 0.0 so the value stays valid
                # when the report is serialized as JSON.
                "min_duration": metrics.min_duration if metrics.execution_count > 0 else 0.0,
                "max_duration": metrics.max_duration,
                "success_rate": metrics.success_rate,
                "error_count": metrics.error_count
            }

        return {
            "total_executions": total_executions,
            "success_rate": successful_executions / total_executions if total_executions > 0 else 0,
            "node_performance": node_performance,
            "recent_executions": self.execution_history[-10:]
        }


# Global metrics collector.
metrics_collector = GraphMetricsCollector()
# 性能监控装饰器
def with_performance_monitoring(node_func):
    """Decorator that records each node call in the global metrics collector.

    The call is recorded as successful when the node returns and as failed
    when it raises; the exception is re-raised unchanged.
    """
    from functools import wraps

    # wraps() keeps the node's __name__ and docstring on the wrapper, so the
    # metrics key stays the real node name when decorators are stacked.
    @wraps(node_func)
    def wrapper(state):
        node_name = node_func.__name__
        execution_id = metrics_collector.start_node_execution(node_name)
        try:
            result = node_func(state)
        except Exception:
            metrics_collector.end_node_execution(node_name, execution_id, success=False)
            raise
        metrics_collector.end_node_execution(node_name, execution_id, success=True)
        return result
    return wrapper
# 使用示例
@with_performance_monitoring
def monitored_node(state):
    """Example node whose executions are recorded by the decorator."""
    # Simulate some processing time.
    time.sleep(0.1)
    return {"processed": True}
4.2 日志记录最佳实践
import logging
import json
from datetime import datetime
from typing import Any, Dict, Optional
class StructuredLogger:
    """Logger that writes node and graph events as JSON payloads.

    Args:
        name: Name of the underlying ``logging`` logger.
        level: Log level set on that logger.
    """

    def __init__(self, name: str, level: int = logging.INFO):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)

        # logging.getLogger returns the same logger object for the same
        # name, so attach the stream handler only once. Otherwise every
        # additional StructuredLogger with this name duplicates each line.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def log_node_execution(
        self,
        node_name: str,
        state_summary: Dict[str, Any],
        execution_time: float,
        success: bool = True,
        error: Optional[str] = None
    ):
        """Log one node execution at INFO (success) or ERROR (failure)."""
        # NOTE(review): datetime.utcnow() is deprecated in recent Python
        # versions; a timezone-aware replacement would change the timestamp
        # text written to the log.
        log_data = {
            "event_type": "node_execution",
            "node_name": node_name,
            "execution_time": execution_time,
            "success": success,
            "state_summary": state_summary,
            "timestamp": datetime.utcnow().isoformat()
        }
        if error:
            log_data["error"] = error

        if success:
            self.logger.info(f"Node executed: {json.dumps(log_data)}")
        else:
            self.logger.error(f"Node failed: {json.dumps(log_data)}")

    def log_graph_execution(
        self,
        execution_id: str,
        total_time: float,
        nodes_executed: List[str],
        success: bool = True,
        final_state_summary: Optional[Dict[str, Any]] = None
    ):
        """Log one graph run at INFO (success) or ERROR (failure)."""
        log_data = {
            "event_type": "graph_execution",
            "execution_id": execution_id,
            "total_time": total_time,
            "nodes_executed": nodes_executed,
            "success": success,
            "timestamp": datetime.utcnow().isoformat()
        }
        if final_state_summary:
            log_data["final_state_summary"] = final_state_summary

        if success:
            self.logger.info(f"Graph execution completed: {json.dumps(log_data)}")
        else:
            self.logger.error(f"Graph execution failed: {json.dumps(log_data)}")


# Logger used for graph execution events.
graph_logger = StructuredLogger("langgraph.execution")
def create_state_summary(state: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize a state for logging without sensitive or bulky values.

    Rules, applied per key:
        - keys starting with "_" or containing "password", "token" or
          "secret" (case-insensitive) become "[REDACTED]";
        - str, int, float and bool values are kept as they are;
        - lists and dicts are replaced by a size description;
        - anything else is replaced by its type name.
    """
    # Match markers as case-insensitive substrings so keys such as
    # "api_token" or "Password" are redacted too, not only exact names.
    sensitive_markers = ("password", "token", "secret")

    summary = {}
    for key, value in state.items():
        lowered_key = key.lower()
        if key.startswith("_") or any(marker in lowered_key for marker in sensitive_markers):
            summary[key] = "[REDACTED]"
        elif isinstance(value, (str, int, float, bool)):
            summary[key] = value
        elif isinstance(value, list):
            summary[key] = f"[List with {len(value)} items]"
        elif isinstance(value, dict):
            summary[key] = f"[Dict with {len(value)} keys]"
        else:
            summary[key] = f"[{type(value).__name__}]"
    return summary
# 日志记录装饰器
def with_logging(logger: StructuredLogger):
    """Decorator factory that logs every node execution via ``logger``.

    Each call is logged with the node name, a summary of the input state and
    the elapsed time. Failures are logged with the exception text and the
    exception is re-raised unchanged.
    """
    from functools import wraps

    def decorator(node_func):
        # wraps() keeps the node's __name__ and docstring on the wrapper.
        @wraps(node_func)
        def wrapper(state):
            node_name = node_func.__name__
            start_time = time.time()

            def record(success, error=None):
                # Single place that builds the log entry for both outcomes.
                logger.log_node_execution(
                    node_name=node_name,
                    state_summary=create_state_summary(state),
                    execution_time=time.time() - start_time,
                    success=success,
                    error=error
                )

            try:
                result = node_func(state)
            except Exception as e:
                record(False, str(e))
                raise
            record(True)
            return result
        return wrapper
    return decorator
# 使用示例
@with_logging(graph_logger)
def logged_node(state):
    """Example node whose executions are logged by the decorator."""
    outcome = {"processed": True}
    return outcome
5. 部署与运维最佳实践
5.1 配置管理
from typing import Any, Dict, Optional
import os
import yaml
from dataclasses import dataclass
@dataclass
class GraphConfig:
    """Configuration values for building and running the graph."""

    # --- Execution ---
    max_steps: int = 100
    step_timeout: float = 30.0
    enable_checkpoints: bool = True

    # --- Performance ---
    max_workers: int = 4
    memory_limit_mb: int = 1024
    enable_caching: bool = True
    cache_size: int = 1000

    # --- Monitoring ---
    enable_metrics: bool = True
    enable_logging: bool = True
    log_level: str = "INFO"

    # --- External services (None is replaced by an empty dict) ---
    llm_config: Dict[str, Any] = None
    database_config: Dict[str, Any] = None

    def __post_init__(self):
        # Give every instance its own empty dict for unset service configs.
        for attr_name in ("llm_config", "database_config"):
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, {})
class ConfigManager:
    """Load a GraphConfig from a YAML file, with environment overrides.

    Args:
        config_file: Path of the YAML file; defaults to "graph_config.yaml".
            A missing file is not an error; defaults and environment
            variables are used instead.
    """

    def __init__(self, config_file: Optional[str] = None):
        self.config_file = config_file or "graph_config.yaml"
        self.config = self._load_config()

    def _load_config(self) -> GraphConfig:
        """Build the configuration from the file, then apply env overrides.

        Raises:
            ValueError: If the YAML document is not a mapping at top level.
        """
        config_data = {}

        if os.path.exists(self.config_file):
            # Read as UTF-8 explicitly so non-ASCII content (for example
            # comments) does not depend on the platform's default encoding.
            with open(self.config_file, 'r', encoding="utf-8") as f:
                config_data = yaml.safe_load(f) or {}
            if not isinstance(config_data, dict):
                raise ValueError(
                    f"Config file {self.config_file} must contain a mapping at the top level"
                )

        # Environment variables take precedence over file values.
        config_data.update(self._load_env_overrides())

        return GraphConfig(**config_data)

    def _load_env_overrides(self) -> Dict[str, Any]:
        """Read the supported GRAPH_* variables and convert their values.

        Values that cannot be converted are skipped with a printed warning.
        Booleans are true for "true", "1", "yes" and "on" (case-insensitive).
        """
        overrides = {}

        # Environment variable -> (config key, converter).
        env_mappings = {
            "GRAPH_MAX_STEPS": ("max_steps", int),
            "GRAPH_STEP_TIMEOUT": ("step_timeout", float),
            "GRAPH_MAX_WORKERS": ("max_workers", int),
            "GRAPH_MEMORY_LIMIT": ("memory_limit_mb", int),
            "GRAPH_LOG_LEVEL": ("log_level", str),
            "GRAPH_ENABLE_CHECKPOINTS": ("enable_checkpoints", bool),
            "GRAPH_ENABLE_CACHING": ("enable_caching", bool),
        }

        for env_var, (config_key, type_func) in env_mappings.items():
            value = os.getenv(env_var)
            if value is None:
                continue
            try:
                if type_func == bool:
                    # bool("false") would be True, so parse the text instead.
                    overrides[config_key] = value.lower() in ('true', '1', 'yes', 'on')
                else:
                    overrides[config_key] = type_func(value)
            except ValueError:
                print(f"Warning: Invalid value for {env_var}: {value}")

        return overrides

    def get_config(self) -> GraphConfig:
        """Return the currently loaded configuration."""
        return self.config

    def reload_config(self):
        """Load the configuration again from file and environment."""
        self.config = self._load_config()


# Global configuration manager.
config_manager = ConfigManager()
# 配置示例文件 (graph_config.yaml)
# The nested keys under llm_config and database_config must be indented;
# without indentation they are parsed as top-level keys.
EXAMPLE_CONFIG = """
# 执行配置
max_steps: 50
step_timeout: 60.0
enable_checkpoints: true
# 性能配置
max_workers: 8
memory_limit_mb: 2048
enable_caching: true
cache_size: 2000
# 监控配置
enable_metrics: true
enable_logging: true
log_level: "DEBUG"
# LLM配置
llm_config:
  provider: "openai"
  model: "gpt-4"
  temperature: 0.7
  max_tokens: 2000
  timeout: 30
# 数据库配置
database_config:
  host: "localhost"
  port: 5432
  database: "langgraph"
  username: "user"
  password: "password"
  pool_size: 10
"""
def create_configured_graph() -> CompiledGraph:
    """Build and compile a graph according to the global configuration.

    Checkpointing uses the Postgres saver when a database host is configured
    and the in-memory saver otherwise; it is off when checkpoints are
    disabled. Debug mode follows the "DEBUG" log level.
    """
    config = config_manager.get_config()

    # Pick the checkpoint saver; None means no checkpointing.
    checkpointer = None
    if config.enable_checkpoints:
        if config.database_config.get("host"):
            checkpointer = PostgresCheckpointSaver(**config.database_config)
        else:
            checkpointer = MemorySaver()

    graph = StateGraph(AgentState)
    # Nodes would be added here ...

    # Apply the configuration at compile time.
    debug_enabled = (config.log_level == "DEBUG")
    return graph.compile(
        checkpointer=checkpointer,
        interrupt_before=None,
        interrupt_after=None,
        debug=debug_enabled
    )
5.2 健康检查和监控
from typing import Dict, Any, List
import time
import threading
from dataclasses import dataclass
from enum import Enum
class HealthStatus(Enum):
    """Health state of a check or of the whole system."""

    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"


@dataclass
class HealthCheck:
    """Result of one health check."""

    name: str
    status: HealthStatus
    message: str
    duration: float
    timestamp: float
    details: Optional[Dict[str, Any]] = None


class HealthMonitor:
    """Run registered health checks on demand or on a background thread.

    Attributes:
        checks: Maps check name -> zero-argument function returning a
            HealthCheck.
        check_history: The most recent (up to 100) full check runs.
        monitoring_interval: Seconds between background runs.
        is_monitoring: Whether the background loop is active.
    """

    def __init__(self):
        self.checks = {}
        self.check_history = []
        self.monitoring_thread = None
        self.monitoring_interval = 30  # run the checks every 30 seconds
        self.is_monitoring = False
        # Lets stop_monitoring() wake the loop immediately instead of
        # waiting for a sleep of up to one full interval to finish.
        self._stop_event = threading.Event()

    def register_check(self, name: str, check_func: Callable[[], HealthCheck]):
        """Register (or replace) a health check under a name."""
        self.checks[name] = check_func

    def run_check(self, name: str) -> HealthCheck:
        """Run one check and return its result.

        An unknown name or a check that raises yields an UNHEALTHY result
        instead of an exception. Duration and timestamp of a returned result
        are overwritten with values measured here.
        """
        if name not in self.checks:
            return HealthCheck(
                name=name,
                status=HealthStatus.UNHEALTHY,
                message=f"Check '{name}' not found",
                duration=0.0,
                timestamp=time.time()
            )

        start_time = time.time()
        try:
            result = self.checks[name]()
            result.duration = time.time() - start_time
            result.timestamp = time.time()
            return result
        except Exception as e:
            return HealthCheck(
                name=name,
                status=HealthStatus.UNHEALTHY,
                message=f"Check failed: {e}",
                duration=time.time() - start_time,
                timestamp=time.time()
            )

    def run_all_checks(self) -> Dict[str, HealthCheck]:
        """Run every registered check and record the run in the history."""
        results = {name: self.run_check(name) for name in self.checks}

        self.check_history.append({
            "timestamp": time.time(),
            "results": results
        })
        # Keep only the 100 most recent runs.
        if len(self.check_history) > 100:
            self.check_history = self.check_history[-100:]

        return results

    def get_overall_status(self) -> HealthStatus:
        """Run all checks and reduce their results to a single status.

        No checks -> UNHEALTHY; all healthy -> HEALTHY; any unhealthy ->
        UNHEALTHY; otherwise DEGRADED.
        """
        results = self.run_all_checks()
        if not results:
            return HealthStatus.UNHEALTHY

        statuses = [result.status for result in results.values()]
        if all(status == HealthStatus.HEALTHY for status in statuses):
            return HealthStatus.HEALTHY
        elif any(status == HealthStatus.UNHEALTHY for status in statuses):
            return HealthStatus.UNHEALTHY
        else:
            return HealthStatus.DEGRADED

    def start_monitoring(self):
        """Start the background loop; does nothing if already running."""
        if self.is_monitoring:
            return
        self._stop_event.clear()
        self.is_monitoring = True
        self.monitoring_thread = threading.Thread(target=self._monitoring_loop)
        self.monitoring_thread.daemon = True
        self.monitoring_thread.start()

    def stop_monitoring(self):
        """Stop the background loop and wait for its thread to finish."""
        self.is_monitoring = False
        self._stop_event.set()
        if self.monitoring_thread:
            self.monitoring_thread.join()

    def _monitoring_loop(self):
        """Run all checks repeatedly until monitoring is stopped."""
        while self.is_monitoring:
            try:
                self.run_all_checks()
                delay = self.monitoring_interval
            except Exception as e:
                print(f"Monitoring error: {e}")
                delay = 5  # short wait after an error
            # wait() returns True as soon as stop_monitoring() sets the event.
            if self._stop_event.wait(delay):
                break


# Global health monitor.
health_monitor = HealthMonitor()
# 具体的健康检查实现
def check_memory_usage() -> HealthCheck:
    """Health check for system memory usage.

    Above 90% -> UNHEALTHY, above 80% -> DEGRADED, otherwise HEALTHY. Any
    exception (including a failed psutil import) yields UNHEALTHY.
    """
    try:
        import psutil
        memory = psutil.virtual_memory()
        used_percent = memory.percent

        if used_percent > 90:
            status, message = HealthStatus.UNHEALTHY, f"Memory usage critical: {used_percent}%"
        elif used_percent > 80:
            status, message = HealthStatus.DEGRADED, f"Memory usage high: {used_percent}%"
        else:
            status, message = HealthStatus.HEALTHY, f"Memory usage normal: {used_percent}%"

        usage_details = {
            "total_mb": memory.total / 1024 / 1024,
            "used_mb": memory.used / 1024 / 1024,
            "percent": memory.percent
        }
        return HealthCheck(
            name="memory_usage",
            status=status,
            message=message,
            duration=0.0,
            timestamp=time.time(),
            details=usage_details
        )
    except Exception as e:
        return HealthCheck(
            name="memory_usage",
            status=HealthStatus.UNHEALTHY,
            message=f"Failed to check memory: {e}",
            duration=0.0,
            timestamp=time.time()
        )
def check_database_connection() -> HealthCheck:
    """Health check for the database connection (placeholder).

    The actual ping is not implemented; only the timing scaffold exists.
    Slower than 5 seconds -> DEGRADED, otherwise HEALTHY; an exception
    yields UNHEALTHY.
    """
    try:
        # A real database connection check belongs here (example code).
        started_at = time.time()
        # db_connection.ping()  # the actual ping operation
        duration = time.time() - started_at

        is_slow = duration > 5.0
        if is_slow:
            status = HealthStatus.DEGRADED
            message = f"Database response slow: {duration:.2f}s"
        else:
            status = HealthStatus.HEALTHY
            message = f"Database connection OK: {duration:.2f}s"

        return HealthCheck(
            name="database_connection",
            status=status,
            message=message,
            duration=duration,
            timestamp=time.time()
        )
    except Exception as e:
        return HealthCheck(
            name="database_connection",
            status=HealthStatus.UNHEALTHY,
            message=f"Database connection failed: {e}",
            duration=0.0,
            timestamp=time.time()
        )
def check_external_api() -> HealthCheck:
    """Health check for the external API's health endpoint.

    HTTP 200 -> HEALTHY; any other status code -> DEGRADED; an exception
    (including a timeout after 10 seconds) -> UNHEALTHY.
    """
    try:
        import requests

        started_at = time.time()
        # Health-check endpoint.
        response = requests.get("https://api.example.com/health", timeout=10)
        duration = time.time() - started_at

        if response.status_code != 200:
            status = HealthStatus.DEGRADED
            message = f"External API error: {response.status_code}"
        else:
            status = HealthStatus.HEALTHY
            message = f"External API OK: {duration:.2f}s"

        response_details = {
            "status_code": response.status_code,
            "response_time": duration
        }
        return HealthCheck(
            name="external_api",
            status=status,
            message=message,
            duration=duration,
            timestamp=time.time(),
            details=response_details
        )
    except Exception as e:
        return HealthCheck(
            name="external_api",
            status=HealthStatus.UNHEALTHY,
            message=f"External API check failed: {e}",
            duration=0.0,
            timestamp=time.time()
        )
# Register the health checks.
health_monitor.register_check(name="memory_usage", check_func=check_memory_usage)
health_monitor.register_check(name="database_connection", check_func=check_database_connection)
health_monitor.register_check(name="external_api", check_func=check_external_api)

# Start background monitoring.
health_monitor.start_monitoring()
# 健康检查API端点(Flask示例)
from flask import Flask, jsonify
app = Flask(__name__)


@app.route('/health')
def health_endpoint():
    """Health endpoint: run all checks once and report them as JSON.

    Responds with 200 when the overall status is healthy and 503 otherwise.
    """
    # Run the checks a single time and derive the overall status from the
    # same results. get_overall_status() runs every check itself, so calling
    # it as well would execute each check twice per request, and the two
    # runs could disagree.
    results = health_monitor.run_all_checks()
    statuses = [check.status for check in results.values()]
    if not statuses or any(status == HealthStatus.UNHEALTHY for status in statuses):
        overall_status = HealthStatus.UNHEALTHY
    elif all(status == HealthStatus.HEALTHY for status in statuses):
        overall_status = HealthStatus.HEALTHY
    else:
        overall_status = HealthStatus.DEGRADED

    response_data = {
        "status": overall_status.value,
        "timestamp": time.time(),
        "checks": {
            name: {
                "status": check.status.value,
                "message": check.message,
                "duration": check.duration,
                "details": check.details
            }
            for name, check in results.items()
        }
    }

    status_code = 200 if overall_status == HealthStatus.HEALTHY else 503
    return jsonify(response_data), status_code
@app.route('/metrics')
def metrics_endpoint():
    """Metrics endpoint: return the collector's performance report as JSON."""
    return jsonify(metrics_collector.get_performance_report())
6. 总结与建议
6.1 核心最佳实践总结
- 架构设计
  - 遵循单一职责原则
  - 合理设计状态结构
  - 使用清晰的条件路由
  - 采用模块化组合
- 性能优化
  - 控制状态大小
  - 实施缓存策略
  - 支持并行执行
  - 监控内存使用
- 错误处理
  - 分层错误处理
  - 优雅降级策略
  - 完善的重试机制
  - 详细的错误上下文
- 监控运维
  - 结构化日志记录
  - 全面的性能监控
  - 健康检查机制
  - 配置管理系统
6.2 常见陷阱与避免方法
- 状态膨胀:定期清理状态,使用引用而非复制
- 内存泄漏:及时释放资源,监控内存使用
- 死锁风险:避免循环依赖,设置超时机制
- 性能瓶颈:识别热点节点,优化关键路径
6.3 发展趋势与建议
- 云原生部署:容器化部署,自动扩缩容
- 微服务架构:服务拆分,独立部署
- AI Ops集成:智能监控,自动优化
- 边缘计算:分布式执行,就近处理
通过遵循这些最佳实践,开发者可以构建出高质量、高性能、高可靠性的LangGraph应用,为用户提供优秀的AI服务体验。
tommie blog