概述

LangGraph预构建模块提供了高级API,用于快速创建和部署智能体和工具。该模块封装了常见的设计模式,如ReAct智能体架构、工具执行节点等,让开发者能够快速构建复杂的AI应用。本文将深入分析这些预构建组件的实现原理。

1. Prebuilt模块架构

1.1 核心组件关系图

classDiagram class create_react_agent { <> +model: LanguageModelLike +tools: Sequence~BaseTool~ +prompt: Optional~Prompt~ +state_schema: Optional~StateSchemaType~ +checkpointer: Optional~Checkpointer~ +create_react_agent() CompiledStateGraph } class ToolNode { +list tools +dict fallbacks +bool handle_tool_errors +invoke(state) dict +_execute_tool_calls(messages) list +_handle_tool_call(call) ToolMessage } class ValidationNode { +Type schema +Callable error_formatter +invoke(state) dict +_validate_calls(messages) list } class AgentState { +messages: Annotated~list, add_messages~ +remaining_steps: NotRequired~RemainingSteps~ } class tools_condition { <> +state: Union~list, dict, BaseModel~ +messages_key: str +tools_condition() Literal } class InjectedState { <> +inject_state() } class InjectedStore { <> +inject_store() } create_react_agent --> AgentState : uses create_react_agent --> ToolNode : creates create_react_agent --> tools_condition : uses ToolNode --> InjectedState : supports ToolNode --> InjectedStore : supports ValidationNode --> ToolNode : extends style create_react_agent fill:#e1f5fe style ToolNode fill:#f3e5f5 style AgentState fill:#e8f5e8

1.2 模块架构图

graph TB subgraph "LangGraph Prebuilt架构" subgraph "智能体层" ReactAgent[create_react_agent
ReAct智能体创建器] AgentState[AgentState
智能体状态模式] Prompt[Prompt
提示词管理] end subgraph "工具层" ToolNode[ToolNode
工具执行节点] ToolValidator[ValidationNode
工具验证节点] ToolCondition[tools_condition
工具路由条件] end subgraph "依赖注入层" StateInjection[InjectedState
状态注入] StoreInjection[InjectedStore
存储注入] DynamicTools[动态工具绑定] end subgraph "执行引擎层" StateGraph[StateGraph
状态图] Pregel[Pregel执行器] Checkpoint[检查点系统] end subgraph "底层服务" LangChain[LangChain工具] LLMModels[语言模型] ExternalAPIs[外部API服务] end end %% 连接关系 ReactAgent --> AgentState ReactAgent --> Prompt ReactAgent --> ToolNode ToolNode --> StateInjection ToolNode --> StoreInjection ToolNode --> DynamicTools ToolValidator --> ToolNode ReactAgent --> StateGraph StateGraph --> Pregel Pregel --> Checkpoint ToolNode --> LangChain ReactAgent --> LLMModels ToolNode --> ExternalAPIs style ReactAgent fill:#e1f5fe style ToolNode fill:#f3e5f5 style StateGraph fill:#e8f5e8 style LangChain fill:#fff3e0

2. create_react_agent:ReAct智能体创建器

2.1 ReAct模式工作流

sequenceDiagram participant User as 用户 participant Agent as ReAct智能体 participant LLM as 语言模型 participant Tools as 工具节点 participant State as 状态管理 Note over User,State: ReAct智能体执行流程 User->>Agent: 用户输入 Agent->>State: 初始化状态 State->>Agent: 返回初始状态 loop ReAct循环 Agent->>LLM: 发送消息和上下文 LLM->>LLM: 推理(Reasoning) LLM->>LLM: 决定行动(Acting) LLM->>Agent: 返回响应和工具调用 alt 有工具调用 Agent->>Tools: 执行工具调用 Tools->>Tools: 并行执行多个工具 Tools->>Agent: 返回工具结果 Agent->>State: 更新状态 State->>Agent: 确认状态更新 else 无工具调用 Agent->>User: 返回最终答案 end Note over Agent: 观察(Observing)结果 end Agent->>User: 返回完整对话历史

2.2 create_react_agent核心实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def create_react_agent(
    model: Union[
        str,
        LanguageModelLike,
        Callable[[StateSchema, Runtime[ContextT]], BaseChatModel],
        Callable[[StateSchema, Runtime[ContextT]], Awaitable[BaseChatModel]],
    ],
    tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
    *,
    prompt: Optional[Prompt] = None,
    response_format: Optional[Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]]] = None,
    pre_model_hook: Optional[RunnableLike] = None,
    post_model_hook: Optional[RunnableLike] = None,
    state_schema: Optional[StateSchemaType] = None,
    context_schema: Optional[Type[Any]] = None,
    checkpointer: Optional[Checkpointer] = None,
    store: Optional[BaseStore] = None,
    interrupt_before: Optional[list[str]] = None,
    interrupt_after: Optional[list[str]] = None,
    debug: bool = False,
    version: Literal["v1", "v2"] = "v2",
    name: Optional[str] = None,
    **deprecated_kwargs: Any,
) -> CompiledStateGraph:
    """创建一个在循环中调用工具的智能体图,直到满足停止条件
    
    该函数实现了ReAct(Reasoning and Acting)模式的智能体:
    1. 接收用户输入
    2. 使用语言模型进行推理
    3. 如果需要,调用工具获取信息
    4. 基于工具结果继续推理
    5. 返回最终答案
    
    Args:
        model: 智能体使用的语言模型,支持静态和动态模型选择
            - 静态模型:聊天模型实例或字符串标识符
            - 动态模型:签名为(state, runtime) -> BaseChatModel的可调用对象
        tools: 工具列表或ToolNode实例
        prompt: 可选的LLM提示词,支持多种格式
        response_format: 可选的最终智能体输出模式
        pre_model_hook: 在智能体节点前添加的可选节点
        post_model_hook: 在智能体节点后添加的可选节点  
        state_schema: 定义图状态的可选状态模式
        context_schema: 运行时上下文的可选模式
        checkpointer: 可选的检查点保存器对象
        store: 可选的存储对象
        interrupt_before: 要在之前中断的节点名称列表
        interrupt_after: 要在之后中断的节点名称列表
        debug: 是否启用调试模式的标志
        version: 要创建的图版本
        name: CompiledStateGraph的可选名称
        
    Returns:
        CompiledStateGraph: 编译后的可用于聊天交互的LangChain runnable
    """
    
    # 处理废弃参数
    if (config_schema := deprecated_kwargs.pop("config_schema", MISSING)) is not MISSING:
        warn(
            "`config_schema` is deprecated and will be removed in v2.0.0. "
            "Please use `context_schema` instead.",
            category=LangGraphDeprecatedSinceV10,
            stacklevel=2,
        )
        if context_schema is None:
            context_schema = config_schema

    # 验证参数
    if deprecated_kwargs:
        raise ValueError(f"Unexpected arguments: {deprecated_kwargs}")
    
    if version not in ("v1", "v2"):
        raise ValueError(f"Version {version} is not supported")
    
    # 处理工具参数
    if isinstance(tools, ToolNode):
        tool_node = tools
        tool_classes = tools.tools_by_name
    else:
        tool_classes = {t.name if isinstance(t, BaseTool) else t.__name__: t for t in tools}
        tool_node = ToolNode(tools) if tools else None
    
    # 确定状态模式
    if state_schema is None:
        if response_format is not None:
            state_schema = cast(
                StateSchemaType,
                AgentStateWithStructuredResponse 
                if not inspect.isclass(context_schema) 
                else AgentStateWithStructuredResponsePydantic
            )
        else:
            state_schema = cast(
                StateSchemaType,
                AgentState 
                if not inspect.isclass(context_schema) 
                else AgentStatePydantic
            )
    
    # 创建状态图
    graph = StateGraph(
        state_schema=state_schema,
        context_schema=context_schema,
    )
    
    # 添加智能体节点
    agent_node = _create_agent_node(
        model=model,
        tools=list(tool_classes.values()) if tool_classes else [],
        prompt=prompt,
        state_schema=state_schema,
    )
    
    # 根据版本创建不同的图结构
    if version == "v1":
        return _create_react_agent_v1(
            graph=graph,
            agent_node=agent_node,
            tool_node=tool_node,
            pre_model_hook=pre_model_hook,
            post_model_hook=post_model_hook,
            response_format=response_format,
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            debug=debug,
            name=name,
        )
    else:  # version == "v2"
        return _create_react_agent_v2(
            graph=graph,
            agent_node=agent_node,
            tool_classes=tool_classes,
            pre_model_hook=pre_model_hook,
            post_model_hook=post_model_hook,
            response_format=response_format,
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            debug=debug,
            name=name,
        )

