1. Project Architecture Design Best Practices
1.1 State Design Principles
1.1.1 State Structure Design
"""
状态设计最佳实践
"""
from typing import TypedDict, Annotated, List, Optional, Dict, Any
from langgraph.graph import add_messages
from langchain_core.messages import AnyMessage
# ❌ 不好的状态设计
class BadState(TypedDict):
everything: dict # 将所有数据都放在一个字典中
data: Any # 使用模糊的类型
# ✅ 好的状态设计
class GoodState(TypedDict):
"""
良好的状态设计特点:
1. 清晰的字段定义和类型注解
2. 合理的数据分组
3. 必要的元数据
4. 适当的默认值
"""
# 核心业务数据
messages: Annotated[List[AnyMessage], add_messages]
user_context: Dict[str, str]
task_result: Optional[str]
# 流程控制
current_stage: str
retry_count: int
max_retries: int
# 元数据
session_id: str
created_at: str
updated_at: str
# 可选状态
debug_info: Optional[Dict[str, Any]]
performance_metrics: Optional[Dict[str, float]]
# 状态设计准则
class StateDesignPrinciples:
"""
状态设计核心准则
1. 单一职责:每个字段都有明确的用途
2. 类型安全:使用具体的类型注解
3. 向前兼容:使用Optional标记可选字段
4. 可扩展性:预留扩展空间
5. 可序列化:确保状态可以被检查点系统序列化
"""
@staticmethod
def validate_state_design(state_class: type) -> List[str]:
"""验证状态设计的合理性"""
issues = []
# 检查类型注解
annotations = getattr(state_class, '__annotations__', {})
if not annotations:
issues.append("缺少类型注解")
# 检查字段命名
for field_name in annotations:
if not field_name.replace('_', '').isalnum():
issues.append(f"字段名 {field_name} 包含特殊字符")
if field_name.startswith('__'):
issues.append(f"字段名 {field_name} 不应以双下划线开头")
return issues
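
As a quick sanity check, the validator can be run directly against schema classes. A minimal sketch (the `NoAnnotations` class is a hypothetical failing case added for illustration):

```python
# Usage sketch: running the validator against state schemas.
class NoAnnotations:  # deliberately missing type annotations
    pass

print(StateDesignPrinciples.validate_state_design(GoodState))      # []
print(StateDesignPrinciples.validate_state_design(NoAnnotations))  # ['Missing type annotations']
```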
1.1.2 State Update Strategy
"""
状态更新最佳实践
"""
def good_node_function(state: GoodState) -> dict:
"""
良好的节点函数特点:
1. 只更新必要的字段
2. 保持状态的一致性
3. 提供清晰的更新日志
"""
current_stage = state.get("current_stage", "unknown")
retry_count = state.get("retry_count", 0)
# 执行业务逻辑
try:
result = perform_business_logic(state)
# 返回最小化的状态更新
return {
"task_result": result,
"current_stage": "completed",
"updated_at": datetime.now().isoformat(),
"performance_metrics": {
"processing_time": 1.23,
"memory_usage": 45.6
}
}
except Exception as e:
# 错误处理时的状态更新
return {
"current_stage": "error",
"retry_count": retry_count + 1,
"updated_at": datetime.now().isoformat(),
"debug_info": {
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc()
}
}
def bad_node_function(state: GoodState) -> dict:
"""
❌ 避免的做法:
1. 更新不相关的字段
2. 直接修改输入状态
3. 返回整个状态副本
"""
# 错误做法1:修改输入状态
state["current_stage"] = "processing" # ❌
# 错误做法2:返回不必要的字段
return dict(state) # ❌ 返回整个状态
# 状态更新最佳实践总结
class StateUpdateBestPractices:
"""状态更新最佳实践汇总"""
@staticmethod
def create_minimal_update(updates: dict) -> dict:
"""创建最小化的状态更新"""
# 移除None值
return {k: v for k, v in updates.items() if v is not None}
@staticmethod
def add_timestamp(update: dict) -> dict:
"""为状态更新添加时间戳"""
update["updated_at"] = datetime.now().isoformat()
return update
@staticmethod
def validate_update(update: dict, state_schema: type) -> bool:
"""验证状态更新的合法性"""
schema_fields = getattr(state_schema, '__annotations__', {})
for key in update:
if key not in schema_fields:
logging.warning(f"Unknown state field: {key}")
return False
return True
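
# Usage sketch: chaining the helpers before returning from a node.
update = StateUpdateBestPractices.create_minimal_update(
    {"task_result": "done", "debug_info": None})        # drops the None entry
update = StateUpdateBestPractices.add_timestamp(update)
assert StateUpdateBestPractices.validate_update(update, GoodState)
```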
1.2 Node Design Patterns
1.2.1 Single-Responsibility Nodes
"""
节点设计最佳实践:单一职责原则
"""
# ✅ 好的节点设计:单一职责
def data_validator_node(state: ProcessingState) -> dict:
"""
数据验证节点:仅负责数据验证
职责:
1. 验证输入数据格式
2. 检查数据完整性
3. 返回验证结果
"""
data = state["raw_data"]
# 专注于验证逻辑
validation_errors = []
if not data:
validation_errors.append("数据为空")
if not isinstance(data, dict):
validation_errors.append("数据格式不正确")
# 返回验证结果
return {
"validation_errors": validation_errors,
"is_valid": len(validation_errors) == 0,
"validation_timestamp": datetime.now().isoformat()
}
def data_processor_node(state: ProcessingState) -> dict:
"""
数据处理节点:仅负责数据转换
职责:
1. 转换数据格式
2. 执行业务逻辑
3. 生成处理结果
"""
if not state.get("is_valid", False):
return {"error": "数据未通过验证,无法处理"}
raw_data = state["raw_data"]
# 专注于处理逻辑
processed_data = transform_data(raw_data)
return {
"processed_data": processed_data,
"processing_timestamp": datetime.now().isoformat()
}
# ❌ 避免的设计:多职责节点
def bad_all_in_one_node(state: ProcessingState) -> dict:
"""
不好的设计:一个节点做太多事情
"""
# 验证、处理、保存、通知 - 太多职责
data = state["raw_data"]
# 验证
if not validate_data(data):
return {"error": "验证失败"}
# 处理
processed = process_data(data)
# 保存
save_to_database(processed)
# 发送通知
send_notification(processed)
return {"result": "everything done"} # 职责不明确
1.2.2 Error Handling Patterns
"""
节点错误处理最佳实践
"""
from typing import Union, Tuple
import logging
class NodeErrorHandler:
"""节点错误处理器"""
@staticmethod
def with_retry(max_retries: int = 3):
"""重试装饰器"""
def decorator(node_func):
def wrapper(state: dict) -> dict:
retry_count = state.get("retry_count", 0)
if retry_count >= max_retries:
return {
"error": f"重试次数超过限制 ({max_retries})",
"retry_count": retry_count,
"status": "failed"
}
try:
result = node_func(state)
# 成功时重置重试计数
result["retry_count"] = 0
return result
except Exception as e:
logging.error(f"节点执行失败: {e}")
return {
"error": str(e),
"retry_count": retry_count + 1,
"status": "retry_needed"
}
return wrapper
return decorator
@staticmethod
def safe_execute(node_func, state: dict, fallback_value: Any = None) -> dict:
"""安全执行节点函数"""
try:
return node_func(state)
except Exception as e:
logging.error(f"节点执行异常: {e}")
return {
"error": str(e),
"fallback_result": fallback_value,
"status": "error",
"error_type": type(e).__name__
}
# 使用示例
@NodeErrorHandler.with_retry(max_retries=3)
def robust_api_call_node(state: dict) -> dict:
"""具有重试机制的API调用节点"""
url = state["api_url"]
payload = state["request_payload"]
# 可能失败的API调用
response = requests.post(url, json=payload, timeout=10)
response.raise_for_status()
return {
"api_response": response.json(),
"status": "success"
}
def graceful_node(state: dict) -> dict:
"""优雅处理错误的节点示例"""
try:
# 主要逻辑
result = complex_operation(state["input_data"])
return {
"result": result,
"status": "success",
"execution_time": time.time() - start_time
}
except ValidationError as e:
# 预期的业务错误
return {
"error": f"数据验证失败: {e}",
"error_code": "VALIDATION_ERROR",
"status": "failed",
"recoverable": True
}
except ConnectionError as e:
# 网络错误,可重试
return {
"error": f"网络连接失败: {e}",
"error_code": "CONNECTION_ERROR",
"status": "retry_needed",
"recoverable": True
}
except Exception as e:
# 未预期的错误
return {
"error": f"未知错误: {e}",
"error_code": "UNKNOWN_ERROR",
"status": "failed",
"recoverable": False,
"debug_info": traceback.format_exc()
}
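
A short usage sketch for `safe_execute`, with a deliberately failing node (`flaky_node` is a hypothetical example):

```python
# Usage sketch: wrapping an unreliable node with safe_execute.
def flaky_node(state: dict) -> dict:
    raise ConnectionError("upstream unavailable")

result = NodeErrorHandler.safe_execute(flaky_node, {}, fallback_value="cached")
# -> {"error": "upstream unavailable", "fallback_result": "cached",
#     "status": "error", "error_type": "ConnectionError"}
```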
1.3 Graph Structure Design
1.3.1 Conditional Routing Best Practices
"""
条件路由设计最佳实践
"""
from typing import Literal
# ✅ 清晰的条件路由函数
def business_logic_router(state: BusinessState) -> Literal["process", "validate", "error"]:
"""
业务逻辑路由器
设计原则:
1. 路由逻辑清晰明确
2. 覆盖所有可能情况
3. 提供默认分支
4. 使用类型提示限制返回值
"""
# 错误状态检查
if state.get("has_error", False):
return "error"
# 验证状态检查
validation_status = state.get("validation_status", "pending")
if validation_status == "failed":
return "validate"
elif validation_status == "passed":
return "process"
# 默认路由
return "validate"
def complex_workflow_router(state: WorkflowState) -> Literal["step1", "step2", "step3", "complete", "error"]:
"""
复杂工作流路由器
处理多步骤工作流的路由决策
"""
current_step = state.get("current_step", 0)
has_errors = bool(state.get("errors", []))
if has_errors:
return "error"
if current_step == 0:
return "step1"
elif current_step == 1:
return "step2"
elif current_step == 2:
return "step3"
else:
return "complete"
# ❌ 避免的路由设计
def bad_router(state: dict) -> str:
"""
不好的路由设计示例
"""
# 问题1:逻辑过于复杂
if (state.get("status") == "ok" and
len(state.get("data", [])) > 0 and
state.get("user_tier") in ["vip", "premium"] and
datetime.now().hour < 18):
return "complex_processing"
# 问题2:魔法数字
if state.get("score", 0) > 0.85:
return "high_quality"
# 问题3:缺少默认分支
# 如果以上条件都不满足会怎样?
# 路由设计最佳实践
class RoutingBestPractices:
"""路由设计最佳实践"""
@staticmethod
def create_state_based_router(state_field: str, route_map: dict):
"""基于状态字段的路由器工厂"""
def router(state: dict) -> str:
field_value = state.get(state_field)
return route_map.get(field_value, route_map.get("default", "error"))
return router
@staticmethod
def create_threshold_router(score_field: str, thresholds: dict):
"""基于阈值的路由器工厂"""
def router(state: dict) -> str:
score = state.get(score_field, 0)
for threshold, route in sorted(thresholds.items(), reverse=True):
if score >= threshold:
return route
return "default"
return router
# 使用示例
quality_router = RoutingBestPractices.create_threshold_router(
score_field="quality_score",
thresholds={
0.9: "excellent",
0.7: "good",
0.5: "acceptable",
0.0: "poor"
}
)
status_router = RoutingBestPractices.create_state_based_router(
state_field="processing_status",
route_map={
"pending": "process",
"completed": "finalize",
"error": "handle_error",
"default": "initialize"
}
)
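
To show where these routers plug in, here is a minimal sketch that wires `status_router` into a graph via `add_conditional_edges`. The node bodies and the `RoutedState` schema are placeholders added for illustration:

```python
# Sketch: using status_router as a conditional entry point.
from typing import TypedDict
from langgraph.graph import StateGraph, START, END

class RoutedState(TypedDict):
    processing_status: str

builder = StateGraph(RoutedState)
builder.add_node("initialize", lambda s: {"processing_status": "pending"})
builder.add_node("process", lambda s: {"processing_status": "completed"})
builder.add_node("finalize", lambda s: {})
builder.add_node("handle_error", lambda s: {})
builder.add_conditional_edges(START, status_router,
                              ["initialize", "process", "finalize", "handle_error"])
for name in ("initialize", "process", "finalize", "handle_error"):
    builder.add_edge(name, END)
graph = builder.compile()
```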
2. Performance Optimization in Practice
2.1 Checkpoint Optimization Strategies
"""
检查点性能优化实战经验
"""
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.checkpoint.base import BaseCheckpointSaver
import asyncio
import time
class OptimizedCheckpointSaver(BaseCheckpointSaver):
"""优化的检查点保存器"""
def __init__(self, batch_size: int = 10, batch_timeout: float = 1.0):
super().__init__()
self.batch_size = batch_size
self.batch_timeout = batch_timeout
self.pending_writes = []
self.last_flush = time.time()
def put(self, config, checkpoint, metadata, new_versions):
"""批量写入优化"""
# 添加到待写入队列
self.pending_writes.append({
"config": config,
"checkpoint": checkpoint,
"metadata": metadata,
"new_versions": new_versions,
"timestamp": time.time()
})
# 检查是否需要刷新
should_flush = (
len(self.pending_writes) >= self.batch_size or
time.time() - self.last_flush > self.batch_timeout
)
if should_flush:
self._flush_batch()
return config
def _flush_batch(self):
"""刷新批次写入"""
if not self.pending_writes:
return
# 批量写入逻辑
batch = self.pending_writes[:]
self.pending_writes.clear()
self.last_flush = time.time()
# 实际写入(这里需要具体实现)
self._batch_write_to_storage(batch)
class CheckpointOptimizer:
"""检查点优化工具类"""
@staticmethod
def minimize_state_size(state: dict) -> dict:
"""最小化状态大小"""
optimized_state = {}
for key, value in state.items():
# 移除大型临时数据
if key.startswith("_temp_") or key.startswith("_cache_"):
continue
# 压缩大型字符串
if isinstance(value, str) and len(value) > 1000:
# 可以使用压缩算法
optimized_state[key] = compress_string(value)
else:
optimized_state[key] = value
return optimized_state
@staticmethod
def should_create_checkpoint(state: dict, last_checkpoint_time: float) -> bool:
"""智能检查点创建决策"""
current_time = time.time()
time_since_last = current_time - last_checkpoint_time
# 基于时间的检查点创建
if time_since_last > 60: # 1分钟
return True
# 基于状态变化的检查点创建
significant_changes = [
state.get("task_completed", False),
state.get("error_occurred", False),
state.get("user_interaction", False)
]
return any(significant_changes)
def compress_string(text: str) -> dict:
"""字符串压缩工具"""
import gzip
import base64
compressed = gzip.compress(text.encode('utf-8'))
encoded = base64.b64encode(compressed).decode('ascii')
return {
"type": "compressed_string",
"data": encoded,
"original_size": len(text)
}
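
A hedged counterpart to `compress_string`, so compressed values can be restored when the state is loaded again (a sketch added here, not part of the original helper set):

```python
def decompress_string(blob: dict) -> str:
    """Reverse compress_string: base64-decode, then gunzip."""
    import gzip
    import base64

    if blob.get("type") != "compressed_string":
        raise ValueError("not a compressed string payload")
    raw = base64.b64decode(blob["data"].encode("ascii"))
    return gzip.decompress(raw).decode("utf-8")
```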
2.2 Memory Management Best Practices
"""
内存管理最佳实践
"""
import gc
import psutil
import weakref
from typing import Any, Dict, List
class MemoryManager:
"""内存管理工具类"""
def __init__(self, memory_limit_mb: int = 1024):
self.memory_limit_mb = memory_limit_mb
self.large_objects: List[weakref.ref] = []
def monitor_memory_usage(self) -> Dict[str, float]:
"""监控内存使用情况"""
process = psutil.Process()
memory_info = process.memory_info()
return {
"rss_mb": memory_info.rss / 1024 / 1024, # 物理内存
"vms_mb": memory_info.vms / 1024 / 1024, # 虚拟内存
"percent": process.memory_percent(), # 内存使用百分比
"available_mb": psutil.virtual_memory().available / 1024 / 1024
}
def cleanup_if_needed(self) -> bool:
"""必要时清理内存"""
memory_stats = self.monitor_memory_usage()
if memory_stats["rss_mb"] > self.memory_limit_mb:
# 执行内存清理
return self._perform_cleanup()
return False
def _perform_cleanup(self) -> bool:
"""执行内存清理"""
initial_memory = self.monitor_memory_usage()["rss_mb"]
# 清理弱引用对象
self.large_objects = [ref for ref in self.large_objects if ref() is not None]
# 强制垃圾回收
gc.collect()
final_memory = self.monitor_memory_usage()["rss_mb"]
freed_mb = initial_memory - final_memory
return freed_mb > 0
def register_large_object(self, obj: Any) -> None:
"""注册大对象用于管理"""
self.large_objects.append(weakref.ref(obj))
class StateMemoryOptimizer:
"""状态内存优化器"""
@staticmethod
def optimize_message_history(messages: List[Any], max_messages: int = 50) -> List[Any]:
"""优化消息历史长度"""
if len(messages) <= max_messages:
return messages
# 保留最新的消息
return messages[-max_messages:]
@staticmethod
def compress_large_data(state: dict, size_threshold: int = 1000) -> dict:
"""压缩大型数据"""
optimized = {}
for key, value in state.items():
if isinstance(value, str) and len(value) > size_threshold:
optimized[key] = {
"type": "large_string",
"size": len(value),
"hash": hash(value),
"compressed": True
}
elif isinstance(value, (list, dict)) and len(str(value)) > size_threshold:
optimized[key] = {
"type": "large_object",
"size": len(str(value)),
"compressed": True
}
else:
optimized[key] = value
return optimized
# 在节点中使用内存优化
memory_manager = MemoryManager(memory_limit_mb=512)
def memory_optimized_node(state: dict) -> dict:
"""内存优化的节点实现"""
# 执行前检查内存
memory_manager.cleanup_if_needed()
# 优化输入状态
optimized_state = StateMemoryOptimizer.compress_large_data(state)
try:
# 执行主要逻辑
result = process_data(optimized_state)
# 注册大对象
if isinstance(result, dict) and len(str(result)) > 1000:
memory_manager.register_large_object(result)
return {"result": result}
finally:
# 执行后清理
memory_manager.cleanup_if_needed()
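
A small usage sketch for `optimize_message_history` inside a trimming node. This assumes a plain `messages` list without an `add_messages` reducer; with a reducer, trimming requires `RemoveMessage` instead of overwriting:

```python
def trim_history_node(state: dict) -> dict:
    """Keep only the most recent messages (plain-list state assumed)."""
    trimmed = StateMemoryOptimizer.optimize_message_history(
        state.get("messages", []), max_messages=50)
    return {"messages": trimmed}
```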
2.3 Concurrency Optimization
"""
并发处理优化实践
"""
import asyncio
import concurrent.futures
from typing import List, Callable, Any
import threading
import queue
class ConcurrentNodeExecutor:
"""并发节点执行器"""
def __init__(self, max_workers: int = 4):
self.max_workers = max_workers
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
def execute_parallel_nodes(
self,
node_functions: List[Callable],
shared_state: dict
) -> List[dict]:
"""并行执行多个节点"""
futures = []
for node_func in node_functions:
future = self.executor.submit(self._safe_execute_node, node_func, shared_state.copy())
futures.append(future)
results = []
for future in concurrent.futures.as_completed(futures):
try:
result = future.result(timeout=30) # 30秒超时
results.append(result)
except Exception as e:
results.append({"error": str(e)})
return results
def _safe_execute_node(self, node_func: Callable, state: dict) -> dict:
"""安全执行单个节点"""
try:
return node_func(state)
except Exception as e:
return {
"error": str(e),
"node_function": node_func.__name__
}
class AsyncStateProcessor:
"""异步状态处理器"""
def __init__(self):
self.processing_queue = asyncio.Queue(maxsize=100)
self.results_cache = {}
async def process_states_async(self, states: List[dict]) -> List[dict]:
"""异步处理多个状态"""
tasks = []
for i, state in enumerate(states):
task = asyncio.create_task(self._process_single_state(state, i))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理异常结果
processed_results = []
for result in results:
if isinstance(result, Exception):
processed_results.append({"error": str(result)})
else:
processed_results.append(result)
return processed_results
async def _process_single_state(self, state: dict, index: int) -> dict:
"""处理单个状态"""
# 模拟异步处理
await asyncio.sleep(0.1) # 模拟I/O操作
# 缓存检查
state_hash = hash(str(sorted(state.items())))
if state_hash in self.results_cache:
return self.results_cache[state_hash]
# 实际处理
result = {
"processed_state": state,
"index": index,
"timestamp": time.time()
}
# 缓存结果
self.results_cache[state_hash] = result
return result
# 工具节点并发优化示例
class OptimizedToolNode:
"""优化的工具节点"""
def __init__(self, tools: List[Any], max_concurrent: int = 3):
self.tools = {tool.name: tool for tool in tools}
self.max_concurrent = max_concurrent
self.semaphore = asyncio.Semaphore(max_concurrent)
async def execute_tools_async(self, tool_calls: List[dict]) -> List[dict]:
"""异步并发执行工具调用"""
tasks = []
for tool_call in tool_calls:
task = asyncio.create_task(self._execute_single_tool(tool_call))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
return [self._format_result(r) for r in results]
async def _execute_single_tool(self, tool_call: dict) -> dict:
"""执行单个工具调用"""
async with self.semaphore: # 限制并发数
tool_name = tool_call["name"]
tool_args = tool_call.get("args", {})
if tool_name not in self.tools:
raise ValueError(f"Unknown tool: {tool_name}")
tool = self.tools[tool_name]
# 在线程池中执行同步工具
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(None, tool.invoke, tool_args)
return {
"tool_call_id": tool_call["id"],
"tool_name": tool_name,
"result": result
}
def _format_result(self, result: Any) -> dict:
"""格式化结果"""
if isinstance(result, Exception):
return {"error": str(result)}
return result
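
A usage sketch for `OptimizedToolNode`, assuming langchain-style tools (anything exposing `.name` and `.invoke` works; `add_numbers` is a hypothetical example):

```python
import asyncio
from langchain_core.tools import tool

@tool
def add_numbers(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b

node = OptimizedToolNode([add_numbers], max_concurrent=2)
calls = [{"id": "call-1", "name": "add_numbers", "args": {"a": 1, "b": 2}}]
print(asyncio.run(node.execute_tools_async(calls)))
# -> [{'tool_call_id': 'call-1', 'tool_name': 'add_numbers', 'result': 3}]
```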
3. Error Handling and Debugging Strategies
3.1 A Comprehensive Error Handling Framework
"""
全面的错误处理体系
"""
import traceback
import logging
from enum import Enum
from typing import Optional, Dict, Any
from contextlib import contextmanager
class ErrorSeverity(Enum):
"""错误严重程度"""
LOW = "low" # 警告级别,不影响主流程
MEDIUM = "medium" # 错误级别,影响当前操作
HIGH = "high" # 严重错误,影响整个会话
CRITICAL = "critical" # 关键错误,需要立即处理
class LangGraphError(Exception):
"""LangGraph自定义错误基类"""
def __init__(
self,
message: str,
error_code: str = "UNKNOWN",
severity: ErrorSeverity = ErrorSeverity.MEDIUM,
context: Optional[Dict[str, Any]] = None,
recoverable: bool = True
):
super().__init__(message)
self.message = message
self.error_code = error_code
self.severity = severity
self.context = context or {}
self.recoverable = recoverable
self.timestamp = time.time()
class NodeExecutionError(LangGraphError):
"""节点执行错误"""
def __init__(self, node_name: str, original_error: Exception, **kwargs):
super().__init__(
f"Node '{node_name}' execution failed: {str(original_error)}",
error_code="NODE_EXECUTION_ERROR",
**kwargs
)
self.node_name = node_name
self.original_error = original_error
class StateValidationError(LangGraphError):
"""状态验证错误"""
def __init__(self, validation_errors: List[str], **kwargs):
super().__init__(
f"State validation failed: {'; '.join(validation_errors)}",
error_code="STATE_VALIDATION_ERROR",
**kwargs
)
self.validation_errors = validation_errors
class ErrorHandler:
"""统一错误处理器"""
def __init__(self, logger: logging.Logger = None):
self.logger = logger or logging.getLogger(__name__)
self.error_counts = {}
self.recovery_strategies = {}
def register_recovery_strategy(
self,
error_type: type,
strategy: Callable[[Exception, dict], dict]
):
"""注册错误恢复策略"""
self.recovery_strategies[error_type] = strategy
@contextmanager
def error_context(self, operation: str, state: dict):
"""错误上下文管理器"""
try:
yield
except Exception as e:
self._handle_error(e, operation, state)
raise
def _handle_error(self, error: Exception, operation: str, state: dict):
"""处理错误"""
# 记录错误
error_info = {
"operation": operation,
"error_type": type(error).__name__,
"error_message": str(error),
"state_keys": list(state.keys()),
"traceback": traceback.format_exc(),
"timestamp": time.time()
}
self.logger.error(f"Error in operation '{operation}': {error_info}")
# 更新错误计数
error_key = f"{operation}:{type(error).__name__}"
self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1
# 检查恢复策略
for error_type, strategy in self.recovery_strategies.items():
if isinstance(error, error_type):
try:
recovery_result = strategy(error, state)
self.logger.info(f"Applied recovery strategy for {error_type.__name__}")
return recovery_result
except Exception as recovery_error:
self.logger.error(f"Recovery strategy failed: {recovery_error}")
# 恢复策略示例
def api_call_recovery_strategy(error: Exception, state: dict) -> dict:
"""API调用失败恢复策略"""
if "timeout" in str(error).lower():
# 超时错误:增加重试延迟
return {
"retry_delay": state.get("retry_delay", 1) * 2,
"should_retry": True,
"recovery_action": "increase_timeout"
}
elif "rate limit" in str(error).lower():
# 限流错误:等待更长时间
return {
"retry_delay": 60, # 等待1分钟
"should_retry": True,
"recovery_action": "wait_for_rate_limit"
}
else:
# 其他错误:使用降级方案
return {
"should_retry": False,
"use_fallback": True,
"recovery_action": "use_fallback_service"
}
# 在节点中使用错误处理
error_handler = ErrorHandler()
error_handler.register_recovery_strategy(requests.RequestException, api_call_recovery_strategy)
def robust_api_node(state: dict) -> dict:
"""具有完善错误处理的API节点"""
with error_handler.error_context("api_call", state):
try:
# 主要API调用逻辑
response = requests.get(
state["api_url"],
params=state["params"],
timeout=state.get("timeout", 10)
)
response.raise_for_status()
return {
"api_response": response.json(),
"status": "success",
"timestamp": time.time()
}
except requests.RequestException as e:
# 应用恢复策略
recovery_info = api_call_recovery_strategy(e, state)
raise NodeExecutionError(
node_name="robust_api_node",
original_error=e,
context=recovery_info,
severity=ErrorSeverity.MEDIUM,
recoverable=recovery_info.get("should_retry", False)
)
3.2 Debugging and Monitoring Tools
"""
调试和监控工具
"""
import time
import json
from typing import Dict, List, Any, Callable
from dataclasses import dataclass, asdict
import threading
@dataclass
class ExecutionMetrics:
"""执行指标"""
node_name: str
start_time: float
end_time: float
duration: float
memory_before: float
memory_after: float
success: bool
error_message: Optional[str] = None
@property
def memory_delta(self) -> float:
return self.memory_after - self.memory_before
class GraphDebugger:
"""图调试器"""
def __init__(self, enabled: bool = True):
self.enabled = enabled
self.execution_log: List[Dict[str, Any]] = []
self.node_metrics: Dict[str, List[ExecutionMetrics]] = {}
self.state_history: List[Dict[str, Any]] = []
self.lock = threading.Lock()
def log_node_execution(
self,
node_name: str,
input_state: dict,
output_state: dict,
execution_time: float,
error: Optional[Exception] = None
):
"""记录节点执行情况"""
if not self.enabled:
return
with self.lock:
log_entry = {
"timestamp": time.time(),
"node_name": node_name,
"input_state_keys": list(input_state.keys()),
"output_state_keys": list(output_state.keys()),
"execution_time": execution_time,
"success": error is None,
"error": str(error) if error else None
}
self.execution_log.append(log_entry)
def log_state_change(self, old_state: dict, new_state: dict, operation: str):
"""记录状态变化"""
if not self.enabled:
return
with self.lock:
changes = self._detect_state_changes(old_state, new_state)
change_entry = {
"timestamp": time.time(),
"operation": operation,
"changes": changes,
"state_size_before": len(str(old_state)),
"state_size_after": len(str(new_state))
}
self.state_history.append(change_entry)
def _detect_state_changes(self, old_state: dict, new_state: dict) -> Dict[str, Dict[str, Any]]:
"""检测状态变化"""
changes = {}
# 检查新增字段
for key in new_state:
if key not in old_state:
changes[key] = {"action": "added", "new_value": new_state[key]}
# 检查删除字段
for key in old_state:
if key not in new_state:
changes[key] = {"action": "removed", "old_value": old_state[key]}
# 检查修改字段
for key in old_state:
if key in new_state and old_state[key] != new_state[key]:
changes[key] = {
"action": "modified",
"old_value": old_state[key],
"new_value": new_state[key]
}
return changes
def get_performance_report(self) -> Dict[str, Any]:
"""获取性能报告"""
if not self.execution_log:
return {"message": "No execution data available"}
# 节点执行时间统计
node_times = {}
node_success_rates = {}
for entry in self.execution_log:
node_name = entry["node_name"]
exec_time = entry["execution_time"]
success = entry["success"]
if node_name not in node_times:
node_times[node_name] = []
node_success_rates[node_name] = {"success": 0, "total": 0}
node_times[node_name].append(exec_time)
node_success_rates[node_name]["total"] += 1
if success:
node_success_rates[node_name]["success"] += 1
# 计算统计信息
performance_stats = {}
for node_name, times in node_times.items():
success_rate = node_success_rates[node_name]["success"] / node_success_rates[node_name]["total"]
performance_stats[node_name] = {
"avg_execution_time": sum(times) / len(times),
"min_execution_time": min(times),
"max_execution_time": max(times),
"total_executions": len(times),
"success_rate": success_rate,
"total_time": sum(times)
}
return {
"total_executions": len(self.execution_log),
"unique_nodes": len(node_times),
"node_performance": performance_stats,
"state_changes": len(self.state_history)
}
def export_debug_data(self, filepath: str):
"""导出调试数据"""
debug_data = {
"execution_log": self.execution_log,
"state_history": self.state_history,
"performance_report": self.get_performance_report(),
"export_timestamp": time.time()
}
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(debug_data, f, indent=2, ensure_ascii=False)
def debug_node(node_func: Callable, debugger: GraphDebugger):
"""节点调试装饰器"""
def wrapper(state: dict) -> dict:
start_time = time.time()
node_name = node_func.__name__
input_state = state.copy()
error = None
try:
result = node_func(state)
return result
except Exception as e:
error = e
result = {"error": str(e)}
raise
finally:
execution_time = time.time() - start_time
debugger.log_node_execution(
node_name=node_name,
input_state=input_state,
output_state=result,
execution_time=execution_time,
error=error
)
return wrapper
# 使用示例
debugger = GraphDebugger(enabled=True)
@debug_node
def monitored_processing_node(state: dict) -> dict:
"""被监控的处理节点"""
# 模拟处理逻辑
time.sleep(0.1) # 模拟耗时操作
return {
"processed_data": f"Processed: {state.get('input_data', 'N/A')}",
"processing_timestamp": time.time()
}
# 定期生成性能报告
def generate_periodic_report(debugger: GraphDebugger, interval: int = 300):
"""定期生成性能报告"""
def report_worker():
while True:
time.sleep(interval)
report = debugger.get_performance_report()
# 输出报告到日志
logging.info("=== Performance Report ===")
logging.info(json.dumps(report, indent=2))
# 导出详细数据
timestamp = int(time.time())
debugger.export_debug_data(f"debug_report_{timestamp}.json")
thread = threading.Thread(target=report_worker, daemon=True)
thread.start()
4. Production Deployment Practices
4.1 Production Configuration
"""
生产环境配置最佳实践
"""
import os
import logging
from typing import Dict, Any, Optional
from dataclasses import dataclass
from langgraph.checkpoint.postgres import PostgresCheckpointSaver
from langgraph.store.postgres import PostgresStore
@dataclass
class ProductionConfig:
"""生产环境配置"""
# 数据库配置
postgres_uri: str
redis_uri: Optional[str] = None
# 性能配置
max_workers: int = 4
request_timeout: int = 30
checkpoint_batch_size: int = 10
# 安全配置
api_key_required: bool = True
rate_limit_per_minute: int = 60
max_message_length: int = 10000
# 监控配置
enable_metrics: bool = True
log_level: str = "INFO"
sentry_dsn: Optional[str] = None
# 资源限制
memory_limit_mb: int = 1024
max_concurrent_sessions: int = 100
@classmethod
def from_env(cls) -> 'ProductionConfig':
"""从环境变量创建配置"""
return cls(
postgres_uri=os.getenv("POSTGRES_URI", ""),
redis_uri=os.getenv("REDIS_URI"),
max_workers=int(os.getenv("MAX_WORKERS", "4")),
request_timeout=int(os.getenv("REQUEST_TIMEOUT", "30")),
api_key_required=os.getenv("API_KEY_REQUIRED", "true").lower() == "true",
rate_limit_per_minute=int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")),
enable_metrics=os.getenv("ENABLE_METRICS", "true").lower() == "true",
log_level=os.getenv("LOG_LEVEL", "INFO"),
sentry_dsn=os.getenv("SENTRY_DSN"),
memory_limit_mb=int(os.getenv("MEMORY_LIMIT_MB", "1024"))
)
class ProductionGraphFactory:
"""生产环境图工厂"""
def __init__(self, config: ProductionConfig):
self.config = config
self._setup_logging()
self._setup_monitoring()
def _setup_logging(self):
"""设置日志"""
logging.basicConfig(
level=getattr(logging, self.config.log_level),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('langgraph_production.log')
]
)
# Sentry集成
if self.config.sentry_dsn:
import sentry_sdk
sentry_sdk.init(dsn=self.config.sentry_dsn)
def _setup_monitoring(self):
"""设置监控"""
if self.config.enable_metrics:
# 可以集成Prometheus等监控系统
pass
def create_checkpointer(self) -> PostgresCheckpointSaver:
"""创建生产级检查点保存器"""
return PostgresCheckpointSaver(
connection_string=self.config.postgres_uri,
# 生产环境配置
pool_size=10,
max_overflow=20,
pool_timeout=30,
pool_recycle=3600
)
def create_store(self) -> PostgresStore:
"""创建生产级存储"""
return PostgresStore(
connection_string=self.config.postgres_uri,
# 存储配置
table_name="langgraph_store",
pool_size=5
)
def create_production_graph(self, graph_builder_func: Callable) -> Any:
"""创建生产环境图"""
# 创建基础组件
checkpointer = self.create_checkpointer()
store = self.create_store()
# 构建图
graph = graph_builder_func()
# 编译带生产配置的图
compiled = graph.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=None, # 生产环境通常不需要交互式中断
interrupt_after=None,
debug=False # 生产环境关闭调试
)
# 包装生产环境功能
return ProductionGraphWrapper(compiled, self.config)
class ProductionGraphWrapper:
"""生产环境图包装器"""
def __init__(self, graph: Any, config: ProductionConfig):
self.graph = graph
self.config = config
self.session_count = 0
self.request_count = 0
def invoke_with_safety(self, input_data: dict, config: dict) -> dict:
"""安全的调用包装"""
# 检查并发会话限制
if self.session_count >= self.config.max_concurrent_sessions:
raise Exception("Max concurrent sessions exceeded")
# 检查输入大小
input_str = str(input_data)
if len(input_str) > self.config.max_message_length:
raise Exception("Input message too large")
# 更新计数器
self.session_count += 1
self.request_count += 1
try:
# 设置超时
import signal
def timeout_handler(signum, frame):
raise TimeoutError("Request timeout")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(self.config.request_timeout)
try:
result = self.graph.invoke(input_data, config)
return result
finally:
signal.alarm(0) # 取消超时
finally:
self.session_count -= 1
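
How the factory pieces fit together, as a minimal sketch; `build_my_graph` is an assumed user-supplied function that returns an uncompiled `StateGraph`:

```python
# Sketch: assembling and invoking the production graph.
config = ProductionConfig.from_env()
factory = ProductionGraphFactory(config)
app = factory.create_production_graph(build_my_graph)  # build_my_graph is assumed

result = app.invoke_with_safety(
    {"messages": []},
    {"configurable": {"thread_id": "session-1"}}
)
```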
4.2 Containerized Deployment
```dockerfile
# Production-optimized Dockerfile
FROM python:3.11-slim as builder

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Set the working directory
WORKDIR /app

# Copy dependency files
COPY requirements.txt .
COPY pyproject.toml .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Production stage
FROM python:3.11-slim as production

# Create a non-root user
RUN useradd --create-home --shell /bin/bash langgraph

# Install runtime dependencies
RUN apt-get update && apt-get install -y \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy the application and its dependencies
WORKDIR /app
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin

# Copy the application code
COPY . .

# Set permissions
RUN chown -R langgraph:langgraph /app
USER langgraph

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Environment variables
ENV PYTHONPATH=/app
ENV LANGGRAPH_ENV=production

# Start command
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
```
```yaml
# docker-compose.production.yml
version: '3.8'

services:
  langgraph-app:
    build:
      context: .
      dockerfile: Dockerfile
      target: production
    ports:
      - "8000:8000"
    environment:
      - POSTGRES_URI=postgresql://user:password@postgres:5432/langgraph
      - REDIS_URI=redis://redis:6379
      - LOG_LEVEL=INFO
      - ENABLE_METRICS=true
      - MAX_WORKERS=4
      - MEMORY_LIMIT_MB=1024
    depends_on:
      - postgres
      - redis
    restart: unless-stopped
    deploy:
      resources:
        limits:
          memory: 1G
          cpus: '2'
        reservations:
          memory: 512M
          cpus: '1'
    volumes:
      - ./logs:/app/logs
    networks:
      - langgraph-network

  postgres:
    image: postgres:15
    environment:
      - POSTGRES_DB=langgraph
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
    volumes:
      - postgres_data:/var/lib/postgresql/data
    ports:
      - "5432:5432"
    restart: unless-stopped
    networks:
      - langgraph-network

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    restart: unless-stopped
    networks:
      - langgraph-network

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
      - ./ssl:/etc/ssl/certs:ro
    depends_on:
      - langgraph-app
    restart: unless-stopped
    networks:
      - langgraph-network

volumes:
  postgres_data:

networks:
  langgraph-network:
    driver: bridge
```
4.3 Monitoring and Alerting
"""
生产环境监控和告警
"""
import time
import threading
from typing import Dict, List, Callable
from dataclasses import dataclass, field
from collections import defaultdict, deque
@dataclass
class MetricPoint:
"""指标数据点"""
timestamp: float
value: float
labels: Dict[str, str] = field(default_factory=dict)
class MetricsCollector:
"""指标收集器"""
def __init__(self, max_points: int = 1000):
self.max_points = max_points
self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_points))
self.lock = threading.Lock()
def record_metric(self, name: str, value: float, labels: Dict[str, str] = None):
"""记录指标"""
with self.lock:
point = MetricPoint(
timestamp=time.time(),
value=value,
labels=labels or {}
)
self.metrics[name].append(point)
def get_metric_stats(self, name: str, window_seconds: int = 300) -> Dict[str, float]:
"""获取指标统计信息"""
with self.lock:
if name not in self.metrics:
return {}
current_time = time.time()
cutoff_time = current_time - window_seconds
# 过滤时间窗口内的数据
recent_points = [
p for p in self.metrics[name]
if p.timestamp >= cutoff_time
]
if not recent_points:
return {}
values = [p.value for p in recent_points]
return {
"count": len(values),
"avg": sum(values) / len(values),
"min": min(values),
"max": max(values),
"latest": values[-1],
"window_seconds": window_seconds
}
class HealthChecker:
"""健康检查器"""
def __init__(self):
self.checks: Dict[str, Callable[[], bool]] = {}
self.check_results: Dict[str, Dict] = {}
def register_check(self, name: str, check_func: Callable[[], bool]):
"""注册健康检查"""
self.checks[name] = check_func
def run_checks(self) -> Dict[str, Any]:
"""运行所有健康检查"""
results = {"overall_healthy": True, "checks": {}}
for name, check_func in self.checks.items():
start_time = time.time()
try:
is_healthy = check_func()
check_time = time.time() - start_time
results["checks"][name] = {
"healthy": is_healthy,
"response_time": check_time,
"timestamp": time.time()
}
if not is_healthy:
results["overall_healthy"] = False
except Exception as e:
results["checks"][name] = {
"healthy": False,
"error": str(e),
"response_time": time.time() - start_time,
"timestamp": time.time()
}
results["overall_healthy"] = False
return results
class AlertManager:
"""告警管理器"""
def __init__(self):
self.alert_rules: List[AlertRule] = []
self.active_alerts: Dict[str, Dict] = {}
self.notification_handlers: List[Callable] = []
def add_alert_rule(self, rule: 'AlertRule'):
"""添加告警规则"""
self.alert_rules.append(rule)
def add_notification_handler(self, handler: Callable[[Dict], None]):
"""添加通知处理器"""
self.notification_handlers.append(handler)
def check_alerts(self, metrics_collector: MetricsCollector):
"""检查告警条件"""
for rule in self.alert_rules:
try:
should_alert = rule.evaluate(metrics_collector)
alert_key = f"{rule.name}:{rule.metric_name}"
if should_alert and alert_key not in self.active_alerts:
# 新告警
alert = {
"rule_name": rule.name,
"metric_name": rule.metric_name,
"description": rule.description,
"severity": rule.severity,
"started_at": time.time(),
"acknowledged": False
}
self.active_alerts[alert_key] = alert
self._send_notification(alert)
elif not should_alert and alert_key in self.active_alerts:
# 告警恢复
alert = self.active_alerts.pop(alert_key)
alert["resolved_at"] = time.time()
alert["duration"] = alert["resolved_at"] - alert["started_at"]
self._send_notification(alert, resolved=True)
except Exception as e:
logging.error(f"Error evaluating alert rule {rule.name}: {e}")
def _send_notification(self, alert: Dict, resolved: bool = False):
"""发送告警通知"""
for handler in self.notification_handlers:
try:
handler(alert, resolved)
except Exception as e:
logging.error(f"Error sending notification: {e}")
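
# A simple notification handler sketch (logging-based; swap in email, Slack,
# or PagerDuty as needed). The signature matches what _send_notification calls.
def log_notification_handler(alert: Dict, resolved: bool = False) -> None:
    status = "RESOLVED" if resolved else "FIRING"
    logging.warning("[%s] %s (%s): %s", status, alert["rule_name"],
                    alert["severity"], alert.get("description", ""))
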
@dataclass
class AlertRule:
    """Alert rule."""
    name: str
    metric_name: str
    condition: str  # "gt", "lt", "eq"
    threshold: float
    window_seconds: int = 300
    description: str = ""
    severity: str = "warning"  # "info", "warning", "critical"

    def evaluate(self, metrics_collector: MetricsCollector) -> bool:
        """Evaluate the alert condition."""
        stats = metrics_collector.get_metric_stats(
            self.metric_name,
            self.window_seconds
        )
        if not stats:
            return False

        current_value = stats.get("avg", 0)  # use the mean value

        if self.condition == "gt":
            return current_value > self.threshold
        elif self.condition == "lt":
            return current_value < self.threshold
        elif self.condition == "eq":
            return abs(current_value - self.threshold) < 0.001
        return False
# Production monitoring integration
class ProductionMonitor:
    """Production monitor."""

    def __init__(self):
        self.metrics_collector = MetricsCollector()
        self.health_checker = HealthChecker()
        self.alert_manager = AlertManager()

        # Register baseline health checks
        self._setup_basic_health_checks()
        # Register baseline alert rules
        self._setup_basic_alert_rules()
        # Start the monitoring thread
        self._start_monitoring_thread()

    def _setup_basic_health_checks(self):
        """Register baseline health checks."""
        def database_check() -> bool:
            # Check the database connection
            try:
                # A real implementation should ping the database here
                return True
            except Exception:
                return False

        def memory_check() -> bool:
            # Check memory usage
            import psutil
            memory_percent = psutil.virtual_memory().percent
            return memory_percent < 90

        self.health_checker.register_check("database", database_check)
        self.health_checker.register_check("memory", memory_check)

    def _setup_basic_alert_rules(self):
        """Register baseline alert rules."""
        # Response-time alert
        response_time_rule = AlertRule(
            name="high_response_time",
            metric_name="response_time",
            condition="gt",
            threshold=5.0,  # 5 seconds
            description="Average response time is too high",
            severity="warning"
        )
        # Error-rate alert
        error_rate_rule = AlertRule(
            name="high_error_rate",
            metric_name="error_rate",
            condition="gt",
            threshold=0.05,  # 5%
            description="Error rate is too high",
            severity="critical"
        )
        self.alert_manager.add_alert_rule(response_time_rule)
        self.alert_manager.add_alert_rule(error_rate_rule)

    def _start_monitoring_thread(self):
        """Start the monitoring thread."""
        def monitor_worker():
            while True:
                try:
                    # Evaluate alerts
                    self.alert_manager.check_alerts(self.metrics_collector)
                    time.sleep(30)  # check every 30 seconds
                except Exception as e:
                    logging.error(f"Monitoring error: {e}")
                    time.sleep(60)  # back off on errors

        thread = threading.Thread(target=monitor_worker, daemon=True)
        thread.start()

    def record_request_metrics(self, response_time: float, success: bool):
        """Record request metrics."""
        self.metrics_collector.record_metric("response_time", response_time)
        error_rate = 0.0 if success else 1.0
        self.metrics_collector.record_metric("error_rate", error_rate)
        # Record the request count
        self.metrics_collector.record_metric("request_count", 1)
```
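
Finally, a sketch of feeding request metrics into the monitor around a graph invocation; `graph`, `inputs`, and `config` are assumed to exist in your application:

```python
# Sketch: wiring ProductionMonitor around a graph call.
monitor = ProductionMonitor()
monitor.alert_manager.add_notification_handler(log_notification_handler)

start = time.time()
try:
    result = graph.invoke(inputs, config)  # graph/inputs/config assumed
    monitor.record_request_metrics(time.time() - start, success=True)
except Exception:
    monitor.record_request_metrics(time.time() - start, success=False)
    raise
```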
5. Summary
This document distills best practices and lessons learned from using LangGraph in real projects, covering the full lifecycle from architecture design to production deployment.
5.1 Key Takeaways
- State design: a clear state structure and minimal state updates
- Node design: single responsibility and thorough error handling
- Performance optimization: checkpoint tuning, memory management, and concurrency
- Error handling: layered error handling and recovery strategies
- Debugging and monitoring: comprehensive debugging tools and production monitoring
- Deployment and operations: containerized deployment and health checks
5.2 Implementation Advice
- Adopt incrementally: start with simple scenarios and apply the practices step by step
- Test-driven: build a solid test suite to validate the practices
- Improve continuously: refine and adjust based on real-world usage
- Train the team: make sure everyone understands and follows the practices
These hands-on lessons give developers practical guidance for building high-quality, maintainable LangGraph applications.