概述

本文档基于网上多篇深度的LangChain源码分析文章,汇总了企业级应用中的高级实践模式。内容涵盖安全与隐私保护、多模态集成、智能负载均衡、性能优化等关键主题,为开发者提供生产环境中的最佳实践指导。

1. 安全与隐私保护机制

1.1 数据加密与隐私保护

基于深度源码分析,LangChain在企业应用中需要完善的安全机制:

import hashlib
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import os
import re
from typing import Dict, Any, Optional

class LangChainSecurityManager:
    """Security manager for LangChain applications.

    Responsibilities:
      * Symmetric encryption/decryption of sensitive strings (Fernet, with a
        key derived from a master key via PBKDF2-HMAC-SHA256).
      * Regex-based redaction of common PII (card numbers, SSNs, emails,
        phone numbers) before text is placed into prompts.

    Reference: LangChain parser & downstream-task integration source analysis
    https://blog.csdn.net/qq_28540861/article/details/149057817
    """

    def __init__(self, master_key: Optional[str] = None):
        """Create the manager.

        Args:
            master_key: secret used for key derivation; falls back to the
                LANGCHAIN_MASTER_KEY environment variable when omitted.

        Raises:
            ValueError: if no master key is available from either source.
        """
        self.master_key = master_key or os.environ.get('LANGCHAIN_MASTER_KEY')
        if not self.master_key:
            raise ValueError("必须提供主密钥")

        self.cipher_suite = self._create_cipher_suite()
        # PII patterns redacted by sanitize_input().
        self.sensitive_patterns = [
            r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',  # credit card number
            r'\b\d{3}-\d{2}-\d{4}\b',  # US SSN
            # Bug fix: the TLD character class used to be [A-Z|a-z], which also
            # matched a literal '|'; [A-Za-z] is the intended class.
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',  # email
            r'\b\d{11}\b',  # 11-digit mobile number
        ]

    def _create_cipher_suite(self) -> Fernet:
        """Derive a Fernet key from the master key and return the cipher."""
        password = self.master_key.encode()
        # SECURITY NOTE: a hard-coded salt makes the derived key deterministic
        # across deployments; production should use a per-installation random
        # salt persisted alongside the ciphertext.
        salt = b'langchain_salt_2024'

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )

        key = base64.urlsafe_b64encode(kdf.derive(password))
        return Fernet(key)

    def encrypt_sensitive_data(self, data: str) -> str:
        """Encrypt *data* and return it as a urlsafe-base64 string.

        NOTE: Fernet tokens are already urlsafe-base64; the extra base64
        layer is redundant but kept for backward compatibility with any
        previously stored ciphertexts.

        Raises:
            ValueError: wrapping any underlying encryption failure.
        """
        try:
            encrypted_data = self.cipher_suite.encrypt(data.encode())
            return base64.urlsafe_b64encode(encrypted_data).decode()
        except Exception as e:
            raise ValueError(f"数据加密失败: {str(e)}")

    def decrypt_sensitive_data(self, encrypted_data: str) -> str:
        """Reverse encrypt_sensitive_data() and return the plaintext.

        Raises:
            ValueError: wrapping any decoding/decryption failure (wrong key,
                corrupted token, invalid base64).
        """
        try:
            encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode())
            decrypted_data = self.cipher_suite.decrypt(encrypted_bytes)
            return decrypted_data.decode()
        except Exception as e:
            raise ValueError(f"数据解密失败: {str(e)}")

    def sanitize_input(self, text: str) -> str:
        """Return *text* with every sensitive pattern replaced by [REDACTED]."""
        sanitized_text = text

        for pattern in self.sensitive_patterns:
            sanitized_text = re.sub(pattern, '[REDACTED]', sanitized_text)

        return sanitized_text

    def create_secure_prompt_template(self, template: str) -> 'SecurePromptTemplate':
        """Wrap *template* in a SecurePromptTemplate bound to this manager."""
        return SecurePromptTemplate(template, self)

class SecurePromptTemplate:
    """Prompt template that redacts sensitive values before formatting."""

    def __init__(self, template: str, security_manager: LangChainSecurityManager):
        self.template = template
        self.security_manager = security_manager

    def format(self, **kwargs) -> str:
        """Format the template, sanitizing every string-valued argument first."""
        sanitize = self.security_manager.sanitize_input
        cleaned = {
            name: sanitize(value) if isinstance(value, str) else value
            for name, value in kwargs.items()
        }
        return self.template.format(**cleaned)

# Usage example
def create_secure_langchain_demo():
    """Demo: sanitize, encrypt and decrypt user data via the security manager."""

    # Initialize the security manager (demo key; use a secret store in prod).
    security_manager = LangChainSecurityManager("your-master-key-here")

    # Create a prompt template that auto-redacts sensitive values.
    secure_template = security_manager.create_secure_prompt_template("""
基于以下用户信息回答问题:

用户信息:{user_info}
问题:{question}

请注意保护用户隐私,不要在回答中包含敏感信息。

回答:
""")

    # Sample input containing PII (an email address and a card number).
    user_info = "我的邮箱是john.doe@example.com,信用卡号是1234-5678-9012-3456"
    question = "请帮我分析一下我的账户情况"

    # Format the prompt (sensitive values are redacted automatically).
    safe_prompt = secure_template.format(
        user_info=user_info,
        question=question
    )

    print("安全处理后的提示:")
    print(safe_prompt)

    # Encrypt the sensitive data for storage.
    encrypted_info = security_manager.encrypt_sensitive_data(user_info)
    print(f"加密后的用户信息:{encrypted_info}")

    # Decrypt it back to verify round-tripping.
    decrypted_info = security_manager.decrypt_sensitive_data(encrypted_info)
    print(f"解密后的用户信息:{decrypted_info}")

if __name__ == "__main__":
    create_secure_langchain_demo()

1.2 访问控制与权限管理

from enum import Enum
from functools import wraps
from typing import List, Dict, Any, Callable
import jwt
import time

class Permission(Enum):
    """Fine-grained permissions grantable to roles (values appear in JWTs)."""
    READ_DOCUMENTS = "read_documents"
    WRITE_DOCUMENTS = "write_documents"
    EXECUTE_TOOLS = "execute_tools"
    MANAGE_AGENTS = "manage_agents"
    ADMIN_ACCESS = "admin_access"

class Role(Enum):
    """User roles, each mapped to a permission set by AccessControlManager."""
    GUEST = "guest"
    USER = "user"
    DEVELOPER = "developer"
    ADMIN = "admin"