def _create_agent_node(
    model: Union[str, LanguageModelLike, Callable],
    tools: list[BaseTool],
    prompt: Optional[Prompt],
    state_schema: StateSchemaType,
) -> RunnableCallable:
    """创建智能体执行节点
    
    Args:
        model: 语言模型
        tools: 工具列表  
        prompt: 提示词
        state_schema: 状态模式
        
    Returns:
        RunnableCallable: 智能体节点
    """
    
    def agent_node(state: StateSchema, *, config: RunnableConfig = None, **kwargs) -> dict:
        """智能体节点执行函数
        
        Args:
            state: 当前图状态
            config: 运行配置
            
        Returns:
            dict: 状态更新
        """
        # 处理动态模型选择
        if callable(model) and not isinstance(model, (str, Runnable)):
            # 动态模型函数
            runtime = Runtime(context=config.get("configurable", {}) if config else {})
            if inspect.iscoroutinefunction(model):
                import asyncio
                current_model = asyncio.run(model(state, runtime))
            else:
                current_model = model(state, runtime)
        else:
            # 静态模型
            current_model = model
            
            # 如果是字符串,解析为模型实例
            if isinstance(current_model, str):
                current_model = _parse_model_string(current_model)
            
            # 绑定工具到模型
            if tools and hasattr(current_model, 'bind_tools'):
                current_model = current_model.bind_tools(tools)
        
        # 构建消息输入
        messages = state.get("messages", [])
        
        # 应用提示词
        if prompt is not None:
            if isinstance(prompt, str):
                system_message = SystemMessage(content=prompt)
                messages = [system_message] + list(messages)
            elif isinstance(prompt, SystemMessage):
                messages = [prompt] + list(messages)
            elif callable(prompt):
                prompt_input = prompt(state)
                if isinstance(prompt_input, str):
                    system_message = SystemMessage(content=prompt_input)
                    messages = [system_message] + list(messages)
                elif isinstance(prompt_input, list):
                    messages = prompt_input + list(messages)
                else:
                    messages = [prompt_input] + list(messages)
            elif hasattr(prompt, 'invoke'):
                # Runnable prompt
                prompt_result = prompt.invoke(state, config)
                if isinstance(prompt_result, str):
                    system_message = SystemMessage(content=prompt_result)
                    messages = [system_message] + list(messages)
                else:
                    messages = [prompt_result] + list(messages)
        
        # 调用模型
        response = current_model.invoke(messages, config)
        
        # 检查剩余步数
        remaining_steps = state.get("remaining_steps")
        if remaining_steps is not None and hasattr(remaining_steps, 'value'):
            if remaining_steps.value <= 1:
                # 步数用尽,强制结束
                if isinstance(response, AIMessage) and response.tool_calls:
                    # 移除工具调用
                    response.tool_calls = []
                    response.content = (
                        response.content + "\n\n(Stopped due to step limit)"
                        if response.content 
                        else "(Stopped due to step limit)"
                    )
        
        return {"messages": [response]}
    
    return RunnableCallable(agent_node, name="agent")

def _create_react_agent_v1(
    graph: StateGraph,
    agent_node: RunnableCallable,
    tool_node: Optional[ToolNode],
    **kwargs
) -> CompiledStateGraph:
    """创建v1版本的ReAct智能体图
    
    v1版本特点:
    - 工具节点处理单个消息
    - 消息中的所有工具调用在工具节点内并行执行
    """
    # 添加前置钩子
    if kwargs.get("pre_model_hook"):
        graph.add_node("pre_hook", kwargs["pre_model_hook"])
        graph.set_entry_point("pre_hook")
        graph.add_edge("pre_hook", "agent")
    else:
        graph.set_entry_point("agent")
    
    # 添加智能体节点
    graph.add_node("agent", agent_node)
    
    # 添加后置钩子
    if kwargs.get("post_model_hook"):
        graph.add_node("post_hook", kwargs["post_model_hook"])
        graph.add_edge("agent", "post_hook")
        
        if tool_node:
            graph.add_node("tools", tool_node)
            graph.add_conditional_edges(
                "post_hook",
                tools_condition,
                {"tools": "tools", END: END}
            )
            graph.add_edge("tools", "agent")
        else:
            graph.set_finish_point("post_hook")
    else:
        if tool_node:
            graph.add_node("tools", tool_node)
            graph.add_conditional_edges(
                "agent",
                tools_condition,
                {"tools": "tools", END: END}
            )
            graph.add_edge("tools", "agent")
        else:
            graph.set_finish_point("agent")
    
    # 添加结构化响应处理
    if kwargs.get("response_format"):
        _add_structured_response_node(graph, kwargs["response_format"])
    
    # 编译图
    return graph.compile(
        checkpointer=kwargs.get("checkpointer"),
        store=kwargs.get("store"),
        interrupt_before=kwargs.get("interrupt_before"),
        interrupt_after=kwargs.get("interrupt_after"),
        debug=kwargs.get("debug", False),
        name=kwargs.get("name"),
    )

