@register_provider([
LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM,
LLMType.MOONSHOT, LLMType.MISTRAL, LLMType.YI,
LLMType.OPEN_ROUTER, LLMType.DEEPSEEK, LLMType.SILICONFLOW,
LLMType.OPENROUTER, LLMType.LLAMA_API,
])
class OpenAILLM(BaseLLM):
    """OpenAI-compatible LLM provider."""

    def __init__(self, config: LLMConfig):
        self.config = config
        self._init_client()
        self.auto_max_tokens = False
        self.cost_manager: Optional[CostManager] = None

    def _init_client(self):
        """Initialize the OpenAI client."""
        self.model = self.config.model
        self.pricing_plan = self.config.pricing_plan or self.model
        kwargs = self._make_client_kwargs()
        self.aclient = AsyncOpenAI(**kwargs)

    def _make_client_kwargs(self) -> dict:
        """Build the client constructor kwargs."""
        kwargs = {
            "api_key": self.config.api_key,
            "base_url": self.config.base_url,
        }
        # Proxy support
        if proxy_params := self._get_proxy_params():
            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
        return kwargs

    def _get_proxy_params(self) -> dict:
        """Collect proxy parameters, if a proxy is configured."""
        params = {}
        if self.config.proxy:
            params = {"proxy": self.config.proxy}
            if self.config.base_url:
                params["base_url"] = self.config.base_url
        return params
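
    # Example (hypothetical value): with config.proxy = "http://127.0.0.1:7890", the params
    # above make _make_client_kwargs wrap the client in AsyncHttpxClientWrapper, routing
    # all requests through that proxy.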

    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
        """Streaming chat completion."""
        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
            **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
            stream=True,
        )
        usage = None
        collected_messages = []
        collected_reasoning_messages = []
        has_finished = False
        async for chunk in response:
            if not chunk.choices:
                # Usage-only chunks (e.g., when stream_options requests usage) carry no
                # choices; capture their usage before skipping them.
                if hasattr(chunk, "usage") and chunk.usage:
                    usage = CompletionUsage(**chunk.usage) if isinstance(chunk.usage, dict) else chunk.usage
                continue
            choice0 = chunk.choices[0]
            choice_delta = choice0.delta
            # Handle reasoning content (e.g., DeepSeek)
            if hasattr(choice_delta, "reasoning_content") and choice_delta.reasoning_content:
                collected_reasoning_messages.append(choice_delta.reasoning_content)
                continue
            # Extract the message content
            chunk_message = choice_delta.content or ""
            finish_reason = choice0.finish_reason if hasattr(choice0, "finish_reason") else None
            log_llm_stream(chunk_message)
            collected_messages.append(chunk_message)
            # Handle usage information
            chunk_has_usage = hasattr(chunk, "usage") and chunk.usage
            if has_finished:
                if chunk_has_usage:
                    usage = CompletionUsage(**chunk.usage) if isinstance(chunk.usage, dict) else chunk.usage
            if finish_reason:
                if chunk_has_usage:
                    usage = CompletionUsage(**chunk.usage) if isinstance(chunk.usage, dict) else chunk.usage
                elif hasattr(choice0, "usage"):
                    usage = CompletionUsage(**choice0.usage)
                has_finished = True
        log_llm_stream("\n")
        full_reply_content = "".join(collected_messages)
        # Save the reasoning content
        if collected_reasoning_messages:
            self.reasoning_content = "".join(collected_reasoning_messages)
        # Fall back to a local estimate when the service did not report usage
        if not usage:
            usage = self._calc_usage(messages, full_reply_content)
        self._update_costs(usage)
        return full_reply_content

    def _cons_kwargs(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, **extra_kwargs) -> dict:
        """Build the request kwargs."""
        kwargs = {
            "messages": messages,
            "max_tokens": self._get_max_tokens(messages),
            "temperature": self.config.temperature,
            "model": self.model,
            "timeout": self.get_timeout(timeout),
        }
        # Special handling for the o1 series
        if "o1-" in self.model:
            kwargs["temperature"] = 1  # o1 only accepts the default temperature
            kwargs.pop("max_tokens", None)  # o1 does not support max_tokens
        # Reasoning-mode support
        if self.config.reasoning and "reasoning" in extra_kwargs:
            kwargs["reasoning"] = extra_kwargs["reasoning"]
            kwargs["max_completion_tokens"] = self.config.reasoning_max_token
        kwargs.update(extra_kwargs)
        return kwargs
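
    # Example (hypothetical values): for model="o1-mini" the kwargs built above pin
    # temperature to 1 and drop max_tokens, yielding roughly
    #   {"messages": [...], "temperature": 1, "model": "o1-mini", "timeout": self.get_timeout(timeout)}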

    def _get_max_tokens(self, messages: list[dict]) -> int:
        """Resolve the maximum number of completion tokens."""
        if self.auto_max_tokens:
            return get_max_completion_tokens(messages, self.model, self.config.max_token)
        return self.config.max_token

    def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
        """Estimate token usage locally."""
        prompt_tokens = count_message_tokens(messages, self.model)
        completion_tokens = count_output_tokens(rsp, self.model)
        return CompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type((APIConnectionError, ConnectionError)),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
        """Async chat completion, retried on connection errors."""
        return await self._achat_completion(messages, timeout)

    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
        """Internal chat completion implementation."""
        rsp: ChatCompletion = await self.aclient.chat.completions.create(
            **self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
        )
        self._update_costs(rsp.usage)
        return rsp
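
# Usage sketch: how a caller might drive this provider. The LLMConfig fields used
# below (api_type, model, api_key) are assumptions inferred from the constructor
# above; adjust them to the real config class if they differ.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        config = LLMConfig(api_type="openai", model="gpt-4o-mini", api_key="sk-...")  # hypothetical values
        llm = OpenAILLM(config)
        rsp = await llm.acompletion([{"role": "user", "content": "Say hello in one word."}])
        print(rsp.choices[0].message.content)

    asyncio.run(_demo())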