Overview

Original Contributions

  • Unified "sequence diagram + call path" paradigm: each module ships runnable code, sequence diagrams, and function chains, supporting engineering adoption and audit traceability.
  • Closed-loop compliance module: a three-dimensional "tenant × region × PII level" configuration matrix drives automatic selection of audit and masking policies.
  • Multimodal call abstraction: image/audio/video handlers are unified into a registrable "modality processor table", with a local-file → Base64 inline fallback.
  • Reliability-cost co-design: rate limiting / circuit breaking / retries are coordinated with "quota-cost" routing, including canary selection logic and rollback paths.
  • High-performance vector retrieval: the FAISS path adds query caching and score-semantics normalization, combined with a hybrid-retrieval RRF → optional re-ranking flow.
  • Observability: a minimal set of structured audit-log fields plus a cost dashboard, integrated with the LangChain callback system with low intrusion.

1. Security and Privacy Protection Mechanisms

1.1 Data Encryption and Privacy Protection

LangChain needs robust security mechanisms in enterprise applications. The "minimally invasive security pipeline" works as follows: sanitize input first → then apply the template → encrypt only the necessary fields → keep the returned form controllable (truncatable/maskable), which lowers the cost of upstream changes:

import hashlib
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import os
import re
from typing import Dict, Any, Optional

class LangChainSecurityManager:
    """LangChain security manager."""

    def __init__(self, master_key: Optional[str] = None):
        self.master_key = master_key or os.environ.get('LANGCHAIN_MASTER_KEY')
        if not self.master_key:
            raise ValueError("A master key must be provided")

        self.cipher_suite = self._create_cipher_suite()
        self.sensitive_patterns = [
            r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',  # credit card number
            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',  # email
            r'\b\d{11}\b',  # phone number (11-digit CN mobile)
        ]

    def _create_cipher_suite(self) -> Fernet:
        """Create the Fernet cipher suite."""
        password = self.master_key.encode()
        salt = b'langchain_salt_2024'  # use a random salt in production

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )

        key = base64.urlsafe_b64encode(kdf.derive(password))
        return Fernet(key)

    def encrypt_sensitive_data(self, data: str) -> str:
        """Encrypt sensitive data."""
        try:
            encrypted_data = self.cipher_suite.encrypt(data.encode())
            return base64.urlsafe_b64encode(encrypted_data).decode()
        except Exception as e:
            raise ValueError(f"Data encryption failed: {str(e)}")

    def decrypt_sensitive_data(self, encrypted_data: str) -> str:
        """Decrypt sensitive data."""
        try:
            encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode())
            decrypted_data = self.cipher_suite.decrypt(encrypted_bytes)
            return decrypted_data.decode()
        except Exception as e:
            raise ValueError(f"Data decryption failed: {str(e)}")

    def sanitize_input(self, text: str) -> str:
        """Remove sensitive information from the input."""
        sanitized_text = text

        for pattern in self.sensitive_patterns:
            # replace sensitive matches with a placeholder
            sanitized_text = re.sub(pattern, '[REDACTED]', sanitized_text)

        return sanitized_text

    def create_secure_prompt_template(self, template: str) -> 'SecurePromptTemplate':
        """Create a secure prompt template."""
        return SecurePromptTemplate(template, self)

class SecurePromptTemplate:
    """Secure prompt template."""

    def __init__(self, template: str, security_manager: LangChainSecurityManager):
        self.template = template
        self.security_manager = security_manager

    def format(self, **kwargs) -> str:
        """Format the prompt, sanitizing sensitive fields automatically."""
        sanitized_kwargs = {}

        for key, value in kwargs.items():
            if isinstance(value, str):
                sanitized_kwargs[key] = self.security_manager.sanitize_input(value)
            else:
                sanitized_kwargs[key] = value

        return self.template.format(**sanitized_kwargs)

# Usage example
def create_secure_langchain_demo():
    """Secure LangChain usage example."""

    # Initialize the security manager
    security_manager = LangChainSecurityManager("your-master-key-here")

    # Create a secure prompt template
    secure_template = security_manager.create_secure_prompt_template("""
Answer the question based on the following user information:

User info: {user_info}
Question: {question}

Please protect user privacy and do not include sensitive information in the answer.

Answer:
""")

    # Test handling of sensitive information
    user_info = "My email is john.doe@example.com and my credit card number is 1234-5678-9012-3456"
    question = "Please help me analyze my account"

    # Format the prompt (sensitive info is sanitized automatically)
    safe_prompt = secure_template.format(
        user_info=user_info,
        question=question
    )

    print("Sanitized prompt:")
    print(safe_prompt)

    # Encrypt sensitive data for storage
    encrypted_info = security_manager.encrypt_sensitive_data(user_info)
    print(f"Encrypted user info: {encrypted_info}")

    # Decrypt the data
    decrypted_info = security_manager.decrypt_sensitive_data(encrypted_info)
    print(f"Decrypted user info: {decrypted_info}")

if __name__ == "__main__":
    create_secure_langchain_demo()
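The fixed salt above is explicitly flagged as non-production. As a sketch of the per-record alternative (standard library only; `hashlib.pbkdf2_hmac` mirrors the `PBKDF2HMAC` parameters used in `_create_cipher_suite()`, and the storage layout is an assumption), a random salt can be generated per record and stored next to the ciphertext so decryption can re-derive the same key:

```python
import hashlib
import os
from typing import Tuple


def derive_key(master_key: str, salt: bytes, iterations: int = 100_000) -> bytes:
    """PBKDF2-HMAC-SHA256 via the standard library, with the same
    parameters (SHA256, 32-byte key, 100k iterations) as above."""
    return hashlib.pbkdf2_hmac("sha256", master_key.encode(), salt, iterations, dklen=32)


def new_envelope_salt(master_key: str) -> Tuple[bytes, bytes]:
    """Generate a fresh random 16-byte salt and its derived key.
    Store the salt alongside (e.g. prepended to) the ciphertext so
    decryption can re-derive the key; the salt itself is not secret."""
    salt = os.urandom(16)
    return salt, derive_key(master_key, salt)
```

With this layout, two encryptions of the same plaintext use different keys, so identical records no longer produce correlatable ciphertexts.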

Input sanitization and encryption/decryption: sequence diagram and call paths

sequenceDiagram
    participant App as Application
    participant Sec as LangChainSecurityManager
    participant Prompt as SecurePromptTemplate
    participant Store as Secure storage

    App->>Sec: sanitize_input(text)
    Sec-->>App: sanitized text
    App->>Prompt: format(user_info, question)
    Prompt->>Sec: sanitize_input (per field)
    Sec-->>Prompt: sanitized fields
    Prompt-->>App: safe_prompt

    App->>Sec: encrypt_sensitive_data(user_info)
    Sec-->>App: ciphertext
    App->>Store: persist ciphertext
    App->>Sec: decrypt_sensitive_data(ciphertext)
    Sec-->>App: plaintext (user_info)

Key call paths (security & privacy)

  • Input sanitization: SecurePromptTemplate.format() -> LangChainSecurityManager.sanitize_input() -> re.sub()
  • Encrypted write: LangChainSecurityManager.encrypt_sensitive_data() -> Fernet.encrypt() -> base64.urlsafe_b64encode()
  • Decrypted read: LangChainSecurityManager.decrypt_sensitive_data() -> base64.urlsafe_b64decode() -> Fernet.decrypt()
  • Template creation: LangChainSecurityManager.create_secure_prompt_template() -> SecurePromptTemplate.__init__()

1.2 Access Control and Permission Management

On top of common RBAC, this adds a "permission decorator with injectable token sources (header/kwargs/context)" and "permission vector snapshots (for audit replay)". Example:

from enum import Enum
from functools import wraps
from typing import List, Dict, Any, Callable
import jwt
import time

class Permission(Enum):
    """Permission enum."""
    READ_DOCUMENTS = "read_documents"
    WRITE_DOCUMENTS = "write_documents"
    EXECUTE_TOOLS = "execute_tools"
    MANAGE_AGENTS = "manage_agents"
    ADMIN_ACCESS = "admin_access"

class Role(Enum):
    """Role enum."""
    GUEST = "guest"
    USER = "user"
    DEVELOPER = "developer"
    ADMIN = "admin"

class AccessControlManager:
    """Access control manager."""

    def __init__(self, secret_key: str):
        self.secret_key = secret_key
        self.role_permissions = {
            Role.GUEST: [Permission.READ_DOCUMENTS],
            Role.USER: [Permission.READ_DOCUMENTS, Permission.EXECUTE_TOOLS],
            Role.DEVELOPER: [
                Permission.READ_DOCUMENTS,
                Permission.WRITE_DOCUMENTS,
                Permission.EXECUTE_TOOLS,
                Permission.MANAGE_AGENTS
            ],
            Role.ADMIN: [
                Permission.READ_DOCUMENTS,
                Permission.WRITE_DOCUMENTS,
                Permission.EXECUTE_TOOLS,
                Permission.MANAGE_AGENTS,
                Permission.ADMIN_ACCESS
            ]
        }

    def create_token(self, user_id: str, role: Role, expires_in: int = 3600) -> str:
        """Create a JWT token."""
        payload = {
            'user_id': user_id,
            'role': role.value,
            'permissions': [p.value for p in self.role_permissions[role]],
            'exp': time.time() + expires_in,
            'iat': time.time()
        }

        return jwt.encode(payload, self.secret_key, algorithm='HS256')

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Verify a JWT token."""
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=['HS256'])
            return payload
        except jwt.ExpiredSignatureError:
            raise ValueError("Token has expired")
        except jwt.InvalidTokenError:
            raise ValueError("Invalid token")

    def check_permission(self, token: str, required_permission: Permission) -> bool:
        """Check whether the token grants a permission."""
        try:
            payload = self.verify_token(token)
            user_permissions = payload.get('permissions', [])
            return required_permission.value in user_permissions
        except ValueError:
            return False

    def require_permission(self, permission: Permission):
        """Permission decorator."""
        def decorator(func: Callable):
            @wraps(func)
            def wrapper(*args, **kwargs):
                # get the token from kwargs, or from the first positional argument
                token = kwargs.get('auth_token') or getattr(args[0], 'auth_token', None)

                if not token:
                    raise PermissionError("Missing authentication token")

                if not self.check_permission(token, permission):
                    raise PermissionError(f"Missing required permission: {permission.value}")

                return func(*args, **kwargs)
            return wrapper
        return decorator

def require_permission(permission: Permission):
    """Module-level decorator for agent methods. (require_permission above is
    an instance method of AccessControlManager, so it cannot be used directly
    as a class-body decorator; this variant resolves the manager from the
    agent instance at call time.)"""
    def decorator(func: Callable):
        @wraps(func)
        def wrapper(self, *args, auth_token: str = None, **kwargs):
            token = auth_token or getattr(self, 'auth_token', None)
            if not token:
                raise PermissionError("Missing authentication token")
            if not self.access_control.check_permission(token, permission):
                raise PermissionError(f"Missing required permission: {permission.value}")
            return func(self, *args, auth_token=auth_token, **kwargs)
        return wrapper
    return decorator

class SecureLangChainAgent:
    """Secure LangChain agent."""

    def __init__(self, access_control: AccessControlManager):
        self.access_control = access_control
        self.auth_token = None

    def authenticate(self, token: str):
        """Authenticate the user."""
        self.auth_token = token
        return self.access_control.verify_token(token)

    @require_permission(Permission.READ_DOCUMENTS)
    def read_documents(self, query: str, auth_token: str = None) -> List[str]:
        """Read documents (requires read permission)."""
        # actual document-reading logic goes here
        return [f"Document content: {query}"]

    @require_permission(Permission.EXECUTE_TOOLS)
    def execute_tool(self, tool_name: str, params: Dict[str, Any], auth_token: str = None) -> Any:
        """Execute a tool (requires execute permission)."""
        # actual tool-execution logic goes here
        return f"Tool {tool_name} result: {params}"

    @require_permission(Permission.MANAGE_AGENTS)
    def create_agent(self, agent_config: Dict[str, Any], auth_token: str = None) -> str:
        """Create an agent (requires manage permission)."""
        # actual agent-creation logic goes here
        return f"Agent created: {agent_config.get('name', 'unnamed')}"

Access control and permission checks: sequence diagram and call paths

sequenceDiagram
    participant Client as Client
    participant ACM as AccessControlManager
    participant Agent as SecureLangChainAgent

    Client->>Agent: authenticate(token)
    Agent->>ACM: verify_token(token)
    ACM-->>Agent: payload(permissions)
    Agent-->>Client: authenticated

    Client->>Agent: read_documents(query, auth_token)
    Note over Agent: @require_permission(READ_DOCUMENTS)
    Agent->>ACM: check_permission(token, READ_DOCUMENTS)
    ACM-->>Agent: allow/deny
    alt allow
        Agent-->>Client: document content
    else deny
        Agent-->>Client: PermissionError
    end

  • Authentication: SecureLangChainAgent.authenticate() -> AccessControlManager.verify_token() -> jwt.decode()
  • Permission check (decorator): require_permission() -> decorator() -> wrapper() -> AccessControlManager.check_permission() -> AccessControlManager.verify_token()
  • Resource access (example): SecureLangChainAgent.read_documents() -> (decorator check passes) -> business logic
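The "permission vector snapshot (for audit replay)" mentioned above can be sketched as follows; field names are illustrative assumptions, and in practice the record would be emitted inside the decorator's wrapper, one JSONL line per access decision:

```python
import json
import time
from typing import Dict, List


def permission_snapshot(user_id: str, granted: List[str], required: str,
                        decision: bool) -> Dict:
    """Freeze the permission set that was in effect when an access
    decision was made, so the decision can be replayed during an audit."""
    return {
        "user_id": user_id,
        "required": required,
        "granted": sorted(granted),  # sorted for stable comparison on replay
        "decision": decision,
        "ts": time.time(),
    }


# Example: one JSONL record per decision
snap = permission_snapshot("u1", ["read_documents"], "read_documents", True)
line = json.dumps(snap)
```

Because the snapshot stores the granted set, not just the outcome, an auditor can later re-run the check against the recorded permissions even after the user's role has changed.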

1.3 Compliance and Data Sovereignty (GDPR/CCPA/Data Residency)

Beyond security, enterprise deployments must satisfy compliance and data-sovereignty requirements:

  • Compliance baseline: GDPR/CCPA/ISO 27001; record processing purposes, practice data minimization, and support traceable deletion (Right to be Forgotten).
  • Data classification and labeling: tag data by PII/sensitivity level, which drives storage location, encryption strength, and access control.
  • Data residency: isolate storage and processing by tenant/region (e.g. EU-only).
  • Keys and KMS: at-rest/in-transit encryption, with centralized key management and rotation.
  • Retention policy: set TTLs and archival per regulation/business; keep auditable records of modifications and deletions.

from dataclasses import dataclass
from typing import Literal, Dict

Region = Literal["eu", "us", "apac"]

@dataclass
class DataResidencyConfig:
    tenant_id: str
    residency: Region  # data residency region
    pii_level: Literal["none", "low", "medium", "high"]
    encrypt_at_rest: bool = True
    kms_key_id: str | None = None
    retention_days: int = 180

    def storage_bucket(self) -> str:
        # route to a region- and tenant-specific object store / database instance
        return f"lc-{self.residency}-tenant-{self.tenant_id}"

    def should_mask_output(self) -> bool:
        return self.pii_level in ("medium", "high")

# Usage example
cfg = DataResidencyConfig(tenant_id="acme", residency="eu", pii_level="high", kms_key_id="kms-eu-123")
bucket = cfg.storage_bucket()  # lc-eu-tenant-acme
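The overview's "tenant × region × PII level" configuration matrix can be sketched as a lookup table keyed on region and PII level, with tenants resolving through it; the concrete policy values below are illustrative assumptions, and unknown combinations fall back to the most restrictive default:

```python
from dataclasses import dataclass
from typing import Dict, Literal, Tuple

Region = Literal["eu", "us", "apac"]
PiiLevel = Literal["none", "low", "medium", "high"]


@dataclass(frozen=True)
class Policy:
    mask_output: bool
    audit_level: str      # "basic" | "full"
    retention_days: int


# (region, pii_level) -> policy; illustrative values only
POLICY_MATRIX: Dict[Tuple[Region, PiiLevel], Policy] = {
    ("eu", "high"): Policy(mask_output=True, audit_level="full", retention_days=30),
    ("eu", "low"):  Policy(mask_output=False, audit_level="basic", retention_days=180),
    ("us", "high"): Policy(mask_output=True, audit_level="full", retention_days=90),
}

# fail closed: most restrictive policy for unmapped combinations
DEFAULT_POLICY = Policy(mask_output=True, audit_level="full", retention_days=30)


def resolve_policy(region: Region, pii_level: PiiLevel) -> Policy:
    return POLICY_MATRIX.get((region, pii_level), DEFAULT_POLICY)
```

The resolved `Policy` can then drive both `should_mask_output()` above and the audit verbosity chosen in section 1.4.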

1.4 Audit Logging and Forensics

Use a "minimal sufficient event field set": run_id, parent_run_id, ts, tenant_id, region, prompt_preview (masked/truncated), usage, plus tool and chain identifiers. This gives a uniform replay format and avoids bloated logs:

import json, time, os
from uuid import uuid4
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackHandler

class AuditCallbackHandler(BaseCallbackHandler):
    """Structured audit: persist key chain/LLM/tool events as JSON Lines.
    - Masking: redact/hash fields that may contain PII
    - Forensics: record run_id, parent_run_id, timestamps, region/tenant tags
    """

    def __init__(self, path: str = "./logs/audit.jsonl", tenant_id: str = "default", region: str = "eu"):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        self.path = path
        self.tenant_id = tenant_id
        self.region = region

    def _write(self, record: Dict[str, Any]):
        record.setdefault("ts", time.time())
        record.setdefault("tenant_id", self.tenant_id)
        record.setdefault("region", self.region)
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    # Example: LLM start/end events
    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id, **kwargs):
        self._write({
            "event": "llm_start",
            "run_id": str(run_id),
            "model": serialized.get("id"),
            "prompt_preview": (prompts[0][:200] if prompts else ""),  # truncate + mask as needed
        })

    def on_llm_end(self, response, *, run_id, **kwargs):
        usage = {}
        if getattr(response, "llm_output", None):
            usage = response.llm_output.get("token_usage", {})
        self._write({
            "event": "llm_end",
            "run_id": str(run_id),
            "usage": usage,
        })

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id, **kwargs):
        self._write({
            "event": "chain_start",
            "run_id": str(run_id),
            "chain": serialized.get("id"),
        })

    def on_chain_end(self, outputs: Dict[str, Any], *, run_id, **kwargs):
        self._write({
            "event": "chain_end",
            "run_id": str(run_id),
        })

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id, **kwargs):
        self._write({
            "event": "tool_start",
            "run_id": str(run_id),
            "tool": serialized.get("name"),
        })

    def on_tool_end(self, output: str, *, run_id, **kwargs):
        self._write({
            "event": "tool_end",
            "run_id": str(run_id),
        })

# Usage: bind on the invocation chain
# chain.invoke(inputs, config={"callbacks": [AuditCallbackHandler(path="/var/log/lc/audit.jsonl", tenant_id="acme", region="eu")]})

Audit callback and persistence: sequence diagram and call paths

sequenceDiagram
    participant User as User
    participant Chain as Chain/Agent
    participant LLM as LLM
    participant Tool as Tool
    participant Audit as AuditCallbackHandler
    participant Store as Audit store (JSONL)

    User->>Chain: invoke(inputs, callbacks=[Audit])
    Chain->>Audit: on_chain_start
    Chain->>LLM: generate(...)
    LLM->>Audit: on_llm_start(prompts)
    LLM-->>Audit: on_llm_end(usage)
    alt tool needed
        Chain->>Tool: run(input)
        Tool->>Audit: on_tool_start
        Tool-->>Audit: on_tool_end
    end
    Chain-->>Audit: on_chain_end(outputs)
    Audit->>Store: structured write (ts, run_id, usage, tags)
    Store-->>Audit: persisted

  • Callback binding: chain.invoke(..., config={callbacks:[AuditCallbackHandler]}) -> CallbackManager.on_chain_start() -> AuditCallbackHandler.on_chain_start()
  • LLM events: LLM.generate() -> CallbackManager.on_llm_start() -> AuditCallbackHandler.on_llm_start() -> CallbackManager.on_llm_end() -> AuditCallbackHandler.on_llm_end()
  • Tool events: Tool.run() -> CallbackManager.on_tool_start() -> AuditCallbackHandler.on_tool_start() -> on_tool_end()
  • Persistence: AuditCallbackHandler._write() -> open().write(jsonl)
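For forensics, the JSONL records written by AuditCallbackHandler can be replayed by grouping on run_id and sorting by ts (a minimal sketch; it assumes only the field set described above):

```python
import json
from collections import defaultdict
from typing import Dict, List


def replay_audit_log(lines: List[str]) -> Dict[str, List[dict]]:
    """Group JSONL audit records by run_id and order each group by
    timestamp, reconstructing the event sequence of every run."""
    runs: Dict[str, List[dict]] = defaultdict(list)
    for line in lines:
        rec = json.loads(line)
        runs[rec["run_id"]].append(rec)
    for events in runs.values():
        events.sort(key=lambda r: r.get("ts", 0))  # restore causal order
    return dict(runs)


# Example: two out-of-order records for the same run
log = [
    '{"event": "llm_end", "run_id": "r1", "ts": 2.0}',
    '{"event": "llm_start", "run_id": "r1", "ts": 1.0}',
]
runs = replay_audit_log(log)
```

Grouping on run_id (and, when present, parent_run_id) is what makes the "uniform replay format" above actionable: a single run's chain, LLM, and tool events can be read back as one timeline.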

2. Multimodal Integration

2.1 Multimodal Chat Model

To make multimodal adoption consistent, use a "modality processor registry" plus a uniform local-file → Base64 fallback, keeping the pipeline runnable end to end:

from typing import Any, Dict, List, Optional, Union
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.outputs import ChatResult, ChatGeneration
import base64
import requests

class MultiModalChatModel(BaseChatModel):
    """Multimodal chat model integration."""

    # BaseChatModel is a pydantic model, so configuration is declared as
    # fields; assigning undeclared attributes in a custom __init__ would fail.
    model_name: str = "gpt-4-vision-preview"
    api_key: Optional[str] = None
    base_url: str = "https://api.openai.com/v1"
    max_tokens: int = 1000
    temperature: float = 0.1

    # Supported image formats
    supported_image_formats: set = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}

    @property
    def modality_processors(self) -> Dict[str, Any]:
        """Modality processor registry, keyed by the content item's type."""
        return {
            'text': self._process_text,
            'image': self._process_image,
            'audio': self._process_audio,
            'video': self._process_video
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a multimodal response."""

        # process multimodal messages
        processed_messages = []
        for message in messages:
            processed_msg = self._process_multimodal_message(message)
            processed_messages.append(processed_msg)

        # build the API request
        request_data = {
            "model": self.model_name,
            "messages": processed_messages,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

        if stop:
            request_data["stop"] = stop

        # send the request
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=request_data,
                headers=headers,
                timeout=60
            )
            response.raise_for_status()

            result = response.json()

            # parse the response
            choice = result["choices"][0]
            message_content = choice["message"]["content"]

            # create the generation result
            generation = ChatGeneration(
                message=AIMessage(content=message_content),
                generation_info={
                    "finish_reason": choice.get("finish_reason"),
                    "model": result.get("model"),
                    "usage": result.get("usage", {})
                }
            )

            return ChatResult(generations=[generation])

        except Exception as e:
            raise ValueError(f"Multimodal model call failed: {str(e)}")

    def _process_multimodal_message(self, message: BaseMessage) -> Dict[str, Any]:
        """Process a multimodal message."""

        if isinstance(message, HumanMessage):
            content = message.content

            # check for multimodal content
            if isinstance(content, str):
                # plain-text message
                return {
                    "role": "user",
                    "content": content
                }
            elif isinstance(content, list):
                # list of multimodal content items
                processed_content = []

                for item in content:
                    if isinstance(item, dict):
                        modality_type = item.get("type", "text")
                        processor = self.modality_processors.get(modality_type)

                        if processor:
                            processed_item = processor(item)
                            processed_content.append(processed_item)
                        else:
                            # unknown modality type; treat as text
                            processed_content.append({
                                "type": "text",
                                "text": str(item)
                            })
                    else:
                        # non-dict item; treat as text
                        processed_content.append({
                            "type": "text",
                            "text": str(item)
                        })

                return {
                    "role": "user",
                    "content": processed_content
                }

        elif isinstance(message, AIMessage):
            return {
                "role": "assistant",
                "content": message.content
            }

        else:
            # other message types
            return {
                "role": "user",
                "content": str(message.content)
            }

    def _process_text(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Process the text modality."""
        return {
            "type": "text",
            "text": item.get("text", "")
        }

    def _process_image(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Process the image modality."""
        image_data = item.get("image_url") or item.get("image")

        if isinstance(image_data, str):
            if image_data.startswith("http"):
                # remote image URL
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": image_data,
                        "detail": item.get("detail", "auto")
                    }
                }
            elif image_data.startswith("data:image"):
                # Base64-encoded image
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": image_data,
                        "detail": item.get("detail", "auto")
                    }
                }
            else:
                # local file path
                encoded_image = self._encode_image_file(image_data)
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encoded_image}",
                        "detail": item.get("detail", "auto")
                    }
                }

        return {
            "type": "text",
            "text": "[unprocessable image data]"
        }

    def _process_audio(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Process the audio modality (not yet supported; describe as text)."""
        return {
            "type": "text",
            "text": f"[audio file: {item.get('audio', 'unknown')}]"
        }

    def _process_video(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Process the video modality (not yet supported; describe as text)."""
        return {
            "type": "text",
            "text": f"[video file: {item.get('video', 'unknown')}]"
        }

    def _encode_image_file(self, image_path: str) -> str:
        """Encode a local image file as Base64."""
        try:
            with open(image_path, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
            return encoded_string
        except Exception as e:
            raise ValueError(f"Failed to encode image file {image_path}: {str(e)}")

    @property
    def _llm_type(self) -> str:
        return "multimodal_chat"

# Usage example
def demo_multimodal_integration():
    """Multimodal integration demo."""

    # Initialize the multimodal model
    multimodal_model = MultiModalChatModel(
        model_name="gpt-4-vision-preview",
        api_key="your-api-key-here"
    )

    # Create a multimodal message
    multimodal_message = HumanMessage(content=[
        {
            "type": "text",
            "text": "Please analyze this image and describe the main elements you see."
        },
        {
            "type": "image",
            "image_url": "https://example.com/image.jpg",
            "detail": "high"
        }
    ])

    # Invoke the model (use the public invoke() rather than the private _generate())
    try:
        result = multimodal_model.invoke([multimodal_message])
        print(f"Multimodal analysis result: {result.content}")
    except Exception as e:
        print(f"Multimodal call failed: {e}")

if __name__ == "__main__":
    demo_multimodal_integration()

Multimodal message processing and model invocation: sequence diagram and call paths

sequenceDiagram
    participant App as Application
    participant MM as MultiModalChatModel
    participant Proc as Modality processors
    participant API as Chat Completions API

    App->>MM: _generate(messages)
    loop for each message
        MM->>MM: _process_multimodal_message()
        alt content is a list
            MM->>Proc: dispatch by type (text/image/audio/video)
            Proc-->>MM: normalized content item
        else plain text
            MM-->>MM: wrap as text
        end
    end
    MM->>API: POST /chat/completions (json)
    API-->>MM: result(choices, usage)
    MM-->>App: ChatResult(generation_info)

  • Main generation flow: MultiModalChatModel._generate() -> _process_multimodal_message() -> requests.post('/chat/completions') -> ChatGeneration/ChatResult
  • Text processing: _process_text() -> normalized to {type:'text', text:...}
  • Image processing (URL/Base64/local): _process_image() -> (for local files) _encode_image_file() -> base64.b64encode()
  • AI message output: ChatGeneration(message=AIMessage(...)) -> ChatResult
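The local-file → Base64 inline fallback can also be exercised standalone. A sketch mirroring `_encode_image_file()`, with one assumption not in the original: the MIME type is guessed from the file extension instead of being hardcoded to image/jpeg:

```python
import base64
import mimetypes
import os
import tempfile


def image_file_to_data_url(path: str) -> str:
    """Inline a local image as a data: URL, as in the Base64 fallback of
    _process_image(); MIME type guessed from the extension (assumption),
    defaulting to image/jpeg."""
    mime = mimetypes.guess_type(path)[0] or "image/jpeg"
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{encoded}"


# Demo with a temporary file standing in for a hypothetical local image path:
_tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
_tmp.write(b"\x89PNG\r\n\x1a\n")  # PNG magic bytes only, for illustration
_tmp.close()
demo_url = image_file_to_data_url(_tmp.name)
os.unlink(_tmp.name)
```

A message item would then use `{"type": "image", "image": "./chart.png", "detail": "high"}` (path illustrative), and the processor inlines it without any external hosting.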

2.2 Output Safety Filtering and SSE Streaming

A "dual-channel minimal implementation": upstream streaming rendering and downstream sensitive-data filtering share the same function interface; an SSE endpoint example is provided for quick integration.

from typing import AsyncIterator, Callable
import asyncio
import re

SENSITIVE_PATTERNS = [
    re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),              # SSN
    re.compile(r"\b\d{16}\b"),                         # rough credit-card match
    re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"),  # email
]

def sanitize_chunk(text: str) -> str:
    for p in SENSITIVE_PATTERNS:
        text = p.sub("[REDACTED]", text)
    return text

async def astream_with_safety(chain, payload: dict, *, on_chunk: Callable[[str], None]) -> str:
    """Filter while streaming; returns the final full text.
    - chain.astream(...) yields chunks
    - each chunk is filtered for sensitive information
    - the callback can push chunks to SSE
    """
    full = []
    async for chunk in chain.astream(payload):
        safe = sanitize_chunk(str(chunk))
        on_chunk(safe)
        full.append(safe)
        await asyncio.sleep(0)  # yield control to the event loop
    return "".join(full)

# Example: FastAPI SSE endpoint (simplified)
"""
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio

app = FastAPI()

@app.get("/stream")
async def stream_endpoint(q: str):
    async def event_source():
        queue: asyncio.Queue = asyncio.Queue()

        async def run():
            # assumes `chain` is an already-built LCEL chain
            await astream_with_safety(chain, {"question": q}, on_chunk=queue.put_nowait)
            queue.put_nowait(None)  # sentinel: stream finished

        task = asyncio.create_task(run())
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield f"data: {chunk}\n\n"
        await task

    return StreamingResponse(event_source(), media_type="text/event-stream")
"""

SSE streaming and output filtering: sequence diagram and call paths

sequenceDiagram
    participant C as Client
    participant API as API (StreamingResponse)
    participant Chain as LCEL chain
    participant Safe as sanitize_chunk

    C->>API: GET /stream?q=...
    API->>Chain: astream({question:q})
    loop chunked generation
        Chain-->>API: chunk
        API->>Safe: filter PII / sensitive terms
        Safe-->>API: safe_chunk
        API-->>C: SSE: data: safe_chunk\n\n
    end
    API-->>C: SSE end

  • Streaming filter: astream_with_safety() -> chain.astream() -> sanitize_chunk() -> on_chunk(safe) -> concatenate and return
  • HTTP push (example): GET /stream -> StreamingResponse(event_source()) -> astream_with_safety() -> yield "data: ...\n\n"
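One caveat with per-chunk filtering: a sensitive token split across two chunks (e.g. "123-4" at the end of one chunk, "5-6789" at the start of the next) escapes sanitize_chunk(). A hedged sketch that holds back a short tail across chunk boundaries; HOLDBACK is an assumed upper bound on the longest pattern:

```python
import re
from typing import Iterator

SENSITIVE = [re.compile(r"\b\d{3}-\d{2}-\d{4}\b")]  # SSN, as above
HOLDBACK = 16  # must be >= the longest pattern that could be split


def sanitize_stream(chunks: Iterator[str]) -> Iterator[str]:
    """Buffer a short tail across chunk boundaries so a sensitive token
    split between two chunks is still redacted before being emitted."""
    buf = ""
    for chunk in chunks:
        buf += chunk
        for p in SENSITIVE:
            buf = p.sub("[REDACTED]", buf)
        if len(buf) > HOLDBACK:
            yield buf[:-HOLDBACK]   # emit everything except the held-back tail
            buf = buf[-HOLDBACK:]
    for p in SENSITIVE:
        buf = p.sub("[REDACTED]", buf)  # final pass on the remaining tail
    yield buf
```

The trade-off is a small output latency of at most HOLDBACK characters, which is usually invisible in SSE rendering.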

3. Intelligent Load Balancing and Failover

3.1 Load Balancing Implementation

from typing import List, Dict, Any, Optional, Callable
import random
import time
import threading
from dataclasses import dataclass
from enum import Enum
import logging

class ProviderStatus(Enum):
    """Provider status enum."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
    MAINTENANCE = "maintenance"

@dataclass
class ProviderMetrics:
    """Provider performance metrics."""
    response_time: float = 0.0
    success_rate: float = 1.0
    error_count: int = 0
    total_requests: int = 0
    last_error_time: Optional[float] = None
    status: ProviderStatus = ProviderStatus.HEALTHY

class LoadBalancedChatModel:
    """Load-balanced chat model."""

    def __init__(
        self,
        providers: List[Dict[str, Any]],
        strategy: str = "round_robin",
        health_check_interval: int = 60,
        max_retries: int = 3,
        circuit_breaker_threshold: float = 0.5,
        **kwargs
    ):
        self.providers = {}
        self.provider_metrics = {}
        self.strategy = strategy
        self.health_check_interval = health_check_interval
        self.max_retries = max_retries
        self.circuit_breaker_threshold = circuit_breaker_threshold

        # initialize providers
        for i, provider_config in enumerate(providers):
            provider_id = f"provider_{i}"
            self.providers[provider_id] = self._create_provider(provider_config)
            self.provider_metrics[provider_id] = ProviderMetrics()

        # load-balancing strategy state
        self.current_index = 0
        self.strategy_lock = threading.Lock()

        # health check
        self.health_check_thread = threading.Thread(
            target=self._health_check_loop,
            daemon=True
        )
        self.health_check_thread.start()

        self.logger = logging.getLogger(__name__)

    def _create_provider(self, config: Dict[str, Any]):
        """Create a provider instance from config."""
        provider_type = config.get("type", "openai")

        if provider_type == "openai":
            from langchain_openai import ChatOpenAI
            return ChatOpenAI(**config.get("params", {}))
        elif provider_type == "anthropic":
            from langchain_anthropic import ChatAnthropic
            return ChatAnthropic(**config.get("params", {}))
        elif provider_type == "groq":
            from langchain_groq import ChatGroq
            return ChatGroq(**config.get("params", {}))
        else:
            raise ValueError(f"Unsupported provider type: {provider_type}")

    def _select_provider(self) -> Optional[str]:
        """Select a provider according to the strategy."""

        # filter out unhealthy providers
        healthy_providers = [
            pid for pid, metrics in self.provider_metrics.items()
            if metrics.status in [ProviderStatus.HEALTHY, ProviderStatus.DEGRADED]
        ]

        if not healthy_providers:
            self.logger.error("No healthy providers available")
            return None

        with self.strategy_lock:
            if self.strategy == "round_robin":
                return self._round_robin_select(healthy_providers)
            elif self.strategy == "weighted":
                return self._weighted_select(healthy_providers)
            elif self.strategy == "least_connections":
                return self._least_connections_select(healthy_providers)
            elif self.strategy == "fastest":
                return self._fastest_select(healthy_providers)
            else:
                return random.choice(healthy_providers)

    def _round_robin_select(self, providers: List[str]) -> str:
        """Round-robin selection."""
        if not providers:
            return None

        provider = providers[self.current_index % len(providers)]
        self.current_index += 1
        return provider

    def _weighted_select(self, providers: List[str]) -> str:
        """Success-rate-weighted selection."""
        weights = []
        for pid in providers:
            metrics = self.provider_metrics[pid]
            # weight based on success rate and response time
            weight = metrics.success_rate / max(metrics.response_time, 0.1)
            weights.append(weight)

        # weighted random choice
        total_weight = sum(weights)
        if total_weight == 0:
            return random.choice(providers)

        rand_val = random.uniform(0, total_weight)
        cumulative = 0

        for i, weight in enumerate(weights):
            cumulative += weight
            if rand_val <= cumulative:
                return providers[i]

        return providers[-1]

    def _least_connections_select(self, providers: List[str]) -> str:
        """Select the provider with the fewest connections."""
        # simplified: pick the provider with the fewest errors
        min_errors = float('inf')
        best_provider = None

        for pid in providers:
            metrics = self.provider_metrics[pid]
            if metrics.error_count < min_errors:
                min_errors = metrics.error_count
                best_provider = pid

        return best_provider or providers[0]

    def _fastest_select(self, providers: List[str]) -> str:
        """Select the fastest-responding provider."""
        min_response_time = float('inf')
        fastest_provider = None

        for pid in providers:
            metrics = self.provider_metrics[pid]
            if metrics.response_time < min_response_time:
                min_response_time = metrics.response_time
                fastest_provider = pid

        return fastest_provider or providers[0]

    def _generate_with_fallback(
        self,
        messages: List[Any],
        **kwargs
    ) -> Any:
        """带故障转移的生成"""

        last_exception = None
        attempted_providers = set()

        for attempt in range(self.max_retries):
            # 选择provider
            provider_id = self._select_provider()

            if not provider_id or provider_id in attempted_providers:
                # 如果没有可用provider或已尝试过,跳出循环
                break

            attempted_providers.add(provider_id)
            provider = self.providers[provider_id]
            metrics = self.provider_metrics[provider_id]

            try:
                start_time = time.time()

                # 调用provider
                result = provider._generate(messages, **kwargs)

                # 更新成功指标
                response_time = time.time() - start_time
                self._update_success_metrics(provider_id, response_time)

                return result

            except Exception as e:
                last_exception = e

                # 更新失败指标
                self._update_failure_metrics(provider_id, e)

                self.logger.warning(
                    f"Provider {provider_id} 调用失败 (尝试 {attempt + 1}): {str(e)}"
                )

                # 如果还有重试机会,继续下一个provider
                continue

        # 所有provider都失败了
        if last_exception:
            raise last_exception
        else:
            raise RuntimeError("没有可用的provider")

    def _update_success_metrics(self, provider_id: str, response_time: float):
        """更新成功指标"""
        metrics = self.provider_metrics[provider_id]

        # 更新响应时间(使用指数移动平均)
        if metrics.total_requests == 0:
            metrics.response_time = response_time
        else:
            alpha = 0.1  # 平滑因子
            metrics.response_time = (
                alpha * response_time + (1 - alpha) * metrics.response_time
            )

        # 更新成功率
        metrics.total_requests += 1
        success_count = metrics.total_requests - metrics.error_count
        metrics.success_rate = success_count / metrics.total_requests

        # 更新状态
        if metrics.success_rate >= 0.95:
            metrics.status = ProviderStatus.HEALTHY
        elif metrics.success_rate >= 0.8:
            metrics.status = ProviderStatus.DEGRADED
        else:
            metrics.status = ProviderStatus.UNHEALTHY

    def _update_failure_metrics(self, provider_id: str, error: Exception):
        """更新失败指标"""
        metrics = self.provider_metrics[provider_id]

        metrics.error_count += 1
        metrics.total_requests += 1
        metrics.last_error_time = time.time()

        # 更新成功率
        success_count = metrics.total_requests - metrics.error_count
        metrics.success_rate = success_count / metrics.total_requests

        # 检查熔断器
        if metrics.success_rate < self.circuit_breaker_threshold:
            metrics.status = ProviderStatus.UNHEALTHY
            self.logger.error(
                f"Provider {provider_id} 触发熔断器,成功率: {metrics.success_rate:.2%}"
            )

    def _health_check_loop(self):
        """健康检查循环"""
        while True:
            try:
                time.sleep(self.health_check_interval)
                self._perform_health_checks()
            except Exception as e:
                self.logger.error(f"健康检查失败: {str(e)}")

    def _perform_health_checks(self):
        """执行健康检查"""
        from langchain_core.messages import HumanMessage

        test_message = [HumanMessage(content="Health check")]

        for provider_id, provider in self.providers.items():
            metrics = self.provider_metrics[provider_id]

            # 跳过维护状态的provider
            if metrics.status == ProviderStatus.MAINTENANCE:
                continue

            try:
                start_time = time.time()

                # 执行简单的健康检查
                provider._generate(test_message, max_tokens=1)

                response_time = time.time() - start_time

                # 如果之前是不健康状态,现在恢复了
                if metrics.status == ProviderStatus.UNHEALTHY:
                    metrics.status = ProviderStatus.DEGRADED
                    self.logger.info(f"Provider {provider_id} 健康检查通过,状态恢复")

                # 更新响应时间
                if metrics.total_requests > 0:
                    alpha = 0.1
                    metrics.response_time = (
                        alpha * response_time + (1 - alpha) * metrics.response_time
                    )

            except Exception as e:
                # 健康检查失败
                if metrics.status != ProviderStatus.UNHEALTHY:
                    metrics.status = ProviderStatus.UNHEALTHY
                    self.logger.warning(f"Provider {provider_id} 健康检查失败: {str(e)}")

    def get_provider_stats(self) -> Dict[str, Dict[str, Any]]:
        """获取所有provider的统计信息"""
        stats = {}

        for provider_id, metrics in self.provider_metrics.items():
            stats[provider_id] = {
                "status": metrics.status.value,
                "success_rate": f"{metrics.success_rate:.2%}",
                "response_time": f"{metrics.response_time:.3f}s",
                "error_count": metrics.error_count,
                "total_requests": metrics.total_requests,
                "last_error_time": metrics.last_error_time
            }

        return stats

负载均衡与故障转移:时序图与调用路径

sequenceDiagram
    participant Client as 调用方
    participant LB as LoadBalancedChatModel
    participant Sel as _select_provider
    participant Prov as Provider

    Client->>LB: _generate_with_fallback(messages)
    loop 最多 max_retries 次
        LB->>Sel: 选择健康Provider(策略)
        Sel-->>LB: provider_id
        LB->>Prov: provider._generate(messages)
        alt 成功
            LB->>LB: _update_success_metrics(pid, rt)
            LB-->>Client: 返回结果
        else 失败
            LB->>LB: _update_failure_metrics(pid, err)
            LB->>LB: 尝试下一个Provider
        end
    end
    alt 全部失败
        LB-->>Client: 抛出最后一次异常
    end
  • 带兜底生成:LoadBalancedChatModel._generate_with_fallback() -> _select_provider() -> provider._generate() -> 成功:_update_success_metrics();失败:_update_failure_metrics() -> 重试
  • 选择策略:_select_provider() -> round_robin/weighted/least_connections/fastest
  • 健康检查:_health_check_loop() -> _perform_health_checks() -> provider._generate(test_message) -> 更新ProviderMetrics
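上述 `_weighted_select` 的权重口径(成功率/响应时间)可以脱离类独立验证。下面是与其同逻辑的最小示例,provider 名称与指标均为假设值:

```python
import random

def weighted_select(providers, metrics):
    """与 _weighted_select 同逻辑:权重 = 成功率 / max(响应时间, 0.1),再做加权随机。"""
    weights = [metrics[p]["success_rate"] / max(metrics[p]["response_time"], 0.1)
               for p in providers]
    total = sum(weights)
    if total == 0:
        return random.choice(providers)
    r = random.uniform(0, total)
    cum = 0.0
    for p, w in zip(providers, weights):
        cum += w
        if r <= cum:
            return p
    return providers[-1]

metrics = {
    "provider_a": {"success_rate": 0.99, "response_time": 0.8},  # 权重 ≈ 1.24
    "provider_b": {"success_rate": 0.95, "response_time": 0.4},  # 权重 ≈ 2.38
}
chosen = weighted_select(["provider_a", "provider_b"], metrics)
```

权重随响应时间下降而上升,因此快且稳的 provider 获得更高的选中概率。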

3.2 可靠性控制:超时/重试/熔断/限流与配额路由

为满足“成本/可靠性”双目标,使用“配额优先 + 质量回退”路由:先用低成本配额,耗尽后切换到更高质量模型,并保留回退路径;重试/熔断与限流在统一装饰器内组合:

import time
import asyncio
from typing import Callable, Any

class RetryPolicy:
    def __init__(self, max_attempts: int = 3, base_delay: float = 0.2, jitter: float = 0.1):
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.jitter = jitter

    async def aretry(self, fn: Callable[[], Any]):
        last = None
        for i in range(self.max_attempts):
            try:
                return await fn()
            except Exception as e:
                last = e
                await asyncio.sleep(self.base_delay * (2 ** i) + self.jitter)  # 简化:固定抖动;生产中建议改为随机抖动以避免重试风暴
        raise last

class CircuitBreaker:
    def __init__(self, failure_threshold: int = 5, cool_down: float = 30.0):
        self.failure_threshold = failure_threshold
        self.cool_down = cool_down
        self.failures = 0
        self.open_until = 0.0

    def allow(self) -> bool:
        return time.time() >= self.open_until

    def record_success(self):
        self.failures = 0

    def record_failure(self):
        self.failures += 1
        if self.failures >= self.failure_threshold:
            self.open_until = time.time() + self.cool_down

class RateLimiter:
    def __init__(self, qps: float = 10.0):
        self.interval = 1.0 / qps
        self.last = 0.0

    async def acquire(self):
        now = time.time()
        delta = self.interval - (now - self.last)
        if delta > 0:
            await asyncio.sleep(delta)
        self.last = time.time()

class QuotaRouter:
    """按模型/Provider 配额与成本做路由:先走便宜配额,耗尽再升级。"""
    def __init__(self, providers: list[dict]):
        self.providers = providers  # [{"id": "openai:gpt-3.5", "cost": 1, "remaining": 1000}, ...]

    def select(self) -> dict:
        affordable = [p for p in self.providers if p.get("remaining", 0) > 0]
        if not affordable:
            # 全部耗尽时,回退到成本最高(质量更高)的模型兜底
            return max(self.providers, key=lambda x: x["cost"])
        # 选择最低成本的可用配额
        return sorted(affordable, key=lambda x: x["cost"])[0]

    def consume(self, pid: str, tokens: int):
        for p in self.providers:
            if p["id"] == pid:
                p["remaining"] = max(0, p.get("remaining", 0) - tokens)
                break

# 组合使用示例
async def robust_generate(llm, messages):
    rl = RateLimiter(qps=20)
    cb = CircuitBreaker(failure_threshold=5, cool_down=15)
    retry = RetryPolicy(max_attempts=3)

    await rl.acquire()

    if not cb.allow():
        raise RuntimeError("circuit_open")

    async def call():
        return await llm._agenerate(messages)

    try:
        result = await retry.aretry(call)
        cb.record_success()
        return result
    except Exception:
        cb.record_failure()
        raise
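`RetryPolicy` 的退避行为可用一个瞬时故障桩独立验证。以下片段复制正文中的类逻辑(延迟参数调小以便演示),`flaky` 为假设的被调函数:

```python
import asyncio

class RetryPolicy:
    """与正文一致的最小副本,便于独立运行。"""
    def __init__(self, max_attempts: int = 3, base_delay: float = 0.01, jitter: float = 0.0):
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.jitter = jitter

    async def aretry(self, fn):
        last = None
        for i in range(self.max_attempts):
            try:
                return await fn()
            except Exception as e:
                last = e
                await asyncio.sleep(self.base_delay * (2 ** i) + self.jitter)
        raise last

calls = {"n": 0}

async def flaky():
    # 前两次模拟瞬时故障,第三次成功
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient")
    return "ok"

result = asyncio.run(RetryPolicy(max_attempts=3).aretry(flaky))
```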

可靠性与配额:时序图与调用路径

sequenceDiagram
    participant Client as 调用方
    participant RL as RateLimiter
    participant CB as CircuitBreaker
    participant Retry as RetryPolicy
    participant LLM as LLM Provider

    Client->>RL: acquire()
    RL-->>Client: 许可
    Client->>CB: allow?
    alt 熔断打开
        CB-->>Client: 拒绝(circuit_open)
    else 允许
        Client->>Retry: aretry(call)
        loop 最多N次
            Retry->>LLM: _agenerate(messages)
            alt 调用失败
                LLM-->>Retry: error
                Retry-->>Retry: 指数退避
            else 成功
                LLM-->>Retry: result
                Retry-->>CB: record_success
                Retry-->>Client: result
            end
        end
        Retry-->>CB: record_failure(超出重试)
    end

sequenceDiagram
    participant Client as 调用方
    participant QR as QuotaRouter
    participant LLM as LLM Provider

    Client->>QR: select()
    QR-->>Client: 选定pid(优先低成本且有剩余额度)
    Client->>LLM: 调用(pid)
    LLM-->>Client: 返回结果(含token使用)
    Client->>QR: consume(pid, tokens)
    QR-->>Client: 更新remaining
  • 可靠性封装:robust_generate() -> RateLimiter.acquire() -> CircuitBreaker.allow() -> RetryPolicy.aretry(call) -> llm._agenerate() -> CircuitBreaker.record_success()/record_failure()
  • 重试策略:RetryPolicy.aretry() -> fn() -> 异常 -> asyncio.sleep(指数退避) -> 最终抛出或返回
  • 配额路由:QuotaRouter.select() -> 选最低成本且remaining>0 -> fallback最便宜 -> QuotaRouter.consume() 更新剩余额度
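与上面 `QuotaRouter` 同思路的“配额优先 + 质量兜底”选择逻辑,可用一个独立函数演示(模型名与配额数值均为假设):

```python
def select_provider(providers):
    """配额优先路由示意:remaining>0 中选最低成本;全部耗尽时回退到最高质量(最贵)。"""
    affordable = [p for p in providers if p.get("remaining", 0) > 0]
    if not affordable:
        return max(providers, key=lambda x: x["cost"])
    return min(affordable, key=lambda x: x["cost"])

providers = [
    {"id": "cheap-model", "cost": 1, "remaining": 100},
    {"id": "premium-model", "cost": 20, "remaining": 1000},
]
first = select_provider(providers)    # 低成本配额优先
providers[0]["remaining"] = 0         # 模拟低成本配额耗尽
second = select_provider(providers)   # 自动升级到更贵的模型
```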

4. 高性能向量存储

4.1 优化的向量存储实现

在 FAISS 路线中加入“查询缓存 + 距离→相似度归一化”的工程建议,并强调按库语义调整阈值方向(越小越近或越大越近),降低误配风险:

from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core.vectorstores import VectorStore
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
import numpy as np
import faiss
import pickle
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

class HighPerformanceVectorStore(VectorStore):
    """高性能向量存储实现

        https://jishuzhan.net/article/1895692926025994242
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        index_factory: str = "IVF1024,Flat",
        metric_type: str = "L2",
        use_gpu: bool = False,
        cache_size: int = 10000,
        batch_size: int = 1000,
        **kwargs
    ):
        self.embedding_function = embedding_function
        self.index_factory = index_factory
        self.metric_type = metric_type
        self.use_gpu = use_gpu
        self.cache_size = cache_size
        self.batch_size = batch_size

        # 初始化FAISS索引
        self.index = None
        self.dimension = None

        # 文档存储
        self.documents = {}
        self.id_to_index = {}
        self.index_to_id = {}

        # 缓存机制
        self.query_cache = {}
        self.cache_lock = threading.RLock()

        # 性能统计
        self.stats = {
            'total_queries': 0,
            'cache_hits': 0,
            'avg_query_time': 0,
            'total_documents': 0
        }

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """批量添加文本"""

        if not texts:
            return []

        # 生成ID
        if ids is None:
            ids = [f"doc_{int(time.time() * 1000000)}_{i}" for i in range(len(texts))]

        # 处理元数据
        if metadatas is None:
            metadatas = [{}] * len(texts)

        # 批量生成嵌入
        embeddings = self._batch_embed_texts(texts)

        # 初始化索引(如果需要)
        if self.index is None:
            self.dimension = len(embeddings[0])
            self._initialize_index()

        # 添加到索引
        start_index = len(self.index_to_id)

        # 批量添加向量(IVF等需训练的索引先用本批向量训练,否则add会失败)
        vectors = np.array(embeddings, dtype=np.float32)
        if hasattr(self.index, 'is_trained') and not self.index.is_trained:
            self.index.train(vectors)
        self.index.add(vectors)

        # 更新映射和文档存储
        for i, (text, metadata, doc_id) in enumerate(zip(texts, metadatas, ids)):
            index_id = start_index + i

            # 创建文档
            doc = Document(page_content=text, metadata=metadata)

            # 更新存储
            self.documents[doc_id] = doc
            self.id_to_index[doc_id] = index_id
            self.index_to_id[index_id] = doc_id

        # 更新统计
        self.stats['total_documents'] += len(texts)

        # 清空查询缓存(因为索引已更新)
        with self.cache_lock:
            self.query_cache.clear()

        return ids

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """相似度搜索(带分数)"""

        start_time = time.time()

        # 检查缓存
        cache_key = self._generate_cache_key(query, k, filter)

        with self.cache_lock:
            if cache_key in self.query_cache:
                self.stats['cache_hits'] += 1
                self.stats['total_queries'] += 1
                return self.query_cache[cache_key]

        # 生成查询向量
        query_embedding = self.embedding_function.embed_query(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # 执行搜索
        if self.index is None or self.index.ntotal == 0:
            return []

        # FAISS搜索
        scores, indices = self.index.search(query_vector, min(k, self.index.ntotal))

        # 处理结果
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:  # FAISS返回-1表示无效结果
                continue

            doc_id = self.index_to_id.get(idx)
            if doc_id and doc_id in self.documents:
                doc = self.documents[doc_id]

                # 应用过滤器
                if filter and not self._apply_filter(doc, filter):
                    continue

                # 转换分数(FAISS返回的是距离,需要转换为相似度)
                similarity_score = self._distance_to_similarity(score)
                results.append((doc, similarity_score))

        # 限制结果数量
        results = results[:k]

        # 缓存结果
        with self.cache_lock:
            if len(self.query_cache) < self.cache_size:
                self.query_cache[cache_key] = results

        # 更新统计
        query_time = time.time() - start_time
        self.stats['total_queries'] += 1

        # 更新平均查询时间
        total_queries = self.stats['total_queries']
        current_avg = self.stats['avg_query_time']
        self.stats['avg_query_time'] = (
            (current_avg * (total_queries - 1) + query_time) / total_queries
        )

        return results

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """相似度搜索"""
        results = self.similarity_search_with_score(query, k, filter, **kwargs)
        return [doc for doc, _ in results]

    def _batch_embed_texts(self, texts: List[str]) -> List[List[float]]:
        """批量生成文本嵌入"""

        if len(texts) <= self.batch_size:
            return self.embedding_function.embed_documents(texts)

        # 分批处理大量文本
        all_embeddings = []

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []

            for i in range(0, len(texts), self.batch_size):
                batch = texts[i:i + self.batch_size]
                future = executor.submit(self.embedding_function.embed_documents, batch)
                futures.append(future)

            for future in as_completed(futures):
                batch_embeddings = future.result()
                all_embeddings.extend(batch_embeddings)

        return all_embeddings

    def _initialize_index(self):
        """初始化FAISS索引"""

        if self.metric_type == "L2":
            metric = faiss.METRIC_L2
        elif self.metric_type == "IP":
            metric = faiss.METRIC_INNER_PRODUCT
        else:
            metric = faiss.METRIC_L2

        # 创建索引(统一经 index_factory 创建,使 Flat 同样遵循 metric 配置)
        self.index = faiss.index_factory(self.dimension, self.index_factory, metric)

        # GPU支持
        if self.use_gpu and faiss.get_num_gpus() > 0:
            gpu_resource = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(gpu_resource, 0, self.index)

        # 需要训练的索引类型(如IVF)须在首次add之前用样本向量调用train()

    def _generate_cache_key(
        self,
        query: str,
        k: int,
        filter: Optional[Dict[str, Any]]
    ) -> str:
        """生成缓存键"""
        import hashlib

        key_data = f"{query}:{k}:{filter}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _apply_filter(self, doc: Document, filter: Dict[str, Any]) -> bool:
        """应用元数据过滤器"""
        for key, value in filter.items():
            if key not in doc.metadata:
                return False
            if doc.metadata[key] != value:
                return False
        return True

    def _distance_to_similarity(self, distance: float) -> float:
        """将距离转换为相似度分数"""
        # 简单转换,仅适用于L2距离(越小越近);若metric为IP(越大越近),应直接使用原分数或另行归一化
        return 1.0 / (1.0 + distance)

    def get_performance_stats(self) -> Dict[str, Any]:
        """获取性能统计"""
        cache_hit_rate = (
            self.stats['cache_hits'] / max(self.stats['total_queries'], 1) * 100
        )

        return {
            **self.stats,
            'cache_hit_rate': f"{cache_hit_rate:.2f}%",
            'cache_size': len(self.query_cache),
            'index_size': self.index.ntotal if self.index else 0
        }

向量检索与查询缓存(FAISS):时序图与调用路径

sequenceDiagram
    participant App as 应用
    participant VS as HighPerformanceVectorStore
    participant Emb as Embeddings
    participant Index as FAISS索引

    App->>VS: similarity_search_with_score(query, k, filter)
    VS->>VS: 生成cache_key
    alt 缓存命中
        VS-->>App: 返回缓存结果
    else 未命中
        VS->>Emb: embed_query(query)
        Emb-->>VS: 向量q
        VS->>Index: search(q, k)
        Index-->>VS: (scores, indices)
        VS->>VS: 过滤/分数转换/截断
        VS->>VS: 写入缓存
        VS-->>App: 返回结果
    end
  • 写入索引:HighPerformanceVectorStore.add_texts() -> _batch_embed_texts() ->(必要时)_initialize_index() -> index.add() -> 更新id_to_index/index_to_id
  • 相似检索:similarity_search_with_score() -> 缓存查找 -> embedding_function.embed_query() -> index.search() -> _apply_filter() -> _distance_to_similarity() -> 写入缓存
  • 只取文档:similarity_search() -> similarity_search_with_score() -> 提取Document
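正文强调分数方向须与度量语义一致。下面给出两种常见归一化的草图:L2 版本与 `_distance_to_similarity` 一致;IP 版本为补充假设,仅适用于已做 L2 归一化的向量:

```python
def l2_to_similarity(distance: float) -> float:
    """L2 距离越小越相似:单调递减地映射到 (0, 1]。"""
    return 1.0 / (1.0 + distance)

def ip_to_similarity(inner_product: float) -> float:
    """内积(向量归一化后即余弦,取值[-1, 1])越大越相似:线性映射到 [0, 1]。"""
    return (inner_product + 1.0) / 2.0
```

据此,同一阈值(如 0.9)在两种度量下对应完全不同的原始分数区间,混用会导致误判。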

4.2 混合检索与重排(Hybrid + Rerank)

采用“RRF 融合 + 可选 Cross-Encoder 重排”的二阶段方案,优先执行 RRF 以降低重排候选集,平衡质量与延迟:

from typing import List, Tuple

class HybridRetriever:
    """向量检索 + BM25(或关键词) 混合,并用 RRF 融合;支持可选重排模型。"""
    def __init__(self, vector_retriever, bm25_retriever, k: int = 8, alpha: float = 0.7, reranker=None):
        self.vec = vector_retriever
        self.bm25 = bm25_retriever
        self.k = k
        self.alpha = alpha  # 预留:线性加权融合系数(当前RRF路径未使用)
        self.reranker = reranker  # 可对融合后的候选做二次重排(如 Cross-Encoder)

    def _rrf(self, lists: List[List[Tuple[str, float]]]) -> List[Tuple[str, float]]:
        # Reciprocal Rank Fusion: score += 1/(rank + 60)
        scores = {}
        for results in lists:
            for rank, (doc_id, _) in enumerate(results):
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rank + 60.0)
        return sorted(scores.items(), key=lambda x: x[1], reverse=True)

    def _topk_pairs(self, docs) -> List[Tuple[str, float]]:
        pairs = []
        for i, d in enumerate(docs):
            # 简化:以位置当分数占位
            pairs.append((getattr(d, "id", f"doc_{i}"), 1.0 / (i + 1)))
        return pairs

    def get_relevant_documents(self, query: str):
        vec_docs = self.vec.get_relevant_documents(query)
        bm25_docs = self.bm25.get_relevant_documents(query)

        fused_ids = self._rrf([self._topk_pairs(vec_docs), self._topk_pairs(bm25_docs)])

        # 恢复文档对象并截断到 k(回退键须与 _topk_pairs 的占位 id 一致)
        id_to_doc = {}
        for docs in (vec_docs, bm25_docs):
            for i, d in enumerate(docs):
                id_to_doc.setdefault(getattr(d, "id", f"doc_{i}"), d)

        candidates = [id_to_doc[i] for i, _ in fused_ids if i in id_to_doc][: self.k]

        if self.reranker:
            candidates = self.reranker.rerank(query, candidates)[: self.k]
        return candidates

混合检索与重排(RRF + Rerank):时序图与调用路径

sequenceDiagram
    participant Q as Query
    participant VR as 向量检索
    participant BR as BM25/关键词
    participant RRF as RRF融合
    participant RR as 重排模型(可选)

    Q->>VR: get_relevant_documents(q)
    Q->>BR: get_relevant_documents(q)
    VR-->>RRF: 列表[id, rank]
    BR-->>RRF: 列表[id, rank]
    RRF-->>RR: 候选TopK
    alt 配置重排
        RR-->>Q: 最终TopK
    else 无重排
        RRF-->>Q: 最终TopK
    end
  • 候选获取:HybridRetriever.get_relevant_documents() -> vec.get_relevant_documents() + bm25.get_relevant_documents()
  • 融合排序:_topk_pairs() -> _rrf() -> 组装候选TopK
  • 可选重排:reranker.rerank() -> 截断到k
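RRF 的融合效果可以用一个最小算例验证(文档 id 为假设值,常数 k=60 与正文一致):

```python
def rrf(ranked_lists, k=60):
    """Reciprocal Rank Fusion:score += 1/(rank + k),rank 从 0 开始,与正文 _rrf 一致。"""
    scores = {}
    for results in ranked_lists:
        for rank, doc_id in enumerate(results):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rank + k)
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)

vec_ranked = ["d1", "d2", "d3"]    # 向量检索排序(示例数据)
bm25_ranked = ["d2", "d4", "d1"]   # BM25 排序(示例数据)
fused = rrf([vec_ranked, bm25_ranked])
```

d2 在两路检索中均排名靠前,融合后稳定居首,这正是 RRF 偏好“多路一致靠前”候选的体现。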

4.3 语义缓存(Semantic Cache)

使用“近似命中可回退”策略:阈值附近可返回“近似命中”并提示生成回退路径,以提升命中率并兼顾答案质量:

from langchain_core.documents import Document

class SemanticCache:
    def __init__(self, embeddings, vectorstore, threshold: float = 0.92):
        self.emb = embeddings
        self.vs = vectorstore
        self.threshold = threshold

    def lookup(self, query: str) -> str | None:
        results = self.vs.similarity_search_with_score(query, k=1)
        if not results:
            return None
        doc, score = results[0]
        if score >= self.threshold:  # 假设 score 越大越相似(需与具体向量库一致)
            return doc.metadata.get("response")
        return None

    def update(self, query: str, response: str):
        doc = Document(page_content=query, metadata={"response": response})
        self.vs.add_texts([doc.page_content], metadatas=[doc.metadata])

提示:生产中应根据具体向量库分数语义(相似度或距离)调整阈值与比较方向,并增加 TTL 与逐出策略。

语义缓存命中与回填:时序图与调用路径

sequenceDiagram
    participant U as 用户
    participant Cache as SemanticCache
    participant VS as VectorStore
    participant LLM as LLM

    U->>Cache: lookup(query)
    Cache->>VS: similarity_search_with_score(q, k=1)
    VS-->>Cache: 最近邻(doc, score)
    alt 分数>=阈值
        Cache-->>U: 命中缓存响应
    else 未命中
        Cache-->>U: None
        U->>LLM: 调用生成
        LLM-->>U: 响应
        U->>Cache: update(query, response)
    end
  • 查询命中:SemanticCache.lookup() -> vs.similarity_search_with_score(k=1) -> 分数阈值判定 -> 返回缓存响应或None
  • 回填写入:SemanticCache.update() -> Document(...) -> vs.add_texts()
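上面提示中提到的 TTL 与逐出策略,可以在语义缓存外层用一个独立的包装结构实现。以下是基于 OrderedDict 的极简示意(容量与 TTL 参数为假设值):

```python
import time
from collections import OrderedDict

class TTLCache:
    """为缓存响应增加 TTL 与 LRU 逐出的极简示意。"""
    def __init__(self, max_size: int = 1024, ttl: float = 3600.0):
        self.max_size = max_size
        self.ttl = ttl
        self._data = OrderedDict()  # key -> (写入时间戳, value)

    def get(self, key: str):
        item = self._data.get(key)
        if item is None:
            return None
        ts, value = item
        if time.time() - ts > self.ttl:     # 过期即失效
            del self._data[key]
            return None
        self._data.move_to_end(key)         # 命中刷新LRU位置
        return value

    def put(self, key: str, value: str):
        self._data[key] = (time.time(), value)
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # 逐出最久未用条目
```

生产中可将其套在 SemanticCache 命中结果之外,避免陈旧响应长期驻留。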

5. 关键函数与结构补充

本节对前述模块的关键函数补充核心代码片段(含简要注释与功能说明)、统一调用链、类结构图/继承关系与时序图索引;内容为客观描述,不含价值判断。

5.1 安全与隐私(SecurityManager / SecurePromptTemplate)

关键函数(核心代码与说明):

class LangChainSecurityManager:
    def sanitize_input(self, text: str) -> str:
        """将输入中的潜在敏感字段替换为占位符。
        - 使用预置正则表达式集合匹配常见PII(邮箱/手机号/卡号等)
        - 返回与原文本形态一致但已脱敏的字符串
        """
        for pattern in self.sensitive_patterns:
            text = re.sub(pattern, '[REDACTED]', text)
        return text

    def encrypt_sensitive_data(self, data: str) -> str:
        """对敏感文本进行对称加密并做URL安全Base64编码,便于持久化传输。"""
        return base64.urlsafe_b64encode(self.cipher_suite.encrypt(data.encode())).decode()

    def decrypt_sensitive_data(self, encrypted_data: str) -> str:
        """对加密密文执行Base64解码与对称解密,恢复明文。"""
        return self.cipher_suite.decrypt(base64.urlsafe_b64decode(encrypted_data.encode())).decode()

class SecurePromptTemplate:
    def format(self, **kwargs) -> str:
        """对入参逐项脱敏后按模板格式化输出,避免PII泄露。"""
        sanitized = {k: (self.security_manager.sanitize_input(v) if isinstance(v, str) else v)
                     for k, v in kwargs.items()}
        return self.template.format(**sanitized)

统一调用链:

  • 模板渲染:SecurePromptTemplate.format() -> LangChainSecurityManager.sanitize_input() -> re.sub()
  • 加解密:encrypt_sensitive_data() -> Fernet.encrypt() / decrypt_sensitive_data() -> Fernet.decrypt()

类结构图(Mermaid):

classDiagram
    class LangChainSecurityManager {
      - master_key: str
      - cipher_suite: Fernet
      - sensitive_patterns: list
      + sanitize_input(text) str
      + encrypt_sensitive_data(data) str
      + decrypt_sensitive_data(encrypted_data) str
    }
    class SecurePromptTemplate {
      - template: str
      - security_manager: LangChainSecurityManager
      + format(**kwargs) str
    }
    LangChainSecurityManager <.. SecurePromptTemplate : uses

时序图索引:见“1.1 输入清理与加密解密”小节。

5.2 访问控制(AccessControlManager / SecureLangChainAgent)

关键函数(核心代码与说明):

class AccessControlManager:
    def create_token(self, user_id: str, role: Role, expires_in: int = 3600) -> str:
        """生成带角色与权限列表的JWT,用于下游鉴权。"""
        payload = {
            'user_id': user_id,
            'role': role.value,
            'permissions': [p.value for p in self.role_permissions[role]],
            'exp': time.time() + expires_in,
            'iat': time.time(),
        }
        return jwt.encode(payload, self.secret_key, algorithm='HS256')

    def check_permission(self, token: str, required_permission: Permission) -> bool:
        """校验令牌中是否包含目标权限,失败返回False。"""
        payload = self.verify_token(token)
        return required_permission.value in payload.get('permissions', [])

    def require_permission(self, permission: Permission):
        """方法装饰器;在进入业务函数前执行权限验证,不通过抛出PermissionError。"""
        ...

统一调用链:

  • 鉴权:SecureLangChainAgent.authenticate() -> AccessControlManager.verify_token() -> jwt.decode()
  • 访问控制:@require_permission(x) -> check_permission() -> 业务函数执行

类结构图(Mermaid):

classDiagram
    class Permission { <<enum>> }
    class Role { <<enum>> }
    class AccessControlManager {
      - secret_key: str
      - role_permissions: dict
      + create_token(user_id, role, expires_in) str
      + verify_token(token) dict
      + check_permission(token, required_permission) bool
      + require_permission(permission) decorator
    }
    class SecureLangChainAgent {
      - access_control: AccessControlManager
      - auth_token: str
      + authenticate(token)
      + read_documents(...)
      + execute_tool(...)
    }
    AccessControlManager <.. SecureLangChainAgent : uses

时序图索引:见“1.2 访问控制与权限校验”。

5.3 多模态(MultiModalChatModel)

关键函数(核心代码与说明):

class MultiModalChatModel(BaseChatModel):
    def _process_multimodal_message(self, message: BaseMessage) -> Dict[str, Any]:
        """将文本/图像等多模态内容规范化为统一API请求体片段。"""
        ...

    def _encode_image_file(self, image_path: str) -> str:
        """将本地图片文件以Base64编码内联,使请求在无外网场景仍可发送。"""
        with open(image_path, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')

统一调用链:

  • 生成:_generate() -> _process_multimodal_message() -> requests.post(/chat/completions)

类结构图(Mermaid):

classDiagram
    class BaseChatModel { <<abstract>> }
    class MultiModalChatModel {
      - model_name: str
      - modality_processors: dict
      + _generate(messages, stop, run_manager) ChatResult
      - _process_multimodal_message(message) dict
      - _encode_image_file(path) str
    }
    BaseChatModel <|-- MultiModalChatModel

时序图索引:见“2.1 多模态消息处理与模型调用”。

5.4 输出安全与SSE(sanitize_chunk / astream_with_safety)

关键函数(核心代码与说明):

def sanitize_chunk(text: str) -> str:
    """对流式分片进行正则过滤,移除常见敏感模式。"""
    for p in SENSITIVE_PATTERNS:
        text = p.sub('[REDACTED]', text)
    return text

async def astream_with_safety(chain, payload: dict, *, on_chunk) -> str:
    """在异步流式生成过程中逐片过滤并回调输出,最终拼接完整响应。"""
    full = []
    async for chunk in chain.astream(payload):
        safe = sanitize_chunk(str(chunk))
        on_chunk(safe)
        full.append(safe)
    return ''.join(full)

统一调用链:astream_with_safety() -> chain.astream() -> sanitize_chunk() -> on_chunk()

类结构图:本模块以函数为主,无独立类。

时序图索引:见“2.2 SSE流式与输出过滤”。

5.5 负载均衡与故障转移(LoadBalancedChatModel)

关键函数(核心代码与说明):

class LoadBalancedChatModel:
    def _select_provider(self) -> Optional[str]:
        """按策略在健康提供方中选择目标(轮询/加权/最快等)。"""
        ...

    def _generate_with_fallback(self, messages: list, **kwargs):
        """带重试与指标更新的生成流程,失败切换下一Provider。"""
        ...

统一调用链:_generate_with_fallback() -> _select_provider() -> provider._generate() -> 指标更新

类结构图(Mermaid):

classDiagram
    class ProviderStatus { <<enum>> }
    class ProviderMetrics {
      +response_time: float
      +success_rate: float
      +error_count: int
      +total_requests: int
    }
    class LoadBalancedChatModel {
      - providers: dict
      - provider_metrics: dict
      + _select_provider() str
      + _generate_with_fallback(messages) Any
      - _update_success_metrics(pid, rt)
      - _update_failure_metrics(pid, err)
    }

时序图索引:见“3.1 负载均衡与故障转移”。

5.6 可靠性控制(RetryPolicy / CircuitBreaker / RateLimiter / QuotaRouter)

关键函数(核心代码与说明):

class RetryPolicy:
    async def aretry(self, fn):
        """以指数退避重试异步函数,达到上限后抛出最后一次异常。"""
        ...

class CircuitBreaker:
    def allow(self) -> bool:
        """根据冷却时间窗口决定是否放行请求。"""
        ...

统一调用链:robust_generate() -> RateLimiter.acquire() -> CircuitBreaker.allow() -> RetryPolicy.aretry(call)

类结构图(Mermaid):

classDiagram
    class RetryPolicy { +aretry(fn) }
    class CircuitBreaker { +allow() +record_success() +record_failure() }
    class RateLimiter { +acquire() }
    class QuotaRouter { +select() +consume(pid,tokens) }

时序图索引:见“3.2 可靠性与配额”。

5.7 向量存储(HighPerformanceVectorStore)

关键函数(核心代码与说明):

class HighPerformanceVectorStore(VectorStore):
    def add_texts(self, texts: List[str], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None) -> List[str]:
        """批量嵌入并写入FAISS索引,更新映射与统计,并清空查询缓存。"""
        ...

    def similarity_search_with_score(self, query: str, k: int = 4, filter: Optional[dict] = None):
        """执行嵌入→检索→过滤→分数转换→缓存→统计更新,返回(doc, score)。"""
        ...

统一调用链:

  • 写入:add_texts() -> _batch_embed_texts() -> _initialize_index() -> index.add()
  • 检索:similarity_search_with_score() -> embed_query() -> index.search() -> _distance_to_similarity()

类结构图(Mermaid):

classDiagram
    class VectorStore { <<abstract>> }
    class HighPerformanceVectorStore {
      - index
      - documents
      + add_texts(texts, metadatas, ids) list
      + similarity_search_with_score(query, k, filter) list
    }
    VectorStore <|-- HighPerformanceVectorStore

时序图索引:见“4.1 向量检索与查询缓存”。

5.8 混合检索与重排(HybridRetriever)

关键函数(核心代码与说明):

class HybridRetriever:
    def get_relevant_documents(self, query: str):
        """分别检索向量/BM25结果,使用RRF融合并可选重排,返回TopK。"""
        ...

统一调用链:get_relevant_documents() -> vec.get_relevant_documents() + bm25.get_relevant_documents() -> _rrf() -> reranker.rerank()(可选)

类结构图(Mermaid):

classDiagram
    class HybridRetriever {
      - vec
      - bm25
      - reranker
      + get_relevant_documents(query)
      - _rrf(lists)
    }

时序图索引:见“4.2 混合检索与重排”。

5.9 语义缓存(SemanticCache)

关键函数(核心代码与说明):

class SemanticCache:
    def lookup(self, query: str) -> str | None:
        """以k=1最近邻检索判断是否达到阈值,命中则返回缓存响应。"""
        ...
    def update(self, query: str, response: str):
        """将查询与响应以Document写入向量库,用于后续近似命中。"""
        ...

统一调用链:lookup() -> vs.similarity_search_with_score() -> 阈值判断;update() -> vs.add_texts()

类结构图(Mermaid):

classDiagram
    class SemanticCache {
      - emb
      - vs
      - threshold: float
      + lookup(query) str|None
      + update(query, response)
    }

时序图索引:见“4.3 语义缓存命中与回填”。
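lookup/update 的阈值判定逻辑可以脱离具体向量库单独验证。以下为玩具级示意:用字符词袋向量代替真实嵌入模型、用线性扫描代替向量库的 k=1 检索,仅演示“近似命中 + 阈值”口径,生产中应替换为正文的 vs/emb 组合:

```python
import math
from typing import List, Optional, Tuple

def toy_embed(text: str) -> List[float]:
    """玩具向量化:按字符码点做 256 维词袋,仅用于演示阈值判定。"""
    vec = [0.0] * 256
    for ch in text.lower():
        vec[ord(ch) % 256] += 1.0
    return vec

def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

class SemanticCacheSketch:
    """基于相似度阈值的近似命中缓存(示意实现)。"""
    def __init__(self, threshold: float = 0.9):
        self.threshold = threshold
        self._entries: List[Tuple[List[float], str]] = []

    def lookup(self, query: str) -> Optional[str]:
        """k=1 最近邻:最高相似度达到阈值则命中,否则返回 None。"""
        qv = toy_embed(query)
        best_sim, best_resp = -1.0, None
        for vec, resp in self._entries:
            sim = cosine(qv, vec)
            if sim > best_sim:
                best_sim, best_resp = sim, resp
        return best_resp if best_sim >= self.threshold else None

    def update(self, query: str, response: str) -> None:
        """回填:记录查询向量与响应,供后续近似命中。"""
        self._entries.append((toy_embed(query), response))
```

阈值是命中率与答非所问风险之间的权衡旋钮:阈值越低命中率越高,但跨问题串答的风险也随之上升。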

5.10 统一时序图索引

  • 安全:见“1.1 输入清理与加密解密”
  • 访问控制:见“1.2 访问控制与权限校验”
  • 审计:见“1.4 审计回调与落盘”
  • 多模态:见“2.1 多模态消息处理与模型调用”
  • SSE与输出过滤:见“2.2 SSE流式与输出过滤”
  • 负载均衡:见“3.1 负载均衡与故障转移”
  • 可靠性与配额:见“3.2 可靠性与配额”
  • 向量检索:见“4.1 向量检索与查询缓存”
  • 混合检索:见“4.2 混合检索与重排”
  • 语义缓存:见“4.3 语义缓存命中与回填”

5.11 内容整合与去重说明

  • 将各模块的“关键函数/调用链/类结构”以统一格式集中于第5章,避免在各分章重复阐述。
  • 时序图维持在原分章位置,统一在第5.10节建立索引,减少图形重复。
  • 对已出现的函数说明采用简述与索引方式,保持篇幅与可读性。

附录A. 可观测性与成本控制

A.1 指标与Tracing

from typing import Dict, Any
from langchain_core.callbacks import BaseCallbackHandler
import time

class MetricsCallback(BaseCallbackHandler):
    def __init__(self, emitter):
        self.emitter = emitter  # 可为 Prometheus/StatsD/OpenTelemetry 导出器
        self.llm_start_time = {}

    def on_llm_start(self, serialized: Dict[str, Any], prompts, *, run_id, **kwargs):
        self.llm_start_time[run_id] = time.time()

    def on_llm_end(self, response, *, run_id, **kwargs):
        start = self.llm_start_time.pop(run_id, None)
        if start is not None:  # 显式判空,避免误伤起点为0的极端情形
            duration = time.time() - start
            self.emitter.gauge("llm.duration", duration)
        usage = (response.llm_output or {}).get("token_usage", {})
        self.emitter.gauge("llm.tokens.input", usage.get("prompt_tokens", 0))
        self.emitter.gauge("llm.tokens.output", usage.get("completion_tokens", 0))
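为便于本地验证,emitter 可以先用一个进程内实现占位。以下 `InMemoryEmitter` 为示意:只保留最近一次 gauge 值,接口与上面 `MetricsCallback` 对 emitter 的调用约定一致,生产中应替换为 Prometheus/StatsD/OpenTelemetry 客户端:

```python
from typing import Dict

class InMemoryEmitter:
    """最小指标发射器(示意):进程内记录最近一次 gauge 值,便于单测与本地调试。"""
    def __init__(self):
        self.gauges: Dict[str, float] = {}

    def gauge(self, name: str, value: float) -> None:
        self.gauges[name] = float(value)
```

接入方式:`MetricsCallback(InMemoryEmitter())` 作为 callbacks 传入链即可,符合回调体系的低侵入原则。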

A.2 成本仪表(按模型/租户)

# 单价:美元 / 1K tokens(示例价格,实际以供应商最新报价为准)
PRICING = {
    "gpt-4": {"in": 0.03, "out": 0.06},
    "gpt-3.5-turbo": {"in": 0.001, "out": 0.002},
}

def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    p = PRICING.get(model, {"in": 0.0, "out": 0.0})
    return (input_tokens / 1000) * p["in"] + (output_tokens / 1000) * p["out"]
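在 `estimate_cost` 之上,可以叠加一层按租户聚合的成本仪表。以下 `CostMeter` 为示意实现:estimator 参数可直接传入上面的 `estimate_cost`,按(租户, 模型)维度累计估算值:

```python
from collections import defaultdict
from typing import Callable, DefaultDict, Tuple

class CostMeter:
    """按(租户, 模型)维度累计估算成本(示意实现)。"""
    def __init__(self, estimator: Callable[[str, int, int], float]):
        self.estimator = estimator
        self.totals: DefaultDict[Tuple[str, str], float] = defaultdict(float)

    def record(self, tenant: str, model: str,
               input_tokens: int, output_tokens: int) -> float:
        """记录一次调用的估算成本并返回本次成本。"""
        cost = self.estimator(model, input_tokens, output_tokens)
        self.totals[(tenant, model)] += cost
        return cost

    def tenant_total(self, tenant: str) -> float:
        """汇总某租户在所有模型上的累计成本。"""
        return sum(v for (t, _), v in self.totals.items() if t == tenant)
```

该聚合结果可直接喂给成本仪表盘,并与“配额-成本”路由共享同一口径。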

附录B. 评测与 A/B

B.1 离线基准集

采用“任务-指标-预算”的最小评测协议:离线阶段以关键词/要点命中率为主要口径;在线阶段以成功率、平均耗时、Token 成本三项指标构建 A/B 看板,并支持按租户/地域分桶对比:

from dataclasses import dataclass
from typing import List

@dataclass
class EvalCase:
    query: str
    expected_keywords: List[str]

def offline_eval(chain, cases: List[EvalCase]) -> float:
    hits = 0
    for c in cases:
        out = chain.invoke({"question": c.query})
        text = str(out)
        if all(k.lower() in text.lower() for k in c.expected_keywords):
            hits += 1
    return hits / max(len(cases), 1)

B.2 在线 A/B(简化)

import random
from typing import Any  # 签名中使用了 Any,需显式导入

class ABRouter:
    def __init__(self, variants: dict[str, Any], weights: dict[str, float]):
        self.variants = variants
        self.weights = weights

    def pick(self) -> str:
        names, ws = zip(*self.weights.items())
        r = random.random() * sum(ws)
        acc = 0
        for n, w in zip(names, ws):
            acc += w
            if r <= acc:
                return n
        return names[-1]

6. 总结

  1. 安全机制:数据加密、访问控制、隐私保护
  2. 多模态集成:图像、音频、视频等多种模态的处理
  3. 负载均衡:智能路由、故障转移、健康检查
  4. 性能优化:高性能向量存储、缓存机制、批处理优化

这些实践模式为开发者在生产环境中部署LangChain应用提供了重要的技术指导和最佳实践参考。