def _create_react_agent_v2(
    graph: StateGraph,
    agent_node: RunnableCallable,
    tool_classes: dict[str, BaseTool],
    **kwargs
) -> CompiledStateGraph:
    """创建v2版本的ReAct智能体图
    
    v2版本特点:
    - 工具节点处理单个工具调用
    - 工具调用使用Send API分发到多个工具节点实例
    - 支持更细粒度的工具执行控制
    """
    
    def route_tools(state: StateSchema) -> list[Send]:
        """路由工具调用到不同的工具节点实例"""
        messages = state.get("messages", [])
        if not messages:
            return []
        
        last_message = messages[-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            return []
        
        sends = []
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            if tool_name in tool_classes:
                # 为每个工具调用创建一个Send命令
                sends.append(Send(f"tool_{tool_name}", {
                    "messages": [last_message],
                    "tool_call_id": tool_call["id"]
                }))
            else:
                # 处理无效工具名称
                error_message = ToolMessage(
                    content=f"Error: {tool_name} is not a valid tool",
                    tool_call_id=tool_call["id"]
                )
                sends.append(Send("agent", {"messages": [error_message]}))
        
        return sends
    
    # 添加节点和边
    if kwargs.get("pre_model_hook"):
        graph.add_node("pre_hook", kwargs["pre_model_hook"])
        graph.set_entry_point("pre_hook")
        graph.add_edge("pre_hook", "agent")
    else:
        graph.set_entry_point("agent")
    
    graph.add_node("agent", agent_node)
    
    if kwargs.get("post_model_hook"):
        graph.add_node("post_hook", kwargs["post_model_hook"])
        graph.add_edge("agent", "post_hook")
        
        if tool_classes:
            # 添加工具路由
            graph.add_conditional_edges("post_hook", route_tools)
            
            # 为每个工具添加节点
            for tool_name, tool in tool_classes.items():
                tool_node = ToolNode([tool])
                graph.add_node(f"tool_{tool_name}", tool_node)
                graph.add_edge(f"tool_{tool_name}", "agent")
        else:
            graph.set_finish_point("post_hook")
    else:
        if tool_classes:
            graph.add_conditional_edges("agent", route_tools)
            
            for tool_name, tool in tool_classes.items():
                tool_node = ToolNode([tool])
                graph.add_node(f"tool_{tool_name}", tool_node)
                graph.add_edge(f"tool_{tool_name}", "agent")
        else:
            graph.set_finish_point("agent")
    
    # 添加结构化响应处理
    if kwargs.get("response_format"):
        _add_structured_response_node(graph, kwargs["response_format"])
    
    return graph.compile(
        checkpointer=kwargs.get("checkpointer"),
        store=kwargs.get("store"),
        interrupt_before=kwargs.get("interrupt_before"),
        interrupt_after=kwargs.get("interrupt_after"),
        debug=kwargs.get("debug", False),
        name=kwargs.get("name"),
    )

2.3 AgentState:智能体状态模式

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class AgentState(TypedDict):
    """智能体的状态定义
    
    这个状态模式定义了ReAct智能体运行过程中需要维护的状态信息:
    - messages:对话消息历史
    - remaining_steps:剩余执行步数
    """
    
    messages: Annotated[Sequence[BaseMessage], add_messages]
    """对话消息序列
    
    使用add_messages reducer确保消息正确累积:
    - 新消息会添加到现有消息列表末尾
    - 支持消息去重和更新
    - 自动处理RemoveMessage指令
    """
    
    remaining_steps: NotRequired[RemainingSteps]
    """剩余执行步数(可选)
    
    用于防止智能体无限循环:
    - 每次调用工具后递减
    - 达到0时强制结束执行
    - 支持动态调整
    """

class AgentStatePydantic(BaseModel):
    """Pydantic版本的智能体状态
    
    提供更强的类型验证和序列化支持
    """
    
    messages: Annotated[Sequence[BaseMessage], add_messages]
    remaining_steps: RemainingSteps = 25
    
    class Config:
        arbitrary_types_allowed = True

class AgentStateWithStructuredResponse(AgentState):
    """带结构化响应的智能体状态
    
    当智能体需要返回结构化输出时使用
    """
    
    structured_response: StructuredResponse
    """结构化响应数据
    
    可以是dict或BaseModel实例,用于存储
    智能体生成的结构化输出结果
    """

# 状态管理辅助函数
def validate_agent_state(state: dict[str, Any]) -> AgentState:
    """验证智能体状态格式
    
    Args:
        state: 待验证的状态字典
        
    Returns:
        AgentState: 验证后的状态
        
    Raises:
        ValueError: 状态格式不正确时
    """
    if "messages" not in state:
        raise ValueError("Agent state must contain 'messages' key")
    
    messages = state["messages"]
    if not isinstance(messages, (list, tuple)):
        raise ValueError("Messages must be a sequence")
    
    # 验证消息格式
    validated_messages = []
    for msg in messages:
        if not isinstance(msg, BaseMessage):
            # 尝试转换
            validated_messages.extend(convert_to_messages([msg]))
        else:
            validated_messages.append(msg)
    
    result = AgentState(messages=validated_messages)
    
    # 添加可选字段
    if "remaining_steps" in state:
        result["remaining_steps"] = state["remaining_steps"]
    
    if "structured_response" in state:
        result = AgentStateWithStructuredResponse(
            messages=result["messages"],
            structured_response=state["structured_response"]
        )
        if "remaining_steps" in state:
            result["remaining_steps"] = state["remaining_steps"]
    
    return result

3. ToolNode:工具执行节点

3.1 ToolNode核心架构

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
class ToolNode(RunnableCallable):
    """LangGraph工作流中的工具执行节点
    
    该类实现了以下关键设计模式:
    - 多个工具调用的并行执行以提高效率
    - 可自定义错误消息的健壮错误处理
    - 需要访问图状态的工具的状态注入
    - 需要持久存储的工具的存储注入
    - 用于高级控制流的基于命令的状态更新
    """
    
    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable, dict[str, Any]]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        handle_tool_errors: Optional[bool] = True,
        fallbacks: Optional[dict[str, Union[str, Callable, Runnable]]] = None,
    ) -> None:
        """初始化工具节点
        
        Args:
            tools: 工具序列,支持多种格式
            name: 节点名称
            tags: 节点标签
            handle_tool_errors: 是否处理工具错误
            fallbacks: 工具调用失败时的回退策略
        """
        super().__init__(self._func, name=name, tags=tags, trace=False)
        
        # 处理工具列表
        self.tools_by_name: dict[str, BaseTool] = {}
        for tool in tools:
            if isinstance(tool, BaseTool):
                self.tools_by_name[tool.name] = tool
            elif callable(tool):
                # 转换函数为工具
                tool_obj = create_tool(tool)
                self.tools_by_name[tool_obj.name] = tool_obj
            elif isinstance(tool, dict):
                # 从字典创建工具
                tool_obj = BaseTool.from_dict(tool)
                self.tools_by_name[tool_obj.name] = tool_obj
            else:
                raise ValueError(f"Unsupported tool type: {type(tool)}")
        
        self.handle_tool_errors = handle_tool_errors
        self.fallbacks = fallbacks or {}
        
        # 分析工具的特殊参数需求
        self._analyze_tool_signatures()
    
    def _analyze_tool_signatures(self) -> None:
        """分析工具签名,识别特殊参数需求"""
        self._tools_with_state: dict[str, list[str]] = {}
        self._tools_with_store: dict[str, list[str]] = {}
        
        for tool_name, tool in self.tools_by_name.items():
            # 获取工具函数签名
            if hasattr(tool, 'func'):
                func = tool.func
            else:
                func = tool
            
            try:
                sig = inspect.signature(func)
                type_hints = get_type_hints(func, include_extras=True)
                
                state_params = []
                store_params = []
                
                for param_name, param in sig.parameters.items():
                    param_type = type_hints.get(param_name, param.annotation)
                    
                    # 检查状态注入
                    if self._is_injected_state(param_type):
                        state_params.append(param_name)
                    
                    # 检查存储注入
                    elif self._is_injected_store(param_type):
                        store_params.append(param_name)
                
                if state_params:
                    self._tools_with_state[tool_name] = state_params
                
                if store_params:
                    self._tools_with_store[tool_name] = store_params
                    
            except (ValueError, TypeError):
                # 无法分析签名,跳过
                continue
    
    def _is_injected_state(self, param_type) -> bool:
        """检查参数是否为注入状态类型"""
        return (
            get_origin(param_type) is Annotated and
            any(isinstance(arg, InjectedState) for arg in get_args(param_type)[1:])
        ) or param_type is InjectedState
    
    def _is_injected_store(self, param_type) -> bool:
        """检查参数是否为注入存储类型"""
        return (
            get_origin(param_type) is Annotated and
            any(isinstance(arg, InjectedStore) for arg in get_args(param_type)[1:])
        ) or param_type is InjectedStore
    
    def _func(
        self,
        input: Union[
            list[AnyMessage],
            dict[str, Any],
            BaseModel,
        ],
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Any:
        """工具节点的主执行函数
        
        Args:
            input: 输入消息或状态
            config: 运行配置
            **kwargs: 额外参数
            
        Returns:
            执行结果
        """
        # 标准化输入格式
        input_type, message_list = self._parse_input(input)
        
        # 提取工具调用
        tool_calls = []
        for message in message_list:
            if isinstance(message, AIMessage) and message.tool_calls:
                tool_calls.extend(message.tool_calls)
        
        if not tool_calls:
            # 无工具调用,返回空结果
            return self._format_output([], input_type)
        
        # 执行工具调用
        tool_messages = self._execute_tool_calls(
            tool_calls=tool_calls,
            config=config,
            state=input if input_type != "list" else None,
            **kwargs
        )
        
        return self._format_output(tool_messages, input_type)
    
    def _execute_tool_calls(
        self,
        tool_calls: list[ToolCall],
        config: RunnableConfig,
        state: Optional[Union[dict[str, Any], BaseModel]] = None,
        **kwargs: Any,
    ) -> list[ToolMessage]:
        """并行执行工具调用
        
        Args:
            tool_calls: 工具调用列表
            config: 运行配置
            state: 图状态(用于状态注入)
            **kwargs: 额外参数
            
        Returns:
            list[ToolMessage]: 工具执行结果消息列表
        """
        # 获取执行器
        executor = get_executor_for_config(config)
        
        # 创建执行任务
        with executor:
            # 为每个工具调用创建Future
            futures = []
            for tool_call in tool_calls:
                future = executor.submit(
                    self._execute_single_tool_call,
                    tool_call=tool_call,
                    config=config,
                    state=state,
                    **kwargs
                )
                futures.append((tool_call, future))
            
            # 收集结果
            tool_messages = []
            for tool_call, future in futures:
                try:
                    result = future.result()
                    tool_messages.append(result)
                except Exception as e:
                    # 处理执行错误
                    error_message = self._handle_tool_error(tool_call, e)
                    tool_messages.append(error_message)
        
        return tool_messages
    
    def _execute_single_tool_call(
        self,
        tool_call: ToolCall,
        config: RunnableConfig,
        state: Optional[Union[dict[str, Any], BaseModel]] = None,
        **kwargs: Any,
    ) -> ToolMessage:
        """执行单个工具调用
        
        Args:
            tool_call: 工具调用信息
            config: 运行配置
            state: 图状态
            **kwargs: 额外参数
            
        Returns:
            ToolMessage: 工具执行结果消息
        """
        tool_name = tool_call["name"]
        tool_args = tool_call.get("args", {})
        tool_call_id = tool_call["id"]
        
        # 检查工具是否存在
        if tool_name not in self.tools_by_name:
            available_tools = ", ".join(self.tools_by_name.keys())
            error_content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
                requested_tool=tool_name,
                available_tools=available_tools
            )
            return ToolMessage(
                content=error_content,
                tool_call_id=tool_call_id,
                name=tool_name,
            )
        
        tool = self.tools_by_name[tool_name]
        
        try:
            # 处理参数注入
            final_args = tool_args.copy()
            
            # 状态注入
            if tool_name in self._tools_with_state and state is not None:
                state_params = self._tools_with_state[tool_name]
                for param_name in state_params:
                    final_args[param_name] = state
            
            # 存储注入
            if tool_name in self._tools_with_store:
                store = kwargs.get('store') or config.get('configurable', {}).get('store')
                if store:
                    store_params = self._tools_with_store[tool_name]
                    for param_name in store_params:
                        final_args[param_name] = store
            
            # 调用工具
            if hasattr(tool, 'ainvoke') and inspect.iscoroutinefunction(tool.ainvoke):
                # 异步工具
                import asyncio
                result = asyncio.run(tool.ainvoke(final_args, config))
            else:
                # 同步工具
                result = tool.invoke(final_args, config)
            
            # 处理结果
            content = msg_content_output(result)
            
            return ToolMessage(
                content=content,
                tool_call_id=tool_call_id,
                name=tool_name,
            )
            
        except Exception as e:
            if self.handle_tool_errors:
                # 生成错误消息
                error_content = TOOL_CALL_ERROR_TEMPLATE.format(error=str(e))
                return ToolMessage(
                    content=error_content,
                    tool_call_id=tool_call_id,
                    name=tool_name,
                )
            else:
                # 重新抛出异常
                raise
    
    def _handle_tool_error(self, tool_call: ToolCall, error: Exception) -> ToolMessage:
        """处理工具执行错误
        
        Args:
            tool_call: 失败的工具调用
            error: 错误实例
            
        Returns:
            ToolMessage: 错误消息
        """
        tool_name = tool_call["name"]
        tool_call_id = tool_call["id"]
        
        # 检查是否有回退策略
        if tool_name in self.fallbacks:
            fallback = self.fallbacks[tool_name]
            
            try:
                if isinstance(fallback, str):
                    # 字符串回退
                    content = fallback
                elif callable(fallback):
                    # 可调用回退
                    content = fallback(tool_call, error)
                elif hasattr(fallback, 'invoke'):
                    # Runnable回退
                    content = fallback.invoke({
                        "tool_call": tool_call,
                        "error": error
                    })
                else:
                    content = str(fallback)
                
                return ToolMessage(
                    content=msg_content_output(content),
                    tool_call_id=tool_call_id,
                    name=tool_name,
                )
                
            except Exception as fallback_error:
                # 回退也失败了
                error_content = (
                    f"Tool execution failed: {error}\n"
                    f"Fallback also failed: {fallback_error}"
                )
        else:
            # 没有回退策略
            error_content = TOOL_CALL_ERROR_TEMPLATE.format(error=str(error))
        
        return ToolMessage(
            content=error_content,
            tool_call_id=tool_call_id,
            name=tool_name,
        )
    
    def _parse_input(self, input: Any) -> tuple[str, list[AnyMessage]]:
        """解析输入格式
        
        Args:
            input: 输入数据
            
        Returns:
            tuple: (输入类型, 消息列表)
        """
        if isinstance(input, list):
            return "list", input
        elif isinstance(input, dict):
            if "messages" in input:
                return "dict", input["messages"]
            else:
                raise ValueError("Dictionary input must contain 'messages' key")
        elif isinstance(input, BaseModel):
            if hasattr(input, 'messages'):
                return "dict", input.messages  
            else:
                raise ValueError("BaseModel input must have 'messages' attribute")
        else:
            raise ValueError(f"Unsupported input type: {type(input)}")
    
    def _format_output(self, tool_messages: list[ToolMessage], input_type: str) -> Any:
        """格式化输出结果
        
        Args:
            tool_messages: 工具消息列表
            input_type: 输入类型
            
        Returns:
            格式化后的输出
        """
        if input_type == "list":
            return tool_messages
        else:
            return {"messages": tool_messages}

