import time
from collections import defaultdict
from typing import Any, Dict, List, Union
from uuid import UUID

import numpy as np
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class PerformanceCallbackHandler(BaseCallbackHandler):
    """Performance-monitoring callback handler."""

    def __init__(self):
        self.start_times: Dict[UUID, float] = {}
        self.metrics: Dict[str, List[float]] = defaultdict(list)
        self.error_counts: Dict[str, int] = defaultdict(int)
        self.total_counts: Dict[str, int] = defaultdict(int)

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, **kwargs: Any) -> None:
        """Record the LLM start time."""
        self.start_times[run_id] = time.time()
        self.total_counts["llm"] += 1

    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
        """Compute LLM latency and record token usage."""
        # Compute elapsed time
        if run_id in self.start_times:
            duration = time.time() - self.start_times[run_id]
            self.metrics["llm_duration"].append(duration)
            del self.start_times[run_id]
        # Record token usage
        if response.llm_output and "token_usage" in response.llm_output:
            token_usage = response.llm_output["token_usage"]
            for key, value in token_usage.items():
                self.metrics[f"tokens_{key}"].append(value)

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, **kwargs: Any) -> None:
        """Record an LLM error."""
        self.error_counts["llm"] += 1
        # Clean up the start time
        if run_id in self.start_times:
            del self.start_times[run_id]

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Record the tool start time."""
        self.start_times[run_id] = time.time()
        tool_name = serialized.get("name", "unknown_tool")
        self.total_counts[f"tool_{tool_name}"] += 1
        # Also keep an aggregate "tool" count so the tool error rate in
        # get_metrics_summary() has a matching denominator.
        self.total_counts["tool"] += 1

    def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Compute tool latency."""
        if run_id in self.start_times:
            duration = time.time() - self.start_times[run_id]
            self.metrics["tool_duration"].append(duration)
            del self.start_times[run_id]

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, **kwargs: Any) -> None:
        """Record a tool error."""
        self.error_counts["tool"] += 1
        if run_id in self.start_times:
            del self.start_times[run_id]

    def get_metrics_summary(self) -> Dict[str, Any]:
        """Return a summary of the collected metrics."""
        summary = {}
        # Latency and token metrics
        for metric_name, values in self.metrics.items():
            if values:
                summary[metric_name] = {
                    "count": len(values),
                    "total": sum(values),
                    "avg": sum(values) / len(values),
                    "min": min(values),
                    "max": max(values),
                    "p50": np.percentile(values, 50) if len(values) > 1 else values[0],
                    "p95": np.percentile(values, 95) if len(values) > 1 else values[0],
                    "p99": np.percentile(values, 99) if len(values) > 1 else values[0],
                }
        # Error rates
        for component, error_count in self.error_counts.items():
            total_count = self.total_counts.get(component, 0)
            if total_count > 0:
                summary[f"{component}_error_rate"] = error_count / total_count
        return summary

    def reset_metrics(self):
        """Reset all collected metrics."""
        self.start_times.clear()
        self.metrics.clear()
        self.error_counts.clear()
        self.total_counts.clear()
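A minimal usage sketch: the handler is passed per call via the `config` argument, which any LangChain runnable accepts. The model name and prompt below are placeholders, and the example assumes the langchain-openai package is installed with OPENAI_API_KEY set; any chat model that emits the standard callback events works the same way.

# Hypothetical usage example; model name and prompt are placeholders.
from langchain_openai import ChatOpenAI

handler = PerformanceCallbackHandler()
llm = ChatOpenAI(model="gpt-4o-mini")

for _ in range(3):
    llm.invoke(
        "Summarize LangChain callbacks in one sentence.",
        config={"callbacks": [handler]},
    )

# After a few calls the summary contains "llm_duration" and "tokens_*"
# entries, e.g. summary["llm_duration"]["p95"]; "llm_error_rate" appears
# only once at least one call has failed.
summary = handler.get_metrics_summary()
print(summary)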