class AccessControlManager:
    """JWT-based access control.

    Maps roles to permission lists, issues and verifies HS256 tokens, and
    provides a decorator that enforces a permission on a method call.
    """

    def __init__(self, secret_key: str):
        """Args:
            secret_key: HMAC secret used to sign and verify JWTs.
        """
        self.secret_key = secret_key
        # Static grant table; each role strictly supersets the previous one.
        self.role_permissions = {
            Role.GUEST: [Permission.READ_DOCUMENTS],
            Role.USER: [Permission.READ_DOCUMENTS, Permission.EXECUTE_TOOLS],
            Role.DEVELOPER: [
                Permission.READ_DOCUMENTS,
                Permission.WRITE_DOCUMENTS,
                Permission.EXECUTE_TOOLS,
                Permission.MANAGE_AGENTS
            ],
            Role.ADMIN: [
                Permission.READ_DOCUMENTS,
                Permission.WRITE_DOCUMENTS,
                Permission.EXECUTE_TOOLS,
                Permission.MANAGE_AGENTS,
                Permission.ADMIN_ACCESS
            ]
        }

    def create_token(self, user_id: str, role: Role, expires_in: int = 3600) -> str:
        """Issue an HS256 JWT embedding the role's permission values.

        Args:
            user_id: subject identifier stored in the token.
            role: role whose permissions are granted.
            expires_in: token lifetime in seconds (default one hour).
        """
        payload = {
            'user_id': user_id,
            'role': role.value,
            'permissions': [p.value for p in self.role_permissions[role]],
            'exp': time.time() + expires_in,
            'iat': time.time()
        }

        return jwt.encode(payload, self.secret_key, algorithm='HS256')

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Decode and validate *token*, returning its payload.

        Raises:
            ValueError: if the token is expired or otherwise invalid.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=['HS256'])
            return payload
        except jwt.ExpiredSignatureError:
            raise ValueError("令牌已过期")
        except jwt.InvalidTokenError:
            raise ValueError("无效的令牌")

    def check_permission(self, token: str, required_permission: Permission) -> bool:
        """Return True iff *token* is valid and grants *required_permission*."""
        try:
            payload = self.verify_token(token)
            user_permissions = payload.get('permissions', [])
            return required_permission.value in user_permissions
        except ValueError:
            return False

    @staticmethod
    def require_permission(permission: Permission):
        """Decorator factory enforcing *permission* on the wrapped callable.

        Bug fix: this was previously an instance method, so the documented
        class-body usage ``@AccessControlManager.require_permission(Permission.X)``
        bound the Permission as ``self`` and raised TypeError when the
        decorated class was defined.  As a staticmethod it works both via the
        class and via an instance; at call time the wrapper resolves the
        manager from the decorated object's ``access_control`` attribute
        (e.g. SecureLangChainAgent.access_control).

        Raises (from the wrapper):
            PermissionError: when no token is supplied, no manager can be
                resolved from the call target, or the token lacks the
                required permission.
        """
        def decorator(func: Callable):
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Token comes from an explicit kwarg, else the bound instance.
                # (Guarding args also fixes an IndexError on no-arg calls.)
                instance = args[0] if args else None
                token = kwargs.get('auth_token') or getattr(instance, 'auth_token', None)

                if not token:
                    raise PermissionError("缺少认证令牌")

                manager = getattr(instance, 'access_control', None)
                if manager is None or not manager.check_permission(token, permission):
                    raise PermissionError(f"缺少必要权限: {permission.value}")

                return func(*args, **kwargs)
            return wrapper
        return decorator

class SecureLangChainAgent:
    """LangChain agent facade whose operations are permission-gated.

    Bug fix: the original decorated its methods with
    ``@AccessControlManager.require_permission(...)``, but that attribute is
    an *instance* method, so the decoration itself raised TypeError at class
    definition time (the Permission was bound as ``self``).  Permission
    checks are now performed explicitly at the top of each method via
    _authorize(); the public method signatures are unchanged.
    """

    def __init__(self, access_control: AccessControlManager):
        self.access_control = access_control
        self.auth_token = None  # set by authenticate()

    def authenticate(self, token: str):
        """Verify and remember the caller's JWT; returns the decoded payload."""
        self.auth_token = token
        return self.access_control.verify_token(token)

    def _authorize(self, permission: Permission, auth_token) -> None:
        """Raise PermissionError unless a valid token grants *permission*.

        Uses the explicit *auth_token* if given, else the token stored by
        authenticate().
        """
        token = auth_token or self.auth_token
        if not token:
            raise PermissionError("缺少认证令牌")
        if not self.access_control.check_permission(token, permission):
            raise PermissionError(f"缺少必要权限: {permission.value}")

    def read_documents(self, query: str, auth_token: str = None) -> List[str]:
        """Read documents (requires READ_DOCUMENTS)."""
        self._authorize(Permission.READ_DOCUMENTS, auth_token)
        # Actual document-retrieval logic would go here.
        return [f"文档内容:{query}"]

    def execute_tool(self, tool_name: str, params: Dict[str, Any], auth_token: str = None) -> Any:
        """Execute a tool (requires EXECUTE_TOOLS)."""
        self._authorize(Permission.EXECUTE_TOOLS, auth_token)
        # Actual tool-execution logic would go here.
        return f"工具 {tool_name} 执行结果:{params}"

    def create_agent(self, agent_config: Dict[str, Any], auth_token: str = None) -> str:
        """Create an agent (requires MANAGE_AGENTS)."""
        self._authorize(Permission.MANAGE_AGENTS, auth_token)
        # Actual agent-creation logic would go here.
        return f"代理已创建:{agent_config.get('name', 'unnamed')}"

1.3 合规与数据主权(GDPR/CCPA/数据驻留)

在企业落地中,除安全之外需满足合规与数据主权要求:

  • 合规基线:GDPR/CCPA/ISO 27001;记录处理目的、数据最小化、可追溯删除(Right to be Forgotten)。
  • 数据分类与标记:按 PII/敏感等级打标,影响存储位置、加密强度与访问控制。
  • 数据驻留(Data Residency):按租户/地域隔离存储与处理(如 EU-only)。
  • 密钥与KMS:At-Rest/In-Transit 加密,集中管理密钥与轮换(Rotation)。
  • 数据保留策略:按法规/业务设置 TTL 与归档;实现可审计的删改记录。
from dataclasses import dataclass
from typing import Literal, Dict

Region = Literal["eu", "us", "apac"]

@dataclass
class DataResidencyConfig:
    tenant_id: str
    residency: Region  # 数据驻留地域
    pii_level: Literal["none", "low", "medium", "high"]
    encrypt_at_rest: bool = True
    kms_key_id: str | None = None
    retention_days: int = 180

    def storage_bucket(self) -> str:
        # 依据地域与租户路由到不同的对象存储/数据库实例
        return f"lc-{self.residency}-tenant-{self.tenant_id}"

    def should_mask_output(self) -> bool:
        return self.pii_level in ("medium", "high")

# 使用示例
cfg = DataResidencyConfig(tenant_id="acme", residency="eu", pii_level="high", kms_key_id="kms-eu-123")
bucket = cfg.storage_bucket()  # lc-eu-tenant-acme

1.4 审计日志与取证(Audit & Forensics)

通过回调将关键事件落盘/上报,满足合规审计与事后取证:

import json, time, os
from uuid import uuid4
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackHandler

class AuditCallbackHandler(BaseCallbackHandler):
    """Structured audit trail: persists chain/LLM/tool events as JSON Lines.

    - Redaction: prompt text is truncated to 200 chars before logging
      (further PII scrubbing should be applied as needed).
    - Forensics: every record carries run_id, a timestamp, and tenant/region
      tags for later correlation.
    """

    def __init__(self, path: str = "./logs/audit.jsonl", tenant_id: str = "default", region: str = "eu"):
        """Args:
            path: JSONL file the audit records are appended to.
            tenant_id: tenant tag stamped on every record.
            region: region tag stamped on every record.
        """
        # Bug fix: os.makedirs("") raises FileNotFoundError when *path* has no
        # directory component (e.g. "audit.jsonl"); only create the parent
        # directory when one is present.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        self.path = path
        self.tenant_id = tenant_id
        self.region = region

    def _write(self, record: Dict[str, Any]):
        """Append one audit record, stamping default time/tenant/region."""
        record.setdefault("ts", time.time())
        record.setdefault("tenant_id", self.tenant_id)
        record.setdefault("region", self.region)
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    # LLM start/end events.
    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id, **kwargs):
        self._write({
            "event": "llm_start",
            "run_id": str(run_id),
            "model": serialized.get("id"),
            "prompt_preview": (prompts[0][:200] if prompts else ""),  # truncated; scrub PII as needed
        })

    def on_llm_end(self, response, *, run_id, **kwargs):
        # Token usage is only present when the provider reports it.
        usage = {}
        if getattr(response, "llm_output", None):
            usage = response.llm_output.get("token_usage", {})
        self._write({
            "event": "llm_end",
            "run_id": str(run_id),
            "usage": usage,
        })

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id, **kwargs):
        self._write({
            "event": "chain_start",
            "run_id": str(run_id),
            "chain": serialized.get("id"),
        })

    def on_chain_end(self, outputs: Dict[str, Any], *, run_id, **kwargs):
        self._write({
            "event": "chain_end",
            "run_id": str(run_id),
        })

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id, **kwargs):
        self._write({
            "event": "tool_start",
            "run_id": str(run_id),
            "tool": serialized.get("name"),
        })

    def on_tool_end(self, output: str, *, run_id, **kwargs):
        self._write({
            "event": "tool_end",
            "run_id": str(run_id),
        })

# 用法:在调用链上绑定
# chain.invoke(inputs, config={"callbacks": [AuditCallbackHandler(path="/var/log/lc/audit.jsonl", tenant_id="acme", region="eu")]})

2. 多模态集成实现

2.1 多模态聊天模型

from typing import Any, Dict, List, Optional, Union
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.outputs import ChatResult, ChatGeneration
import base64
import requests
from PIL import Image
import io