def msg_content_output(output: Any) -> Union[str, list[dict]]:
    """将工具输出转换为有效的消息内容格式
    
    LangChain ToolMessages接受字符串内容或内容块列表。
    此函数确保工具输出正确格式化以供消息消费,
    尽可能保留结构化数据,回退到JSON序列化或字符串转换。
    
    Args:
        output: 工具执行的原始输出,可以是任何类型
        
    Returns:
        Union[str, list[dict]]: 输出的字符串表示或内容块列表
    """
    if isinstance(output, str):
        return output
    elif isinstance(output, (dict, list)):
        # 尝试保持结构化格式
        if isinstance(output, list) and all(
            isinstance(item, dict) and 
            item.get("type") in TOOL_MESSAGE_BLOCK_TYPES 
            for item in output
        ):
            # 已经是有效的内容块格式
            return output
        else:
            # 转换为JSON字符串
            import json
            try:
                return json.dumps(output, ensure_ascii=False, indent=2)
            except (TypeError, ValueError):
                return str(output)
    else:
        # 其他类型转换为字符串
        return str(output)

3.2 依赖注入机制

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class InjectedState:
    """状态注入注解类
    
    用于标记工具参数需要注入当前图状态
    """
    
    def __init__(self, *, key: Optional[str] = None):
        """初始化状态注入器
        
        Args:
            key: 可选的状态键,如果指定则只注入特定键的值
        """
        self.key = key
    
    def __repr__(self) -> str:
        return f"InjectedState(key={self.key!r})"

class InjectedStore:
    """存储注入注解类
    
    用于标记工具参数需要注入持久存储实例
    """
    
    def __init__(self, *, namespace: Optional[str] = None):
        """初始化存储注入器
        
        Args:
            namespace: 可选的命名空间,用于隔离不同工具的存储
        """
        self.namespace = namespace
    
    def __repr__(self) -> str:
        return f"InjectedStore(namespace={self.namespace!r})"

# 使用示例
def weather_tool_with_state(
    location: str,
    state: Annotated[dict, InjectedState()],
    store: Annotated[BaseStore, InjectedStore(namespace="weather")],
) -> str:
    """带状态和存储注入的天气工具
    
    Args:
        location: 查询位置
        state: 注入的图状态
        store: 注入的持久存储
        
    Returns:
        str: 天气信息
    """
    # 访问图状态
    user_preferences = state.get("user_preferences", {})
    temperature_unit = user_preferences.get("temperature_unit", "celsius")
    
    # 访问持久存储
    cached_weather = store.get(("weather_cache", location))
    if cached_weather and not _is_cache_expired(cached_weather):
        return cached_weather["data"]
    
    # 调用天气API
    weather_data = _fetch_weather(location, temperature_unit)
    
    # 缓存结果
    store.put(("weather_cache", location), {
        "data": weather_data,
        "timestamp": time.time()
    })
    
    return weather_data

def _is_cache_expired(cached_data: dict, max_age: int = 3600) -> bool:
    """检查缓存是否过期"""
    return time.time() - cached_data.get("timestamp", 0) > max_age

def _fetch_weather(location: str, unit: str) -> str:
    """模拟天气API调用"""
    import random
    temp = random.randint(-10, 35) if unit == "celsius" else random.randint(14, 95)
    return f"The weather in {location} is {temp}°{unit[0].upper()}"

3.3 tools_condition:工具路由条件

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def tools_condition(
    state: Union[list[AnyMessage], dict[str, Any], BaseModel],
    messages_key: str = "messages",
) -> Literal["tools", "__end__"]:
    """工具调用工作流的条件路由函数
    
    此实用函数实现ReAct风格智能体的标准条件逻辑:
    如果最后一条AI消息包含工具调用,则路由到工具执行节点;
    否则,结束工作流。这种模式是大多数工具调用智能体架构的基础。
    
    该函数处理LangGraph应用程序中常用的多种状态格式,
    使其在不同图设计中具有灵活性,同时保持一致的行为。
    
    Args:
        state: 要检查工具调用的当前图状态。支持的格式:
            - 包含messages键的字典(用于StateGraph)
            - 具有messages属性的BaseModel实例
        messages_key: 包含状态中消息列表的键或属性名称。
            这允许使用不同状态模式的图进行自定义。
            默认为"messages"。
            
    Returns:
        如果最后一条AI消息中存在工具调用,则返回"tools",
        否则返回"__end__"以终止工作流。这些是工具调用条件边的标准路由目标。
        
    Raises:
        ValueError: 如果在提供的状态格式中找不到消息。
        
    Example:
        在ReAct智能体中的基本用法:
        
        ```python
        from langgraph.graph import StateGraph  
        from langgraph.prebuilt import ToolNode, tools_condition
        from typing_extensions import TypedDict
        
        class State(TypedDict):
            messages: list
            
        graph = StateGraph(State)
        graph.add_node("llm", call_model)
        graph.add_node("tools", ToolNode([my_tool]))
        graph.add_conditional_edges(
            "llm",
            tools_condition,  # 路由到"tools"或"__end__"
            {"tools": "tools", "__end__": "__end__"}
        )
        ```
        
        自定义消息键:
        
        ```python
        def custom_condition(state):
            return tools_condition(state, messages_key="chat_history")
        ```
    """
    # 处理不同的状态格式
    if isinstance(state, list):
        # 状态本身就是消息列表
        messages = state
    elif isinstance(state, dict):
        # 字典状态格式
        if messages_key not in state:
            raise ValueError(f"No '{messages_key}' key found in state")
        messages = state[messages_key]
    elif isinstance(state, BaseModel):
        # Pydantic模型状态格式
        if not hasattr(state, messages_key):
            raise ValueError(f"No '{messages_key}' attribute found in state")
        messages = getattr(state, messages_key)
    else:
        raise ValueError(f"Unsupported state type: {type(state)}")
    
    # 检查消息列表
    if not messages:
        return "__end__"
    
    # 获取最后一条消息
    last_message = messages[-1]
    
    # 检查是否为带工具调用的AI消息
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tools"
    else:
        return "__end__"

# 扩展版本的条件函数
def advanced_tools_condition(
    state: Union[list[AnyMessage], dict[str, Any], BaseModel],
    *,
    messages_key: str = "messages",
    max_iterations: Optional[int] = None,
    check_remaining_steps: bool = True,
) -> Literal["tools", "__end__"]:
    """高级工具条件函数,支持更多控制选项
    
    Args:
        state: 图状态
        messages_key: 消息键名
        max_iterations: 最大迭代次数
        check_remaining_steps: 是否检查剩余步数
        
    Returns:
        路由目标:"tools"或"__end__"
    """
    # 基础工具调用检查
    basic_result = tools_condition(state, messages_key)
    
    if basic_result == "__end__":
        return "__end__"
    
    # 高级检查
    if isinstance(state, dict):
        # 检查剩余步数
        if check_remaining_steps and "remaining_steps" in state:
            remaining = state["remaining_steps"]
            if hasattr(remaining, 'value') and remaining.value <= 0:
                return "__end__"
        
        # 检查最大迭代次数
        if max_iterations is not None:
            messages = state.get(messages_key, [])
            tool_call_count = sum(
                1 for msg in messages 
                if isinstance(msg, AIMessage) and msg.tool_calls
            )
            if tool_call_count >= max_iterations:
                return "__end__"
    
    return "tools"

