1. RLlib概述
RLlib是Ray生态系统中的分布式强化学习库,提供可扩展、高性能的强化学习算法实现。它支持多种算法、环境接口和分布式训练模式,被广泛应用于工业级强化学习场景。
1.1 核心特性
graph LR
subgraph "RLlib核心特性"
A[算法丰富<br/>Rich Algorithms]
B[分布式训练<br/>Distributed Training]
C[多环境支持<br/>Multi-Env Support]
D[可扩展架构<br/>Scalable Architecture]
E[生产就绪<br/>Production Ready]
end
style A fill:#e3f2fd
style B fill:#e8f5e8
style C fill:#fff3e0
style D fill:#fce4ec
style E fill:#f3e5f5
1.2 强化学习训练流程
stateDiagram-v2
[*] --> 环境初始化: gym.make()
环境初始化 --> 算法配置: AlgorithmConfig()
算法配置 --> 算法构建: config.build()
算法构建 --> 训练循环: algorithm.train()
state 训练循环 {
[*] --> 数据收集: EnvRunners
数据收集 --> 经验回放: ReplayBuffer
经验回放 --> 策略更新: Learner.update()
策略更新 --> 模型同步: sync weights
模型同步 --> 数据收集: next iteration
策略更新 --> 评估阶段: periodic evaluation
评估阶段 --> 数据收集: continue training
}
训练循环 --> 训练完成: stop criteria met
训练完成 --> [*]
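上述训练流程可以直接映射到RLlib的公开API。下面是一个最小的端到端示意(以CartPole-v1和PPO为例,迭代次数、并行度等数值仅为演示,省略了评估与检查点逻辑):
from ray.rllib.algorithms.ppo import PPOConfig

# 算法配置:对应图中的 AlgorithmConfig() 阶段
config = (
    PPOConfig()
    .environment("CartPole-v1")      # 环境由RLlib内部通过gym.make()创建
    .env_runners(num_env_runners=2)  # 并行数据收集的EnvRunner数量
)

# 算法构建:对应图中的 config.build() 阶段
algo = config.build()

# 训练循环:algorithm.train() 内部完成数据收集、策略更新与权重同步
for i in range(10):
    results = algo.train()

algo.stop()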
2. 核心架构设计
2.1 RLlib整体架构
graph TB
subgraph "算法层"
A1[PPO]
A2[DQN]
A3[IMPALA]
A4[SAC]
end
subgraph "核心组件层"
B1[RLModule<br/>神经网络模型]
B2[Learner<br/>训练引擎]
B3[EnvRunner<br/>环境交互]
B4[ReplayBuffer<br/>经验存储]
end
subgraph "执行层"
C1[Algorithm<br/>算法协调器]
C2[EnvRunnerGroup<br/>分布式采样]
C3[LearnerGroup<br/>分布式学习]
end
subgraph "基础设施层"
D1[Ray Core<br/>分布式运行时]
D2[Gymnasium<br/>环境接口]
D3[PyTorch/TF<br/>深度学习框架]
end
A1 --> B1
A2 --> B1
A3 --> B2
A4 --> B2
B1 --> C1
    B2 --> C3
    B3 --> C2
B4 --> C1
C1 --> D1
C2 --> D2
C3 --> D3
style A1 fill:#e3f2fd
style B1 fill:#e8f5e8
style C1 fill:#fff3e0
style D1 fill:#fce4ec
2.2 RLModule核心实现
# 位置: rllib/core/rl_module/rl_module.py:256-402
@PublicAPI(stability="beta")
class RLModule(Checkpointable, abc.ABC):
"""
RLlib模型的基础类
核心职责:
1. 定义神经网络模型结构
2. 实现前向传播逻辑
3. 管理模型状态和检查点
4. 支持多智能体场景
前向传播模式:
- forward_inference: 推理模式(确定性策略)
- forward_exploration: 探索模式(随机策略)
- forward_train: 训练模式(计算损失)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
inference_only: bool = False,
learner_only: bool = False,
model_config: Optional[Dict] = None,
**kwargs
):
"""
初始化RLModule
参数说明:
observation_space: 观测空间定义
action_space: 动作空间定义
inference_only: 是否仅用于推理
learner_only: 是否仅用于学习(不在EnvRunner上构建)
model_config: 模型配置字典
"""
self.observation_space = observation_space
self.action_space = action_space
self.inference_only = inference_only
self.learner_only = learner_only
self.model_config = model_config or {}
# 动作分布类(用于采样动作)
self.action_dist_cls = None
# 模型组件初始化标志
self._is_setup = False
@abc.abstractmethod
def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
"""
推理前向传播(抽象方法)
用途:在环境中执行策略,生成确定性动作
参数:
batch: 包含观测数据的字典
- "obs": 观测数据张量
返回:
包含动作分布参数的字典
- "action_dist_inputs": 动作分布的输入参数
"""
pass
@abc.abstractmethod
def _forward_exploration(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
"""
探索前向传播(抽象方法)
用途:在训练期间进行探索,生成随机策略动作
与forward_inference的区别:
- 可能包含噪声或随机性
- 用于数据收集阶段
- 可能使用不同的网络分支
"""
pass
@abc.abstractmethod
def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
"""
训练前向传播(抽象方法)
用途:计算损失函数所需的输出
参数:
batch: 训练批次数据
- "obs": 观测数据
- "actions": 历史动作(用于计算动作概率)
- "rewards": 奖励信号
- "next_obs": 下一状态观测
返回:
训练所需的输出
- "action_dist_inputs": 动作分布参数
- "vf_preds": 价值函数预测(Actor-Critic算法)
"""
pass
def get_inference_action_dist_cls(self):
"""获取推理时使用的动作分布类"""
return self.action_dist_cls
def get_exploration_action_dist_cls(self):
"""获取探索时使用的动作分布类"""
return self.action_dist_cls
def get_train_action_dist_cls(self):
"""获取训练时使用的动作分布类"""
return self.action_dist_cls
# PPO算法的RLModule实现示例
class PPOTorchRLModule(RLModule):
"""
PPO算法的PyTorch实现
模型结构:
- 共享编码器网络
- 策略头(输出动作分布参数)
- 价值头(输出状态价值估计)
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 构建网络结构
self._build_networks()
def _build_networks(self):
"""构建PPO网络结构"""
        import torch
        import torch.nn as nn
# 共享编码器
obs_dim = self.observation_space.shape[0]
hidden_sizes = self.model_config.get("fcnet_hiddens", [256, 256])
encoder_layers = []
prev_size = obs_dim
for hidden_size in hidden_sizes:
encoder_layers.extend([
nn.Linear(prev_size, hidden_size),
nn.ReLU(),
])
prev_size = hidden_size
self.encoder = nn.Sequential(*encoder_layers)
# 策略头
if isinstance(self.action_space, gym.spaces.Discrete):
# 离散动作空间
self.policy_head = nn.Linear(prev_size, self.action_space.n)
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
self.action_dist_cls = TorchCategorical
else:
# 连续动作空间
action_dim = self.action_space.shape[0]
self.policy_mean = nn.Linear(prev_size, action_dim)
self.policy_logstd = nn.Parameter(torch.zeros(action_dim))
from ray.rllib.models.torch.torch_action_dist import TorchDiagGaussian
self.action_dist_cls = TorchDiagGaussian
# 价值头
self.value_head = nn.Linear(prev_size, 1)
def _forward_inference(self, batch):
"""PPO推理前向传播"""
obs = batch["obs"]
# 编码器前向传播
encoded = self.encoder(obs)
# 策略输出
if isinstance(self.action_space, gym.spaces.Discrete):
action_logits = self.policy_head(encoded)
return {"action_dist_inputs": action_logits}
else:
action_mean = self.policy_mean(encoded)
action_logstd = self.policy_logstd.expand_as(action_mean)
return {"action_dist_inputs": torch.cat([action_mean, action_logstd], dim=-1)}
def _forward_exploration(self, batch):
"""PPO探索前向传播(与推理相同)"""
return self._forward_inference(batch)
def _forward_train(self, batch):
"""PPO训练前向传播"""
obs = batch["obs"]
# 编码器前向传播
encoded = self.encoder(obs)
# 策略输出
action_dist_inputs = self._compute_action_dist_inputs(encoded)
# 价值函数输出
value_predictions = self.value_head(encoded).squeeze(-1)
return {
"action_dist_inputs": action_dist_inputs,
"vf_preds": value_predictions,
}
def _compute_action_dist_inputs(self, encoded):
"""计算动作分布输入"""
if isinstance(self.action_space, gym.spaces.Discrete):
return self.policy_head(encoded)
else:
action_mean = self.policy_mean(encoded)
action_logstd = self.policy_logstd.expand_as(action_mean)
return torch.cat([action_mean, action_logstd], dim=-1)
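下面给出一个直接实例化上述PPOTorchRLModule草图并执行推理前向传播的最小用法示意。其中的观测/动作空间、批次大小和fcnet_hiddens取值均为假设的演示值;在真实RLlib中通常不会手动调用下划线方法,而是由公开的forward_inference()分发到_forward_inference():
import gymnasium as gym
import torch

# 假设的观测/动作空间,仅用于演示草图类的前向传播
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
act_space = gym.spaces.Discrete(2)

module = PPOTorchRLModule(
    observation_space=obs_space,
    action_space=act_space,
    model_config={"fcnet_hiddens": [64, 64]},
)

# 构造一个只含观测的小批次,执行推理前向传播
batch = {"obs": torch.randn(8, 4)}
out = module._forward_inference(batch)
print(out["action_dist_inputs"].shape)  # 期望输出形状: (8, 2)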
3. 算法体系详解
3.1 Algorithm基类实现
# 位置: rllib/algorithms/algorithm.py:207-241
@PublicAPI
class Algorithm(Checkpointable, Trainable):
"""
RLlib算法的基础类
核心功能:
1. 算法生命周期管理
2. 训练循环协调
3. 评估和检查点管理
4. 分布式训练支持
重要属性:
config: AlgorithmConfig实例
env_runner_group: 环境交互的运行组
eval_env_runner_group: 评估专用运行组
learner_group: 学习者组(用于分布式学习)
"""
def __init__(self, config: AlgorithmConfig, env: Optional[EnvType] = None, **kwargs):
"""
初始化算法实例
初始化流程:
1. 验证和设置配置
2. 创建环境交互组件
3. 构建模型和学习器
4. 设置分布式训练
"""
super().__init__()
self.config = config
self._env_id = config.env
# 创建度量记录器
self.metrics = MetricsLogger()
# 初始化训练组件
self._setup_training_components(env)
# 设置评估
if config.evaluation_interval > 0:
self._setup_evaluation_components()
def _setup_training_components(self, env):
"""设置训练组件"""
# 1. 创建环境运行组
self.env_runner_group = EnvRunnerGroup(
env_creator=lambda: env or gym.make(self._env_id),
default_policy_class=self.get_default_policy_class(),
config=self.config,
num_workers=self.config.num_env_runners,
local_worker=True,
)
# 2. 创建学习器组
if self.config.enable_rl_module_and_learner:
self.learner_group = LearnerGroup(
config=self.config,
module_spec=self._get_rl_module_spec(),
num_learners=self.config.num_learners,
)
# 3. 创建经验回放缓冲区(如果需要)
if self.config.replay_buffer_config:
self.replay_buffer = ReplayBuffer(
capacity=self.config.replay_buffer_config["capacity"],
storage_unit=self.config.replay_buffer_config.get("storage_unit", "timesteps")
)
def train(self) -> ResultDict:
"""
执行一次训练迭代
标准训练流程:
1. 数据收集(环境交互)
2. 策略更新(梯度计算和应用)
3. 模型同步(分布式场景)
4. 指标收集和报告
"""
with self.metrics.timer("train_time_ms"):
# 1. 执行算法特定的训练步骤
train_results = self.training_step()
# 2. 更新全局步数
self._iteration += 1
# 3. 同步模型权重(如果使用分布式)
if self.config.enable_rl_module_and_learner:
self._sync_weights_to_env_runners()
# 4. 执行定期评估
if (self.config.evaluation_interval > 0 and
self._iteration % self.config.evaluation_interval == 0):
eval_results = self.evaluate()
train_results.update(eval_results)
# 5. 保存检查点(如果需要)
if (self.config.checkpoint_frequency > 0 and
self._iteration % self.config.checkpoint_frequency == 0):
self.save_checkpoint()
return train_results
@abc.abstractmethod
def training_step(self) -> ResultDict:
"""
算法特定的训练步骤(抽象方法)
子类必须实现此方法,定义:
1. 如何收集训练数据
2. 如何更新策略参数
3. 如何计算和报告指标
"""
pass
def evaluate(self) -> ResultDict:
"""
执行策略评估
评估流程:
1. 使用当前策略在评估环境中运行
2. 收集性能指标
3. 计算平均奖励、成功率等
"""
if not hasattr(self, 'eval_env_runner_group'):
return {}
# 同步最新权重到评估运行器
self._sync_weights_to_eval_env_runners()
# 执行评估回合
eval_results = self.eval_env_runner_group.sample()
# 处理评估指标
eval_metrics = {
"evaluation": {
"episode_reward_mean": np.mean([ep.get_return() for ep in eval_results]),
"episode_len_mean": np.mean([ep.length for ep in eval_results]),
"episodes_this_iter": len(eval_results),
}
}
return eval_metrics
# PPO算法的具体实现示例
class PPO(Algorithm):
"""
PPO算法实现
PPO特点:
1. On-policy策略梯度算法
2. 使用clipped surrogate objective
3. Actor-Critic架构
4. 适合并行数据收集
"""
def training_step(self) -> ResultDict:
"""
PPO训练步骤实现
PPO训练流程:
1. 使用当前策略收集轨迹数据
2. 计算GAE优势估计
3. 多次更新策略和价值网络
4. 应用PPO损失函数(clipped objective)
"""
train_results = {}
# 1. 数据收集阶段
with self.metrics.timer("sample_time_ms"):
# 从所有EnvRunner收集数据
sample_batches = self.env_runner_group.sample()
# 合并批次数据
combined_batch = SampleBatch.concat_samples(sample_batches)
# 计算GAE优势
combined_batch = compute_gae_for_sample_batch(
batch=combined_batch,
last_r=0.0,
gamma=self.config.gamma,
lambda_=self.config.lambda_,
use_gae=self.config.use_gae,
)
# 2. 策略更新阶段
with self.metrics.timer("learn_time_ms"):
if self.config.enable_rl_module_and_learner:
# 使用新的Learner API
learn_results = self.learner_group.update_from_batch(
batch=combined_batch,
                    minibatch_size=self.config.sgd_minibatch_size,
num_iters=self.config.num_sgd_iter,
)
train_results.update(learn_results)
else:
# 使用传统Policy API
for _ in range(self.config.num_sgd_iter):
# 随机打乱数据
shuffled_batch = combined_batch.shuffle()
# 分批更新
                    for minibatch in shuffled_batch.timeslices(self.config.sgd_minibatch_size):
policy_loss, value_loss, entropy = self._update_policy(minibatch)
train_results.update({
"policy_loss": policy_loss,
"value_loss": value_loss,
"entropy": entropy,
})
# 3. 收集训练指标
train_results.update({
"episodes_this_iter": len(sample_batches),
"num_env_steps_sampled": combined_batch.count(),
"num_env_steps_trained": combined_batch.count(),
"learner": self.learner_group.get_stats() if hasattr(self, 'learner_group') else {},
})
return train_results
def _update_policy(self, batch: SampleBatch) -> Tuple[float, float, float]:
"""
单次策略更新
PPO损失函数:
1. Clipped surrogate objective
2. 价值函数损失
3. 熵正则化项
"""
# 计算策略比率
curr_action_logp = self.get_policy().compute_log_likelihoods(
actions=batch["actions"],
obs_batch=batch["obs"]
)
old_action_logp = batch["action_logp"]
ratio = torch.exp(curr_action_logp - old_action_logp)
# PPO clipped objective
advantages = batch["advantages"]
unclipped_loss = ratio * advantages
clipped_loss = torch.clamp(
ratio,
1 - self.config.clip_param,
1 + self.config.clip_param
) * advantages
policy_loss = -torch.mean(torch.min(unclipped_loss, clipped_loss))
        # 前向传播获取当前策略的分布参数(同时为value_function()准备内部状态)
        dist_inputs, _ = self.get_policy().model(batch)
        # 价值函数损失
        value_predictions = self.get_policy().model.value_function()
        value_targets = batch["value_targets"]
        value_loss = torch.mean((value_predictions - value_targets) ** 2)
        # 熵损失(动作分布应由分布参数构造,而非log概率)
        action_dist = self.get_policy().dist_class(dist_inputs, self.get_policy().model)
        entropy = torch.mean(action_dist.entropy())
# 总损失
total_loss = (
policy_loss +
self.config.vf_loss_coeff * value_loss -
self.config.entropy_coeff * entropy
)
# 反向传播和参数更新
self.get_policy().optimizer.zero_grad()
total_loss.backward()
# 梯度裁剪
if self.config.grad_clip:
torch.nn.utils.clip_grad_norm_(
self.get_policy().model.parameters(),
self.config.grad_clip
)
self.get_policy().optimizer.step()
return policy_loss.item(), value_loss.item(), entropy.item()
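上面_update_policy中的策略损失实现的是PPO的标准裁剪代理目标(clipped surrogate objective),其数学形式为:

$$L^{CLIP}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big],\qquad r_t(\theta)=\frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_\mathrm{old}}(a_t\mid s_t)}$$

代码中的总损失为 $-L^{CLIP}(\theta) + c_{vf}\,L^{VF}(\theta) - c_{ent}\,H[\pi_\theta]$,其中 $\epsilon$、$c_{vf}$、$c_{ent}$ 分别对应配置项clip_param、vf_loss_coeff与entropy_coeff,$\hat{A}_t$ 即批次中由GAE计算得到的advantages。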
4. 分布式训练架构
4.1 分布式组件交互
graph TB
subgraph "Driver进程"
A[Algorithm实例]
B[训练协调器]
C[结果聚合器]
end
subgraph "EnvRunner节点"
D[EnvRunner 1<br/>环境交互]
E[EnvRunner 2<br/>环境交互]
F[EnvRunner N<br/>环境交互]
end
subgraph "Learner节点"
G[Learner 1<br/>策略更新]
H[Learner 2<br/>策略更新]
I[Learner M<br/>策略更新]
end
subgraph "共享存储"
J[参数服务器<br/>Parameter Server]
K[经验回放<br/>Replay Buffer]
L[检查点存储<br/>Checkpoint Store]
end
A --> B
B --> D
B --> E
B --> F
D --> K
E --> K
F --> K
B --> G
B --> H
B --> I
G --> J
H --> J
I --> J
J --> D
J --> E
J --> F
G --> L
C --> L
style A fill:#e3f2fd
style D fill:#e8f5e8
style G fill:#fff3e0
style J fill:#fce4ec
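图中各类分布式组件的规模直接由AlgorithmConfig的对应配置项控制。下面是一个简化的伸缩配置示意(数值仅为示例,完整的训练配置见6.1节):
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # 对应图中的EnvRunner节点:并行环境交互的规模
    .env_runners(num_env_runners=8, num_envs_per_env_runner=2)
    # 对应图中的Learner节点:并行策略更新的规模与资源
    .learners(num_learners=2, num_gpus_per_learner=1)
)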
4.2 EnvRunner实现
"""
EnvRunner - 环境交互组件
负责与环境交互,收集训练数据
"""
class EnvRunner:
"""
环境运行器 - 负责策略与环境的交互
核心功能:
1. 环境管理和重置
2. 策略执行和动作采样
3. 经验数据收集
4. 轨迹处理和格式化
支持的模式:
- 单环境模式
- 向量化环境模式
- 异步环境模式
"""
def __init__(
self,
env_creator: Callable,
policy_config: Dict,
rollout_fragment_length: int,
batch_mode: str = "truncate_episodes",
):
"""
初始化环境运行器
参数:
env_creator: 环境创建函数
policy_config: 策略配置
rollout_fragment_length: 单次rollout的长度
batch_mode: 批次模式(truncate_episodes/complete_episodes)
"""
# 创建环境
self.env = env_creator()
# 初始化策略
self.policy = self._create_policy(policy_config)
# 配置参数
self.rollout_fragment_length = rollout_fragment_length
self.batch_mode = batch_mode
# 状态管理
self._episode_id = 0
self._agent_id = 0
self._current_obs = None
self._episode_reward = 0
self._episode_length = 0
def sample(self) -> SampleBatch:
"""
采样一个训练批次
采样流程:
1. 环境重置(如果需要)
2. 策略执行和动作采样
3. 环境步进和奖励收集
4. 经验存储和批次构建
"""
batch_builder = SampleBatchBuilder()
steps_collected = 0
while steps_collected < self.rollout_fragment_length:
# 环境重置(如果episode结束)
if self._current_obs is None:
                # Gymnasium的reset()返回(obs, info)二元组
                self._current_obs, _ = self.env.reset()
self._episode_reward = 0
self._episode_length = 0
self._episode_id += 1
# 策略决策
action, action_logp, state_out = self.policy.compute_action(
obs=self._current_obs,
state=self.policy.get_initial_state(),
explore=True # 训练时启用探索
)
# 环境步进
next_obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
# 存储经验
batch_builder.add_values(
obs=self._current_obs,
actions=action,
rewards=reward,
dones=done,
action_logp=action_logp,
next_obs=next_obs,
infos=info,
)
# 更新状态
self._current_obs = next_obs if not done else None
self._episode_reward += reward
self._episode_length += 1
steps_collected += 1
# Episode结束处理
if done:
self._log_episode_end()
# 构建并返回SampleBatch
return batch_builder.build_and_reset()
def _create_policy(self, config: Dict):
"""创建策略实例"""
policy_class = config["policy_class"]
return policy_class(
observation_space=self.env.observation_space,
action_space=self.env.action_space,
config=config
)
    def _log_episode_end(self):
        """记录episode结束信息"""
        # 简化实现:记录到内部列表(实际的EnvRunner会通过MetricsLogger上报指标)
        self._episode_metrics = getattr(self, "_episode_metrics", [])
        self._episode_metrics.append({
            "episode_reward": self._episode_reward,
            "episode_length": self._episode_length,
        })
class EnvRunnerGroup:
"""
环境运行器组 - 管理多个EnvRunner的协调器
功能:
1. 分布式环境交互
2. 并行数据收集
3. 故障恢复和弹性扩展
4. 负载均衡
"""
def __init__(
self,
env_creator: Callable,
policy_class: Type,
config: AlgorithmConfig,
num_workers: int = 0,
local_worker: bool = True,
):
"""
初始化运行器组
参数:
env_creator: 环境创建函数
policy_class: 策略类
config: 算法配置
num_workers: 远程worker数量
local_worker: 是否包含本地worker
"""
self.config = config
self.num_workers = num_workers
# 创建本地worker
self.local_worker = None
if local_worker:
self.local_worker = EnvRunner(
env_creator=env_creator,
policy_config=config.to_dict(),
rollout_fragment_length=config.rollout_fragment_length,
)
# 创建远程workers
self.remote_workers = []
if num_workers > 0:
RemoteEnvRunner = ray.remote(EnvRunner)
for i in range(num_workers):
worker = RemoteEnvRunner.remote(
env_creator=env_creator,
policy_config=config.to_dict(),
rollout_fragment_length=config.rollout_fragment_length,
)
self.remote_workers.append(worker)
def sample(self) -> List[SampleBatch]:
"""
并行采样训练数据
并行策略:
1. 所有worker同时采样
2. 异步收集结果
3. 处理worker故障
"""
# 启动远程采样
remote_futures = []
for worker in self.remote_workers:
future = worker.sample.remote()
remote_futures.append(future)
# 本地采样
local_batch = None
if self.local_worker:
local_batch = self.local_worker.sample()
# 收集远程结果
remote_batches = ray.get(remote_futures)
# 合并所有批次
all_batches = []
if local_batch:
all_batches.append(local_batch)
all_batches.extend(remote_batches)
return all_batches
def sync_weights(self, weights: Dict):
"""同步模型权重到所有worker"""
# 同步到本地worker
if self.local_worker:
self.local_worker.set_weights(weights)
# 同步到远程workers
if self.remote_workers:
ray.get([
worker.set_weights.remote(weights)
for worker in self.remote_workers
])
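上面的sample()实现使用阻塞的ray.get()收集结果,属于简化版本;若要体现docstring中提到的异步收集与故障处理,可以改用带超时的ray.wait()。下面是一个假设性的辅助函数示意(函数名collect_with_timeout与超时取值均为示例):
import ray

def collect_with_timeout(remote_workers, timeout_s: float = 10.0):
    """用ray.wait()异步收集各worker的采样结果,跳过超时或故障的worker(示意)"""
    futures = [w.sample.remote() for w in remote_workers]
    # 最多等待timeout_s秒,只收集已就绪的结果
    ready, not_ready = ray.wait(futures, num_returns=len(futures), timeout=timeout_s)
    batches = ray.get(ready)
    if not_ready:
        print(f"{len(not_ready)} 个worker未在 {timeout_s}s 内返回,本轮跳过")
    return batches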
5. 环境交互系统
5.1 多环境支持架构
"""
RLlib多环境支持系统
支持的环境类型:
1. Gymnasium环境
2. Unity ML-Agents
3. 自定义环境
4. 多智能体环境
5. 向量化环境
"""
import gymnasium as gym
import numpy as np

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
# 1. 自定义环境示例
class CustomTradingEnv(gym.Env):
"""
自定义交易环境示例
功能:
- 模拟股票交易环境
- 支持连续动作空间
- 提供复杂奖励函数
"""
def __init__(self, config: Dict):
"""
初始化交易环境
参数:
config: 环境配置字典
"""
super().__init__()
# 配置参数
self.initial_balance = config.get("initial_balance", 10000)
self.transaction_cost = config.get("transaction_cost", 0.001)
self.max_steps = config.get("max_steps", 1000)
# 动作空间:[-1, 1]表示卖出到买入的比例
self.action_space = gym.spaces.Box(
low=-1, high=1, shape=(1,), dtype=np.float32
)
# 观测空间:价格历史、技术指标、持仓信息
self.observation_space = gym.spaces.Box(
low=-np.inf, high=np.inf, shape=(50,), dtype=np.float32
)
# 环境状态
self.current_step = 0
self.balance = self.initial_balance
self.position = 0 # 当前持仓
self.price_history = []
# 加载历史价格数据
self.price_data = self._load_price_data()
def reset(self, seed=None, options=None):
"""重置环境到初始状态"""
super().reset(seed=seed)
self.current_step = 0
self.balance = self.initial_balance
self.position = 0
self.price_history = []
# 随机选择起始位置
self.start_idx = np.random.randint(50, len(self.price_data) - self.max_steps)
return self._get_observation(), {}
def step(self, action):
"""执行动作,返回下一状态"""
action = action[0] # 提取标量动作
# 获取当前价格
current_price = self.price_data[self.start_idx + self.current_step]
# 执行交易
if action > 0:
# 买入操作
buy_amount = min(action * self.balance / current_price,
self.balance / current_price)
cost = buy_amount * current_price * (1 + self.transaction_cost)
if cost <= self.balance:
self.balance -= cost
self.position += buy_amount
elif action < 0:
# 卖出操作
sell_amount = min(-action * self.position, self.position)
revenue = sell_amount * current_price * (1 - self.transaction_cost)
self.balance += revenue
self.position -= sell_amount
# 更新状态
self.current_step += 1
self.price_history.append(current_price)
# 计算奖励
portfolio_value = self.balance + self.position * current_price
reward = self._calculate_reward(portfolio_value, action)
# 检查终止条件
terminated = (
self.current_step >= self.max_steps or
portfolio_value <= 0
)
return self._get_observation(), reward, terminated, False, {}
def _get_observation(self):
"""构建观测向量"""
# 价格特征
recent_prices = self.price_history[-20:] if len(self.price_history) >= 20 else \
[self.price_data[self.start_idx]] * 20
# 技术指标
sma_5 = np.mean(recent_prices[-5:])
sma_20 = np.mean(recent_prices)
rsi = self._calculate_rsi(recent_prices)
# 账户状态
current_price = self.price_data[self.start_idx + self.current_step] \
if self.current_step < len(self.price_data) - self.start_idx else 0
portfolio_value = self.balance + self.position * current_price
# 构建观测向量
obs = np.array([
# 归一化的价格历史(20维)
*(np.array(recent_prices) / recent_prices[-1] - 1),
# 技术指标(3维)
sma_5 / current_price - 1,
sma_20 / current_price - 1,
rsi / 100 - 0.5,
            # 账户状态(4维)
            self.balance / self.initial_balance - 1,
            self.position * current_price / self.initial_balance,
            portfolio_value / self.initial_balance - 1,
            self.current_step / self.max_steps,
            # 零填充到50维(20 + 3 + 4 = 27,补23个零)
            *np.zeros(23),
], dtype=np.float32)
return obs[:50] # 确保维度正确
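在接入RLlib之前,可以先用标准的Gymnasium交互循环对自定义环境做一次快速自检。下面的示意假设草图中省略的_load_price_data、_calculate_reward、_calculate_rsi等辅助方法已经实现,配置取值也仅为示例:
# 快速自检:随机动作运行若干步,确认接口与数值无异常
env = CustomTradingEnv({"initial_balance": 10000, "max_steps": 200})
obs, info = env.reset(seed=42)
print("obs shape:", obs.shape, "dtype:", obs.dtype)

for _ in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()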
# 2. 多智能体环境
class MultiAgentTradingEnv(MultiAgentEnv):
"""
多智能体交易环境
场景:多个智能体在同一市场中竞争交易
"""
def __init__(self, config: Dict):
super().__init__()
self.num_agents = config.get("num_agents", 3)
self.market_impact = config.get("market_impact", True)
# 每个智能体的独立环境
self.agents = {f"agent_{i}": CustomTradingEnv(config) for i in range(self.num_agents)}
        # 共享市场状态(MarketState为示意占位,需根据实际场景自行实现)
        self.market_state = MarketState()
def reset(self, *, seed=None, options=None):
"""重置所有智能体"""
observations = {}
for agent_id, agent_env in self.agents.items():
obs, info = agent_env.reset(seed=seed)
observations[agent_id] = obs
return observations, {}
def step(self, action_dict):
"""执行多智能体动作"""
observations = {}
rewards = {}
terminateds = {}
truncateds = {}
infos = {}
# 市场影响建模
total_market_action = sum(action_dict.values())
market_pressure = self._calculate_market_pressure(total_market_action)
# 各智能体并行执行
for agent_id, action in action_dict.items():
# 考虑市场影响的修正动作
adjusted_action = self._adjust_action_for_market_impact(action, market_pressure)
obs, reward, terminated, truncated, info = self.agents[agent_id].step([adjusted_action])
observations[agent_id] = obs
rewards[agent_id] = reward
terminateds[agent_id] = terminated
truncateds[agent_id] = truncated
infos[agent_id] = info
# 检查全局终止条件
terminateds["__all__"] = all(terminateds.values())
truncateds["__all__"] = all(truncateds.values())
return observations, rewards, terminateds, truncateds, infos
def _calculate_market_pressure(self, total_action):
"""计算市场压力"""
return np.tanh(total_action * 0.1) # 简化的市场影响模型
def _adjust_action_for_market_impact(self, action, market_pressure):
"""根据市场影响调整动作"""
if market_pressure > 0: # 买方压力
return action * (1 - 0.1 * market_pressure) # 买入成本增加
else: # 卖方压力
return action * (1 + 0.1 * abs(market_pressure)) # 卖出收益减少
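自定义环境通常通过register_env注册后,在配置中按名称引用;多智能体环境还需要通过multi_agent()指定策略集合与policy_mapping_fn。下面是一个简化示意(此处让所有智能体共享同一策略,仅作演示):
from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPOConfig

# 注册单智能体与多智能体交易环境
register_env("trading_env", lambda cfg: CustomTradingEnv(cfg))
register_env("ma_trading_env", lambda cfg: MultiAgentTradingEnv(cfg))

config = (
    PPOConfig()
    .environment("ma_trading_env", env_config={"num_agents": 3})
    .multi_agent(
        # 所有智能体映射到同一个共享策略(也可按agent_id映射到不同策略)
        policies={"shared_policy"},
        policy_mapping_fn=lambda agent_id, episode, **kwargs: "shared_policy",
    )
)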
6. 使用示例与最佳实践
6.1 PPO算法完整训练示例
"""
使用RLlib训练PPO算法的完整示例
"""
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
# 1. 环境配置和包装
def create_atari_env(config):
"""创建Atari游戏环境"""
import gymnasium as gym
env = gym.make("ALE/Breakout-v5")
# 应用Atari特定的预处理
env = wrap_atari_for_new_api_stack(env, framestack=4)
return env
# 2. 算法配置
def configure_ppo_algorithm():
"""配置PPO算法"""
config = (
PPOConfig()
.environment(
env="ALE/Breakout-v5",
clip_rewards=True, # Atari奖励裁剪
)
.framework("torch")
.training(
# PPO特定参数
lr=0.0001,
num_sgd_iter=10, # SGD迭代次数
sgd_minibatch_size=500, # 小批次大小
train_batch_size=5000, # 训练批次大小
clip_param=0.2, # PPO裁剪参数
entropy_coeff=0.01, # 熵系数
vf_loss_coeff=0.5, # 价值函数损失系数
gamma=0.99, # 折扣因子
lambda_=0.95, # GAE参数
use_gae=True, # 使用GAE
# 模型架构
model={
"conv_filters": [
[32, [8, 8], 4],
[64, [4, 4], 2],
[64, [3, 3], 1],
],
"fcnet_hiddens": [512],
"fcnet_activation": "relu",
}
)
.resources(
num_gpus=1, # 学习器GPU
num_cpus_per_worker=1, # 每个worker的CPU
num_gpus_per_worker=0, # 每个worker的GPU
)
.env_runners(
num_env_runners=8, # 环境运行器数量
num_envs_per_env_runner=1, # 每个运行器的环境数
rollout_fragment_length=200, # rollout长度
)
.learners(
num_learners=1, # 学习器数量
num_cpus_per_learner=4, # 每个学习器的CPU
num_gpus_per_learner=1, # 每个学习器的GPU
)
.evaluation(
evaluation_interval=20, # 评估间隔
evaluation_duration=10, # 评估episode数
evaluation_parallel_to_training=True, # 并行评估
)
.reporting(
min_sample_timesteps_per_iteration=5000, # 最小采样步数
)
)
return config
# 3. 分布式训练
def run_distributed_ppo():
"""运行分布式PPO训练"""
# 初始化Ray
ray.init(address="ray://head-node:10001") # 连接到Ray集群
# 配置算法
config = configure_ppo_algorithm()
# 构建算法实例
algorithm = config.build()
try:
# 训练循环
for iteration in range(1000):
# 执行训练步骤
train_results = algorithm.train()
# 打印训练进度
if iteration % 10 == 0:
print(f"Iteration {iteration}:")
print(f" Episode reward mean: {train_results['env_runners']['episode_reward_mean']:.2f}")
print(f" Policy loss: {train_results.get('learner', {}).get('policy_loss', 0):.4f}")
print(f" Value loss: {train_results.get('learner', {}).get('vf_loss', 0):.4f}")
# 保存检查点
if iteration % 100 == 0:
checkpoint_path = algorithm.save_to_path(f"./checkpoints/ppo_iteration_{iteration}")
print(f"Checkpoint saved: {checkpoint_path}")
# 检查停止条件
if train_results['env_runners']['episode_reward_mean'] > 400:
print("Training completed - target reward reached!")
break
finally:
algorithm.stop()
ray.shutdown()
# 4. 超参数调优集成
def hyperparameter_tuning_with_tune():
"""使用Ray Tune进行RLlib超参数调优"""
def train_function(config):
"""训练函数,接收超参数配置"""
# 基础PPO配置
ppo_config = (
PPOConfig()
.environment("CartPole-v1")
.framework("torch")
.training(
lr=config["lr"],
num_sgd_iter=config["num_sgd_iter"],
sgd_minibatch_size=config["sgd_minibatch_size"],
clip_param=config["clip_param"],
entropy_coeff=config["entropy_coeff"],
vf_loss_coeff=config["vf_loss_coeff"],
)
.env_runners(
num_env_runners=4,
rollout_fragment_length=config["rollout_fragment_length"],
)
)
# 构建算法
algorithm = ppo_config.build()
try:
# 训练循环
for i in range(100):
result = algorithm.train()
# 报告给Ray Tune
tune.report(
episode_reward_mean=result['env_runners']['episode_reward_mean'],
policy_loss=result.get('learner', {}).get('policy_loss', 0),
training_iteration=i,
)
finally:
algorithm.stop()
# 定义超参数搜索空间
search_space = {
"lr": tune.loguniform(1e-5, 1e-3),
"num_sgd_iter": tune.choice([5, 10, 20]),
"sgd_minibatch_size": tune.choice([128, 256, 512]),
"clip_param": tune.uniform(0.1, 0.3),
"entropy_coeff": tune.loguniform(1e-5, 1e-2),
"vf_loss_coeff": tune.uniform(0.1, 1.0),
"rollout_fragment_length": tune.choice([200, 400, 800]),
}
# 运行超参数优化
tuner = tune.Tuner(
train_function,
param_space=search_space,
tune_config=tune.TuneConfig(
metric="episode_reward_mean",
mode="max",
num_samples=50,
),
run_config=tune.RunConfig(
name="ppo_hyperparameter_tuning",
stop={"training_iteration": 100},
)
)
results = tuner.fit()
best_result = results.get_best_result()
print(f"Best configuration: {best_result.config}")
print(f"Best reward: {best_result.metrics['episode_reward_mean']}")
return best_result
# 5. 自定义算法实现
class CustomPPOAlgorithm(PPO):
"""
自定义PPO算法实现
扩展功能:
- 自定义损失函数
- 特殊的奖励处理
- 额外的正则化项
"""
def training_step(self) -> ResultDict:
"""覆盖训练步骤,添加自定义逻辑"""
# 1. 执行标准PPO训练步骤
train_results = super().training_step()
# 2. 添加自定义指标计算
custom_metrics = self._compute_custom_metrics()
train_results.update(custom_metrics)
# 3. 执行自定义的后处理
self._custom_post_processing(train_results)
return train_results
def _compute_custom_metrics(self) -> Dict:
"""计算自定义指标"""
# 获取最新的策略参数
policy_weights = self.get_policy().get_weights()
# 计算参数统计
param_stats = {}
for name, param in policy_weights.items():
if isinstance(param, np.ndarray):
param_stats[f"{name}_mean"] = np.mean(param)
param_stats[f"{name}_std"] = np.std(param)
param_stats[f"{name}_max"] = np.max(param)
param_stats[f"{name}_min"] = np.min(param)
return {"custom_metrics": param_stats}
def _custom_post_processing(self, train_results: Dict):
"""自定义后处理逻辑"""
# 例如:动态调整学习率
current_performance = train_results.get('env_runners', {}).get('episode_reward_mean', 0)
if hasattr(self, '_last_performance'):
improvement = current_performance - self._last_performance
if improvement < 0.1: # 性能改善缓慢
# 降低学习率
current_lr = self.config.lr
new_lr = current_lr * 0.95
self.config.lr = max(1e-6, new_lr)
print(f"Performance plateau detected, reducing LR to {new_lr}")
self._last_performance = current_performance
# 使用示例
def main():
"""主函数"""
ray.init()
try:
# 1. 基础PPO训练
print("Starting basic PPO training...")
run_distributed_ppo()
# 2. 超参数调优
print("Starting hyperparameter tuning...")
best_config = hyperparameter_tuning_with_tune()
# 3. 使用最佳配置进行最终训练
print("Training with best configuration...")
final_config = configure_ppo_algorithm()
final_config.update_from_dict(best_config.config)
final_algorithm = final_config.build()
# 训练更多轮次
for i in range(500):
result = final_algorithm.train()
if i % 50 == 0:
print(f"Final training iteration {i}: "
f"reward = {result['env_runners']['episode_reward_mean']:.2f}")
# 保存最终模型
final_checkpoint = final_algorithm.save_to_path("./final_model")
print(f"Final model saved: {final_checkpoint}")
finally:
ray.shutdown()
if __name__ == "__main__":
main()
总结
RLlib作为Ray生态系统中的强化学习引擎,提供了完整的分布式强化学习解决方案:
核心架构优势
- 模块化设计 - RLModule、Learner、EnvRunner的清晰分离
- 分布式优先 - 原生支持大规模分布式训练
- 算法丰富 - 涵盖主流强化学习算法
- 高度可扩展 - 易于添加新算法和自定义组件
关键特性
- 多环境支持: Gymnasium、Unity、自定义环境
- 分布式训练: 并行环境交互和策略更新
- 容错机制: 自动故障恢复和重启
- 超参数优化: 与Ray Tune的无缝集成
- 生产部署: 检查点管理和模型服务支持
RLlib为强化学习从研究到生产部署提供了完整的工具链,显著降低了分布式强化学习的实现复杂度。