class MultiModalChatModel(BaseChatModel):
    """Multimodal chat model integration.

    Converts LangChain messages whose content may be a plain string or a
    list of typed items (text/image/audio/video) into an OpenAI-style
    chat-completions request and posts it over HTTP.

    Reference: LangChain core concepts deep-dive and practical guide
    https://blog.csdn.net/jkgSFS/article/details/145068612
    """

    def __init__(
        self,
        model_name: str = "gpt-4-vision-preview",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: int = 1000,
        temperature: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        # NOTE(review): BaseChatModel is pydantic-based; plain attribute
        # assignment of undeclared fields like these may be rejected on some
        # langchain-core/pydantic versions — verify against the pinned deps.
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url or "https://api.openai.com/v1"
        self.max_tokens = max_tokens
        self.temperature = temperature

        # Image file extensions this class considers supported.
        # NOTE(review): this set is never consulted below — confirm intent.
        self.supported_image_formats = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}

        # Registry mapping a modality type string to its handler.
        self.modality_processors = {
            'text': self._process_text,
            'image': self._process_image,
            'audio': self._process_audio,
            'video': self._process_video
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response for (possibly multimodal) messages.

        Raises:
            ValueError: wrapping any HTTP/parsing failure from the API call.
        """

        # Convert each LangChain message to the provider's payload shape.
        processed_messages = []
        for message in messages:
            processed_msg = self._process_multimodal_message(message)
            processed_messages.append(processed_msg)

        # Build the chat-completions request body.
        request_data = {
            "model": self.model_name,
            "messages": processed_messages,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

        if stop:
            request_data["stop"] = stop

        # Send the HTTP request with bearer authentication.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=request_data,
                headers=headers,
                timeout=60
            )
            response.raise_for_status()

            result = response.json()

            # Parse the first choice from the response.
            choice = result["choices"][0]
            message_content = choice["message"]["content"]

            # Wrap the reply in LangChain's generation types.
            generation = ChatGeneration(
                message=AIMessage(content=message_content),
                generation_info={
                    "finish_reason": choice.get("finish_reason"),
                    "model": result.get("model"),
                    "usage": result.get("usage", {})
                }
            )

            return ChatResult(generations=[generation])

        except Exception as e:
            raise ValueError(f"多模态模型调用失败: {str(e)}")

    def _process_multimodal_message(self, message: BaseMessage) -> Dict[str, Any]:
        """Convert one LangChain message to the provider's message dict.

        HumanMessage content may be a plain string or a list of typed items;
        unknown item types are degraded to text.
        NOTE(review): a HumanMessage whose content is neither str nor list
        falls through and returns None — confirm callers never produce that.
        """

        if isinstance(message, HumanMessage):
            content = message.content

            # Content is either a plain string or a list of modality items.
            if isinstance(content, str):
                # Plain-text message.
                return {
                    "role": "user",
                    "content": content
                }
            elif isinstance(content, list):
                # List of multimodal content items.
                processed_content = []

                for item in content:
                    if isinstance(item, dict):
                        modality_type = item.get("type", "text")
                        processor = self.modality_processors.get(modality_type)

                        if processor:
                            processed_item = processor(item)
                            processed_content.append(processed_item)
                        else:
                            # Unknown modality type: degrade to text.
                            processed_content.append({
                                "type": "text",
                                "text": str(item)
                            })
                    else:
                        # Non-dict item: degrade to text.
                        processed_content.append({
                            "type": "text",
                            "text": str(item)
                        })

                return {
                    "role": "user",
                    "content": processed_content
                }

        elif isinstance(message, AIMessage):
            return {
                "role": "assistant",
                "content": message.content
            }

        else:
            # Any other message type is sent as user text.
            return {
                "role": "user",
                "content": str(message.content)
            }

    def _process_text(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Handle a text item: pass the text through (empty if missing)."""
        return {
            "type": "text",
            "text": item.get("text", "")
        }

    def _process_image(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an image item: URL, data URI, or local file path.

        Local files are base64-encoded into a data URI.
        NOTE(review): the data URI always claims image/jpeg regardless of the
        actual file format — confirm downstream tolerance.
        """
        image_data = item.get("image_url") or item.get("image")

        if isinstance(image_data, str):
            if image_data.startswith("http"):
                # Remote image URL.
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": image_data,
                        "detail": item.get("detail", "auto")
                    }
                }
            elif image_data.startswith("data:image"):
                # Already a base64 data URI.
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": image_data,
                        "detail": item.get("detail", "auto")
                    }
                }
            else:
                # Local file path: encode to base64.
                encoded_image = self._encode_image_file(image_data)
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encoded_image}",
                        "detail": item.get("detail", "auto")
                    }
                }

        # Non-string image data: degrade to a text placeholder.
        return {
            "type": "text",
            "text": "[无法处理的图像数据]"
        }

    def _process_audio(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an audio item (unsupported: degraded to a text label)."""
        return {
            "type": "text",
            "text": f"[音频文件: {item.get('audio', 'unknown')}]"
        }

    def _process_video(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Handle a video item (unsupported: degraded to a text label)."""
        return {
            "type": "text",
            "text": f"[视频文件: {item.get('video', 'unknown')}]"
        }

    def _encode_image_file(self, image_path: str) -> str:
        """Base64-encode a local image file.

        Raises:
            ValueError: if the file cannot be read/encoded.
        """
        try:
            with open(image_path, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
            return encoded_string
        except Exception as e:
            raise ValueError(f"无法编码图像文件 {image_path}: {str(e)}")

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses to label this model class.
        return "multimodal_chat"

# Usage example
def demo_multimodal_integration():
    """Demo: send a combined text+image message through MultiModalChatModel."""

    # Initialize the multimodal model.
    multimodal_model = MultiModalChatModel(
        model_name="gpt-4-vision-preview",
        api_key="your-api-key-here"
    )

    # Build a multimodal message: one text item plus one image item.
    multimodal_message = HumanMessage(content=[
        {
            "type": "text",
            "text": "请分析这张图片中的内容,并描述你看到的主要元素。"
        },
        {
            "type": "image",
            "image_url": "https://example.com/image.jpg",
            "detail": "high"
        }
    ])

    # Invoke the model.
    # NOTE(review): calls the private _generate() directly for demo purposes;
    # production code would normally use the model's public invoke interface.
    try:
        result = multimodal_model._generate([multimodal_message])
        print(f"多模态分析结果: {result.generations[0].message.content}")
    except Exception as e:
        print(f"多模态调用失败: {e}")

if __name__ == "__main__":
    demo_multimodal_integration()

2.2 输出安全过滤与SSE流式对接

生产中常见双目标:一边流式渲染响应(SSE),一边做输出过滤(敏感词/PII/政策)。

from typing import AsyncIterator, Callable
import asyncio
import re

# Pre-compiled PII patterns applied to every streamed fragment.
SENSITIVE_PATTERNS = [
    re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),            # US SSN
    re.compile(r"\b\d{16}\b"),                         # naive credit-card match
    re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"),
]

def sanitize_chunk(text: str) -> str:
    """Return *text* with every sensitive-pattern match replaced by [REDACTED]."""
    cleaned = text
    for pattern in SENSITIVE_PATTERNS:
        cleaned = pattern.sub("[REDACTED]", cleaned)
    return cleaned

async def astream_with_safety(chain, payload: dict, *, on_chunk: Callable[[str], None]) -> str:
    """Stream a chain's output while filtering each fragment; return the full text.

    Each fragment produced by ``chain.astream(payload)`` is sanitized,
    forwarded to *on_chunk* (e.g. an SSE push), and accumulated. The event
    loop is yielded after every fragment so other tasks can run.
    """
    pieces = []
    async for fragment in chain.astream(payload):
        cleaned = sanitize_chunk(str(fragment))
        on_chunk(cleaned)
        pieces.append(cleaned)
        await asyncio.sleep(0)  # yield control to the event loop
    return "".join(pieces)

# 示例:FastAPI SSE 推送(简化)
"""
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/stream")
async def stream_endpoint(q: str):
    async def event_source():
        async def push(data: str):
            yield f"data: {data}\n\n"

        # 假设 chain 为已构建的 LCEL 链
        final = await astream_with_safety(chain, {"question": q}, on_chunk=lambda s: None)
        # 这里简化为一次性返回,真实情况可逐chunk yield
        yield f"data: {final}\n\n"

    return StreamingResponse(event_source(), media_type="text/event-stream")
"""

3. 智能负载均衡与故障转移

3.1 负载均衡实现

from typing import List, Dict, Any, Optional, Callable
import random
import time
import threading
from dataclasses import dataclass
from enum import Enum
import logging

class ProviderStatus(Enum):
    """Health states a model provider can be in."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
    MAINTENANCE = "maintenance"

@dataclass
class ProviderMetrics:
    """Per-provider performance counters used for routing decisions."""
    response_time: float = 0.0                       # latest/rolling response time (seconds)
    success_rate: float = 1.0                        # fraction of successful requests
    error_count: int = 0                             # total errors observed
    total_requests: int = 0                          # total requests sent
    last_error_time: Optional[float] = None          # epoch time of the most recent error
    status: ProviderStatus = ProviderStatus.HEALTHY  # current health classification

class LoadBalancedChatModel:
    """Chat-model wrapper that load-balances calls across multiple LLM providers.

    Supports several routing strategies (round-robin, weighted, least
    connections, fastest), automatic failover with retries, a success-rate
    based circuit breaker, and a background health-check thread.

    Reference: 逐行深入分析:langchain-ChatGLM项目的源码解读
    https://jishu.proginn.com/doc/298065111cfa69fe7
    """

    def __init__(
        self,
        providers: List[Dict[str, Any]],
        strategy: str = "round_robin",
        health_check_interval: int = 60,
        max_retries: int = 3,
        circuit_breaker_threshold: float = 0.5,
        **kwargs
    ):
        """
        Args:
            providers: provider configs, each like {"type": ..., "params": {...}}.
            strategy: "round_robin" | "weighted" | "least_connections" |
                "fastest"; any other value falls back to random choice.
            health_check_interval: seconds between background health checks.
            max_retries: maximum number of providers tried per generate call.
            circuit_breaker_threshold: success rate below which a provider is
                marked UNHEALTHY and excluded from routing.
        """
        self.providers = {}
        self.provider_metrics = {}
        self.strategy = strategy
        self.health_check_interval = health_check_interval
        self.max_retries = max_retries
        self.circuit_breaker_threshold = circuit_breaker_threshold

        # BUG FIX: the logger must exist before the health-check thread starts,
        # because the thread logs through self.logger; previously the logger
        # was assigned last, after the thread was already running.
        self.logger = logging.getLogger(__name__)

        # Instantiate each configured provider under a stable positional id.
        for i, provider_config in enumerate(providers):
            provider_id = f"provider_{i}"
            self.providers[provider_id] = self._create_provider(provider_config)
            self.provider_metrics[provider_id] = ProviderMetrics()

        # Round-robin cursor; strategy selection may run from multiple
        # threads, so it is guarded by a lock.
        self.current_index = 0
        self.strategy_lock = threading.Lock()

        # Daemon thread so periodic health checks never block interpreter exit.
        self.health_check_thread = threading.Thread(
            target=self._health_check_loop,
            daemon=True
        )
        self.health_check_thread.start()

    def _create_provider(self, config: Dict[str, Any]):
        """Instantiate a chat model from a provider config dict.

        Raises:
            ValueError: for an unknown "type" value.
        """
        provider_type = config.get("type", "openai")

        if provider_type == "openai":
            from langchain_openai import ChatOpenAI
            return ChatOpenAI(**config.get("params", {}))
        elif provider_type == "anthropic":
            from langchain_anthropic import ChatAnthropic
            return ChatAnthropic(**config.get("params", {}))
        elif provider_type == "groq":
            from langchain_groq import ChatGroq
            return ChatGroq(**config.get("params", {}))
        else:
            raise ValueError(f"不支持的provider类型: {provider_type}")

    def _select_provider(self) -> Optional[str]:
        """Pick a provider id according to the configured strategy.

        Returns None when no HEALTHY or DEGRADED provider exists.
        """
        # Only route to providers that are currently usable.
        healthy_providers = [
            pid for pid, metrics in self.provider_metrics.items()
            if metrics.status in [ProviderStatus.HEALTHY, ProviderStatus.DEGRADED]
        ]

        if not healthy_providers:
            self.logger.error("没有可用的健康providers")
            return None

        with self.strategy_lock:
            if self.strategy == "round_robin":
                return self._round_robin_select(healthy_providers)
            elif self.strategy == "weighted":
                return self._weighted_select(healthy_providers)
            elif self.strategy == "least_connections":
                return self._least_connections_select(healthy_providers)
            elif self.strategy == "fastest":
                return self._fastest_select(healthy_providers)
            else:
                return random.choice(healthy_providers)

    def _round_robin_select(self, providers: List[str]) -> Optional[str]:
        """Cycle through providers in order (None for an empty list)."""
        if not providers:
            return None

        provider = providers[self.current_index % len(providers)]
        self.current_index += 1
        return provider

    def _weighted_select(self, providers: List[str]) -> str:
        """Weighted random choice: weight = success_rate / response_time."""
        weights = []
        for pid in providers:
            metrics = self.provider_metrics[pid]
            # Clamp response time so an unmeasured (0.0) latency cannot
            # produce a division by zero or an outsized weight.
            weight = metrics.success_rate / max(metrics.response_time, 0.1)
            weights.append(weight)

        total_weight = sum(weights)
        if total_weight == 0:
            return random.choice(providers)

        rand_val = random.uniform(0, total_weight)
        cumulative = 0

        for i, weight in enumerate(weights):
            cumulative += weight
            if rand_val <= cumulative:
                return providers[i]

        # Float-rounding safety net.
        return providers[-1]

    def _least_connections_select(self, providers: List[str]) -> str:
        """Pick the provider with the fewest errors (connection-count proxy)."""
        min_errors = float('inf')
        best_provider = None

        for pid in providers:
            metrics = self.provider_metrics[pid]
            if metrics.error_count < min_errors:
                min_errors = metrics.error_count
                best_provider = pid

        return best_provider or providers[0]

    def _fastest_select(self, providers: List[str]) -> str:
        """Pick the provider with the lowest (EMA) response time."""
        min_response_time = float('inf')
        fastest_provider = None

        for pid in providers:
            metrics = self.provider_metrics[pid]
            if metrics.response_time < min_response_time:
                min_response_time = metrics.response_time
                fastest_provider = pid

        return fastest_provider or providers[0]

    def _generate_with_fallback(
        self,
        messages: List[Any],
        **kwargs
    ) -> Any:
        """Generate with automatic failover across providers.

        Tries up to max_retries DISTINCT providers; raises the last error
        when every attempt fails.
        """
        last_exception = None
        attempted_providers = set()

        for attempt in range(self.max_retries):
            provider_id = self._select_provider()

            if provider_id is None:
                break

            if provider_id in attempted_providers:
                # BUG FIX: deterministic strategies ("fastest",
                # "least_connections") keep returning the same provider, which
                # previously aborted the retry loop on the second attempt and
                # defeated failover. Instead, route to any healthy provider
                # that has not been tried yet.
                remaining = [
                    pid for pid, m in self.provider_metrics.items()
                    if m.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED)
                    and pid not in attempted_providers
                ]
                if not remaining:
                    break
                provider_id = remaining[0]

            attempted_providers.add(provider_id)
            provider = self.providers[provider_id]

            try:
                start_time = time.time()

                # NOTE(review): calls the provider's private _generate API,
                # mirroring the original — confirm against the installed
                # langchain version.
                result = provider._generate(messages, **kwargs)

                response_time = time.time() - start_time
                self._update_success_metrics(provider_id, response_time)

                return result

            except Exception as e:
                last_exception = e
                self._update_failure_metrics(provider_id, e)
                self.logger.warning(
                    f"Provider {provider_id} 调用失败 (尝试 {attempt + 1}): {str(e)}"
                )
                continue

        # Every provider failed (or none was available).
        if last_exception:
            raise last_exception
        else:
            raise RuntimeError("没有可用的provider")

    def _update_success_metrics(self, provider_id: str, response_time: float):
        """Record a successful call: EMA latency, success rate, status."""
        metrics = self.provider_metrics[provider_id]

        # Exponential moving average keeps latency responsive but smooth.
        if metrics.total_requests == 0:
            metrics.response_time = response_time
        else:
            alpha = 0.1  # smoothing factor
            metrics.response_time = (
                alpha * response_time + (1 - alpha) * metrics.response_time
            )

        metrics.total_requests += 1
        success_count = metrics.total_requests - metrics.error_count
        metrics.success_rate = success_count / metrics.total_requests

        # Promote/demote status based on the running success rate.
        if metrics.success_rate >= 0.95:
            metrics.status = ProviderStatus.HEALTHY
        elif metrics.success_rate >= 0.8:
            metrics.status = ProviderStatus.DEGRADED
        else:
            metrics.status = ProviderStatus.UNHEALTHY

    def _update_failure_metrics(self, provider_id: str, error: Exception):
        """Record a failed call and trip the circuit breaker if needed."""
        metrics = self.provider_metrics[provider_id]

        metrics.error_count += 1
        metrics.total_requests += 1
        metrics.last_error_time = time.time()

        success_count = metrics.total_requests - metrics.error_count
        metrics.success_rate = success_count / metrics.total_requests

        # Circuit breaker: below the threshold the provider is removed from
        # routing until a health check passes again.
        if metrics.success_rate < self.circuit_breaker_threshold:
            metrics.status = ProviderStatus.UNHEALTHY
            self.logger.error(
                f"Provider {provider_id} 触发熔断器,成功率: {metrics.success_rate:.2%}"
            )

    def _health_check_loop(self):
        """Background loop: sleep, then probe every provider, forever."""
        while True:
            try:
                time.sleep(self.health_check_interval)
                self._perform_health_checks()
            except Exception as e:
                self.logger.error(f"健康检查失败: {str(e)}")

    def _perform_health_checks(self):
        """Probe each provider with a one-token request; update status."""
        from langchain_core.messages import HumanMessage

        test_message = [HumanMessage(content="Health check")]

        for provider_id, provider in self.providers.items():
            metrics = self.provider_metrics[provider_id]

            # Providers in maintenance are deliberately left untouched.
            if metrics.status == ProviderStatus.MAINTENANCE:
                continue

            try:
                start_time = time.time()

                provider._generate(test_message, max_tokens=1)

                response_time = time.time() - start_time

                # A previously-broken provider recovers to DEGRADED first;
                # real traffic promotes it back to HEALTHY.
                if metrics.status == ProviderStatus.UNHEALTHY:
                    metrics.status = ProviderStatus.DEGRADED
                    self.logger.info(f"Provider {provider_id} 健康检查通过,状态恢复")

                # Fold the probe latency into the EMA.
                if metrics.total_requests > 0:
                    alpha = 0.1
                    metrics.response_time = (
                        alpha * response_time + (1 - alpha) * metrics.response_time
                    )

            except Exception as e:
                if metrics.status != ProviderStatus.UNHEALTHY:
                    metrics.status = ProviderStatus.UNHEALTHY
                    self.logger.warning(f"Provider {provider_id} 健康检查失败: {str(e)}")

    def get_provider_stats(self) -> Dict[str, Dict[str, Any]]:
        """Return a human-readable snapshot of every provider's metrics."""
        stats = {}

        for provider_id, metrics in self.provider_metrics.items():
            stats[provider_id] = {
                "status": metrics.status.value,
                "success_rate": f"{metrics.success_rate:.2%}",
                "response_time": f"{metrics.response_time:.3f}s",
                "error_count": metrics.error_count,
                "total_requests": metrics.total_requests,
                "last_error_time": metrics.last_error_time
            }

        return stats

3.2 可靠性控制:超时/重试/熔断/限流与配额路由

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import asyncio
import random
import time
from typing import Any, Callable

class RetryPolicy:
    """Async retry with exponential backoff and random jitter.

    Attributes:
        max_attempts: total call attempts (not retries-after-first).
        base_delay: backoff base; attempt i waits base_delay * 2**i seconds.
        jitter: upper bound of the random extra delay added to each wait.
    """

    def __init__(self, max_attempts: int = 3, base_delay: float = 0.2, jitter: float = 0.1):
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.jitter = jitter

    async def aretry(self, fn: Callable[[], Any]):
        """Await fn() up to max_attempts times; re-raise the final error."""
        last = None
        for attempt in range(self.max_attempts):
            try:
                return await fn()
            except Exception as e:
                last = e
                # BUG FIX: do not sleep after the final failed attempt — the
                # original always slept before raising, adding useless latency.
                if attempt + 1 < self.max_attempts:
                    # BUG FIX: jitter is now random; a constant addend (the
                    # original behavior) does not de-synchronize retry storms.
                    delay = self.base_delay * (2 ** attempt) + random.uniform(0, self.jitter)
                    await asyncio.sleep(delay)
        if last is None:
            # BUG FIX: with max_attempts <= 0 the original executed
            # `raise None`, a confusing TypeError.
            raise RuntimeError("aretry called with max_attempts <= 0")
        raise last

class CircuitBreaker:
    """Trips open after N consecutive failures; stays open for cool_down seconds."""

    def __init__(self, failure_threshold: int = 5, cool_down: float = 30.0):
        self.failure_threshold = failure_threshold
        self.cool_down = cool_down
        self.failures = 0       # consecutive-failure counter
        self.open_until = 0.0   # wall-clock time until which calls are blocked

    def allow(self) -> bool:
        """True once the cool-down window has elapsed (breaker closed/half-open)."""
        return not time.time() < self.open_until

    def record_success(self):
        """Any success resets the consecutive-failure streak."""
        self.failures = 0

    def record_failure(self):
        """Count a failure; reaching the threshold opens the breaker."""
        self.failures += 1
        if self.failures < self.failure_threshold:
            return
        self.open_until = time.time() + self.cool_down

class RateLimiter:
    """Simple async pacing: enforces a minimum interval between calls."""

    def __init__(self, qps: float = 10.0):
        self.interval = 1.0 / qps  # minimum seconds between consecutive calls
        self.last = 0.0            # timestamp of the most recent permit

    async def acquire(self):
        """Sleep just long enough to keep the call rate at or below qps."""
        wait_for = self.interval - (time.time() - self.last)
        if wait_for > 0:
            await asyncio.sleep(wait_for)
        self.last = time.time()

class QuotaRouter:
    """按模型/Provider 配额与成本做路由:先走便宜配额,耗尽再升级。

    Each provider dict looks like
    {"id": "openai:gpt-3.5", "cost": 1, "remaining": 1000}.
    """

    def __init__(self, providers: list[dict]):
        self.providers = providers

    def select(self) -> dict:
        """Return the cheapest provider with remaining quota, escalating
        to the most expensive one when every quota is exhausted."""
        affordable = [p for p in self.providers if p.get("remaining", 0) > 0]
        if not affordable:
            # 全部耗尽时,选择质量更高但更贵的作为兜底
            # BUG FIX: the original sorted ascending and returned [0], i.e.
            # the CHEAPEST provider, contradicting the stated fallback intent.
            return max(self.providers, key=lambda p: p["cost"])
        # 选择最低成本的可用配额
        return min(affordable, key=lambda p: p["cost"])

    def consume(self, pid: str, tokens: int):
        """Deduct tokens from the named provider's quota (floored at 0)."""
        for p in self.providers:
            if p["id"] == pid:
                p["remaining"] = max(0, p.get("remaining", 0) - tokens)
                break

# 组合使用示例
# BUG FIX: the original constructed a fresh RateLimiter / CircuitBreaker /
# RetryPolicy on EVERY call, resetting their state each invocation — the
# breaker could never open and the limiter never limited. The controls are
# now shared at module level (lazily created) and overridable per call.
_shared_controls: dict = {}

def _shared(name: str, factory):
    """Lazily create and cache one shared control object per name."""
    if name not in _shared_controls:
        _shared_controls[name] = factory()
    return _shared_controls[name]

async def robust_generate(llm, messages, *, rate_limiter=None, circuit_breaker=None, retry_policy=None):
    """Call llm._agenerate guarded by rate limiting, circuit breaking and retries.

    Args:
        llm: model exposing an async `_agenerate(messages)`.
        messages: forwarded to the model unchanged.
        rate_limiter / circuit_breaker / retry_policy: optional overrides;
            default to module-level shared instances so state persists.

    Raises:
        RuntimeError: "circuit_open" when the breaker is open.
    """
    rl = rate_limiter or _shared("rl", lambda: RateLimiter(qps=20))
    cb = circuit_breaker or _shared("cb", lambda: CircuitBreaker(failure_threshold=5, cool_down=15))
    retry = retry_policy or _shared("retry", lambda: RetryPolicy(max_attempts=3))

    await rl.acquire()

    if not cb.allow():
        raise RuntimeError("circuit_open")

    async def call():
        return await llm._agenerate(messages)

    try:
        result = await retry.aretry(call)
        cb.record_success()
        return result
    except Exception:
        cb.record_failure()
        raise

4. 高性能向量存储

4.1 优化的向量存储实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core.vectorstores import VectorStore
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
import numpy as np
import faiss
import pickle
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

class HighPerformanceVectorStore(VectorStore):
    """High-performance FAISS-backed vector store.

    Features: batched/parallel embedding, query-result caching, optional GPU
    indexes, and simple exact-match metadata filtering.

    Reference: LangChain原理解析及开发实战指南(2025年最新版)
    https://jishuzhan.net/article/1895692926025994242
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        index_factory: str = "IVF1024,Flat",
        metric_type: str = "L2",
        use_gpu: bool = False,
        cache_size: int = 10000,
        batch_size: int = 1000,
        **kwargs
    ):
        """
        Args:
            embedding_function: embedding model for documents and queries.
            index_factory: FAISS index factory string ("Flat" = exact search).
            metric_type: "L2" (Euclidean) or "IP" (inner product).
            use_gpu: move the index to GPU 0 when FAISS reports GPUs.
            cache_size: maximum number of cached query results.
            batch_size: number of texts embedded per batch.
        """
        self.embedding_function = embedding_function
        self.index_factory = index_factory
        self.metric_type = metric_type
        self.use_gpu = use_gpu
        self.cache_size = cache_size
        self.batch_size = batch_size

        # The FAISS index is created lazily on the first add_texts call,
        # once the embedding dimension is known.
        self.index = None
        self.dimension = None

        # Document storage and doc-id <-> FAISS-row mappings.
        self.documents = {}
        self.id_to_index = {}
        self.index_to_id = {}

        # Query-result cache; invalidated whenever the index changes.
        self.query_cache = {}
        self.cache_lock = threading.RLock()

        # Performance counters.
        self.stats = {
            'total_queries': 0,
            'cache_hits': 0,
            'avg_query_time': 0,
            'total_documents': 0
        }

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "HighPerformanceVectorStore":
        """Build a store from raw texts.

        NOTE(review): langchain_core's VectorStore declares `from_texts` as
        abstract; without it this class could not be instantiated — confirm
        against the installed langchain_core version.
        """
        store = cls(embedding_function=embedding, **kwargs)
        if texts:
            store.add_texts(texts, metadatas=metadatas)
        return store

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed and index a batch of texts; returns the assigned ids."""

        if not texts:
            return []

        # Generate ids when the caller did not supply any.
        if ids is None:
            ids = [f"doc_{int(time.time() * 1000000)}_{i}" for i in range(len(texts))]

        # BUG FIX: the original used `[{}] * len(texts)`, which aliases ONE
        # shared dict across every document — mutating one document's
        # metadata silently mutated them all.
        if metadatas is None:
            metadatas = [{} for _ in texts]

        # Batch-embed all texts (parallelized for large inputs).
        embeddings = self._batch_embed_texts(texts)

        # Create the index on first use, now that the dimension is known.
        if self.index is None:
            self.dimension = len(embeddings[0])
            self._initialize_index()

        # FAISS assigns consecutive row ids starting at the current size.
        start_index = len(self.index_to_id)

        vectors = np.array(embeddings, dtype=np.float32)
        self.index.add(vectors)

        # Record documents and both directions of the id mapping.
        for i, (text, metadata, doc_id) in enumerate(zip(texts, metadatas, ids)):
            index_id = start_index + i
            doc = Document(page_content=text, metadata=metadata)
            self.documents[doc_id] = doc
            self.id_to_index[doc_id] = index_id
            self.index_to_id[index_id] = doc_id

        self.stats['total_documents'] += len(texts)

        # The index changed, so every cached query result is now stale.
        with self.cache_lock:
            self.query_cache.clear()

        return ids

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return up to k (document, similarity) pairs for the query.

        NOTE(review): the metadata filter is applied AFTER retrieving only k
        index hits, so filtered queries can return fewer than k results even
        when more matches exist in the index.
        """
        start_time = time.time()

        # Serve from the cache when possible.
        cache_key = self._generate_cache_key(query, k, filter)

        with self.cache_lock:
            if cache_key in self.query_cache:
                self.stats['cache_hits'] += 1
                self.stats['total_queries'] += 1
                return self.query_cache[cache_key]

        # Embed the query.
        query_embedding = self.embedding_function.embed_query(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # Nothing indexed yet: no results.
        if self.index is None or self.index.ntotal == 0:
            return []

        scores, indices = self.index.search(query_vector, min(k, self.index.ntotal))

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:  # FAISS marks padding slots with -1
                continue

            doc_id = self.index_to_id.get(idx)
            if doc_id and doc_id in self.documents:
                doc = self.documents[doc_id]

                # Exact-match metadata filtering.
                if filter and not self._apply_filter(doc, filter):
                    continue

                # FAISS returns distances; convert to a similarity score.
                similarity_score = self._distance_to_similarity(score)
                results.append((doc, similarity_score))

        results = results[:k]

        # Cache the result (bounded, no eviction beyond the size cap).
        with self.cache_lock:
            if len(self.query_cache) < self.cache_size:
                self.query_cache[cache_key] = results

        # Update query-time statistics (running average).
        query_time = time.time() - start_time
        self.stats['total_queries'] += 1

        total_queries = self.stats['total_queries']
        current_avg = self.stats['avg_query_time']
        self.stats['avg_query_time'] = (
            (current_avg * (total_queries - 1) + query_time) / total_queries
        )

        return results

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return up to k documents most similar to the query."""
        results = self.similarity_search_with_score(query, k, filter, **kwargs)
        return [doc for doc, _ in results]

    def _batch_embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed texts, splitting large inputs into parallel batches.

        Returns embeddings in the SAME order as `texts`.
        """
        if len(texts) <= self.batch_size:
            return self.embedding_function.embed_documents(texts)

        all_embeddings = []

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(
                    self.embedding_function.embed_documents,
                    texts[i:i + self.batch_size],
                )
                for i in range(0, len(texts), self.batch_size)
            ]

            # BUG FIX: the original consumed futures via as_completed(),
            # which yields in COMPLETION order — embeddings could end up
            # paired with the wrong texts. Iterating in submission order
            # preserves alignment (result() blocks until each batch is done).
            for future in futures:
                all_embeddings.extend(future.result())

        return all_embeddings

    def _initialize_index(self):
        """Create the FAISS index described by index_factory/metric_type."""
        if self.metric_type == "L2":
            metric = faiss.METRIC_L2
        elif self.metric_type == "IP":
            metric = faiss.METRIC_INNER_PRODUCT
        else:
            metric = faiss.METRIC_L2

        if self.index_factory == "Flat":
            self.index = faiss.IndexFlatL2(self.dimension)
        else:
            self.index = faiss.index_factory(self.dimension, self.index_factory, metric)

        # Optionally move the index to the first GPU.
        if self.use_gpu and faiss.get_num_gpus() > 0:
            gpu_resource = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(gpu_resource, 0, self.index)

        # NOTE(review): IVF-style indexes need training data before add();
        # training is not implemented here — confirm the chosen factory
        # string works untrained or supply training vectors.
        if hasattr(self.index, 'is_trained') and not self.index.is_trained:
            pass

    def _generate_cache_key(
        self,
        query: str,
        k: int,
        filter: Optional[Dict[str, Any]]
    ) -> str:
        """Deterministic cache key over (query, k, filter)."""
        import hashlib

        key_data = f"{query}:{k}:{filter}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _apply_filter(self, doc: Document, filter: Dict[str, Any]) -> bool:
        """True when every filter key is present and equal in doc.metadata."""
        for key, value in filter.items():
            if key not in doc.metadata:
                return False
            if doc.metadata[key] != value:
                return False
        return True

    def _distance_to_similarity(self, distance: float) -> float:
        """Map a FAISS distance to a (0, 1] similarity score."""
        # Simple monotone transform; tune for the metric in production.
        return 1.0 / (1.0 + distance)

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return counters plus derived cache/index statistics."""
        cache_hit_rate = (
            self.stats['cache_hits'] / max(self.stats['total_queries'], 1) * 100
        )

        return {
            **self.stats,
            'cache_hit_rate': f"{cache_hit_rate:.2f}%",
            'cache_size': len(self.query_cache),
            'index_size': self.index.ntotal if self.index else 0
        }

4.2 混合检索与重排(Hybrid + Rerank)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import List, Tuple

class HybridRetriever:
    """Fuse vector and BM25 retrieval with Reciprocal Rank Fusion (RRF),
    with an optional second-stage reranker (e.g. a cross-encoder).

    BUG FIX: the original keyed documents with `f"doc_{i}"` in `_topk_pairs`
    but with `id(d)` as the fallback in the id->doc map, so documents
    without an `id` attribute could never be recovered after fusion, and the
    positional fallback ids collided across the two result lists. Both
    paths now share one consistent key function.
    """

    def __init__(self, vector_retriever, bm25_retriever, k: int = 8, alpha: float = 0.7, reranker=None):
        self.vec = vector_retriever
        self.bm25 = bm25_retriever
        self.k = k
        self.alpha = alpha  # reserved for weighted-fusion variants
        self.reranker = reranker  # 可对融合后的候选做二次重排(如 Cross-Encoder)

    def _doc_key(self, doc) -> str:
        """Stable key for a document: its id when present, else its content."""
        key = getattr(doc, "id", None)
        if key is not None:
            return key
        # Content-based fallback also deduplicates identical docs across lists.
        return getattr(doc, "page_content", repr(doc))

    def _rrf(self, lists: List[List[Tuple[str, float]]]) -> List[Tuple[str, float]]:
        """Reciprocal Rank Fusion: score += 1/(rank + 60), highest first."""
        scores = {}
        for results in lists:
            for rank, (doc_id, _) in enumerate(results):
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rank + 60.0)
        return sorted(scores.items(), key=lambda x: x[1], reverse=True)

    def _topk_pairs(self, docs) -> List[Tuple[str, float]]:
        """Pair each doc's key with a rank-derived placeholder score."""
        return [(self._doc_key(d), 1.0 / (i + 1)) for i, d in enumerate(docs)]

    def get_relevant_documents(self, query: str):
        """Run both retrievers, fuse by RRF, optionally rerank, return top k."""
        vec_docs = self.vec.get_relevant_documents(query)
        bm25_docs = self.bm25.get_relevant_documents(query)

        fused = self._rrf([self._topk_pairs(vec_docs), self._topk_pairs(bm25_docs)])

        # Map keys back to document objects with the SAME key function;
        # the first occurrence of a key wins.
        key_to_doc = {}
        for d in vec_docs + bm25_docs:
            key_to_doc.setdefault(self._doc_key(d), d)

        candidates = [key_to_doc[key] for key, _ in fused if key in key_to_doc][: self.k]

        if self.reranker:
            candidates = self.reranker.rerank(query, candidates)[: self.k]
        return candidates

4.3 语义缓存(Semantic Cache)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
from langchain_core.documents import Document

class SemanticCache:
    def __init__(self, embeddings, vectorstore, threshold: float = 0.92):
        self.emb = embeddings
        self.vs = vectorstore
        self.threshold = threshold

    def lookup(self, query: str) -> str | None:
        results = self.vs.similarity_search_with_score(query, k=1)
        if not results:
            return None
        doc, score = results[0]
        if score >= self.threshold:  # 假设 score 越大越相似(需与具体向量库一致)
            return doc.metadata.get("response")
        return None

    def update(self, query: str, response: str):
        doc = Document(page_content=query, metadata={"response": response})
        self.vs.add_texts([doc.page_content], metadatas=[doc.metadata])

提示:生产中应根据具体向量库分数语义(相似度或距离)调整阈值与比较方向,并增加 TTL 与逐出策略。

5. 参考文献

本文档基于以下深度源码分析文章:

  1. LangChain解析器与下游任务集成的源码实现剖析

  2. LangChain 核心概念深度解析与实战指南

  3. 逐行深入分析:langchain-ChatGLM项目的源码解读

  4. LangChain原理解析及开发实战指南(2025年最新版)

  5. 从源码视角,窥探LangChain的运行逻辑

附录A. 可观测性与成本控制

A.1 指标与Tracing

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
from typing import Dict, Any
from langchain_core.callbacks import BaseCallbackHandler
import time

class MetricsCallback(BaseCallbackHandler):
    """LangChain callback that emits LLM latency and token-usage metrics.

    `emitter` can be any object exposing a `gauge(name, value)` method
    (Prometheus / StatsD / OpenTelemetry exporter adapters).
    """

    def __init__(self, emitter):
        self.emitter = emitter
        # run_id -> wall-clock start time of the in-flight LLM call
        self.llm_start_time = {}

    def on_llm_start(self, serialized: Dict[str, Any], prompts, *, run_id, **kwargs):
        """Record the start time of an LLM invocation."""
        self.llm_start_time[run_id] = time.time()

    def on_llm_end(self, response, *, run_id, **kwargs):
        """Emit duration and token-usage gauges for a finished LLM call."""
        start = self.llm_start_time.pop(run_id, None)
        # BUG FIX: `if start:` treated a falsy (0.0) timestamp as missing;
        # compare against None explicitly.
        if start is not None:
            duration = time.time() - start
            self.emitter.gauge("llm.duration", duration)
        usage = (response.llm_output or {}).get("token_usage", {})
        self.emitter.gauge("llm.tokens.input", usage.get("prompt_tokens", 0))
        self.emitter.gauge("llm.tokens.output", usage.get("completion_tokens", 0))

A.2 成本仪表(按模型/租户)

1
2
3
4
5
6
7
8
# Per-1K-token USD pricing, keyed by model name.
PRICING = {
    "gpt-4": {"in": 0.03, "out": 0.06},
    "gpt-3.5-turbo": {"in": 0.001, "out": 0.002},
}

def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate the USD cost of one call; unknown models cost 0."""
    rates = PRICING.get(model, {"in": 0.0, "out": 0.0})
    input_cost = (input_tokens / 1000) * rates["in"]
    output_cost = (output_tokens / 1000) * rates["out"]
    return input_cost + output_cost

附录B. 评测与 A/B

B.1 离线基准集

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
from dataclasses import dataclass
from typing import List

@dataclass
class EvalCase:
    """One offline benchmark item: a query plus keywords the answer must contain."""
    query: str
    expected_keywords: List[str]

def _case_hits(chain, case: EvalCase) -> bool:
    """True when every expected keyword appears (case-insensitively) in the output."""
    answer = str(chain.invoke({"question": case.query})).lower()
    return all(keyword.lower() in answer for keyword in case.expected_keywords)

def offline_eval(chain, cases: List[EvalCase]) -> float:
    """Return the fraction of cases whose output contains all expected keywords."""
    passed = sum(1 for case in cases if _case_hits(chain, case))
    return passed / max(len(cases), 1)

B.2 在线 A/B(简化)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
import random

class ABRouter:
    """Weighted random router over named experiment variants.

    BUG FIX: the original annotated `variants` as `dict[str, Any]` but the
    snippet never imported `Any`, so defining the class raised NameError.
    """

    def __init__(self, variants: dict, weights: dict[str, float]):
        self.variants = variants   # variant name -> chain/model implementation
        self.weights = weights     # variant name -> relative traffic weight

    def pick(self) -> str:
        """Sample one variant name proportionally to its weight."""
        names, ws = zip(*self.weights.items())
        r = random.random() * sum(ws)
        cumulative = 0.0
        for name, weight in zip(names, ws):
            cumulative += weight
            if r <= cumulative:
                return name
        # Float-rounding safety net.
        return names[-1]

6. 总结

本文档汇总了基于网上深度源码分析文章的LangChain高级实践模式,涵盖了:

  1. 安全机制:数据加密、访问控制、隐私保护
  2. 多模态集成:图像、音频、视频等多种模态的处理
  3. 负载均衡:智能路由、故障转移、健康检查
  4. 性能优化:高性能向量存储、缓存机制、批处理优化

这些实践模式为开发者在生产环境中部署LangChain应用提供了重要的技术指导和最佳实践参考。


创建时间: 2025年09月13日

本文档基于网上优秀的LangChain源码分析文章,为企业级AI应用开发提供了全面的高级实践指导。