# 工具条件的工厂函数
def create_tools_condition(
    *,
    messages_key: str = "messages",
    max_iterations: Optional[int] = None,
    required_tools: Optional[set[str]] = None,
    forbidden_tools: Optional[set[str]] = None,
) -> Callable:
    """创建自定义工具条件函数
    
    Args:
        messages_key: 消息键名
        max_iterations: 最大迭代次数
        required_tools: 必需工具集合(如果指定,只有这些工具被允许)
        forbidden_tools: 禁止工具集合
        
    Returns:
        Callable: 自定义工具条件函数
    """
    
    def custom_condition(state) -> Literal["tools", "__end__"]:
        # 基础检查
        if tools_condition(state, messages_key) == "__end__":
            return "__end__"
        
        # 获取消息
        if isinstance(state, list):
            messages = state
        elif isinstance(state, dict):
            messages = state.get(messages_key, [])
        else:
            messages = getattr(state, messages_key, [])
        
        if not messages:
            return "__end__"
        
        last_message = messages[-1]
        if not (isinstance(last_message, AIMessage) and last_message.tool_calls):
            return "__end__"
        
        # 检查工具限制
        if required_tools or forbidden_tools:
            tool_names = {call["name"] for call in last_message.tool_calls}
            
            if required_tools and not tool_names.issubset(required_tools):
                # 有不被允许的工具
                return "__end__"
            
            if forbidden_tools and tool_names.intersection(forbidden_tools):
                # 有被禁止的工具
                return "__end__"
        
        # 检查迭代次数
        if max_iterations is not None:
            tool_call_count = sum(
                1 for msg in messages
                if isinstance(msg, AIMessage) and msg.tool_calls
            )
            if tool_call_count >= max_iterations:
                return "__end__"
        
        return "tools"
    
    return custom_condition

4. ValidationNode:工具验证节点

4.1 ValidationNode实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
class ValidationNode(RunnableCallable):
    """工具调用验证节点
    
    该节点在工具执行前验证工具调用参数,确保:
    - 参数类型正确
    - 必需参数存在
    - 参数值有效
    - 符合工具的输入模式
    """
    
    def __init__(
        self,
        schemas: dict[str, Union[Type[BaseModel], Type[BaseModelV1]]],
        *,
        name: str = "validation",
        format_error: Optional[Callable[[BaseException, ToolCall, Any], str]] = None,
        handle_validation_errors: bool = True,
    ) -> None:
        """初始化验证节点
        
        Args:
            schemas: 工具名称到验证模式的映射
            name: 节点名称
            format_error: 自定义错误格式化函数
            handle_validation_errors: 是否处理验证错误
        """
        super().__init__(self._func, name=name, trace=False)
        
        self.schemas = schemas
        self.format_error = format_error or _default_format_error
        self.handle_validation_errors = handle_validation_errors
    
    def _func(
        self,
        input: Union[list[AnyMessage], dict[str, Any], BaseModel],
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Any:
        """验证节点执行函数
        
        Args:
            input: 输入消息或状态
            config: 运行配置
            **kwargs: 额外参数
            
        Returns:
            验证结果
        """
        # 解析输入
        input_type, messages = self._parse_input(input)
        
        # 查找需要验证的工具调用
        tool_calls_to_validate = []
        validated_messages = []
        
        for message in messages:
            if isinstance(message, AIMessage) and message.tool_calls:
                # 验证工具调用
                new_tool_calls = []
                validation_errors = []
                
                for tool_call in message.tool_calls:
                    try:
                        validated_call = self._validate_tool_call(tool_call)
                        new_tool_calls.append(validated_call)
                    except Exception as e:
                        if self.handle_validation_errors:
                            # 创建错误消息
                            error_msg = ToolMessage(
                                content=self.format_error(
                                    e, tool_call, self.schemas.get(tool_call["name"])
                                ),
                                tool_call_id=tool_call["id"],
                                name=tool_call["name"],
                            )
                            validation_errors.append(error_msg)
                        else:
                            raise
                
                # 创建新的AI消息(如果有有效的工具调用)
                if new_tool_calls:
                    new_message = AIMessage(
                        content=message.content,
                        tool_calls=new_tool_calls,
                        id=message.id,
                    )
                    validated_messages.append(new_message)
                
                # 添加验证错误消息
                validated_messages.extend(validation_errors)
                
            else:
                # 非工具调用消息,直接添加
                validated_messages.append(message)
        
        # 格式化输出
        return self._format_output(validated_messages, input_type)
    
    def _validate_tool_call(self, tool_call: ToolCall) -> ToolCall:
        """验证单个工具调用
        
        Args:
            tool_call: 工具调用信息
            
        Returns:
            ToolCall: 验证后的工具调用
            
        Raises:
            ValidationError: 验证失败时
        """
        tool_name = tool_call["name"]
        tool_args = tool_call.get("args", {})
        
        # 检查是否有对应的模式
        if tool_name not in self.schemas:
            # 没有模式,跳过验证
            return tool_call
        
        schema = self.schemas[tool_name]
        
        try:
            # 使用Pydantic模式验证参数
            if hasattr(schema, 'model_validate'):
                # Pydantic v2
                validated_args = schema.model_validate(tool_args).model_dump()
            else:
                # Pydantic v1
                validated_args = schema(**tool_args).dict()
            
            # 返回验证后的工具调用
            return {
                **tool_call,
                "args": validated_args
            }
            
        except Exception as e:
            # 验证失败
            raise ValidationError(f"Tool '{tool_name}' validation failed: {e}") from e
    
    def _parse_input(self, input: Any) -> tuple[str, list[AnyMessage]]:
        """解析输入格式"""
        if isinstance(input, list):
            return "list", input
        elif isinstance(input, dict):
            return "dict", input.get("messages", [])
        elif isinstance(input, BaseModel):
            return "dict", getattr(input, 'messages', [])
        else:
            raise ValueError(f"Unsupported input type: {type(input)}")
    
    def _format_output(self, messages: list[AnyMessage], input_type: str) -> Any:
        """格式化输出"""
        if input_type == "list":
            return messages
        else:
            return {"messages": messages}

def _default_format_error(
    error: BaseException,
    call: ToolCall,
    schema: Union[Type[BaseModel], Type[BaseModelV1]],
) -> str:
    """默认错误格式化函数
    
    Args:
        error: 验证错误
        call: 失败的工具调用
        schema: 验证模式
        
    Returns:
        str: 格式化的错误消息
    """
    tool_name = call["name"]
    
    error_msg = f"Tool '{tool_name}' validation error: {repr(error)}"
    
    # 添加模式信息
    if schema:
        if hasattr(schema, 'model_fields'):
            # Pydantic v2
            fields_info = []
            for field_name, field_info in schema.model_fields.items():
                required = field_info.is_required()
                field_type = field_info.annotation
                fields_info.append(f"  - {field_name}: {field_type} {'(required)' if required else '(optional)'}")
            
            if fields_info:
                error_msg += f"\n\nExpected parameters:\n" + "\n".join(fields_info)
        
        elif hasattr(schema, '__fields__'):
            # Pydantic v1
            fields_info = []
            for field_name, field in schema.__fields__.items():
                required = field.required
                field_type = field.type_
                fields_info.append(f"  - {field_name}: {field_type} {'(required)' if required else '(optional)'}")
            
            if fields_info:
                error_msg += f"\n\nExpected parameters:\n" + "\n".join(fields_info)
    
    error_msg += "\n\nPlease fix the validation errors and try again."
    
    return error_msg

class ValidationError(Exception):
    """工具验证错误"""
    pass

# 验证节点工厂函数
def create_validation_node(
    tools: Sequence[BaseTool],
    *,
    name: str = "validation",
    auto_generate_schemas: bool = True,
    custom_schemas: Optional[dict[str, Type[BaseModel]]] = None,
    format_error: Optional[Callable] = None,
) -> ValidationNode:
    """创建验证节点的工厂函数
    
    Args:
        tools: 工具列表
        name: 节点名称
        auto_generate_schemas: 是否自动生成验证模式
        custom_schemas: 自定义验证模式
        format_error: 自定义错误格式化函数
        
    Returns:
        ValidationNode: 验证节点实例
    """
    schemas = custom_schemas or {}
    
    if auto_generate_schemas:
        for tool in tools:
            if tool.name not in schemas:
                # 尝试从工具生成模式
                if hasattr(tool, 'args_schema') and tool.args_schema:
                    schemas[tool.name] = tool.args_schema
                elif hasattr(tool, '_schema'):
                    schemas[tool.name] = tool._schema
    
    return ValidationNode(
        schemas=schemas,
        name=name,
        format_error=format_error,
    )

5. 高级特性与扩展

5.1 动态工具绑定

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class DynamicToolAgent:
    """支持动态工具绑定的智能体"""
    
    def __init__(
        self,
        base_model: BaseChatModel,
        tool_registry: dict[str, BaseTool],
    ):
        self.base_model = base_model
        self.tool_registry = tool_registry
    
    def create_agent_with_tools(
        self,
        tool_names: list[str],
        context: Optional[dict] = None,
    ) -> CompiledStateGraph:
        """根据指定工具动态创建智能体
        
        Args:
            tool_names: 要绑定的工具名称列表
            context: 上下文信息
            
        Returns:
            CompiledStateGraph: 动态创建的智能体图
        """
        # 选择工具
        selected_tools = []
        for name in tool_names:
            if name in self.tool_registry:
                tool = self.tool_registry[name]
                
                # 根据上下文定制工具
                if context and hasattr(tool, 'with_config'):
                    tool = tool.with_config({"context": context})
                
                selected_tools.append(tool)
        
        # 创建动态模型函数
        def dynamic_model(state, runtime):
            """动态模型选择函数"""
            # 根据状态选择合适的模型配置
            model_config = self._select_model_config(state, runtime)
            
            # 绑定选择的工具
            if selected_tools:
                return model_config.bind_tools(selected_tools)
            else:
                return model_config
        
        # 创建智能体
        return create_react_agent(
            model=dynamic_model,
            tools=selected_tools,
            prompt=self._create_dynamic_prompt(tool_names, context),
        )
    
    def _select_model_config(self, state, runtime) -> BaseChatModel:
        """根据状态选择模型配置"""
        # 示例逻辑:根据消息长度选择模型
        messages = state.get("messages", [])
        total_tokens = sum(len(str(msg)) for msg in messages)
        
        if total_tokens > 10000:
            # 长对话使用更强的模型
            return self.base_model.with_config({"model": "gpt-4"})
        else:
            # 短对话使用效率模型
            return self.base_model.with_config({"model": "gpt-3.5-turbo"})
    
    def _create_dynamic_prompt(
        self, 
        tool_names: list[str], 
        context: Optional[dict]
    ) -> str:
        """创建动态提示词"""
        base_prompt = "You are a helpful assistant with access to the following tools:\n"
        
        # 添加工具描述
        for name in tool_names:
            if name in self.tool_registry:
                tool = self.tool_registry[name]
                base_prompt += f"- {name}: {tool.description}\n"
        
        # 添加上下文信息
        if context:
            base_prompt += f"\nContext: {context}\n"
        
        base_prompt += "\nUse tools when appropriate to help answer questions."
        
        return base_prompt

# 工具注册表
class ToolRegistry:
    """工具注册表,支持动态注册和发现"""
    
    def __init__(self):
        self._tools: dict[str, BaseTool] = {}
        self._categories: dict[str, set[str]] = {}
        self._dependencies: dict[str, set[str]] = {}
    
    def register(
        self,
        tool: BaseTool,
        *,
        category: Optional[str] = None,
        dependencies: Optional[list[str]] = None,
    ) -> None:
        """注册工具
        
        Args:
            tool: 要注册的工具
            category: 工具类别
            dependencies: 工具依赖列表
        """
        self._tools[tool.name] = tool
        
        if category:
            if category not in self._categories:
                self._categories[category] = set()
            self._categories[category].add(tool.name)
        
        if dependencies:
            self._dependencies[tool.name] = set(dependencies)
    
    def get_tools_by_category(self, category: str) -> list[BaseTool]:
        """根据类别获取工具"""
        if category not in self._categories:
            return []
        
        return [
            self._tools[name] 
            for name in self._categories[category]
            if name in self._tools
        ]
    
    def resolve_dependencies(self, tool_names: list[str]) -> list[str]:
        """解析工具依赖关系"""
        resolved = set()
        to_resolve = set(tool_names)
        
        while to_resolve:
            tool_name = to_resolve.pop()
            if tool_name in resolved:
                continue
            
            resolved.add(tool_name)
            
            # 添加依赖工具
            if tool_name in self._dependencies:
                for dep in self._dependencies[tool_name]:
                    if dep not in resolved:
                        to_resolve.add(dep)
        
        return list(resolved)

5.2 工具执行监控

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class MonitoredToolNode(ToolNode):
    """带监控的工具节点"""
    
    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable, dict[str, Any]]],
        *,
        monitor_callback: Optional[Callable] = None,
        performance_tracker: Optional["PerformanceTracker"] = None,
        **kwargs
    ):
        super().__init__(tools, **kwargs)
        self.monitor_callback = monitor_callback
        self.performance_tracker = performance_tracker or PerformanceTracker()
    
    def _execute_single_tool_call(
        self,
        tool_call: ToolCall,
        config: RunnableConfig,
        **kwargs
    ) -> ToolMessage:
        """带监控的工具执行"""
        import time
        
        tool_name = tool_call["name"]
        start_time = time.time()
        
        try:
            # 记录开始执行
            if self.monitor_callback:
                self.monitor_callback("tool_start", {
                    "tool_name": tool_name,
                    "tool_call_id": tool_call["id"],
                    "args": tool_call.get("args", {}),
                    "start_time": start_time,
                })
            
            # 执行工具
            result = super()._execute_single_tool_call(tool_call, config, **kwargs)
            
            # 记录执行结果
            end_time = time.time()
            duration = end_time - start_time
            
            self.performance_tracker.record_execution(
                tool_name=tool_name,
                duration=duration,
                success=True,
                result_size=len(str(result.content)) if result.content else 0,
            )
            
            if self.monitor_callback:
                self.monitor_callback("tool_success", {
                    "tool_name": tool_name,
                    "tool_call_id": tool_call["id"],
                    "duration": duration,
                    "result": result,
                })
            
            return result
            
        except Exception as e:
            # 记录错误
            end_time = time.time()
            duration = end_time - start_time
            
            self.performance_tracker.record_execution(
                tool_name=tool_name,
                duration=duration,
                success=False,
                error=str(e),
            )
            
            if self.monitor_callback:
                self.monitor_callback("tool_error", {
                    "tool_name": tool_name,
                    "tool_call_id": tool_call["id"],
                    "duration": duration,
                    "error": e,
                })
            
            raise

