LangGraph-04: Framework Usage Examples and Best Practices
0. Document Overview
This document is a complete usage guide for the LangGraph framework, covering hands-on examples from basic onboarding through advanced applications, plus best practices for deploying and optimizing in production. Through code examples and case studies, it helps developers quickly master LangGraph's core concepts and practical techniques.
Content Structure
- Basic usage examples: entry-level code samples covering the core features
- Advanced application patterns: architecture designs and implementations for complex scenarios
- Lessons from production: experience and optimization techniques from real deployments
- Best practices guide: standardized processes and conventions for development, testing, and deployment
1. Basic Usage Examples
1.1 The Simplest ReAct Agent
from langgraph.prebuilt import create_react_agent
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool

@tool
def get_weather(city: str) -> str:
    """Get weather information for the specified city."""
    # Simulate a weather API call
    weather_data = {
        "Beijing": "Sunny, 25°C",
        "Shanghai": "Cloudy, 22°C",
        "Shenzhen": "Rainy, 28°C"
    }
    return weather_data.get(city, "No weather information available for this city")

# Create the model and tools
model = ChatOpenAI(model="gpt-4", temperature=0)
tools = [get_weather]

# Build the agent
agent = create_react_agent(
    model=model,
    tools=tools,
    prompt="You are a professional weather assistant that provides users with accurate weather information."
)

# Run a conversation
result = agent.invoke({
    "messages": [{"role": "user", "content": "What's the weather like in Beijing today?"}]
})
# The returned messages are message objects, so read .content rather than ["content"]
print(result["messages"][-1].content)
# Output: According to the query, Beijing is sunny today with a temperature of 25°C.
Key points:
- The @tool decorator turns a plain function into a tool
- create_react_agent handles tool binding and graph construction automatically
- The standard message format makes integration and extension straightforward, as shown in the streaming sketch below
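Because create_react_agent returns a compiled graph, you can also stream intermediate steps instead of waiting for the final answer. A minimal sketch reusing the agent above; with stream_mode="values", each chunk is the full state after one step:

# Stream the agent's intermediate steps (model calls and tool executions).
for chunk in agent.stream(
    {"messages": [{"role": "user", "content": "What's the weather like in Shanghai?"}]},
    stream_mode="values",
):
    chunk["messages"][-1].pretty_print()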
1.2 A Conversational Agent with Persistence
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.prebuilt import create_react_agent
import uuid

# Create a SQLite checkpoint saver.
# Note: in recent langgraph-checkpoint-sqlite releases, from_conn_string is a
# context manager; there you would write
# `with SqliteSaver.from_conn_string("agent_memory.db") as checkpointer: ...`
checkpointer = SqliteSaver.from_conn_string("agent_memory.db")

# Build an agent with memory
agent = create_react_agent(
    model=ChatOpenAI(model="gpt-4"),
    tools=[get_weather],
    checkpointer=checkpointer,
    prompt="You are a friendly assistant that remembers our previous conversation."
)

# Create a conversation thread
thread_id = str(uuid.uuid4())
config = {"configurable": {"thread_id": thread_id}}

# First turn
result1 = agent.invoke(
    {"messages": [{"role": "user", "content": "My name is Xiaoming"}]},
    config=config
)

# Second turn (the agent remembers the earlier content)
result2 = agent.invoke(
    {"messages": [{"role": "user", "content": "Do you still remember my name?"}]},
    config=config
)
print(result2["messages"][-1].content)
# Output: Of course, your name is Xiaoming. How can I help you?
Key points:
- SqliteSaver provides local persistence
- thread_id distinguishes separate conversation sessions
- Checkpoints save and restore the conversation history automatically
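The saved checkpoints can also be inspected directly. A small sketch, assuming the agent and config from above, using the compiled graph's get_state API:

# Read the latest checkpoint for this thread.
snapshot = agent.get_state(config)
print(f"Messages stored so far: {len(snapshot.values['messages'])}")
print(f"Next nodes scheduled to run: {snapshot.next}")  # empty when the run finished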
1.3 Complex Workflows with Custom State
from typing import TypedDict, List, Annotated
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
import operator

class ResearchState(TypedDict):
    """State definition for the research assistant."""
    messages: Annotated[List, add_messages]
    research_topic: str
    collected_info: Annotated[List[str], operator.add]
    analysis_result: str
    confidence_score: float

@tool
def search_academic_papers(query: str) -> str:
    """Search for academic papers."""
    # Simulate an academic search API
    papers = [
        f"Paper 1 on '{query}': introduces the theoretical foundations",
        f"Paper 2 on '{query}': provides experimental validation results",
        f"Paper 3 on '{query}': analyzes practical application scenarios"
    ]
    return "\n".join(papers)

def research_planner(state: ResearchState) -> ResearchState:
    """Research planning node."""
    topic = state["research_topic"]
    return {
        "messages": [{"role": "assistant", "content": f"Starting research on topic: {topic}"}],
        "collected_info": [f"Research plan: conduct an in-depth investigation of '{topic}'"]
    }

def information_collector(state: ResearchState) -> ResearchState:
    """Information collection node."""
    # Use the search tool to gather information.
    # @tool functions are BaseTool objects, so call them via .invoke(...)
    search_result = search_academic_papers.invoke(state["research_topic"])
    return {
        "collected_info": [f"Search results: {search_result}"],
        "messages": [{"role": "assistant", "content": "Information collection phase complete"}]
    }

def analysis_synthesizer(state: ResearchState) -> ResearchState:
    """Analysis and synthesis node."""
    info_count = len(state["collected_info"])
    analysis = f"Synthesized analysis of '{state['research_topic']}' based on {info_count} pieces of information..."
    confidence = min(0.9, info_count * 0.2)
    return {
        "analysis_result": analysis,
        "confidence_score": confidence,
        "messages": [{"role": "assistant", "content": f"Analysis complete, confidence: {confidence:.2f}"}]
    }

# Build the research workflow graph
workflow = StateGraph(ResearchState)

# Add nodes
workflow.add_node("planner", research_planner)
workflow.add_node("collector", information_collector)
workflow.add_node("analyzer", analysis_synthesizer)

# Define the execution order
workflow.add_edge(START, "planner")
workflow.add_edge("planner", "collector")
workflow.add_edge("collector", "analyzer")
workflow.add_edge("analyzer", END)

# Compile the workflow
research_app = workflow.compile()

# Run a research task
result = research_app.invoke({
    "research_topic": "Applications of AI in medical diagnosis",
    "messages": [],
    "collected_info": [],
    "analysis_result": "",
    "confidence_score": 0.0
})

print(f"Research topic: {result['research_topic']}")
print(f"Pieces of information collected: {len(result['collected_info'])}")
print(f"Analysis result: {result['analysis_result']}")
print(f"Confidence: {result['confidence_score']:.2f}")
Key points:
- A custom TypedDict state supports complex data structures
- Annotated with operator.add accumulates list entries across nodes (see the illustration below)
- Clear separation of node responsibilities keeps the workflow maintainable and testable
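The Annotated reducers are what let each node return only a partial update: LangGraph merges that update into the existing channel value with the declared reducer. A tiny illustration of the merge semantics for the collected_info channel:

import operator

# For `Annotated[List[str], operator.add]`, LangGraph effectively performs
# list concatenation when a node returns a partial update:
existing = ["Research plan: investigate the topic"]
update = ["Search results: three papers found"]
merged = operator.add(existing, update)
assert merged == existing + update  # old entries kept, new ones appended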
2. Advanced Application Patterns
2.1 Multi-Agent Collaboration System
from langgraph.graph import StateGraph, END
from langgraph.types import Send
from typing import List, Dict
from datetime import datetime

class MultiAgentState(TypedDict):
    """State for multi-agent collaboration."""
    task_description: str
    agent_assignments: Dict[str, str]
    agent_results: Annotated[List[Dict], operator.add]
    final_result: str

class SpecialistAgent:
    """Base class for specialist agents."""

    def __init__(self, name: str, expertise: str, model: ChatOpenAI):
        self.name = name
        self.expertise = expertise
        self.model = model

    def process_task(self, task: str) -> Dict:
        """Process the assigned task."""
        prompt = f"""
        You are {self.name}, an expert in {self.expertise}.
        Please handle the following task based on your domain knowledge: {task}
        Provide a professional, detailed analysis and recommendations.
        """
        result = self.model.invoke([{"role": "user", "content": prompt}])
        return {
            "agent": self.name,
            "expertise": self.expertise,
            "task": task,
            "result": result.content,
            "timestamp": datetime.now().isoformat()
        }

# Create the specialist agent instances
agents = {
    "tech_expert": SpecialistAgent("Tech Expert", "technical architecture and implementation", ChatOpenAI(model="gpt-4")),
    "business_expert": SpecialistAgent("Business Expert", "business models and market analysis", ChatOpenAI(model="gpt-4")),
    "risk_expert": SpecialistAgent("Risk Expert", "risk assessment and mitigation", ChatOpenAI(model="gpt-4"))
}

def task_coordinator(state: MultiAgentState) -> MultiAgentState:
    """Coordinator node - marks the start of task dispatch."""
    return {}

def dispatch_to_agents(state: MultiAgentState) -> List[Send]:
    """Routing function - fans the task out to each specialist via Send.

    Send objects are returned from a conditional-edge function (not from a
    node); this is how LangGraph dispatches dynamic parallel work.
    """
    sends = []
    for agent_name, assignment in state["agent_assignments"].items():
        sends.append(Send(
            f"agent_{agent_name}",
            {"task": assignment, "context": state["task_description"]}
        ))
    return sends

def create_agent_node(agent: SpecialistAgent):
    """Create a processing node for an agent."""
    def agent_node(input_data: Dict) -> MultiAgentState:
        # With Send, the node receives the Send payload as its input
        task = input_data["task"]
        result = agent.process_task(task)
        return {
            "agent_results": [result]
        }
    return agent_node

def result_synthesizer(state: MultiAgentState) -> MultiAgentState:
    """Result synthesizer."""
    results = state["agent_results"]
    # Combine all of the experts' analyses
    synthesis_prompt = f"""
    Based on the following expert analyses, provide a consolidated recommendation:
    Original task: {state['task_description']}
    Expert analyses:
    """
    for result in results:
        synthesis_prompt += f"\n{result['expertise']} ({result['agent']}):\n{result['result']}\n"
    synthesis_prompt += "\nPlease provide a consolidated analysis and final recommendation:"
    model = ChatOpenAI(model="gpt-4")
    final_result = model.invoke([{"role": "user", "content": synthesis_prompt}])
    return {"final_result": final_result.content}

# Build the multi-agent collaboration graph
multi_agent_graph = StateGraph(MultiAgentState)

# Add the coordinator node
multi_agent_graph.add_node("coordinator", task_coordinator)

# Add the specialist agent nodes
for agent_name, agent in agents.items():
    multi_agent_graph.add_node(f"agent_{agent_name}", create_agent_node(agent))

# Add the result synthesis node
multi_agent_graph.add_node("synthesizer", result_synthesizer)

# Conditional edges: coordinator -> specialist agents (dynamic fan-out via Send)
multi_agent_graph.add_conditional_edges(
    "coordinator",
    dispatch_to_agents,
    [f"agent_{name}" for name in agents]
)

# Each agent feeds into the synthesizer when finished
for agent_name in agents:
    multi_agent_graph.add_edge(f"agent_{agent_name}", "synthesizer")
multi_agent_graph.add_edge("synthesizer", END)
multi_agent_graph.set_entry_point("coordinator")

# Compile the multi-agent system
multi_agent_app = multi_agent_graph.compile()

# Run a collaborative task
collaboration_result = multi_agent_app.invoke({
    "task_description": "Assess the feasibility of building an AI-driven customer service system",
    "agent_assignments": {
        "tech_expert": "Analyze the technical approach, architecture design, and technical risks",
        "business_expert": "Evaluate market demand, business value, and revenue model",
        "risk_expert": "Identify project risks, compliance requirements, and mitigation strategies"
    },
    "agent_results": [],
    "final_result": ""
})

print("=== Multi-Agent Collaboration Result ===")
print(f"Task description: {collaboration_result['task_description']}")
print(f"Experts involved: {len(collaboration_result['agent_results'])}")
print(f"Final recommendation:\n{collaboration_result['final_result']}")
Key points:
- The Send mechanism enables dynamic task dispatch; each Send's payload becomes the input of the target node, independent of the shared state
- Specialist agents encapsulate domain-specific knowledge
- The result synthesizer consolidates the different perspectives
- The pattern extends naturally to other collaboration topologies
2.2 Human-in-the-Loop and Approval Workflows
from langgraph.types import interrupt, Command
from datetime import datetime
import uuid

class ApprovalState(TypedDict):
    """State for the approval workflow."""
    request_id: str
    request_content: str
    current_step: str
    approval_history: Annotated[List[Dict], operator.add]
    final_decision: str
    reason: str

def create_approval_request(state: ApprovalState) -> ApprovalState:
    """Create an approval request."""
    request_id = str(uuid.uuid4())[:8]
    return {
        "request_id": request_id,
        "current_step": "pending_review",
        "approval_history": [{
            "step": "created",
            "timestamp": datetime.now().isoformat(),
            "action": "request_created"
        }]
    }

def human_review_required(state: ApprovalState) -> ApprovalState:
    """Node that requires human review."""
    # Pause execution and wait for human input
    human_decision = interrupt(
        value={
            "request_id": state["request_id"],
            "content": state["request_content"],
            "question": "Please review this request. Enter 'approve' to approve or 'reject' to reject, and state your reason."
        }
    )
    # Expected input format: {"decision": "approve", "reason": "complies with policy"}
    decision_data = human_decision if isinstance(human_decision, dict) else {"decision": "reject", "reason": "invalid input"}
    return {
        "current_step": "human_reviewed",
        "approval_history": [{
            "step": "human_review",
            "timestamp": datetime.now().isoformat(),
            "decision": decision_data["decision"],
            "reason": decision_data["reason"]
        }]
    }

def auto_policy_check(state: ApprovalState) -> ApprovalState:
    """Automated policy check."""
    content = state["request_content"]
    # Simulated policy-check logic
    if "urgent" in content.lower():
        policy_result = "requires_approval"
        reason = "Contains an urgency marker; requires supervisor approval"
    elif len(content) > 1000:
        policy_result = "requires_approval"
        reason = "Content is long; requires detailed review"
    else:
        policy_result = "auto_approved"
        reason = "Meets the conditions for automatic approval"
    return {
        "current_step": policy_result,
        "approval_history": [{
            "step": "policy_check",
            "timestamp": datetime.now().isoformat(),
            "result": policy_result,
            "reason": reason
        }]
    }

def finalize_decision(state: ApprovalState) -> ApprovalState:
    """Final decision."""
    history = state["approval_history"]
    # Derive the final decision from the approval history
    human_decisions = [h for h in history if h.get("step") == "human_review"]
    policy_decisions = [h for h in history if h.get("step") == "policy_check"]
    if human_decisions:
        # If a human reviewed the request, the human decision wins
        final_decision = human_decisions[-1]["decision"]
        reason = f"Human review result: {human_decisions[-1]['reason']}"
    elif policy_decisions:
        # Otherwise use the policy-check result
        policy_result = policy_decisions[-1]["result"]
        final_decision = "approved" if policy_result == "auto_approved" else "rejected"
        reason = policy_decisions[-1]["reason"]
    else:
        final_decision = "rejected"
        reason = "Approval process not completed"
    return {
        "final_decision": final_decision,
        "reason": reason,
        "current_step": "completed"
    }

def approval_router(state: ApprovalState) -> str:
    """Approval router."""
    current_step = state["current_step"]
    if current_step == "auto_approved":
        return "finalize"
    elif current_step == "requires_approval":
        return "human_review"
    elif current_step == "human_reviewed":
        return "finalize"
    else:
        return END

# Build the approval workflow
approval_workflow = StateGraph(ApprovalState)

# Add nodes
approval_workflow.add_node("create_request", create_approval_request)
approval_workflow.add_node("policy_check", auto_policy_check)
approval_workflow.add_node("human_review", human_review_required)
approval_workflow.add_node("finalize", finalize_decision)

# Wire up the routing
approval_workflow.set_entry_point("create_request")
approval_workflow.add_edge("create_request", "policy_check")
approval_workflow.add_conditional_edges(
    "policy_check",
    approval_router,
    {
        "human_review": "human_review",
        "finalize": "finalize",
        END: END
    }
)
approval_workflow.add_conditional_edges(
    "human_review",
    approval_router,
    {
        "finalize": "finalize",
        END: END
    }
)
approval_workflow.add_edge("finalize", END)

# Compile the approval system.
# The human_review node pauses itself via interrupt(), so a static
# interrupt_before=["human_review"] is not needed on top of it.
approval_app = approval_workflow.compile(
    checkpointer=SqliteSaver.from_conn_string("approval_system.db")
)

# Usage example
config = {"configurable": {"thread_id": "approval_001"}}

# Submit an approval request
approval_result = approval_app.invoke({
    "request_content": "Request to purchase new server hardware, budget 500,000 CNY, to support business growth. Marked as urgent.",
    "current_step": "pending_review",
    "approval_history": [],
    "final_decision": "",
    "reason": ""
}, config=config)

print("=== Approval request submitted, awaiting human review ===")
print(f"Request ID: {approval_result['request_id']}")
print(f"Current step: {approval_result['current_step']}")

# Simulate the human review (in a real application this would come from a UI)
human_input = {"decision": "approve", "reason": "Budget is reasonable and the business need is clear; approved"}

# Resume execution, supplying the human review result
final_result = approval_app.invoke(
    Command(resume=human_input),
    config=config
)

print("\n=== Final approval result ===")
print(f"Decision: {final_result['final_decision']}")
print(f"Reason: {final_result['reason']}")
print(f"Approval history: {len(final_result['approval_history'])} steps")
Key points:
- interrupt() creates a human-in-the-loop pause point
- Command(resume=value) resumes execution and delivers the user input
- Flexible routing logic supports complex approval flows
- The approval history provides a complete audit trail
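When driving this workflow from a service, the pause is visible in the invoke output. A hedged sketch: recent LangGraph versions surface dynamic interrupts under an "__interrupt__" key, and initial_request stands for the request dict shown above:

# Detect the pause and show the review prompt before collecting the decision.
paused = approval_app.invoke(initial_request, config=config)  # initial_request: same dict as above
if "__interrupt__" in paused:
    payload = paused["__interrupt__"][0].value
    print(f"Awaiting human review for request {payload['request_id']}")
    print(payload["question"])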
3. Lessons from Production
3.1 Performance Optimization in Practice
3.1.1 Concurrent Tool Execution
from langgraph.prebuilt import ToolNode
import asyncio
import time

# Example of concurrent tool execution
@tool
async def fetch_user_data(user_id: str) -> str:
    """Fetch user data asynchronously."""
    await asyncio.sleep(1)  # simulated API latency
    return f"Detailed profile for user {user_id}"

@tool
async def fetch_order_history(user_id: str) -> str:
    """Fetch order history asynchronously."""
    await asyncio.sleep(0.8)  # simulated database latency
    return f"Order history for user {user_id}"

@tool
async def calculate_user_score(user_id: str) -> str:
    """Compute a user score asynchronously."""
    await asyncio.sleep(0.5)  # simulated computation latency
    return f"Score for user {user_id}: 85"

# Performance comparison: serial vs. concurrent
class PerformanceComparison:

    @staticmethod
    def serial_execution():
        """Simulate serial execution."""
        start_time = time.time()
        # Simulate running the three calls back-to-back
        time.sleep(1)    # fetch_user_data
        time.sleep(0.8)  # fetch_order_history
        time.sleep(0.5)  # calculate_user_score
        end_time = time.time()
        return end_time - start_time

    @staticmethod
    async def concurrent_execution():
        """Simulate concurrent execution."""
        start_time = time.time()
        # Run all three tools concurrently.
        # @tool wraps the functions as BaseTool objects, so use .ainvoke(...)
        tasks = [
            fetch_user_data.ainvoke("user123"),
            fetch_order_history.ainvoke("user123"),
            calculate_user_score.ainvoke("user123")
        ]
        results = await asyncio.gather(*tasks)
        end_time = time.time()
        return end_time - start_time, results

# Performance test
print("=== Performance Comparison ===")
serial_time = PerformanceComparison.serial_execution()
concurrent_time, results = asyncio.run(PerformanceComparison.concurrent_execution())
print(f"Serial execution time: {serial_time:.2f}s")
print(f"Concurrent execution time: {concurrent_time:.2f}s")
print(f"Speedup: {(serial_time - concurrent_time) / serial_time * 100:.1f}%")
3.1.2 State Management Optimization
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
import json

@dataclass
class OptimizedState:
    """Optimized state structure."""
    # Core data (accessed frequently)
    user_id: str
    session_id: str
    current_step: str
    # Large objects (loaded on demand)
    _large_data: Optional[str] = None
    _computation_cache: Optional[Dict] = None

    def get_large_data(self) -> Dict:
        """Lazily load the large data object."""
        if self._large_data is None:
            # Simulate loading a large payload from external storage
            self._large_data = json.dumps({"data": ["item"] * 1000})
        return json.loads(self._large_data)

    def get_cached_computation(self, key: str) -> Any:
        """Return a cached computation result."""
        if self._computation_cache is None:
            self._computation_cache = {}
        if key not in self._computation_cache:
            # Run the expensive computation once
            self._computation_cache[key] = self._expensive_computation(key)
        return self._computation_cache[key]

    def _expensive_computation(self, key: str) -> Any:
        """Simulate an expensive computation."""
        # In a real application this could be complex data processing
        return f"computed_result_for_{key}"

# State compression strategies
class StateCompressionManager:
    """State compression manager."""

    @staticmethod
    def compress_message_history(messages: List[Dict], max_messages: int = 10) -> List[Dict]:
        """Compress the message history."""
        if len(messages) <= max_messages:
            return messages
        # Keep the most recent messages plus important system messages
        important_messages = [msg for msg in messages if msg.get("role") == "system"]
        recent_messages = messages[-(max_messages - len(important_messages)):]
        return important_messages + recent_messages

    @staticmethod
    def summarize_old_context(messages: List[Dict]) -> str:
        """Summarize the older context."""
        # Use an LLM to summarize the older part of the conversation
        old_messages = messages[:-10]  # the older messages
        if not old_messages:
            return ""
        summary_prompt = "Please briefly summarize the key information in the following conversation:"
        for msg in old_messages:
            summary_prompt += f"\n{msg['role']}: {msg['content']}"
        # A real implementation would call an LLM here; simplified for this example
        return f"Summary of {len(old_messages)} earlier messages"

# Applying the optimization strategy inside a workflow
def optimized_state_processor(state: Dict) -> Dict:
    """Optimized state processor."""
    # Compress the message history
    if "messages" in state and len(state["messages"]) > 20:
        compressed_messages = StateCompressionManager.compress_message_history(
            state["messages"], max_messages=15
        )
        # If messages were dropped, generate a summary of them
        if len(compressed_messages) < len(state["messages"]):
            summary = StateCompressionManager.summarize_old_context(state["messages"])
            # Insert the summary as a system message
            compressed_messages.insert(0, {
                "role": "system",
                "content": f"Conversation history summary: {summary}"
            })
        state["messages"] = compressed_messages
    return state
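A sketch of wiring the processor into a graph: it runs as an ordinary node ahead of the model call. Here call_model is a hypothetical placeholder for the actual LLM-invoking node:

from typing import List, TypedDict
from langgraph.graph import StateGraph, START, END

class ChatState(TypedDict):
    messages: List[dict]

def call_model(state: ChatState) -> ChatState:
    # Hypothetical stand-in for the real LLM call
    return {"messages": state["messages"] + [{"role": "assistant", "content": "..."}]}

graph = StateGraph(ChatState)
graph.add_node("compress_state", optimized_state_processor)  # compression runs first
graph.add_node("call_model", call_model)
graph.add_edge(START, "compress_state")
graph.add_edge("compress_state", "call_model")
graph.add_edge("call_model", END)
compressing_app = graph.compile()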
3.2 Error Handling and Monitoring
3.2.1 A Comprehensive Error-Handling Strategy
from enum import Enum
from datetime import datetime
import logging
import traceback
from functools import wraps

class ErrorSeverity(Enum):
    """Error severity levels."""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

class ErrorHandler:
    """Unified error handler."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.error_stats = {}

    def handle_error(self, error: Exception, context: Dict, severity: ErrorSeverity = ErrorSeverity.MEDIUM):
        """Handle an error."""
        error_info = {
            "error_type": type(error).__name__,
            "error_message": str(error),
            "context": context,
            "severity": severity.value,
            "timestamp": datetime.now().isoformat(),
            "traceback": traceback.format_exc()
        }
        # Update the error statistics
        error_key = f"{error_info['error_type']}_{severity.value}"
        self.error_stats[error_key] = self.error_stats.get(error_key, 0) + 1
        # Choose the handling strategy based on severity
        if severity == ErrorSeverity.CRITICAL:
            self.logger.critical(f"Critical error: {error_info}")
            # Send an alert notification
            self._send_alert(error_info)
        elif severity == ErrorSeverity.HIGH:
            self.logger.error(f"High severity error: {error_info}")
        else:
            self.logger.warning(f"Error occurred: {error_info}")
        return error_info

    def _send_alert(self, error_info: Dict):
        """Send an alert notification."""
        # A real deployment would integrate with an alerting system here
        print(f"🚨 ALERT: Critical error detected - {error_info['error_message']}")

# Retry decorator
def retry_with_backoff(max_retries: int = 3, base_delay: float = 1.0):
    """Retry decorator with exponential backoff."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_exception = None
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    if attempt < max_retries - 1:
                        delay = base_delay * (2 ** attempt)  # exponential backoff
                        time.sleep(delay)
                        print(f"Retry {attempt + 1}/{max_retries} after {delay}s delay")
                    else:
                        print(f"All {max_retries} attempts failed")
            raise last_exception
        return wrapper
    return decorator

# Example of a fault-tolerant tool
error_handler = ErrorHandler()

@tool
@retry_with_backoff(max_retries=3, base_delay=0.5)
def robust_api_call(endpoint: str) -> str:
    """A fault-tolerant API call."""
    try:
        # Simulate an API call
        import random
        if random.random() < 0.3:  # 30% chance of failure
            raise ConnectionError("API service temporarily unavailable")
        return f"Successfully called API endpoint: {endpoint}"
    except ConnectionError as e:
        error_handler.handle_error(
            e,
            {"endpoint": endpoint, "operation": "api_call"},
            ErrorSeverity.HIGH
        )
        raise
    except Exception as e:
        error_handler.handle_error(
            e,
            {"endpoint": endpoint, "operation": "api_call"},
            ErrorSeverity.CRITICAL
        )
        raise

# Agent node with monitoring
def monitored_agent_node(state: Dict) -> Dict:
    """Agent node with monitoring."""
    node_start_time = time.time()
    try:
        # Execute the node logic
        result = process_agent_logic(state)
        # Record success metrics
        execution_time = time.time() - node_start_time
        log_metrics("agent_node_success", {
            "execution_time": execution_time,
            "state_size": len(str(state))
        })
        return result
    except Exception as e:
        execution_time = time.time() - node_start_time
        error_handler.handle_error(e, {
            "node": "agent_node",
            "execution_time": execution_time,
            "state_size": len(str(state))
        })
        # Record failure metrics
        log_metrics("agent_node_failure", {
            "execution_time": execution_time,
            "error_type": type(e).__name__
        })
        # Decide whether to continue based on the error type
        if isinstance(e, (ConnectionError, TimeoutError)):
            # Network-related error: return a degraded fallback result
            return {"error": "service_unavailable", "fallback": True}
        else:
            # Any other error: re-raise
            raise

def log_metrics(metric_name: str, data: Dict):
    """Record metric data."""
    # A real deployment would send this to a monitoring system
    print(f"📊 Metric: {metric_name} - {data}")

def process_agent_logic(state: Dict) -> Dict:
    """Agent logic."""
    # Simulated processing
    return {"processed": True, "result": "success"}
3.3 Testing Strategy and Quality Assurance
3.3.1 Unit Testing Framework
import pytest
import time
from typing import TypedDict
from unittest.mock import Mock, patch
from langgraph.graph import StateGraph

class FlowState(TypedDict, total=False):
    """Minimal schema for the workflow tests (StateGraph needs typed state keys)."""
    initial: str
    test: str
    step1_complete: bool
    step2_complete: bool
    processed: bool

class TestAgentWorkflow:
    """Tests for agent workflows."""

    @pytest.fixture
    def mock_llm(self):
        """A mocked LLM."""
        mock = Mock()
        mock.invoke.return_value = Mock(content="mocked LLM response")
        return mock

    @pytest.fixture
    def test_state(self):
        """A test state."""
        return {
            "messages": [{"role": "user", "content": "test message"}],
            "context": {"user_id": "test_user"}
        }

    def test_single_node_execution(self, mock_llm, test_state):
        """Single-node execution test."""
        def test_node(state):
            return {"processed": True}

        result = test_node(test_state)
        assert result["processed"] is True

    def test_workflow_end_to_end(self, mock_llm):
        """End-to-end workflow test."""
        def node1(state):
            return {"step1_complete": True}

        def node2(state):
            return {"step2_complete": True}

        # Build the test graph
        graph = StateGraph(FlowState)
        graph.add_node("node1", node1)
        graph.add_node("node2", node2)
        graph.add_edge("node1", "node2")
        graph.set_entry_point("node1")
        graph.set_finish_point("node2")
        app = graph.compile()

        # Run the test
        result = app.invoke({"initial": "state"})
        assert result["step1_complete"] is True
        assert result["step2_complete"] is True

    def test_error_handling(self):
        """Error-handling test."""
        def failing_node(state):
            raise ValueError("test error")

        graph = StateGraph(FlowState)
        graph.add_node("failing_node", failing_node)
        graph.set_entry_point("failing_node")
        graph.set_finish_point("failing_node")
        app = graph.compile()

        with pytest.raises(ValueError, match="test error"):
            app.invoke({"test": "data"})

    @patch('external_api.call')
    def test_external_dependency_mock(self, mock_api_call):
        """Mocking an external dependency (external_api is a hypothetical module)."""
        mock_api_call.return_value = {"result": "mocked_data"}

        @tool
        def api_tool():
            """Call the external API."""
            import external_api
            return external_api.call()

        result = api_tool.invoke({})
        assert result == {"result": "mocked_data"}
        mock_api_call.assert_called_once()

# Integration tests
class TestAgentIntegration:
    """Agent integration tests."""

    def test_react_agent_integration(self):
        """ReAct agent integration test."""
        @tool
        def test_tool(query: str) -> str:
            """A test tool."""
            return f"Test result: {query}"

        # Use the real graph structure but mock the LLM
        with patch('langchain_openai.ChatOpenAI') as mock_openai:
            mock_instance = Mock()
            mock_instance.bind_tools.return_value = mock_instance
            mock_instance.invoke.return_value = Mock(
                content="Test complete",
                tool_calls=[]
            )
            mock_openai.return_value = mock_instance

            agent = create_react_agent(
                model=mock_instance,
                tools=[test_tool]
            )
            result = agent.invoke({
                "messages": [{"role": "user", "content": "run the test"}]
            })
            assert "messages" in result
            assert len(result["messages"]) > 0

# Performance tests
class TestAgentPerformance:
    """Agent performance tests."""

    def test_execution_time(self):
        """Execution-time test."""
        def timed_node(state):
            time.sleep(0.1)  # simulated processing time
            return {"processed": True}

        graph = StateGraph(FlowState)
        graph.add_node("timed_node", timed_node)
        graph.set_entry_point("timed_node")
        graph.set_finish_point("timed_node")
        app = graph.compile()

        start_time = time.time()
        result = app.invoke({"test": "performance"})
        execution_time = time.time() - start_time

        assert execution_time < 0.2  # should finish within a reasonable time
        assert result["processed"] is True

    def test_memory_usage(self):
        """Memory-usage test."""
        import psutil
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        # Run a memory-intensive operation
        large_state = {"data": ["item"] * 10000}

        class MemoryState(TypedDict, total=False):
            data: list
            processed_data: list

        def memory_intensive_node(state):
            # Simulate a memory-intensive operation
            processed_data = [item.upper() for item in state["data"]]
            return {"processed_data": processed_data}

        graph = StateGraph(MemoryState)
        graph.add_node("memory_node", memory_intensive_node)
        graph.set_entry_point("memory_node")
        graph.set_finish_point("memory_node")
        app = graph.compile()

        result = app.invoke(large_state)

        final_memory = process.memory_info().rss
        memory_increase = final_memory - initial_memory

        # Ensure memory growth stays within a reasonable bound (100 MB)
        assert memory_increase < 100 * 1024 * 1024
4. Best Practices Guide
4.1 Development-Phase Best Practices
4.1.1 Project Structure
langgraph_project/
├── src/
│   ├── agents/
│   │   ├── __init__.py
│   │   ├── base_agent.py            # Agent base class
│   │   ├── react_agent.py           # ReAct agent implementation
│   │   └── multi_agent_system.py    # Multi-agent system
│   ├── tools/
│   │   ├── __init__.py
│   │   ├── base_tool.py             # Tool base class
│   │   ├── api_tools.py             # API-calling tools
│   │   └── data_tools.py            # Data-processing tools
│   ├── states/
│   │   ├── __init__.py
│   │   ├── common_states.py         # Shared state definitions
│   │   └── domain_states.py         # Domain-specific states
│   ├── workflows/
│   │   ├── __init__.py
│   │   ├── approval_workflow.py     # Approval workflow
│   │   └── data_pipeline.py         # Data pipeline
│   └── utils/
│       ├── __init__.py
│       ├── error_handling.py        # Error-handling utilities
│       ├── monitoring.py            # Monitoring utilities
│       └── testing.py               # Test utilities
├── tests/
│   ├── unit/
│   ├── integration/
│   └── performance/
├── config/
│   ├── development.yml
│   ├── production.yml
│   └── testing.yml
├── requirements.txt
└── README.md
4.1.2 Code Quality Standards
# Example of a well-structured agent design
from abc import ABC, abstractmethod
from typing import Dict, Generic, List, TypedDict, TypeVar
from langgraph.graph import StateGraph
import yaml

StateT = TypeVar('StateT', bound=dict)

class BaseAgent(ABC, Generic[StateT]):
    """Agent base class defining the common interface."""

    def __init__(self, name: str, model, tools: List, config: Dict = None):
        self.name = name
        self.model = model
        self.tools = tools
        self.config = config or {}
        self._graph = None

    @abstractmethod
    def build_graph(self) -> StateGraph:
        """Build the agent's graph structure - subclasses must implement."""
        pass

    @abstractmethod
    def validate_state(self, state: StateT) -> bool:
        """Validate the state - subclasses must implement."""
        pass

    def compile(self, **kwargs) -> 'CompiledAgent':
        """Compile the agent into an executable instance."""
        if self._graph is None:
            self._graph = self.build_graph()
        return self._graph.compile(**kwargs)

    def get_metrics(self) -> Dict:
        """Return the agent's performance metrics."""
        return {
            "name": self.name,
            "tools_count": len(self.tools),
            "config": self.config
        }

class CustomerServiceState(TypedDict, total=False):
    """State for the customer service agent (the fields used below)."""
    customer_id: str
    request_type: str
    content: str

class CustomerServiceAgent(BaseAgent[CustomerServiceState]):
    """Customer service agent implementation."""

    def build_graph(self) -> StateGraph:
        """Build the customer service workflow."""
        graph = StateGraph(CustomerServiceState)
        # Add nodes
        graph.add_node("classify_request", self._classify_request)
        graph.add_node("handle_complaint", self._handle_complaint)
        graph.add_node("provide_info", self._provide_info)
        graph.add_node("escalate", self._escalate)
        # Routing logic
        graph.set_entry_point("classify_request")
        graph.add_conditional_edges(
            "classify_request",
            self._route_request,
            {
                "complaint": "handle_complaint",
                "inquiry": "provide_info",
                "complex": "escalate"
            }
        )
        return graph

    def validate_state(self, state: CustomerServiceState) -> bool:
        """Validate the customer service state."""
        required_fields = ["customer_id", "request_type", "content"]
        return all(field in state for field in required_fields)

    def _classify_request(self, state: CustomerServiceState) -> CustomerServiceState:
        """Request classification node."""
        # Implement the classification logic here
        pass

    def _route_request(self, state: CustomerServiceState) -> str:
        """Request routing logic."""
        # Implement the routing decision here
        pass

    def _handle_complaint(self, state: CustomerServiceState) -> CustomerServiceState:
        """Complaint-handling node (stub)."""
        pass

    def _provide_info(self, state: CustomerServiceState) -> CustomerServiceState:
        """Information node (stub)."""
        pass

    def _escalate(self, state: CustomerServiceState) -> CustomerServiceState:
        """Escalation node (stub)."""
        pass

# Configuration management best practice
class ConfigManager:
    """Configuration manager."""

    def __init__(self, env: str = "development"):
        self.env = env
        self._config = self._load_config()

    def _load_config(self) -> Dict:
        """Load the configuration file."""
        config_file = f"config/{self.env}.yml"
        with open(config_file, 'r') as f:
            return yaml.safe_load(f)

    def get(self, key: str, default=None):
        """Get a configuration value by dotted key."""
        keys = key.split('.')
        value = self._config
        for k in keys:
            if isinstance(value, dict) and k in value:
                value = value[k]
            else:
                return default
        return value

    def get_model_config(self) -> Dict:
        """Get the model configuration."""
        return self.get('model', {})

    def get_tool_config(self) -> Dict:
        """Get the tool configuration."""
        return self.get('tools', {})

# Usage example (load_customer_service_tools and get_checkpointer are
# project-specific helpers assumed to exist elsewhere)
config = ConfigManager(env="production")
model_config = config.get_model_config()

agent = CustomerServiceAgent(
    name="customer_service_v1",
    model=ChatOpenAI(**model_config),
    tools=load_customer_service_tools(),
    config=config.get("agent.customer_service", {})
)

compiled_agent = agent.compile(
    checkpointer=get_checkpointer(config),
    debug=config.get("debug", False)
)
4.2 Production Deployment
4.2.1 Containerized Deployment Configuration
# Dockerfile
FROM python:3.11-slim

# Set the working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy the dependency manifest
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY src/ ./src/
COPY config/ ./config/

# Set environment variables
ENV PYTHONPATH=/app/src
ENV ENVIRONMENT=production

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Expose the port
EXPOSE 8000

# Start the application
CMD ["python", "-m", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'
services:
  langgraph-agent:
    build: .
    ports:
      - "8000:8000"
    environment:
      - DATABASE_URL=postgresql://user:password@postgres:5432/langgraph
      - REDIS_URL=redis://redis:6379/0
      - LOG_LEVEL=INFO
    depends_on:
      - postgres
      - redis
    volumes:
      - ./config:/app/config:ro
      - ./logs:/app/logs
    restart: unless-stopped

  postgres:
    image: postgres:15
    environment:
      - POSTGRES_DB=langgraph
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
    volumes:
      - postgres_data:/var/lib/postgresql/data
    restart: unless-stopped

  redis:
    image: redis:7-alpine
    volumes:
      - redis_data:/data
    restart: unless-stopped

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro
      - ./nginx/ssl:/etc/nginx/ssl:ro
    depends_on:
      - langgraph-agent
    restart: unless-stopped

volumes:
  postgres_data:
  redis_data:
4.2.2 Monitoring and Alerting
# monitoring.py
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import logging
import time
from functools import wraps

# Prometheus metric definitions
REQUEST_COUNT = Counter('langgraph_requests_total', 'Total requests', ['method', 'endpoint', 'status'])
REQUEST_DURATION = Histogram('langgraph_request_duration_seconds', 'Request duration')
ACTIVE_AGENTS = Gauge('langgraph_active_agents', 'Number of active agents')
TOOL_EXECUTION_TIME = Histogram('langgraph_tool_execution_seconds', 'Tool execution time', ['tool_name'])

class MonitoringManager:
    """Monitoring manager."""

    def __init__(self):
        start_http_server(8001)  # Prometheus metrics port
        self.logger = self._setup_logging()

    def _setup_logging(self):
        """Set up structured logging."""
        logger = logging.getLogger('langgraph')
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            '{"timestamp": "%(asctime)s", "level": "%(levelname)s", '
            '"logger": "%(name)s", "message": "%(message)s"}'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        return logger

    def track_request(self, method: str, endpoint: str):
        """Request-tracking decorator."""
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                start_time = time.time()
                status = "success"
                try:
                    result = func(*args, **kwargs)
                    return result
                except Exception as e:
                    status = "error"
                    self.logger.error(f"Request failed: {str(e)}")
                    raise
                finally:
                    duration = time.time() - start_time
                    REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status).inc()
                    REQUEST_DURATION.observe(duration)
            return wrapper
        return decorator

    def track_tool_execution(self, tool_name: str):
        """Tool-execution tracking decorator."""
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                start_time = time.time()
                try:
                    result = func(*args, **kwargs)
                    self.logger.info(f"Tool executed successfully: {tool_name}")
                    return result
                except Exception as e:
                    self.logger.error(f"Tool execution failed: {tool_name}, error: {str(e)}")
                    raise
                finally:
                    duration = time.time() - start_time
                    TOOL_EXECUTION_TIME.labels(tool_name=tool_name).observe(duration)
            return wrapper
        return decorator

# Alert rule configuration (the conditions are PromQL expressions)
ALERT_RULES = {
    "high_error_rate": {
        "condition": "rate(langgraph_requests_total{status='error'}[5m]) > 0.1",
        "severity": "warning",
        "message": "Error rate is above 10%"
    },
    "high_response_time": {
        "condition": "histogram_quantile(0.95, rate(langgraph_request_duration_seconds_bucket[5m])) > 5",
        "severity": "critical",
        "message": "95th percentile response time is above 5 seconds"
    },
    "tool_execution_slow": {
        "condition": "histogram_quantile(0.95, rate(langgraph_tool_execution_seconds_bucket[5m])) > 10",
        "severity": "warning",
        "message": "Tool execution is slower than expected"
    }
}
4.3 Maintenance and Optimization
4.3.1 Version Management and Canary Releases
# version_management.py
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, Optional
import hashlib
import json
import logging
import time

class DeploymentStrategy(Enum):
    BLUE_GREEN = "blue_green"
    CANARY = "canary"
    ROLLING = "rolling"

class AgentVersionManager:
    """Agent version manager."""

    def __init__(self):
        self.versions = {}
        self.active_version = None
        self.traffic_split = {}
        self.logger = logging.getLogger(__name__)

    def register_version(self, version: str, agent_factory: Callable, config: Dict):
        """Register a new agent version."""
        version_hash = self._generate_version_hash(config)
        self.versions[version] = {
            "factory": agent_factory,
            "config": config,
            "hash": version_hash,
            "created_at": datetime.now(),
            "status": "registered"
        }
        return version_hash

    def deploy_version(self, version: str, strategy: DeploymentStrategy = DeploymentStrategy.CANARY):
        """Deploy the given version."""
        if version not in self.versions:
            raise ValueError(f"Version {version} not found")
        if strategy == DeploymentStrategy.CANARY:
            self._canary_deployment(version)
        elif strategy == DeploymentStrategy.BLUE_GREEN:
            self._blue_green_deployment(version)
        else:
            self._rolling_deployment(version)

    def _canary_deployment(self, version: str, initial_traffic: float = 0.05):
        """Canary deployment."""
        if self.active_version is None:
            # First deployment
            self.active_version = version
            self.traffic_split = {version: 1.0}
        else:
            # Canary: the new version receives 5% of traffic
            old_version = self.active_version
            self.traffic_split = {
                old_version: 1.0 - initial_traffic,
                version: initial_traffic
            }
            self.versions[version]["status"] = "canary"

    def promote_canary(self, version: str):
        """Promote a canary version to primary."""
        if self.versions[version]["status"] != "canary":
            raise ValueError(f"Version {version} is not in canary status")
        # Gradually increase the new version's share of traffic
        for traffic_percent in [0.1, 0.25, 0.5, 0.75, 1.0]:
            self.traffic_split[version] = traffic_percent
            remaining_traffic = 1.0 - traffic_percent
            # Redistribute the remaining traffic among the other versions
            other_versions = [v for v in self.traffic_split.keys() if v != version]
            for other_version in other_versions:
                self.traffic_split[other_version] = remaining_traffic / len(other_versions)
            # Watch the health metrics
            if not self._check_version_health(version):
                self._rollback_deployment(version)
                return False
            time.sleep(300)  # observe for 5 minutes
        # Switch fully to the new version
        self.active_version = version
        self.traffic_split = {version: 1.0}
        self.versions[version]["status"] = "active"
        # Clean up old versions
        self._cleanup_old_versions()
        return True

    def _check_version_health(self, version: str) -> bool:
        """Check the health of a version."""
        # A real deployment would check a range of health metrics here
        health_checks = [
            self._check_error_rate(version),
            self._check_response_time(version),
            self._check_resource_usage(version)
        ]
        return all(health_checks)

    def _rollback_deployment(self, failed_version: str):
        """Roll back a deployment."""
        self.logger.warning(f"Rolling back deployment of version {failed_version}")
        # Restore the previous version
        if self.active_version and self.active_version != failed_version:
            self.traffic_split = {self.active_version: 1.0}
        # Mark the failed version
        self.versions[failed_version]["status"] = "failed"

    def get_agent_for_request(self, request_id: str):
        """Return the agent version assigned to this request by the traffic split."""
        # Route based on a hash of the request ID
        hash_value = int(hashlib.md5(request_id.encode()).hexdigest(), 16)
        normalized_hash = (hash_value % 100) / 100.0
        cumulative_traffic = 0.0
        for version, traffic in self.traffic_split.items():
            cumulative_traffic += traffic
            if normalized_hash <= cumulative_traffic:
                return self.versions[version]["factory"]()
        # Fall back to the currently active version
        return self.versions[self.active_version]["factory"]()

    def _generate_version_hash(self, config: Dict) -> str:
        """Generate a version hash."""
        config_str = json.dumps(config, sort_keys=True)
        return hashlib.sha256(config_str.encode()).hexdigest()[:8]

    # Placeholder hooks - a real implementation queries the monitoring backend
    def _check_error_rate(self, version: str) -> bool:
        return True

    def _check_response_time(self, version: str) -> bool:
        return True

    def _check_resource_usage(self, version: str) -> bool:
        return True

    def _cleanup_old_versions(self):
        pass

    def _blue_green_deployment(self, version: str):
        pass  # omitted for brevity

    def _rolling_deployment(self, version: str):
        pass  # omitted for brevity

# Usage example
version_manager = AgentVersionManager()

# Register v1
version_manager.register_version(
    "v1.0.0",
    lambda: CustomerServiceAgent(
        name="cs_v1", model=ChatOpenAI(model="gpt-3.5-turbo"), tools=[]
    ),
    {"model": "gpt-3.5-turbo", "temperature": 0.1}
)

# Register v2 (using a stronger model)
version_manager.register_version(
    "v2.0.0",
    lambda: CustomerServiceAgent(
        name="cs_v2", model=ChatOpenAI(model="gpt-4"), tools=[]
    ),
    {"model": "gpt-4", "temperature": 0.0}
)

# Canary-deploy v2
version_manager.deploy_version("v2.0.0", DeploymentStrategy.CANARY)

# After observing for a while, promote if it performs well
if version_manager.promote_canary("v2.0.0"):
    print("v2.0.0 successfully promoted to production")
4.3.2 Performance Tuning Guide
# performance_tuning.py
from typing import Dict, List
import time
import numpy

class PerformanceTuner:
    """Performance tuning tool."""

    def __init__(self):
        # MetricsCollector is an assumed project-specific helper
        self.metrics_collector = MetricsCollector()
        self.optimization_strategies = [
            self._optimize_model_calls,
            self._optimize_tool_execution,
            self._optimize_state_management,
            self._optimize_memory_usage
        ]

    def analyze_performance(self, agent, test_cases: List[Dict]) -> Dict:
        """Analyze an agent's performance."""
        results = {
            "baseline_metrics": {},
            "bottlenecks": [],
            "recommendations": []
        }
        # Baseline performance test
        baseline_metrics = self._run_baseline_test(agent, test_cases)
        results["baseline_metrics"] = baseline_metrics
        # Identify bottlenecks
        bottlenecks = self._identify_bottlenecks(baseline_metrics)
        results["bottlenecks"] = bottlenecks
        # Generate optimization recommendations
        recommendations = self._generate_recommendations(bottlenecks)
        results["recommendations"] = recommendations
        return results

    def _run_baseline_test(self, agent, test_cases: List[Dict]) -> Dict:
        """Run the baseline performance test."""
        metrics = {
            "avg_response_time": 0,
            "p95_response_time": 0,
            "tool_call_frequency": 0,
            "memory_usage": 0,
            "error_rate": 0
        }
        response_times = []
        tool_calls = 0
        errors = 0
        for test_case in test_cases:
            start_time = time.time()
            try:
                result = agent.invoke(test_case["input"])
                # Count tool calls
                if "messages" in result:
                    for msg in result["messages"]:
                        if hasattr(msg, 'tool_calls') and msg.tool_calls:
                            tool_calls += len(msg.tool_calls)
            except Exception:
                errors += 1
            end_time = time.time()
            response_times.append(end_time - start_time)
        # Compute the metrics
        metrics["avg_response_time"] = sum(response_times) / len(response_times)
        metrics["p95_response_time"] = numpy.percentile(response_times, 95)
        metrics["tool_call_frequency"] = tool_calls / len(test_cases)
        metrics["error_rate"] = errors / len(test_cases)
        return metrics

    def _identify_bottlenecks(self, metrics: Dict) -> List[str]:
        """Identify performance bottlenecks."""
        bottlenecks = []
        if metrics["avg_response_time"] > 5.0:
            bottlenecks.append("high_avg_response_time")
        if metrics["p95_response_time"] > 10.0:
            bottlenecks.append("high_p95_response_time")
        if metrics["tool_call_frequency"] > 5:
            bottlenecks.append("excessive_tool_calls")
        if metrics["error_rate"] > 0.05:
            bottlenecks.append("high_error_rate")
        return bottlenecks

    def _generate_recommendations(self, bottlenecks: List[str]) -> List[Dict]:
        """Generate optimization recommendations."""
        recommendations = []
        recommendation_map = {
            "high_avg_response_time": {
                "issue": "Average response time is too high",
                "suggestions": [
                    "Enable model response caching",
                    "Shorten the prompts",
                    "Use a faster model variant",
                    "Execute tool calls concurrently"
                ]
            },
            "high_p95_response_time": {
                "issue": "P95 response time is too high; there is tail latency",
                "suggestions": [
                    "Set sensible timeouts",
                    "Add a request retry mechanism",
                    "Optimize the slowest tool calls",
                    "Consider an asynchronous processing model"
                ]
            },
            "excessive_tool_calls": {
                "issue": "Tools are called too frequently",
                "suggestions": [
                    "Refine the agent's reasoning logic",
                    "Merge tools with similar functionality",
                    "Cache tool-call results",
                    "Improve the prompt to avoid unnecessary tool calls"
                ]
            },
            "high_error_rate": {
                "issue": "The error rate is too high",
                "suggestions": [
                    "Strengthen tool parameter validation",
                    "Improve the error-handling logic",
                    "Add retries around tool calls",
                    "Make the model instructions clearer"
                ]
            }
        }
        for bottleneck in bottlenecks:
            if bottleneck in recommendation_map:
                recommendations.append(recommendation_map[bottleneck])
        return recommendations

    def apply_optimizations(self, agent, optimizations: List[str]):
        """Apply performance optimizations."""
        for optimization in optimizations:
            if optimization == "enable_caching":
                self._enable_response_caching(agent)
            elif optimization == "optimize_prompts":
                self._optimize_prompt_templates(agent)
            elif optimization == "concurrent_tools":
                self._enable_concurrent_tool_execution(agent)
            # ... other optimization strategies

    # Placeholder strategy hooks - fill these in for your own deployment
    def _optimize_model_calls(self, agent):
        pass

    def _optimize_tool_execution(self, agent):
        pass

    def _optimize_state_management(self, agent):
        pass

    def _optimize_memory_usage(self, agent):
        pass

    def _enable_response_caching(self, agent):
        """Enable response caching."""
        pass

    def _optimize_prompt_templates(self, agent):
        """Optimize the prompt templates."""
        pass

    def _enable_concurrent_tool_execution(self, agent):
        """Enable concurrent tool execution."""
        pass

# Using the performance tuner
tuner = PerformanceTuner()

# Prepare the test cases
test_cases = [
    {"input": {"messages": [{"role": "user", "content": "Check my account balance"}]}},
    {"input": {"messages": [{"role": "user", "content": "Process a refund request"}]}},
    {"input": {"messages": [{"role": "user", "content": "Update my profile"}]}}
]

# Analyze performance (customer_service_agent: the compiled agent from section 4.1.2)
performance_report = tuner.analyze_performance(customer_service_agent, test_cases)

print("=== Performance Analysis Report ===")
print(f"Average response time: {performance_report['baseline_metrics']['avg_response_time']:.2f}s")
print(f"P95 response time: {performance_report['baseline_metrics']['p95_response_time']:.2f}s")
print(f"Error rate: {performance_report['baseline_metrics']['error_rate']:.2%}")

print("\n=== Identified Bottlenecks ===")
for bottleneck in performance_report['bottlenecks']:
    print(f"- {bottleneck}")

print("\n=== Optimization Recommendations ===")
for recommendation in performance_report['recommendations']:
    print(f"Issue: {recommendation['issue']}")
    print("Suggestions:")
    for suggestion in recommendation['suggestions']:
        print(f"  • {suggestion}")
    print()
5. Summary and Outlook
As a powerful agent development framework, LangGraph provides a solid foundation for building complex AI applications. With the examples and best practices in this document, developers can:
5.1 Core Benefits
- Rapid development: assemble agent systems quickly from prebuilt components
- Reliable operation: checkpointing and error handling keep the system stable
- Flexible extension: the modular design supports customization for complex business scenarios
- Production readiness: a complete monitoring, deployment, and maintenance story
5.2 Technical Highlights
- State management: design the state structure deliberately to optimize memory use and performance
- Tool integration: efficient concurrent execution and error-handling strategies
- Human-in-the-loop: flexible interrupt/resume mechanics support complex approval flows
- Quality assurance: a comprehensive testing strategy and a performance-tuning methodology
By applying these practical lessons and best practices, developers can build high-quality, scalable AI agent systems that deliver smarter, more reliable experiences to their users.