# AutoGPT Platform Best Practices and Field Experience

## Overview

This document collects the best practices and hands-on experience gathered while developing, deploying, and operating the AutoGPT platform. Through concrete case studies, performance-tuning techniques, security recommendations, and troubleshooting methods, it aims to help developers and users get the most out of the platform.

## Development Best Practices

### 1. Block Development Best Practices

#### 1.1 Block Design Principles

**Single Responsibility Principle**

```python
# ✅ Good practice: a single-responsibility Block
# (Assumed context for the code samples in this document: `Block` is the
#  platform's block base class; `BaseModel`, `Field`, and `field_validator`
#  come from pydantic; `AsyncGenerator`, `Any`, `List`, `Optional` from typing.)
class EmailSenderBlock(Block):
    """A Block dedicated to sending email."""

    class Input(BaseModel):
        to: str = Field(..., description="Recipient address")
        subject: str = Field(..., description="Email subject")
        content: str = Field(..., description="Email body")

    class Output(BaseModel):
        success: bool = Field(..., description="Whether the send succeeded")
        message_id: str = Field(..., description="Message ID")

    async def run(self, input_data: Input, **kwargs) -> AsyncGenerator[tuple[str, Any], None]:
        # Focus exclusively on the email-sending logic
        result = await self.send_email(input_data)
        yield "result", self.Output(
            success=result.success,
            message_id=result.message_id,
        )


# ❌ Avoid: a Block with too many responsibilities
class EmailAndSMSBlock(Block):
    """Bad design: handles both email and SMS."""
    # Violates the single responsibility principle; split it into two Blocks.
```
**Input Validation and Error Handling**

```python
class DataProcessorBlock(Block):
    """Best practices for a data-processing Block."""

    class Input(BaseModel):
        data: List[dict] = Field(..., description="Data to process")
        operation: str = Field(..., description="Operation type")

        @field_validator('data')
        @classmethod
        def validate_data(cls, v):
            if not v:
                raise ValueError("Data must not be empty")
            if len(v) > 1000:
                raise ValueError("Too much data; please process it in batches")
            return v

        @field_validator('operation')
        @classmethod
        def validate_operation(cls, v):
            allowed_ops = ['filter', 'transform', 'aggregate']
            if v not in allowed_ops:
                raise ValueError(f"Unsupported operation type: {v}")
            return v

    async def run(self, input_data: Input, **kwargs) -> AsyncGenerator[tuple[str, Any], None]:
        try:
            # Detailed execution logging
            logger.info(f"Processing {len(input_data.data)} records")

            # Process large datasets in batches
            batch_size = 100
            results = []
            for i in range(0, len(input_data.data), batch_size):
                batch = input_data.data[i:i + batch_size]
                batch_result = await self.process_batch(batch, input_data.operation)
                results.extend(batch_result)

                # Progress feedback
                progress = (i + len(batch)) / len(input_data.data) * 100
                yield "progress", {"percentage": progress, "processed": i + len(batch)}

            yield "result", {"data": results, "total": len(results)}
        except Exception as e:
            logger.error(f"Data processing failed: {e}")
            yield "error", f"Processing failed: {str(e)}"

    async def process_batch(self, batch: List[dict], operation: str) -> List[dict]:
        """Per-batch processing logic (placeholder)."""
        # Implement the concrete batch logic here; returning the batch
        # unchanged keeps the example runnable.
        return batch
```
#### 1.2 Async and Concurrent Processing

**Using Asynchronous Operations Correctly**

```python
class APICallBlock(Block):
    """Async best practices for an API-calling Block."""

    async def run(self, input_data: Input, **kwargs) -> AsyncGenerator[tuple[str, Any], None]:
        # ✅ Use an async HTTP client
        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    input_data.url,
                    json=input_data.payload,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        result = await response.json()
                        yield "success", result
                    else:
                        yield "error", f"API call failed: {response.status}"
            except asyncio.TimeoutError:
                yield "error", "API call timed out"
            except Exception as e:
                yield "error", f"API call raised an exception: {str(e)}"


class ConcurrentProcessorBlock(Block):
    """Best practices for a concurrent-processing Block."""

    async def run(self, input_data: Input, **kwargs) -> AsyncGenerator[tuple[str, Any], None]:
        # ✅ Cap concurrency to avoid exhausting resources
        semaphore = asyncio.Semaphore(10)  # at most 10 concurrent tasks

        async def process_item(item):
            async with semaphore:
                return await self.process_single_item(item)

        # Process the batch concurrently
        tasks = [process_item(item) for item in input_data.items]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Separate results from exceptions
        successful_results = []
        errors = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                errors.append(f"Item {i}: {str(result)}")
            else:
                successful_results.append(result)

        yield "results", successful_results
        if errors:
            yield "errors", errors
```
#### 1.3 Resource Management and Cleanup

**Managing Resources Correctly**

```python
class FileProcessorBlock(Block):
    """Resource-management best practices for a file-processing Block."""

    async def run(self, input_data: Input, **kwargs) -> AsyncGenerator[tuple[str, Any], None]:
        temp_files = []
        try:
            # Create a temporary file
            temp_file = await self.create_temp_file()
            temp_files.append(temp_file)

            # Process the file
            result = await self.process_file(temp_file, input_data)
            yield "result", result
        except Exception as e:
            yield "error", str(e)
        finally:
            # ✅ Always clean up resources
            for temp_file in temp_files:
                try:
                    if os.path.exists(temp_file):
                        os.remove(temp_file)
                except Exception as cleanup_error:
                    logger.warning(f"Failed to clean up temp file: {cleanup_error}")


class DatabaseBlock(Block):
    """Connection management for a database Block."""

    async def run(self, input_data: Input, **kwargs) -> AsyncGenerator[tuple[str, Any], None]:
        # ✅ Use a connection pool and context managers
        async with self.get_db_connection() as conn:
            try:
                async with conn.transaction():
                    result = await self.execute_query(conn, input_data.query)
                    yield "result", result
            except Exception as e:
                # The transaction rolls back automatically
                yield "error", f"Database operation failed: {str(e)}"
```
### 2. Graph Design Best Practices

#### 2.1 Graph Structure Design

**A Well-Formed Graph Structure**

```python
# ✅ Good graph design: a clear data flow
"""
input node → data validation → data processing → result output
                   ↓
            error handling → logging
"""

# Graph design principles:
# 1. One-directional data flow; avoid circular dependencies
# 2. Dedicated error-handling nodes for exceptional cases
# 3. Logging nodes record the execution trail
# 4. Output nodes produce results in a uniform format

class GraphDesignHelper:
    """Helper utilities for graph design."""

    @staticmethod
    def validate_graph_structure(graph: Graph) -> List[str]:
        """Best-practice checks for a graph's structure."""
        issues = []

        # Check for entry nodes
        entry_nodes = graph.get_entry_nodes()
        if not entry_nodes:
            issues.append("Graph has no entry node")

        # Check for exit nodes
        exit_nodes = graph.get_exit_nodes()
        if not exit_nodes:
            issues.append("Graph has no exit node")

        # Check for circular dependencies
        if GraphDesignHelper.has_cycles(graph):
            issues.append("Graph contains a circular dependency")

        # Check for isolated nodes
        isolated_nodes = GraphDesignHelper.find_isolated_nodes(graph)
        if isolated_nodes:
            issues.append(f"Isolated nodes found: {isolated_nodes}")

        # Check for error handling
        has_error_handling = any(
            'error' in node.block_id.lower()
            for node in graph.nodes
        )
        if not has_error_handling:
            issues.append("Consider adding an error-handling node")

        return issues

    @staticmethod
    def suggest_optimizations(graph: Graph) -> List[str]:
        """Optimization suggestions for a graph."""
        suggestions = []

        # Look for parallelization opportunities
        parallel_opportunities = GraphDesignHelper.find_parallel_opportunities(graph)
        if parallel_opportunities:
            suggestions.append("These node groups could run in parallel")

        # Look for caching opportunities
        cacheable_nodes = GraphDesignHelper.find_cacheable_nodes(graph)
        if cacheable_nodes:
            suggestions.append("Consider caching nodes that repeat the same computation")

        # Check resource usage
        resource_intensive_nodes = GraphDesignHelper.find_resource_intensive_nodes(graph)
        if resource_intensive_nodes:
            suggestions.append("Consider resource limits for resource-intensive nodes")

        return suggestions
```
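`GraphDesignHelper.has_cycles` is referenced above but not shown. Below is a minimal sketch of one way to implement it with a recursive depth-first search; it assumes each node exposes an `id` and that `graph.get_outgoing_links(node_id)` yields the IDs of downstream nodes — both assumptions, not confirmed platform API.

```python
@staticmethod
def has_cycles(graph: Graph) -> bool:
    """Cycle detection via DFS three-coloring (method of GraphDesignHelper)."""
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited / on the current path / finished
    color = {node.id: WHITE for node in graph.nodes}

    def visit(node_id: str) -> bool:
        color[node_id] = GRAY
        for next_id in graph.get_outgoing_links(node_id):  # assumed helper
            if color.get(next_id) == GRAY:
                return True  # back edge: a cycle exists
            if color.get(next_id) == WHITE and visit(next_id):
                return True
        color[node_id] = BLACK
        return False

    return any(color[node.id] == WHITE and visit(node.id) for node in graph.nodes)
```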
#### 2.2 Error Handling Strategy

**Comprehensive Error Handling**

```python
class ErrorHandlingStrategy:
    """Best practices for error-handling strategies."""

    @staticmethod
    def create_error_handling_subgraph() -> List[Node]:
        """Create a standard error-handling subgraph."""
        return [
            Node(
                id="error_detector",
                block_id="ErrorDetectorBlock",
                label="Error detection",
                input_default={"error_types": ["validation", "execution", "timeout"]}
            ),
            Node(
                id="error_logger",
                block_id="LoggerBlock",
                label="Error logging",
                input_default={"level": "ERROR"}
            ),
            Node(
                id="error_notifier",
                block_id="NotificationBlock",
                label="Error notification",
                input_default={"channels": ["email", "slack"]}
            ),
            Node(
                id="fallback_handler",
                block_id="FallbackBlock",
                label="Fallback handling",
                input_default={"strategy": "default_response"}
            )
        ]

    @staticmethod
    def create_retry_mechanism() -> Node:
        """Create a retry node."""
        return Node(
            id="retry_handler",
            block_id="RetryBlock",
            label="Retry handling",
            input_default={
                "max_retries": 3,
                "backoff_strategy": "exponential",
                "retry_conditions": ["timeout", "rate_limit", "server_error"]
            }
        )
```
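For reference, this is a minimal sketch of the exponential backoff with jitter that a retry node configured as above might apply; `run_with_retries` and its parameters are illustrative helpers, not platform API.

```python
import asyncio
import random

async def run_with_retries(operation, max_retries: int = 3, base_delay: float = 1.0):
    """Retry an async zero-argument operation with exponential backoff plus jitter."""
    for attempt in range(max_retries + 1):
        try:
            return await operation()
        except Exception:
            if attempt == max_retries:
                raise  # out of retries: surface the original error
            # 1s, 2s, 4s, ... plus up to 0.5s of jitter to avoid thundering herds
            await asyncio.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.5))
```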
### 3. Code Quality Best Practices

#### 3.1 Coding Standards

**Code Style and Documentation**

```python
class ExampleBlock(Block):
    """
    Documentation best practices, demonstrated on an example Block.

    This Block shows how to write high-quality Block code, including:
    - complete type annotations
    - detailed docstrings
    - sensible error handling
    - clear logging

    Usage example:
        block = ExampleBlock()
        async for output_name, output_data in block.run(input_data):
            print(f"{output_name}: {output_data}")

    Notes:
    - The input must include the required_field field
    - Large inputs are processed in batches automatically
    - On failure, a detailed error payload is returned
    """

    # Type annotations
    id: str = "example_block"
    name: str = "Example Block"
    description: str = "An example Block that demonstrates best practices"

    class Input(BaseModel):
        """
        Input data model.

        Fields:
            required_field: required; must not be empty
            optional_field: optional; has a default value
            config: configuration parameters that affect processing
        """
        required_field: str = Field(
            ...,
            description="Required input field",
            min_length=1,
            max_length=1000
        )
        optional_field: Optional[str] = Field(
            None,
            description="Optional input field"
        )
        config: dict = Field(
            default_factory=dict,
            description="Configuration parameters"
        )

    class Output(BaseModel):
        """
        Output data model.

        Fields:
            result: the processing result
            metadata: processing metadata
            stats: processing statistics
        """
        result: Any = Field(..., description="Processing result")
        metadata: dict = Field(default_factory=dict, description="Metadata")
        stats: dict = Field(default_factory=dict, description="Statistics")

    async def run(
        self,
        input_data: Input,
        **kwargs
    ) -> AsyncGenerator[tuple[str, Any], None]:
        """
        Execute the Block's logic.

        Args:
            input_data: the validated input data
            **kwargs: extra execution parameters, which may include:
                - user_id: the user ID
                - execution_id: the execution ID
                - credentials: credential information

        Yields:
            tuple[str, Any]: (output name, output data) tuples

        Raises:
            ValidationError: input validation failed
            ProcessingError: an error occurred during processing
        """
        # Fetch the execution context
        user_id = kwargs.get('user_id')
        execution_id = kwargs.get('execution_id')

        # Log the start of execution
        logger.info(
            f"Starting Block {self.id}",
            extra={
                "user_id": user_id,
                "execution_id": execution_id,
                "input_size": len(str(input_data.required_field))
            }
        )

        try:
            # Run the processing logic
            start_time = time.time()
            result = await self._process_data(input_data)
            processing_time = time.time() - start_time

            # Build the output
            output = self.Output(
                result=result,
                metadata={
                    "processing_time": processing_time,
                    "input_length": len(input_data.required_field),
                    "timestamp": datetime.utcnow().isoformat()
                },
                stats={
                    "success": True,
                    "error_count": 0
                }
            )

            # Log success
            logger.info(
                f"Block {self.id} completed successfully",
                extra={
                    "processing_time": processing_time,
                    "output_size": len(str(result))
                }
            )

            yield "result", output

        except Exception as e:
            # Log the failure
            logger.error(
                f"Block {self.id} failed: {str(e)}",
                extra={
                    "user_id": user_id,
                    "execution_id": execution_id,
                    "error_type": type(e).__name__
                },
                exc_info=True
            )

            # Emit the error payload
            yield "error", {
                "error_type": type(e).__name__,
                "error_message": str(e),
                "timestamp": datetime.utcnow().isoformat()
            }

    async def _process_data(self, input_data: Input) -> Any:
        """
        Private method: the core processing logic.

        Args:
            input_data: the input data

        Returns:
            Any: the processing result
        """
        # Implement the concrete processing logic here; echoing the input
        # keeps the example runnable.
        return input_data.required_field
```
#### 3.2 Testing Best Practices

**Complete Test Coverage**

```python
import pytest
import asyncio
from unittest.mock import Mock, patch
from pydantic import ValidationError

from your_block import ExampleBlock


class TestExampleBlock:
    """Test cases for ExampleBlock."""

    @pytest.fixture
    def block(self):
        """A Block instance for testing."""
        return ExampleBlock()

    @pytest.fixture
    def valid_input(self):
        """Valid input data."""
        return ExampleBlock.Input(
            required_field="test data",
            optional_field="optional",
            config={"setting": "value"}
        )

    @pytest.mark.asyncio
    async def test_successful_execution(self, block, valid_input):
        """Test the happy path."""
        results = []
        async for output_name, output_data in block.run(valid_input):
            results.append((output_name, output_data))

        # Verify the output
        assert len(results) == 1
        output_name, output_data = results[0]
        assert output_name == "result"
        assert isinstance(output_data, ExampleBlock.Output)
        assert output_data.stats["success"] is True

    @pytest.mark.asyncio
    async def test_input_validation(self, block):
        """Test input validation."""
        # A missing or empty required field should fail
        with pytest.raises(ValidationError):
            ExampleBlock.Input(
                required_field="",  # an empty string must be rejected
                optional_field="test"
            )

    @pytest.mark.asyncio
    async def test_error_handling(self, block, valid_input):
        """Test error handling."""
        # Simulate an exception during processing
        with patch.object(block, '_process_data', side_effect=Exception("Test error")):
            results = []
            async for output_name, output_data in block.run(valid_input):
                results.append((output_name, output_data))

        # Verify the error output
        assert len(results) == 1
        output_name, output_data = results[0]
        assert output_name == "error"
        assert "Test error" in output_data["error_message"]

    @pytest.mark.asyncio
    async def test_performance(self, block, valid_input):
        """Test the performance requirement."""
        import time

        start_time = time.time()
        results = []
        async for output_name, output_data in block.run(valid_input):
            results.append((output_name, output_data))
        end_time = time.time()

        # The execution time should stay within a reasonable bound
        execution_time = end_time - start_time
        assert execution_time < 5.0  # should finish within 5 seconds

    @pytest.mark.asyncio
    async def test_concurrent_execution(self, block, valid_input):
        """Test concurrent execution."""
        # Spawn several concurrent tasks
        tasks = []
        for i in range(10):
            task_input = ExampleBlock.Input(
                required_field=f"test data {i}",
                config={"task_id": i}
            )
            task = asyncio.create_task(self._run_block(block, task_input))
            tasks.append(task)

        # Wait for all tasks to finish
        results = await asyncio.gather(*tasks)

        # Verify that every task completed successfully
        assert len(results) == 10
        for result in results:
            assert result is not None

    async def _run_block(self, block, input_data):
        """Helper: run a Block and return its result."""
        async for output_name, output_data in block.run(input_data):
            if output_name == "result":
                return output_data
        return None


# Integration tests
class TestBlockIntegration:
    """Block integration tests."""

    @pytest.mark.asyncio
    async def test_block_in_graph(self):
        """Test a Block running inside a graph."""
        # Build a test graph containing ExampleBlock
        graph = create_test_graph_with_example_block()

        # Execute the graph
        execution_result = await execute_graph(graph, test_inputs)

        # Verify the execution result
        assert execution_result.status == "completed"
        assert "result" in execution_result.outputs

    @pytest.mark.asyncio
    async def test_block_with_credentials(self):
        """Test a Block that requires credentials."""
        # Mock the credentials
        mock_credentials = {
            "api_provider": Mock(spec=["get_token", "is_valid"])
        }
        mock_credentials["api_provider"].is_valid.return_value = True
        mock_credentials["api_provider"].get_token.return_value = "test_token"

        # Run the Block with the credentials
        block = ExampleBlock()
        input_data = ExampleBlock.Input(required_field="test")

        results = []
        async for output_name, output_data in block.run(
            input_data,
            credentials=mock_credentials
        ):
            results.append((output_name, output_data))

        # Verify the credentials were used correctly
        mock_credentials["api_provider"].is_valid.assert_called_once()
```
## Architecture Design Principles

### 1. Microservice Architecture Principles

#### 1.1 Service Decomposition Strategy

**Decompose by Business Capability**

```python
# ✅ A good service decomposition
"""
User service (UserService)
- user registration, login, profile management
- permission and role management

Graph service (GraphService)
- graph creation, editing, version management
- graph structure validation and optimization

Execution service (ExecutionService)
- graph execution scheduling and management
- execution status tracking

Block service (BlockService)
- Block registration and management
- Block metadata and schemas

Notification service (NotificationService)
- email, SMS, and push notifications
- notification templates and rule management
"""

class ServiceBoundaryPrinciples:
    """Principles for drawing service boundaries."""

    @staticmethod
    def evaluate_service_design(service_name: str, responsibilities: List[str]) -> dict:
        """Evaluate how sound a service design is."""
        evaluation = {
            "service_name": service_name,
            "score": 0,
            "issues": [],
            "suggestions": []
        }

        # Check single responsibility
        if len(responsibilities) > 5:
            evaluation["issues"].append("Too many responsibilities; consider splitting")
            evaluation["score"] -= 20

        # Check how closely related the responsibilities are
        if not ServiceBoundaryPrinciples._are_responsibilities_related(responsibilities):
            evaluation["issues"].append("Responsibilities are only weakly related")
            evaluation["score"] -= 15

        # Check data consistency implications
        if ServiceBoundaryPrinciples._requires_distributed_transaction(responsibilities):
            evaluation["issues"].append("May require distributed transactions")
            evaluation["suggestions"].append("Consider the Saga pattern")

        return evaluation
```
#### 1.2 Inter-Service Communication

**Asynchronous Messaging**

```python
class EventDrivenCommunication:
    """Event-driven communication between services."""

    def __init__(self, message_broker):
        self.broker = message_broker
        self.event_handlers = {}

    async def publish_event(self, event_type: str, event_data: dict):
        """Publish an event."""
        event = {
            "event_type": event_type,
            "event_data": event_data,
            "timestamp": datetime.utcnow().isoformat(),
            "event_id": str(uuid.uuid4())
        }

        await self.broker.publish(
            exchange="events",
            routing_key=event_type,
            message=json.dumps(event)
        )

        logger.info(f"Published event: {event_type}")

    async def subscribe_to_event(self, event_type: str, handler: callable):
        """Subscribe to an event."""
        self.event_handlers[event_type] = handler

        await self.broker.subscribe(
            queue=f"{event_type}_queue",
            callback=self._handle_event
        )

    async def _handle_event(self, message):
        """Handle a received event."""
        try:
            event = json.loads(message.body)
            event_type = event["event_type"]

            if event_type in self.event_handlers:
                handler = self.event_handlers[event_type]
                await handler(event["event_data"])

            await message.ack()
        except Exception as e:
            logger.error(f"Event handling failed: {e}")
            await message.nack(requeue=True)


# Usage example
class GraphService:
    """Event handling in the graph service."""

    def __init__(self, message_broker):
        self.event_bus = EventDrivenCommunication(message_broker)
        # Subscribing is async, so register handlers from the service's
        # startup hook: `await service.setup_event_handlers()`.

    async def setup_event_handlers(self):
        """Register event handlers."""
        await self.event_bus.subscribe_to_event(
            "user.deleted",
            self._handle_user_deleted
        )
        await self.event_bus.subscribe_to_event(
            "execution.completed",
            self._handle_execution_completed
        )

    async def _handle_user_deleted(self, event_data):
        """Handle the user-deleted event."""
        user_id = event_data["user_id"]

        # Delete all of the user's graphs
        deleted_count = await self.delete_user_graphs(user_id)

        # Publish a graphs-deleted event
        await self.event_bus.publish_event(
            "graphs.user_deleted",
            {"user_id": user_id, "deleted_count": deleted_count}
        )

    async def create_graph(self, graph_data):
        """Create a graph and publish an event."""
        graph = await self._create_graph_in_db(graph_data)

        # Publish a graph-created event
        await self.event_bus.publish_event(
            "graph.created",
            {
                "graph_id": graph.id,
                "user_id": graph.user_id,
                "graph_name": graph.name
            }
        )

        return graph
```
### 2. Data Consistency Strategies

#### 2.1 Distributed Transaction Handling

**Implementing the Saga Pattern**

```python
class SagaOrchestrator:
    """Saga transaction coordinator."""

    def __init__(self):
        self.steps = []
        self.compensations = []

    def add_step(self, action: callable, compensation: callable):
        """Add a transaction step."""
        self.steps.append(action)
        self.compensations.append(compensation)

    async def execute(self) -> bool:
        """Execute the saga."""
        executed_steps = []
        try:
            # Run the steps in order
            for i, step in enumerate(self.steps):
                result = await step()
                executed_steps.append(i)

                if not result:
                    raise Exception(f"Step {i} failed")

            return True
        except Exception as e:
            logger.error(f"Saga execution failed: {e}")
            # Run the compensations
            await self._compensate(executed_steps)
            return False

    async def _compensate(self, executed_steps: List[int]):
        """Run the compensation actions."""
        # Compensate in reverse order
        for step_index in reversed(executed_steps):
            try:
                compensation = self.compensations[step_index]
                await compensation()
            except Exception as e:
                logger.error(f"Compensation {step_index} failed: {e}")


# Usage example: user registration as a distributed transaction
class UserRegistrationSaga:
    """The saga for user registration."""

    async def register_user(self, user_data: dict) -> bool:
        saga = SagaOrchestrator()

        # Step 1: create the user account
        saga.add_step(
            action=lambda: self._create_user_account(user_data),
            compensation=lambda: self._delete_user_account(user_data["user_id"])
        )

        # Step 2: initialize the user's credits
        saga.add_step(
            action=lambda: self._initialize_user_credits(user_data["user_id"]),
            compensation=lambda: self._delete_user_credits(user_data["user_id"])
        )

        # Step 3: send the welcome email
        saga.add_step(
            action=lambda: self._send_welcome_email(user_data["email"]),
            compensation=lambda: self._send_cancellation_email(user_data["email"])
        )

        # Step 4: create the default workspace
        saga.add_step(
            action=lambda: self._create_default_workspace(user_data["user_id"]),
            compensation=lambda: self._delete_user_workspace(user_data["user_id"])
        )

        return await saga.execute()
```
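One subtlety worth noting: the compensation closures above read `user_data["user_id"]`, so the ID must already be present in `user_data` before the saga starts (e.g. generated up front) rather than being assigned by step 1. A minimal, illustrative invocation under that assumption:

```python
import uuid

user_data = {
    "user_id": str(uuid.uuid4()),  # pre-generated so compensations can reference it
    "email": "alice@example.com",
}

ok = await UserRegistrationSaga().register_user(user_data)
if not ok:
    logger.error("User registration failed and was rolled back")
```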
#### 2.2 Eventual Consistency

**Event Sourcing and CQRS**

```python
class EventStore:
    """Event store."""

    def __init__(self, storage_backend):
        self.storage = storage_backend

    async def append_event(self, stream_id: str, event: dict):
        """Append an event to a stream."""
        event_with_metadata = {
            **event,
            "event_id": str(uuid.uuid4()),
            "timestamp": datetime.utcnow().isoformat(),
            "stream_id": stream_id
        }

        await self.storage.append(stream_id, event_with_metadata)

        # Publish the event to the message bus
        await self._publish_event(event_with_metadata)

    async def get_events(self, stream_id: str, from_version: int = 0) -> List[dict]:
        """Fetch the events in a stream."""
        return await self.storage.get_events(stream_id, from_version)

    async def _publish_event(self, event: dict):
        """Publish the event to the message bus."""
        # Implement the event-publishing logic here
        pass


class GraphAggregate:
    """Graph aggregate root."""

    def __init__(self, graph_id: str):
        self.graph_id = graph_id
        self.version = 0
        self.uncommitted_events = []
        self.state = GraphState()

    def create_graph(self, name: str, user_id: str):
        """Create a graph."""
        event = {
            "event_type": "GraphCreated",
            "data": {
                "graph_id": self.graph_id,
                "name": name,
                "user_id": user_id
            }
        }
        self._apply_event(event)
        self.uncommitted_events.append(event)

    def add_node(self, node_data: dict):
        """Add a node."""
        event = {
            "event_type": "NodeAdded",
            "data": {
                "graph_id": self.graph_id,
                "node_id": node_data["id"],
                "node_data": node_data
            }
        }
        self._apply_event(event)
        self.uncommitted_events.append(event)

    def _apply_event(self, event: dict):
        """Apply an event to the aggregate state."""
        event_type = event["event_type"]
        data = event["data"]

        if event_type == "GraphCreated":
            self.state.name = data["name"]
            self.state.user_id = data["user_id"]
            self.state.created_at = datetime.utcnow()
        elif event_type == "NodeAdded":
            self.state.nodes[data["node_id"]] = data["node_data"]

        self.version += 1

    async def save(self, event_store: EventStore):
        """Persist the uncommitted events."""
        for event in self.uncommitted_events:
            await event_store.append_event(self.graph_id, event)
        self.uncommitted_events.clear()


class GraphProjection:
    """Graph projection (the read model)."""

    def __init__(self, read_db):
        self.read_db = read_db

    async def handle_graph_created(self, event_data: dict):
        """Handle the graph-created event."""
        graph_record = {
            "id": event_data["graph_id"],
            "name": event_data["name"],
            "user_id": event_data["user_id"],
            "created_at": datetime.utcnow(),
            "node_count": 0
        }
        await self.read_db.insert("graphs", graph_record)

    async def handle_node_added(self, event_data: dict):
        """Handle the node-added event."""
        await self.read_db.increment(
            "graphs",
            {"id": event_data["graph_id"]},
            {"node_count": 1}
        )

        # Update the node index
        node_record = {
            "id": event_data["node_id"],
            "graph_id": event_data["graph_id"],
            **event_data["node_data"]
        }
        await self.read_db.insert("nodes", node_record)
```
## Performance Optimization in Practice

### 1. Database Optimization

#### 1.1 Query Optimization

**Indexing Strategy**

```sql
-- ✅ Effective index design
-- 1. Composite index: order columns by query frequency and selectivity
CREATE INDEX idx_graph_user_active ON graphs(user_id, is_active, created_at DESC);

-- 2. Partial index: index only the active rows
CREATE INDEX idx_active_executions ON graph_executions(status, created_at)
WHERE status IN ('queued', 'running');

-- 3. Expression index: index a computed value
CREATE INDEX idx_execution_duration ON graph_executions(
    EXTRACT(EPOCH FROM (ended_at - started_at))
) WHERE ended_at IS NOT NULL;

-- 4. Covering index: include every column the query needs
CREATE INDEX idx_graph_list_covering ON graphs(user_id, is_active)
INCLUDE (id, name, created_at, updated_at);
```
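To confirm an index is actually used, inspect the plan with `EXPLAIN (ANALYZE, BUFFERS)`. A small sketch using asyncpg is below; the DSN and user ID are placeholders, and an Index Only Scan on `idx_graph_list_covering` is what you would hope to see for the list query.

```python
import asyncpg

async def show_query_plan(dsn: str) -> None:
    """Print the execution plan for the graph-list query."""
    conn = await asyncpg.connect(dsn)
    try:
        plan = await conn.fetch(
            "EXPLAIN (ANALYZE, BUFFERS) "
            "SELECT id, name, created_at, updated_at FROM graphs "
            "WHERE user_id = 'some-user-id' AND is_active = true "
            "ORDER BY created_at DESC LIMIT 20"
        )
        for line in plan:
            print(line[0])  # look for "Index Only Scan using idx_graph_list_covering"
    finally:
        await conn.close()
```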
**Query Optimization Techniques**

```python
class OptimizedGraphRepository:
    """An optimized data-access layer for graphs."""

    async def get_user_graphs_paginated(
        self,
        user_id: str,
        last_created_at: Optional[datetime] = None,
        page_size: int = 20,
        include_inactive: bool = False
    ) -> dict:
        """List a user's graphs with cursor pagination (optimized version)."""

        # ✅ Lean on the covering index to avoid heap lookups
        base_query = """
            SELECT id, name, description, is_active, created_at, updated_at,
                   (SELECT COUNT(*) FROM nodes WHERE graph_id = g.id) as node_count
            FROM graphs g
            WHERE user_id = $1
        """

        conditions = []
        params = [user_id]

        if not include_inactive:
            conditions.append("AND is_active = $2")
            params.append(True)

        # ✅ Use keyset (cursor) pagination instead of OFFSET: the client
        # passes back the created_at of the last row it received
        if last_created_at is not None:
            conditions.append("AND created_at < $" + str(len(params) + 1))
            params.append(last_created_at)

        query = base_query + " ".join(conditions) + """
            ORDER BY created_at DESC
            LIMIT $""" + str(len(params) + 1)
        params.append(page_size)

        # Run the query (convert Records to dicts so they can be annotated)
        graphs = [dict(r) for r in await self.db.fetch(query, *params)]

        # ✅ Fetch related data in bulk to avoid N+1 queries
        graph_ids = [g['id'] for g in graphs]
        if graph_ids:
            # Fetch execution statistics in a single round trip
            exec_stats = await self._get_execution_stats_batch(graph_ids)

            # Merge the statistics in
            for graph in graphs:
                graph['execution_stats'] = exec_stats.get(graph['id'], {})

        return {
            'graphs': graphs,
            'has_more': len(graphs) == page_size
        }

    async def _get_execution_stats_batch(self, graph_ids: List[str]) -> dict:
        """Fetch execution statistics in bulk."""
        query = """
            SELECT
                graph_id,
                COUNT(*) as total_executions,
                COUNT(*) FILTER (WHERE status = 'completed') as successful_executions,
                AVG(EXTRACT(EPOCH FROM (ended_at - started_at))) FILTER (WHERE ended_at IS NOT NULL) as avg_duration
            FROM graph_executions
            WHERE graph_id = ANY($1)
            GROUP BY graph_id
        """

        results = await self.db.fetch(query, graph_ids)
        return {r['graph_id']: dict(r) for r in results}
```
#### 1.2 Connection Pool Optimization

**Connection Pool Configuration**

```python
class DatabaseConnectionManager:
    """Database connection manager."""

    def __init__(self):
        self.pools = {}

    async def create_pool(self, database_url: str, **kwargs):
        """Create a tuned connection pool."""
        pool_config = {
            # ✅ Pool sizing
            'min_size': 5,    # minimum connections
            'max_size': 20,   # maximum connections

            # ✅ Timeouts
            'command_timeout': 30,  # per-command timeout
            'server_settings': {
                'application_name': 'autogpt_platform',
                'tcp_keepalives_idle': '600',
                'tcp_keepalives_interval': '30',
                'tcp_keepalives_count': '3',
            },

            # ✅ Per-connection initialization
            'init': self._init_connection,
            **kwargs
        }

        pool = await asyncpg.create_pool(database_url, **pool_config)
        self.pools['default'] = pool
        return pool

    async def _init_connection(self, conn):
        """Initialize a connection."""
        # Connection-level settings
        await conn.execute("SET timezone = 'UTC'")
        await conn.execute("SET statement_timeout = '30s'")

        # Warm asyncpg's statement cache for frequent queries
        await conn.prepare("SELECT * FROM graphs WHERE id = $1")
        await conn.prepare("SELECT * FROM users WHERE id = $1")

    def get_connection(self, pool_name: str = 'default'):
        """Acquire a connection (returns an async context manager)."""
        pool = self.pools.get(pool_name)
        if not pool:
            raise ValueError(f"Pool {pool_name} not found")
        return pool.acquire()

    async def execute_with_retry(self, query: str, *args, max_retries: int = 3):
        """Execute a query with retries."""
        for attempt in range(max_retries):
            try:
                async with self.get_connection() as conn:
                    return await conn.fetch(query, *args)
            except (asyncpg.ConnectionDoesNotExistError,
                    asyncpg.InterfaceError) as e:
                if attempt == max_retries - 1:
                    raise
                logger.warning(f"Database connection error, retrying: {e}")
                await asyncio.sleep(2 ** attempt)  # exponential backoff
```
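A short, illustrative wiring of the manager (the DSN and query are placeholders):

```python
db = DatabaseConnectionManager()
await db.create_pool("postgresql://user:pass@localhost:5432/autogpt")

# Retries transparently on dropped connections
rows = await db.execute_with_retry(
    "SELECT id, name FROM graphs WHERE user_id = $1",
    "some-user-id",
)
```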
### 2. Caching Strategy

#### 2.1 Multi-Level Caching

**Designing the Cache Hierarchy**

```python
class MultiLevelCache:
    """A multi-level cache."""

    def __init__(self):
        # L1: in-process memory (fastest)
        self.l1_cache = {}
        self.l1_max_size = 1000
        self.l1_ttl = 300  # 5 minutes

        # L2: Redis (medium speed)
        self.l2_cache = None  # Redis client
        self.l2_ttl = 3600  # 1 hour

        # L3: the database (slowest but authoritative)
        self.l3_cache = None  # database connection

    async def get(self, key: str, fetch_func: callable = None):
        """Read through the cache levels."""
        # Check L1
        l1_value = self._get_from_l1(key)
        if l1_value is not None:
            return l1_value

        # Check L2
        l2_value = await self._get_from_l2(key)
        if l2_value is not None:
            # Backfill L1
            self._set_to_l1(key, l2_value)
            return l2_value

        # Fall back to L3 / the original data source
        if fetch_func:
            l3_value = await fetch_func()
            if l3_value is not None:
                # Backfill every cache level
                await self._set_to_l2(key, l3_value)
                self._set_to_l1(key, l3_value)
                return l3_value

        return None

    async def set(self, key: str, value: any, ttl: int = None):
        """Write to every cache level."""
        self._set_to_l1(key, value)
        await self._set_to_l2(key, value, ttl or self.l2_ttl)

    async def invalidate(self, key: str):
        """Invalidate every cache level."""
        self._remove_from_l1(key)
        await self._remove_from_l2(key)

    def _get_from_l1(self, key: str):
        """Read from L1."""
        entry = self.l1_cache.get(key)
        if entry and time.time() - entry['timestamp'] < self.l1_ttl:
            return entry['value']
        elif entry:
            # Expired: evict
            del self.l1_cache[key]
        return None

    def _set_to_l1(self, key: str, value: any):
        """Write to L1."""
        # LRU-style eviction
        if len(self.l1_cache) >= self.l1_max_size:
            oldest_key = min(
                self.l1_cache.keys(),
                key=lambda k: self.l1_cache[k]['timestamp']
            )
            del self.l1_cache[oldest_key]

        self.l1_cache[key] = {
            'value': value,
            'timestamp': time.time()
        }

    def _remove_from_l1(self, key: str):
        """Evict a key from L1."""
        self.l1_cache.pop(key, None)
```
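The L2 helpers referenced above are not shown in the original. A minimal sketch follows, assuming `self.l2_cache` is a `redis.asyncio` client and that cached values are JSON-serializable; both are assumptions that go beyond the text.

```python
import json

# Illustrative L2 helpers for MultiLevelCache (redis.asyncio client assumed):

async def _get_from_l2(self, key: str):
    """Read from Redis and deserialize JSON."""
    if self.l2_cache is None:
        return None
    raw = await self.l2_cache.get(key)
    return json.loads(raw) if raw is not None else None

async def _set_to_l2(self, key: str, value, ttl: int = None):
    """Serialize to JSON and write to Redis with a TTL."""
    if self.l2_cache is not None:
        await self.l2_cache.set(key, json.dumps(value), ex=ttl or self.l2_ttl)

async def _remove_from_l2(self, key: str):
    """Evict a key from Redis."""
    if self.l2_cache is not None:
        await self.l2_cache.delete(key)
```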
```python
# Usage example
class CachedGraphService:
    """A graph service with caching."""

    def __init__(self):
        self.cache = MultiLevelCache()
        self.db = DatabaseConnectionManager()

    async def get_graph(self, graph_id: str) -> dict:
        """Fetch a graph (cached)."""
        cache_key = f"graph:{graph_id}"

        return await self.cache.get(
            cache_key,
            fetch_func=lambda: self._fetch_graph_from_db(graph_id)
        )

    async def update_graph(self, graph_id: str, updates: dict) -> dict:
        """Update a graph and invalidate its cache entries."""
        # Update the database
        updated_graph = await self._update_graph_in_db(graph_id, updates)

        # Invalidate the related cache entries
        await self.cache.invalidate(f"graph:{graph_id}")
        await self.cache.invalidate(f"user_graphs:{updated_graph['user_id']}")

        # Re-warm the cache
        await self.cache.set(f"graph:{graph_id}", updated_graph)

        return updated_graph
```
#### 2.2 Cache Warm-up and Invalidation Strategies

**Smart Cache Management**

```python
class CacheWarmupManager:
    """Cache warm-up manager."""

    def __init__(self, cache_service, analytics_service):
        self.cache = cache_service
        self.analytics = analytics_service

    async def warmup_popular_content(self):
        """Warm up popular content."""
        # Fetch the most popular graphs
        popular_graphs = await self.analytics.get_popular_graphs(limit=100)

        # Warm them up concurrently
        tasks = []
        for graph_id in popular_graphs:
            task = asyncio.create_task(self._warmup_graph(graph_id))
            tasks.append(task)

        await asyncio.gather(*tasks, return_exceptions=True)

    async def _warmup_graph(self, graph_id: str):
        """Warm the cache for a single graph (backfills L1 from L2;
        pass a fetch_func to also populate from the database)."""
        try:
            # Warm the graph's basic info
            await self.cache.get(f"graph:{graph_id}")

            # Warm the graph's node info
            await self.cache.get(f"graph_nodes:{graph_id}")

            # Warm the recent execution records
            await self.cache.get(f"graph_executions:{graph_id}")
        except Exception as e:
            logger.warning(f"Failed to warmup graph {graph_id}: {e}")

    async def schedule_cache_cleanup(self):
        """Periodic cache cleanup."""
        while True:
            try:
                # Evict expired entries
                await self._cleanup_expired_cache()

                # Evict entries with a low hit rate
                await self._cleanup_low_hit_rate_cache()

                # Wait for the next cleanup round
                await asyncio.sleep(3600)  # once per hour
            except Exception as e:
                logger.error(f"Cache cleanup failed: {e}")
                await asyncio.sleep(300)  # retry after 5 minutes on failure


class CacheInvalidationStrategy:
    """Cache invalidation strategy."""

    def __init__(self, cache_service, event_bus):
        self.cache = cache_service
        self.event_bus = event_bus
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Register event handlers."""
        self.event_bus.subscribe("graph.updated", self._handle_graph_updated)
        self.event_bus.subscribe("user.updated", self._handle_user_updated)
        self.event_bus.subscribe("execution.completed", self._handle_execution_completed)

    async def _handle_graph_updated(self, event_data: dict):
        """Handle the graph-updated event."""
        graph_id = event_data["graph_id"]
        user_id = event_data["user_id"]

        # Invalidate the related cache entries
        invalidation_keys = [
            f"graph:{graph_id}",
            f"graph_nodes:{graph_id}",
            f"user_graphs:{user_id}",
            f"graph_stats:{graph_id}"
        ]

        for key in invalidation_keys:
            await self.cache.invalidate(key)

    async def _handle_execution_completed(self, event_data: dict):
        """Handle the execution-completed event."""
        graph_id = event_data["graph_id"]
        user_id = event_data["user_id"]

        # Invalidate the statistics-related entries
        await self.cache.invalidate(f"graph_stats:{graph_id}")
        await self.cache.invalidate(f"user_stats:{user_id}")
        await self.cache.invalidate(f"execution_history:{graph_id}")
```
### 3. Async Processing Optimization

#### 3.1 Task Queue Optimization

**Efficient Task Scheduling**

```python
class OptimizedTaskQueue:
    """An optimized task queue."""

    def __init__(self, redis_client):
        self.redis = redis_client
        self.worker_pools = {}
        self.metrics = TaskQueueMetrics()

    async def enqueue_task(
        self,
        task_type: str,
        task_data: dict,
        priority: int = 0,
        delay: int = 0
    ):
        """Enqueue a task."""
        task = {
            "id": str(uuid.uuid4()),
            "type": task_type,
            "data": task_data,
            "priority": priority,
            "created_at": time.time(),
            "execute_at": time.time() + delay
        }

        # Pick a queue based on priority
        queue_name = f"tasks:{task_type}:priority_{priority}"

        if delay > 0:
            # Delayed tasks go to the delay queue (a sorted set keyed by execute_at)
            await self.redis.zadd(
                "delayed_tasks",
                {json.dumps(task): task["execute_at"]}
            )
        else:
            # Immediate tasks go straight to the work queue
            await self.redis.lpush(queue_name, json.dumps(task))

        self.metrics.task_enqueued(task_type, priority)
        return task["id"]

    async def start_worker_pool(self, task_type: str, worker_count: int = 5):
        """Start a pool of workers."""
        workers = []
        for i in range(worker_count):
            worker = TaskWorker(
                worker_id=f"{task_type}_worker_{i}",
                task_type=task_type,
                redis_client=self.redis,
                metrics=self.metrics
            )
            workers.append(worker)
            asyncio.create_task(worker.start())

        self.worker_pools[task_type] = workers

    async def start_delayed_task_processor(self):
        """Start the delayed-task processor."""
        while True:
            try:
                current_time = time.time()

                # Fetch the delayed tasks that are now due
                ready_tasks = await self.redis.zrangebyscore(
                    "delayed_tasks",
                    0,
                    current_time,
                    withscores=True
                )

                for task_data, score in ready_tasks:
                    task = json.loads(task_data)

                    # Move the task to its work queue
                    queue_name = f"tasks:{task['type']}:priority_{task['priority']}"
                    await self.redis.lpush(queue_name, task_data)

                    # Remove it from the delay queue
                    await self.redis.zrem("delayed_tasks", task_data)

                await asyncio.sleep(1)  # poll once per second
            except Exception as e:
                logger.error(f"Delayed task processor error: {e}")
                await asyncio.sleep(5)


class TaskWorker:
    """A task worker."""

    def __init__(self, worker_id: str, task_type: str, redis_client, metrics):
        self.worker_id = worker_id
        self.task_type = task_type
        self.redis = redis_client
        self.metrics = metrics
        self.is_running = False

    async def start(self):
        """Start the worker loop."""
        self.is_running = True
        logger.info(f"Worker {self.worker_id} started")

        while self.is_running:
            try:
                # Fetch the next task by priority
                task_data = await self._get_next_task()
                if task_data:
                    await self._process_task(task_data)
                else:
                    # No work: sleep briefly
                    await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"Worker {self.worker_id} error: {e}")
                await asyncio.sleep(1)

    async def _get_next_task(self) -> dict:
        """Fetch the next task (highest priority first)."""
        # Check the queues from highest to lowest priority
        for priority in [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]:
            queue_name = f"tasks:{self.task_type}:priority_{priority}"
            task_data = await self.redis.brpop(queue_name, timeout=1)
            if task_data:
                return json.loads(task_data[1])
        return None

    async def _process_task(self, task_data: dict):
        """Process a task."""
        task_id = task_data["id"]
        start_time = time.time()

        try:
            # Record that the task started
            self.metrics.task_started(self.task_type, task_id)

            # Run the task
            result = await self._execute_task(task_data)

            # Record success
            duration = time.time() - start_time
            self.metrics.task_completed(self.task_type, task_id, duration)
            logger.info(f"Task {task_id} completed in {duration:.2f}s")
        except Exception as e:
            # Record failure
            duration = time.time() - start_time
            self.metrics.task_failed(self.task_type, task_id, duration, str(e))

            # Retry logic
            await self._handle_task_failure(task_data, e)

    async def _execute_task(self, task_data: dict):
        """Execute the task body."""
        task_type = task_data["type"]
        task_payload = task_data["data"]

        # Dispatch to a handler based on the task type
        if task_type == "graph_execution":
            return await self._execute_graph_task(task_payload)
        elif task_type == "notification":
            return await self._execute_notification_task(task_payload)
        elif task_type == "data_processing":
            return await self._execute_data_processing_task(task_payload)
        else:
            raise ValueError(f"Unknown task type: {task_type}")
```
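`TaskQueueMetrics` is used above but never defined. A minimal in-process sketch is shown below; in production these counters would typically be exported to a metrics backend such as Prometheus, and the class shape here is an assumption.

```python
from collections import defaultdict

class TaskQueueMetrics:
    """In-memory task-queue counters (illustrative only)."""

    def __init__(self):
        self.counters = defaultdict(int)
        self.durations = defaultdict(list)

    def task_enqueued(self, task_type: str, priority: int):
        self.counters[f"{task_type}:enqueued"] += 1

    def task_started(self, task_type: str, task_id: str):
        self.counters[f"{task_type}:started"] += 1

    def task_completed(self, task_type: str, task_id: str, duration: float):
        self.counters[f"{task_type}:completed"] += 1
        self.durations[task_type].append(duration)

    def task_failed(self, task_type: str, task_id: str, duration: float, error: str):
        self.counters[f"{task_type}:failed"] += 1
```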
## Security Hardening

### 1. Authentication and Authorization

#### 1.1 A Multi-Layer Authentication Scheme

**Dual Authentication: JWT + API Key**

```python
class SecurityManager:
    """Security manager."""

    def __init__(self):
        self.jwt_secret = os.getenv("JWT_SECRET")
        self.api_key_store = APIKeyStore()
        self.rate_limiter = RateLimiter()

    async def authenticate_request(self, request: Request) -> AuthContext:
        """Authenticate a request."""
        auth_context = AuthContext()

        # 1. JWT authentication
        jwt_token = self._extract_jwt_token(request)
        if jwt_token:
            jwt_payload = await self._validate_jwt_token(jwt_token)
            if jwt_payload:
                auth_context.user_id = jwt_payload["user_id"]
                auth_context.user_role = jwt_payload["role"]
                auth_context.auth_method = "jwt"

        # 2. API key authentication (alternative or complement)
        api_key = self._extract_api_key(request)
        if api_key and not auth_context.user_id:
            api_key_info = await self.api_key_store.validate_key(api_key)
            if api_key_info:
                auth_context.user_id = api_key_info["user_id"]
                auth_context.user_role = api_key_info["role"]
                auth_context.auth_method = "api_key"
                auth_context.api_key_id = api_key_info["id"]

        # 3. Rate-limit check
        if auth_context.user_id:
            rate_limit_key = f"user:{auth_context.user_id}"
            if not await self.rate_limiter.check_limit(rate_limit_key):
                raise RateLimitExceededError("Rate limit exceeded")

        return auth_context

    async def authorize_action(
        self,
        auth_context: AuthContext,
        resource: str,
        action: str
    ) -> bool:
        """Authorize an action."""
        # Administrators hold every permission
        if auth_context.user_role == "admin":
            return True

        # Check resource ownership
        if resource.startswith("graph:"):
            graph_id = resource.split(":")[1]
            graph_owner = await self._get_graph_owner(graph_id)

            if graph_owner == auth_context.user_id:
                return True

            # Check sharing permissions
            if action == "read":
                return await self._check_graph_sharing(graph_id, auth_context.user_id)

        # Check role-based permissions
        return await self._check_rbac_permission(
            auth_context.user_role,
            resource,
            action
        )


class SecureAPIKeyManager:
    """Secure API key management."""

    # API keys are stored only as SHA-256 hashes, so no reversible
    # encryption (and hence no per-instance Fernet key) is needed here.

    async def generate_api_key(self, user_id: str, permissions: List[str]) -> dict:
        """Generate an API key."""
        # Generate a random secret
        raw_key = secrets.token_urlsafe(32)
        key_hash = hashlib.sha256(raw_key.encode()).hexdigest()

        # Build the key record
        api_key_record = {
            "id": str(uuid.uuid4()),
            "user_id": user_id,
            "key_hash": key_hash,
            "permissions": permissions,
            "created_at": datetime.utcnow(),
            "last_used_at": None,
            "is_active": True,
            "usage_count": 0
        }

        # Persist the record (only the hash is stored)
        await self._store_api_key_record(api_key_record)

        # Return the full key (this is the only time it is visible)
        return {
            "api_key": f"agpt_{raw_key}",
            "key_id": api_key_record["id"],
            "permissions": permissions
        }

    async def validate_api_key(self, api_key: str) -> dict:
        """Validate an API key."""
        if not api_key.startswith("agpt_"):
            return None

        raw_key = api_key[5:]  # strip the "agpt_" prefix
        key_hash = hashlib.sha256(raw_key.encode()).hexdigest()

        # Look up the key record
        key_record = await self._get_api_key_by_hash(key_hash)
        if not key_record or not key_record["is_active"]:
            return None

        # Update usage statistics
        await self._update_key_usage(key_record["id"])

        return {
            "user_id": key_record["user_id"],
            "permissions": key_record["permissions"],
            "key_id": key_record["id"]
        }
```
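The `RateLimiter` used by `SecurityManager` is not defined in this document. A minimal fixed-window limiter backed by Redis might look like the sketch below; the class shape, limits, and key scheme are assumptions.

```python
import time

class RateLimiter:
    """Fixed-window rate limiter on Redis (illustrative sketch)."""

    def __init__(self, redis_client, max_requests: int = 100, window_seconds: int = 60):
        self.redis = redis_client
        self.max_requests = max_requests
        self.window = window_seconds

    async def check_limit(self, key: str) -> bool:
        """Return True while the caller is within its quota for the current window."""
        window_key = f"ratelimit:{key}:{int(time.time() // self.window)}"
        count = await self.redis.incr(window_key)
        if count == 1:
            # First hit in this window: set the expiry so the counter resets
            await self.redis.expire(window_key, self.window)
        return count <= self.max_requests
```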
#### 1.2 The Access Control Model

**Role- and Resource-Based Access Control (RBAC + ABAC)**

```python
import functools

from fastapi import HTTPException, Request


class PermissionManager:
    """Permission manager."""

    def __init__(self):
        self.role_permissions = self._load_role_permissions()
        self.resource_policies = self._load_resource_policies()

    def _load_role_permissions(self) -> dict:
        """Load the role-to-permission mapping."""
        return {
            "admin": ["*"],  # administrators hold every permission
            "user": [
                "graph:create",
                "graph:read:own",
                "graph:update:own",
                "graph:delete:own",
                "execution:create:own",
                "execution:read:own"
            ],
            "viewer": [
                "graph:read:shared",
                "execution:read:shared"
            ]
        }

    def _load_resource_policies(self) -> dict:
        """Load the resource policies."""
        return {
            "graph": {
                "ownership_field": "user_id",
                "sharing_enabled": True,
                "public_read": False
            },
            "execution": {
                "ownership_field": "user_id",
                "sharing_enabled": True,
                "public_read": False
            },
            "user": {
                "ownership_field": "id",
                "sharing_enabled": False,
                "public_read": False
            }
        }

    async def check_permission(
        self,
        user_id: str,
        user_role: str,
        resource_type: str,
        resource_id: str,
        action: str
    ) -> bool:
        """Check a permission."""
        # 1. Look up the role's permissions
        role_perms = self.role_permissions.get(user_role, [])

        # Administrators hold every permission
        if "*" in role_perms:
            return True

        # 2. Build the candidate permission strings
        permission_patterns = [
            f"{resource_type}:{action}",
            f"{resource_type}:{action}:own",
            f"{resource_type}:{action}:shared"
        ]

        # 3. Check each candidate
        for pattern in permission_patterns:
            if pattern in role_perms:
                # An :own permission requires an ownership check
                if pattern.endswith(":own"):
                    return await self._check_ownership(
                        user_id, resource_type, resource_id
                    )
                # A :shared permission requires a sharing check
                elif pattern.endswith(":shared"):
                    return await self._check_sharing(
                        user_id, resource_type, resource_id
                    )
                # A plain permission passes directly
                else:
                    return True

        return False

    async def _check_ownership(
        self,
        user_id: str,
        resource_type: str,
        resource_id: str
    ) -> bool:
        """Check resource ownership."""
        policy = self.resource_policies.get(resource_type)
        if not policy:
            return False

        ownership_field = policy["ownership_field"]

        # Look up the resource's owner
        resource = await self._get_resource(resource_type, resource_id)
        if not resource:
            return False

        return resource.get(ownership_field) == user_id

    async def _check_sharing(
        self,
        user_id: str,
        resource_type: str,
        resource_id: str
    ) -> bool:
        """Check resource sharing."""
        policy = self.resource_policies.get(resource_type)
        if not policy or not policy["sharing_enabled"]:
            return False

        # Check whether the resource is shared with this user
        return await self._is_resource_shared_with_user(
            resource_type, resource_id, user_id
        )


# Permission decorator
def require_permission(resource_type: str, action: str):
    """Decorator that enforces a permission check."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Pull the auth context off the request
            request = kwargs.get('request') or args[0]
            auth_context = getattr(request.state, 'auth_context', None)

            if not auth_context or not auth_context.user_id:
                raise HTTPException(status_code=401, detail="Authentication required")

            # Resolve the resource ID
            resource_id = kwargs.get('resource_id') or kwargs.get('graph_id') or kwargs.get('id')

            # Check the permission
            permission_manager = PermissionManager()
            has_permission = await permission_manager.check_permission(
                user_id=auth_context.user_id,
                user_role=auth_context.user_role,
                resource_type=resource_type,
                resource_id=resource_id,
                action=action
            )

            if not has_permission:
                raise HTTPException(status_code=403, detail="Permission denied")

            return await func(*args, **kwargs)
        return wrapper
    return decorator


# Usage example
@require_permission("graph", "update")
async def update_graph(graph_id: str, updates: dict, request: Request):
    """Update a graph (requires the update permission)."""
    # Implement the update logic here
    pass
```
### 2. Data Security

#### 2.1 Encrypting Sensitive Data

**An End-to-End Encryption Scheme**

```python
import base64
import os

from cryptography.fernet import Fernet


class DataEncryptionManager:
    """Data encryption manager."""

    def __init__(self):
        # Master key (from an environment variable or a key management service)
        self.master_key = self._get_master_key()
        self.fernet = Fernet(self.master_key)

        # Field-level encryption configuration
        self.encrypted_fields = {
            "users": ["email", "phone"],
            "credentials": ["access_token", "refresh_token", "api_key", "client_secret"],
            "graphs": ["sensitive_data"],
            "executions": ["input_data", "output_data"]
        }

    def _get_master_key(self) -> bytes:
        """Fetch the master encryption key."""
        key_str = os.getenv("ENCRYPTION_MASTER_KEY")
        if not key_str:
            # In production, fetch this from a key management service
            raise ValueError("Master encryption key not configured")
        # Fernet expects the urlsafe-base64-encoded key itself, as bytes
        return key_str.encode()

    def encrypt_field(self, table: str, field: str, value: str) -> str:
        """Encrypt a field value."""
        if not self._should_encrypt_field(table, field):
            return value
        if not value:
            return value

        # Encrypt the data
        encrypted_bytes = self.fernet.encrypt(value.encode('utf-8'))
        return base64.urlsafe_b64encode(encrypted_bytes).decode('utf-8')

    def decrypt_field(self, table: str, field: str, encrypted_value: str) -> str:
        """Decrypt a field value."""
        if not self._should_encrypt_field(table, field):
            return encrypted_value
        if not encrypted_value:
            return encrypted_value

        try:
            # Decrypt the data
            encrypted_bytes = base64.urlsafe_b64decode(encrypted_value.encode('utf-8'))
            decrypted_bytes = self.fernet.decrypt(encrypted_bytes)
            return decrypted_bytes.decode('utf-8')
        except Exception as e:
            logger.error(f"Failed to decrypt field {table}.{field}: {e}")
            return encrypted_value

    def _should_encrypt_field(self, table: str, field: str) -> bool:
        """Check whether a field should be encrypted."""
        return field in self.encrypted_fields.get(table, [])

    def encrypt_record(self, table: str, record: dict) -> dict:
        """Encrypt the sensitive fields in a record."""
        encrypted_record = record.copy()
        for field, value in record.items():
            if isinstance(value, str):
                encrypted_record[field] = self.encrypt_field(table, field, value)
        return encrypted_record

    def decrypt_record(self, table: str, record: dict) -> dict:
        """Decrypt the sensitive fields in a record."""
        decrypted_record = record.copy()
        for field, value in record.items():
            if isinstance(value, str):
                decrypted_record[field] = self.decrypt_field(table, field, value)
        return decrypted_record
```
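For completeness, `ENCRYPTION_MASTER_KEY` must be a Fernet key, i.e. 32 random bytes in urlsafe base64. It can be generated once and kept in a secret manager; a small illustrative snippet:

```python
from cryptography.fernet import Fernet

# One-time key generation; store the output in a secret manager, never in code:
print(Fernet.generate_key().decode())
# e.g. export ENCRYPTION_MASTER_KEY='<printed value>'
```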
```python
class SecureCredentialsStore:
    """Secure credential storage."""

    def __init__(self):
        self.encryption_manager = DataEncryptionManager()
        self.db = DatabaseManager()

    async def store_credentials(
        self,
        user_id: str,
        provider: str,
        credentials: dict
    ) -> str:
        """Store encrypted credentials."""
        # Generate a credential ID
        credential_id = str(uuid.uuid4())

        # Encrypt the sensitive fields
        encrypted_credentials = {}
        for key, value in credentials.items():
            if key in ["access_token", "refresh_token", "api_key", "client_secret"]:
                encrypted_credentials[key] = self.encryption_manager.encrypt_field(
                    "credentials", key, str(value)
                )
            else:
                encrypted_credentials[key] = value

        # Persist to the database
        record = {
            "id": credential_id,
            "user_id": user_id,
            "provider": provider,
            "credentials": json.dumps(encrypted_credentials),
            "created_at": datetime.utcnow(),
            "updated_at": datetime.utcnow()
        }

        await self.db.insert("user_credentials", record)

        # Write an audit log entry
        await self._log_credential_access(
            user_id, credential_id, "store", provider
        )

        return credential_id

    async def retrieve_credentials(
        self,
        user_id: str,
        credential_id: str
    ) -> dict:
        """Retrieve and decrypt credentials."""
        # Fetch from the database
        record = await self.db.get_one(
            "user_credentials",
            {"id": credential_id, "user_id": user_id}
        )

        if not record:
            return None

        # Decrypt the credentials
        encrypted_credentials = json.loads(record["credentials"])
        decrypted_credentials = {}

        for key, value in encrypted_credentials.items():
            if key in ["access_token", "refresh_token", "api_key", "client_secret"]:
                decrypted_credentials[key] = self.encryption_manager.decrypt_field(
                    "credentials", key, value
                )
            else:
                decrypted_credentials[key] = value

        # Write an audit log entry
        await self._log_credential_access(
            user_id, credential_id, "retrieve", record["provider"]
        )

        return decrypted_credentials

    async def _log_credential_access(
        self,
        user_id: str,
        credential_id: str,
        action: str,
        provider: str
    ):
        """Write a credential-access audit log entry."""
        audit_log = {
            "id": str(uuid.uuid4()),
            "user_id": user_id,
            "credential_id": credential_id,
            "action": action,
            "provider": provider,
            "timestamp": datetime.utcnow(),
            "ip_address": self._get_current_ip(),
            "user_agent": self._get_current_user_agent()
        }

        await self.db.insert("credential_audit_log", audit_log)
```
#### 2.2 Input Validation and Defense

**Comprehensive Input Validation**

```python
import html
import re


class InputValidator:
    """Input validator."""

    def __init__(self):
        self.sql_injection_patterns = [
            r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION)\b)",
            r"(--|#|/\*|\*/)",
            r"(\b(OR|AND)\s+\d+\s*=\s*\d+)",
        ]
        self.xss_patterns = [
            r"<script[^>]*>.*?</script>",
            r"javascript:",
            r"on\w+\s*=",
            r"<iframe[^>]*>.*?</iframe>",
        ]
        self.command_injection_patterns = [
            r"[;&|`$(){}[\]\\]",
            r"\b(cat|ls|pwd|whoami|id|uname|ps|netstat|ifconfig)\b",
        ]

    def validate_and_sanitize(self, data: dict, schema: dict) -> dict:
        """Validate and sanitize input data."""
        validated_data = {}

        for field, value in data.items():
            field_schema = schema.get(field, {})

            # Type validation
            expected_type = field_schema.get("type")
            if expected_type and not self._validate_type(value, expected_type):
                raise ValidationError(f"Field {field} has invalid type")

            # Length validation
            if isinstance(value, str):
                min_length = field_schema.get("min_length", 0)
                max_length = field_schema.get("max_length", 10000)
                if len(value) < min_length:
                    raise ValidationError(f"Field {field} is too short")
                if len(value) > max_length:
                    raise ValidationError(f"Field {field} is too long")

            # Security validation
            if isinstance(value, str):
                value = self._sanitize_string(value, field_schema)

            validated_data[field] = value

        return validated_data

    def _sanitize_string(self, value: str, field_schema: dict) -> str:
        """Sanitize a string input."""
        # HTML-escape
        if field_schema.get("escape_html", True):
            value = html.escape(value)

        # SQL injection detection
        if self._detect_sql_injection(value):
            raise SecurityError("Potential SQL injection detected")

        # XSS detection
        if self._detect_xss(value):
            raise SecurityError("Potential XSS attack detected")

        # Command injection detection
        if field_schema.get("check_command_injection", False):
            if self._detect_command_injection(value):
                raise SecurityError("Potential command injection detected")

        return value

    def _detect_sql_injection(self, value: str) -> bool:
        """Detect SQL injection."""
        for pattern in self.sql_injection_patterns:
            if re.search(pattern, value, re.IGNORECASE):
                return True
        return False

    def _detect_xss(self, value: str) -> bool:
        """Detect XSS attacks."""
        for pattern in self.xss_patterns:
            if re.search(pattern, value, re.IGNORECASE):
                return True
        return False

    def _detect_command_injection(self, value: str) -> bool:
        """Detect command injection."""
        for pattern in self.command_injection_patterns:
            if re.search(pattern, value, re.IGNORECASE):
                return True
        return False
```
```python
# Security middleware
class SecurityMiddleware:
    """Security middleware."""

    def __init__(self):
        self.validator = InputValidator()
        self.rate_limiter = RateLimiter()
        self.ip_whitelist = self._load_ip_whitelist()
        self.blocked_ips = self._load_blocked_ips()

    async def __call__(self, request: Request, call_next):
        """Middleware entry point."""
        # 1. IP check
        client_ip = self._get_client_ip(request)
        if client_ip in self.blocked_ips:
            raise HTTPException(status_code=403, detail="IP blocked")

        # 2. Rate-limit check
        if not await self.rate_limiter.check_limit(f"ip:{client_ip}"):
            raise HTTPException(status_code=429, detail="Rate limit exceeded")

        # 3. Request-size check
        content_length = request.headers.get("content-length")
        if content_length and int(content_length) > 10 * 1024 * 1024:  # 10 MB
            raise HTTPException(status_code=413, detail="Request too large")

        # 4. Security-header check
        self._validate_security_headers(request)

        # Handle the request
        response = await call_next(request)

        # 5. Add security response headers
        self._add_security_headers(response)

        return response

    def _add_security_headers(self, response: Response):
        """Add security response headers."""
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        response.headers["Content-Security-Policy"] = "default-src 'self'"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
```
## Summary

With the best practices and field experience above, the AutoGPT platform can achieve:

**Higher development efficiency**
- Standardized workflows: uniform Block development conventions and a shared testing framework
- Code quality assurance: thorough code review and quality-check mechanisms
- Fast problem diagnosis: detailed logging and error-handling strategies

**Better system performance**
- Database optimization: sound index design and query tuning
- Caching strategy: multi-level caches with smart invalidation
- Async processing: efficient task queues and concurrency control

**Stronger security**
- Multi-layer authentication: a dual JWT + API key scheme
- Access control: a fine-grained RBAC + ABAC permission model
- Data protection: end-to-end encryption and input-validation defenses

**More mature operations**
- Real-time monitoring: comprehensive system and business metrics
- Incident handling: automated fault detection and recovery
- Performance analysis: detailed profiling and optimization guidance

These practices provide a solid foundation for the platform's stable operation and continued growth.