class PerformanceTracker:
    """工具性能追踪器"""
    
    def __init__(self):
        self.executions: list[dict] = []
        self.stats: dict[str, dict] = {}
    
    def record_execution(
        self,
        tool_name: str,
        duration: float,
        success: bool,
        result_size: int = 0,
        error: Optional[str] = None,
    ):
        """记录工具执行"""
        execution = {
            "tool_name": tool_name,
            "duration": duration,
            "success": success,
            "result_size": result_size,
            "error": error,
            "timestamp": time.time(),
        }
        
        self.executions.append(execution)
        
        # 更新统计信息
        if tool_name not in self.stats:
            self.stats[tool_name] = {
                "total_executions": 0,
                "successful_executions": 0,
                "failed_executions": 0,
                "total_duration": 0.0,
                "avg_duration": 0.0,
                "max_duration": 0.0,
                "min_duration": float('inf'),
            }
        
        stats = self.stats[tool_name]
        stats["total_executions"] += 1
        
        if success:
            stats["successful_executions"] += 1
        else:
            stats["failed_executions"] += 1
        
        stats["total_duration"] += duration
        stats["avg_duration"] = stats["total_duration"] / stats["total_executions"]
        stats["max_duration"] = max(stats["max_duration"], duration)
        stats["min_duration"] = min(stats["min_duration"], duration)
    
    def get_performance_report(self) -> dict:
        """生成性能报告"""
        return {
            "total_executions": len(self.executions),
            "tool_stats": self.stats,
            "recent_failures": [
                ex for ex in self.executions[-10:]
                if not ex["success"]
            ],
        }

5.3 错误分级与告警策略

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class ToolErrorClassifier:
    """工具错误分级:用于区分可重试/不可重试/降级可接受的错误"""
    
    RETRYABLE = (TimeoutError, ConnectionError, OSError)
    NON_RETRYABLE = (ValueError, TypeError)
    
    @staticmethod
    def classify(error: Exception) -> str:
        if isinstance(error, ToolErrorClassifier.RETRYABLE):
            return "retryable"
        if isinstance(error, ToolErrorClassifier.NON_RETRYABLE):
            return "non_retryable"
        return "unknown"

class AlertingGateway:
    """告警网关:将高等级错误和异常QPS/延迟推送到外部告警系统"""
    
    def __init__(self, send_fn: Callable[[str, dict], None]):
        self.send = send_fn
    
    def alert(self, level: str, message: str, **kv):
        payload = {"level": level, "message": message, **kv}
        self.send("tool_alert", payload)

def monitored_fallback(tool_node: MonitoredToolNode, alerting: AlertingGateway):
    """给MonitoredToolNode注入错误分级与告警回调"""
    original_handle = tool_node._handle_tool_error
    
    def _wrapped_handle(call: ToolCall, error: Exception) -> ToolMessage:
        cls = ToolErrorClassifier.classify(error)
        level = "warning" if cls == "retryable" else "critical"
        alerting.alert(level, "tool_execution_failed", tool=call.get("name"), error=str(error), type=cls)
        return original_handle(call, error)
    
    tool_node._handle_tool_error = _wrapped_handle  # 注入包装
    return tool_node

要点:

  • 将错误按“可重试/不可重试/未知”分级,配合回退策略与指数退避。
  • 仅对高等级(critical)错误推送告警,避免告警风暴;对retryable聚合统计上报。
  • 与性能追踪合并,形成“错误率、平均延迟、P95”三指标联动的SLO看板。

6. 生产级应用案例

6.1 法律信息结构化智能体

基于LangGraph在法律领域的实际应用,展示如何构建专业的法律信息处理系统:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from typing import List, Dict, Optional
from pydantic import BaseModel, Field
from langgraph.prebuilt import create_react_agent, ToolNode
from langchain_core.tools import tool

class LegalEntity(BaseModel):
    """法律实体模型"""
    name: str = Field(description="实体名称")
    type: str = Field(description="实体类型:人员、组织、法条等")
    attributes: Dict[str, str] = Field(default_factory=dict, description="实体属性")
    relationships: List[str] = Field(default_factory=list, description="关联关系")

class LegalDocument(BaseModel):
    """法律文档模型"""
    title: str
    content: str
    entities: List[LegalEntity] = Field(default_factory=list)
    legal_provisions: List[str] = Field(default_factory=list)
    case_references: List[str] = Field(default_factory=list)

@tool
def extract_legal_entities(
    text: str,
    state: Annotated[dict, InjectedState(key="knowledge_base")]
) -> str:
    """提取法律实体和关系"""
    # 使用知识图谱辅助实体识别
    knowledge_base = state.get("knowledge_base", {})
    
    # 实体提取逻辑(结合NER和规则)
    entities = []
    
    # 使用预训练的法律NER模型
    ner_results = _extract_entities_with_ner(text)
    
    # 结合法律知识库进行实体链接
    for entity in ner_results:
        linked_entity = _link_to_knowledge_base(entity, knowledge_base)
        if linked_entity:
            entities.append(linked_entity)
    
    return json.dumps({
        "entities": [entity.dict() for entity in entities],
        "confidence_scores": _calculate_confidence_scores(entities),
    }, ensure_ascii=False, indent=2)

@tool
def analyze_legal_provisions(document_text: str) -> str:
    """分析法律条文和适用性"""
    # 法条识别和解析
    provisions = _identify_legal_provisions(document_text)
    
    # 适用性分析
    applicability = _analyze_applicability(provisions, document_text)
    
    return json.dumps({
        "provisions": provisions,
        "applicability_analysis": applicability,
        "recommendation": _generate_legal_recommendation(provisions, applicability)
    }, ensure_ascii=False, indent=2)

@tool
def structure_legal_information(
    entities: str,
    provisions: str,
    store: Annotated[BaseStore, InjectedStore(namespace="legal_kb")]
) -> str:
    """结构化法律信息并更新知识库"""
    entities_data = json.loads(entities)
    provisions_data = json.loads(provisions)
    
    # 构建结构化法律文档
    structured_doc = LegalDocument(
        title="分析结果",
        content="",
        entities=[LegalEntity(**entity) for entity in entities_data["entities"]],
        legal_provisions=provisions_data["provisions"],
    )
    
    # 更新知识库
    for entity in structured_doc.entities:
        store.put(
            ("entities", entity.name),
            entity.dict()
        )
    
    # 存储案例关联
    case_id = f"case_{int(time.time())}"
    store.put(
        ("cases", case_id),
        {
            "entities": entities_data,
            "provisions": provisions_data,
            "structured_document": structured_doc.dict(),
            "created_at": time.time(),
        }
    )
    
    return structured_doc.model_dump_json(indent=2, ensure_ascii=False)

# 创建专业的法律分析智能体
def create_legal_analysis_agent():
    """创建法律分析智能体"""
    from langchain_openai import ChatOpenAI
    
    # 专业模型配置
    model = ChatOpenAI(
        model="gpt-4-turbo-preview",
        temperature=0.1,  # 法律分析需要稳定性
    )
    
    # 法律专业工具
    legal_tools = [
        extract_legal_entities,
        analyze_legal_provisions,
        structure_legal_information,
    ]
    
    # 专业提示词
    legal_prompt = """你是一个专业的法律分析助手,具备以下能力:

1. 法律实体识别:准确识别人员、组织、法律条文等实体
2. 条文分析:分析相关法律条文的适用性和影响
3. 信息结构化:将非结构化的法律文本转换为结构化信息

分析时请遵循以下原则:
- 准确性优先:确保法律信息的准确性和完整性
- 引用规范:明确标注法律条文的出处和版本
- 关联分析:识别实体间的法律关系和影响

请使用提供的工具进行专业的法律分析。"""
    
    return create_react_agent(
        model=model,
        tools=legal_tools,
        prompt=legal_prompt,
        checkpointer=PostgresCheckpointSaver.from_conn_string(
            "postgresql://user:pass@localhost/legal_db"
        ),
        store=PostgresStore.from_conn_string(
            "postgresql://user:pass@localhost/legal_kb"
        ),
        name="LegalAnalysisAgent",
        version="v2",
    )

def _extract_entities_with_ner(text: str) -> List[Dict]:
    """使用NER模型提取实体"""
    # 这里使用专业的法律NER模型
    # 实际实现中会调用如spaCy、transformers等库
    entities = []
    
    # 模拟NER结果
    mock_entities = [
        {"text": "最高人民法院", "label": "ORGANIZATION", "start": 0, "end": 6},
        {"text": "民法典", "label": "LAW", "start": 20, "end": 23},
        {"text": "第一百条", "label": "PROVISION", "start": 24, "end": 28},
    ]
    
    return mock_entities

def _link_to_knowledge_base(entity: Dict, knowledge_base: Dict) -> Optional[LegalEntity]:
    """将实体链接到知识库"""
    entity_name = entity["text"]
    entity_type = entity["label"]
    
    # 在知识库中查找匹配实体
    kb_matches = knowledge_base.get(entity_type.lower(), {})
    
    if entity_name in kb_matches:
        kb_info = kb_matches[entity_name]
        return LegalEntity(
            name=entity_name,
            type=entity_type,
            attributes=kb_info.get("attributes", {}),
            relationships=kb_info.get("relationships", [])
        )
    
    return LegalEntity(name=entity_name, type=entity_type)

法律智能体特点

  • 专业知识集成:结合法律知识库和NER模型
  • 结构化输出:将非结构化文本转换为标准化数据模型
  • 知识图谱支持:通过实体链接建立法律概念间的关联
  • 持久化存储:自动更新和维护法律知识库

6.2 代码生成与调试智能体

展示LangGraph在代码生成领域的强大能力,实现反思式代码生成:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
class CodeGenerationState(TypedDict):
    """代码生成状态模式"""
    messages: Annotated[list[BaseMessage], add_messages]
    current_code: Optional[str]
    test_results: Optional[dict]
    error_history: list[str]
    iteration_count: int
    max_iterations: int

@tool
def generate_code(
    requirements: str,
    programming_language: str = "python",
    state: Annotated[dict, InjectedState()],
) -> str:
    """生成代码工具"""
    iteration = state.get("iteration_count", 0)
    error_history = state.get("error_history", [])
    
    # 构建上下文提示
    context_prompt = f"""
要求:{requirements}
语言:{programming_language}

之前的错误(如果有):
{chr(10).join(error_history[-3:]) if error_history else "无"}

请生成满足要求的代码,注意避免之前的错误。
"""
    
    # 调用代码生成模型
    generated_code = _call_code_generation_model(context_prompt)
    
    return generated_code

@tool
def execute_and_test_code(
    code: str,
    test_cases: Optional[str] = None
) -> str:
    """执行和测试代码"""
    results = {
        "execution_success": False,
        "test_success": False,
        "errors": [],
        "output": "",
        "test_results": {},
    }
    
    try:
        # 安全代码执行环境
        execution_globals = {
            "__builtins__": {
                "print": print,
                "len": len,
                "str": str,
                "int": int,
                "float": float,
                "list": list,
                "dict": dict,
                # 限制可用的内置函数
            }
        }
        
        # 执行代码
        exec(code, execution_globals)
        results["execution_success"] = True
        
        # 运行测试用例(如果提供)
        if test_cases:
            test_results = _run_test_cases(code, test_cases, execution_globals)
            results["test_results"] = test_results
            results["test_success"] = all(
                test["passed"] for test in test_results.values()
            )
        
    except Exception as e:
        results["errors"].append({
            "type": type(e).__name__,
            "message": str(e),
            "traceback": traceback.format_exc(),
        })
    
    return json.dumps(results, ensure_ascii=False, indent=2)

@tool
def reflect_on_errors(
    code: str,
    test_results: str,
    state: Annotated[dict, InjectedState()],
) -> str:
    """反思代码错误并提供改进建议"""
    results = json.loads(test_results)
    current_iteration = state.get("iteration_count", 0)
    
    if results["execution_success"] and results.get("test_success", True):
        return "代码执行成功,无需反思"
    
    # 错误分析
    error_analysis = {
        "iteration": current_iteration,
        "error_types": [],
        "improvement_suggestions": [],
        "code_issues": [],
    }
    
    # 分析执行错误
    for error in results.get("errors", []):
        error_analysis["error_types"].append(error["type"])
        error_analysis["improvement_suggestions"].append(
            _generate_improvement_suggestion(error, code)
        )
    
    # 分析测试失败
    for test_name, test_result in results.get("test_results", {}).items():
        if not test_result["passed"]:
            error_analysis["code_issues"].append({
                "test": test_name,
                "expected": test_result.get("expected"),
                "actual": test_result.get("actual"),
                "suggestion": _suggest_test_fix(test_result, code),
            })
    
    return json.dumps(error_analysis, ensure_ascii=False, indent=2)

def create_code_generation_agent():
    """创建代码生成智能体"""
    graph = StateGraph(CodeGenerationState)
    
    # 代码生成节点
    def code_gen_node(state: CodeGenerationState) -> dict:
        messages = state["messages"]
        requirements = messages[-1].content
        
        code = generate_code.invoke({
            "requirements": requirements,
            "programming_language": "python"
        })
        
        return {
            "current_code": code,
            "iteration_count": state.get("iteration_count", 0) + 1,
        }
    
    # 测试执行节点
    def test_node(state: CodeGenerationState) -> dict:
        code = state["current_code"]
        test_results = execute_and_test_code.invoke({"code": code})
        
        results_data = json.loads(test_results)
        success = results_data["execution_success"] and results_data.get("test_success", True)
        
        return {
            "test_results": test_results,
            "success": success,
        }
    
    # 反思节点
    def reflection_node(state: CodeGenerationState) -> dict:
        code = state["current_code"]
        test_results = state["test_results"]
        
        reflection = reflect_on_errors.invoke({
            "code": code,
            "test_results": test_results
        })
        
        # 更新错误历史
        error_history = state.get("error_history", [])
        error_history.append(reflection)
        
        return {
            "error_history": error_history,
            "reflection": reflection,
        }
    
    # 构建图结构
    graph.add_node("generate", code_gen_node)
    graph.add_node("test", test_node)
    graph.add_node("reflect", reflection_node)
    
    # 设置流程
    graph.set_entry_point("generate")
    graph.add_edge("generate", "test")
    
    # 条件分支:成功则结束,失败则反思
    graph.add_conditional_edges(
        "test",
        lambda state: END if state.get("success") else (
            "reflect" if state.get("iteration_count", 0) < state.get("max_iterations", 3)
            else END
        ),
        {"reflect": "reflect", END: END}
    )
    
    graph.add_edge("reflect", "generate")
    
    return graph.compile(
        checkpointer=InMemorySaver(),
        interrupt_before=["reflect"],  # 允许人工介入反思过程
    )

def _generate_improvement_suggestion(error: dict, code: str) -> str:
    """生成代码改进建议"""
    error_type = error["type"]
    error_message = error["message"]
    
    suggestions = {
        "SyntaxError": "检查代码语法,特别是括号、引号匹配",
        "NameError": "检查变量名拼写和作用域",
        "TypeError": "检查数据类型和函数参数",
        "ImportError": "检查导入模块是否正确",
    }
    
    return suggestions.get(error_type, f"分析错误: {error_message}")

代码生成智能体特点

  • 迭代改进:通过反思机制持续优化代码质量
  • 错误学习:从历史错误中学习,避免重复问题
  • 测试驱动:自动执行测试用例验证代码正确性
  • 人工介入:在关键决策点支持人工审核

6.3 完整的智能体应用示例

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import asyncio
from typing import Annotated
from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState, InjectedStore
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore

# 定义工具
@tool
def weather_tool(location: str) -> str:
    """获取指定位置的天气信息"""
    # 模拟天气API调用
    weather_data = {
        "北京": "晴天,温度25°C",
        "上海": "多云,温度22°C", 
        "广州": "小雨,温度28°C",
    }
    return weather_data.get(location, f"{location}的天气信息暂未获取到")

@tool
def calculator_tool(expression: str) -> str:
    """计算数学表达式"""
    try:
        # 安全的数学计算
        import ast
        import operator
        
        ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.Pow: operator.pow,
        }
        
        def eval_expr(expr):
            if isinstance(expr, ast.Num):
                return expr.n
            elif isinstance(expr, ast.BinOp):
                return ops[type(expr.op)](eval_expr(expr.left), eval_expr(expr.right))
            else:
                raise TypeError(expr)
        
        result = eval_expr(ast.parse(expression, mode='eval').body)
        return f"计算结果: {result}"
    except Exception as e:
        return f"计算错误: {e}"

@tool
def search_tool(
    query: str,
    state: Annotated[dict, InjectedState()],
    store: Annotated[InMemoryStore, InjectedStore(namespace="search")],
) -> str:
    """搜索信息工具(带状态和存储注入)"""
    # 从状态获取用户偏好
    user_id = state.get("user_id", "anonymous")
    search_history = store.get(("history", user_id)) or []
    
    # 模拟搜索
    search_results = f"搜索'{query}'的结果:相关信息已找到"
    
    # 保存搜索历史
    search_history.append({
        "query": query,
        "timestamp": time.time(),
    })
    store.put(("history", user_id), search_history[-10:])  # 保留最近10次搜索
    
    return search_results

async def main():
    """主程序演示"""
    
    # 创建模型
    model = ChatAnthropic(model="claude-3-sonnet-20240229")
    
    # 创建检查点保存器和存储
    checkpointer = InMemorySaver()
    store = InMemoryStore()
    
    # 创建智能体
    agent = create_react_agent(
        model=model,
        tools=[weather_tool, calculator_tool, search_tool],
        checkpointer=checkpointer,
        store=store,
        prompt="""你是一个有用的助手,可以:
1. 查询天气信息
2. 进行数学计算  
3. 搜索信息

请根据用户需求选择合适的工具来帮助回答问题。""",
        version="v2",
        name="MultiToolAssistant",
    )
    
    # 配置
    config = {
        "configurable": {
            "thread_id": "demo-session-1",
            "user_id": "user123",
        }
    }
    
    # 测试对话
    test_messages = [
        "北京今天天气怎么样?",
        "帮我计算 15 * 8 + 32",
        "搜索一下人工智能的最新发展",
        "再帮我算一下 (25 + 15) / 4",
    ]
    
    for user_input in test_messages:
        print(f"\n用户: {user_input}")
        print("助手:", end=" ")
        
        # 流式输出
        async for chunk in agent.astream(
            {"messages": [{"role": "user", "content": user_input}]},
            config=config,
            stream_mode="updates"
        ):
            for node_name, output in chunk.items():
                if "messages" in output:
                    for msg in output["messages"]:
                        if hasattr(msg, 'content') and msg.content:
                            print(msg.content, end="", flush=True)
        
        print()  # 换行

if __name__ == "__main__":
    asyncio.run(main())

6.2 最佳实践

  1. 工具设计原则

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    
    # 好的工具设计
    @tool
    def good_weather_tool(
        location: str,
        units: str = "celsius",
        include_forecast: bool = False
    ) -> str:
        """获取天气信息
    
        Args:
            location: 位置名称
            units: 温度单位 (celsius/fahrenheit)
            include_forecast: 是否包含预报
        """
        # 清晰的功能、完整的参数说明、合理的默认值
    
  2. 错误处理策略

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    
    # 健壮的工具实现
    @tool  
    def robust_api_tool(endpoint: str) -> str:
        """调用外部API"""
        try:
            response = requests.get(endpoint, timeout=30)
            response.raise_for_status()
            return response.text
        except requests.Timeout:
            return "请求超时,请稍后重试"
        except requests.RequestException as e:
            return f"API调用失败: {e}"
    
  3. 性能优化

    1
    2
    3
    4
    5
    6
    7
    
    # 使用缓存和批量处理
    @tool
    def optimized_batch_tool(items: list[str]) -> str:
        """批量处理项目"""
        # 批量处理而不是逐个处理
        results = process_items_batch(items)
        return json.dumps(results)
    

7. 总结

LangGraph预构建模块通过封装常见模式和提供高级API,大大简化了智能体应用的开发:

7.1 核心优势

  • 快速原型:create_react_agent让ReAct智能体开发变得简单直接
  • 灵活工具系统:ToolNode支持复杂的工具执行和依赖注入
  • 生产就绪:内置错误处理、监控、验证等企业级特性
  • 可扩展架构:清晰的扩展点支持自定义功能

7.2 设计亮点

  1. 模块化设计:每个组件职责明确,可独立使用或组合
  2. 类型安全:完整的类型注解确保开发时的类型检查
  3. 异步支持:原生支持异步执行,适合IO密集型应用
  4. 监控友好:内置性能追踪和监控钩子

7.3 适用场景

  • 聊天机器人:需要工具调用能力的对话系统
  • 任务自动化:复杂的多步骤任务执行
  • 知识问答:结合外部数据源的智能问答
  • 工作流代理:业务流程的智能化处理

通过深入理解这些预构建组件,开发者能够快速构建高质量的AI应用,同时保持足够的灵活性来满足特定需求。


创建时间: 2025年09月13日

本文由 tommie blog 原创发布