#include <torch/torch.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>

namespace torch::distributed::pipeline {
// Pipeline-parallelism manager (modeled on the GPipe design)
class PipelineParallelManager {
 private:
  // One stage of the pipeline
  struct PipelineStage {
    torch::nn::AnyModule module;  // module executed by this stage
    c10::DeviceIndex device;      // device the stage lives on
    int stage_id;                 // stage index
    // Activation-checkpointing configuration
    bool use_checkpointing;
    int checkpoint_segments;
    PipelineStage(torch::nn::AnyModule mod, c10::DeviceIndex dev, int id)
        : module(std::move(mod)), device(dev), stage_id(id),
          use_checkpointing(false), checkpoint_segments(1) {}
  };
  std::vector<PipelineStage> stages_;
  // Microbatch bookkeeping
  struct Microbatch {
    at::Tensor data;
    at::Tensor target;
    int64_t microbatch_id;
    int current_stage;
    // Cached per-stage outputs (consumed by the backward pass)
    std::vector<at::Tensor> cached_activations;
    Microbatch(at::Tensor data, at::Tensor target, int64_t id)
        : data(std::move(data)), target(std::move(target)),
          microbatch_id(id), current_stage(0) {}
  };
  // Pipeline operation types (declared before the scheduler that emits them,
  // so the code below compiles)
  enum class OperationType {
    FORWARD,
    BACKWARD,
    OPTIMIZER_STEP,
    COMMUNICATION
  };
  struct PipelineOperation {
    OperationType type;
    int microbatch_id;
    int stage_id;
  };
  // Pipeline scheduler
  class PipelineScheduler {
   private:
    enum class SchedulePhase {
      FORWARD_WARMUP,     // fill the pipeline with forward passes
      FORWARD_BACKWARD,   // steady state: alternate forward and backward
      BACKWARD_COOLDOWN   // drain the pipeline with the remaining backwards
    };
    SchedulePhase current_phase_;
    int num_microbatches_;
    int num_stages_;
   public:
    PipelineScheduler(int num_microbatches, int num_stages)
        : current_phase_(SchedulePhase::FORWARD_WARMUP),
          num_microbatches_(num_microbatches),
          num_stages_(num_stages) {}
    // Generate a 1F1B (one-forward-one-backward) schedule
    std::vector<PipelineOperation> generate_1f1b_schedule() {
      std::vector<PipelineOperation> schedule;
      // Warm-up: fill the pipeline with the first forward passes
      for (int i = 0; i < num_stages_ - 1; ++i) {
        schedule.push_back({OperationType::FORWARD, i, i});
      }
      // Steady state: pair each forward with the backward of a microbatch
      // issued num_stages_ - 1 steps earlier (the loop starts at
      // num_stages_ - 1, so the paired backward index is never negative)
      for (int i = num_stages_ - 1; i < num_microbatches_; ++i) {
        schedule.push_back({OperationType::FORWARD, i, num_stages_ - 1});
        int backward_microbatch = i - num_stages_ + 1;
        schedule.push_back(
            {OperationType::BACKWARD, backward_microbatch, num_stages_ - 1});
      }
      // Cool-down: drain the remaining backward passes
      for (int i = num_microbatches_ - num_stages_ + 1; i < num_microbatches_; ++i) {
        int offset = i - (num_microbatches_ - num_stages_ + 1);
        schedule.push_back(
            {OperationType::BACKWARD, i, num_stages_ - 1 - offset});
      }
      return schedule;
    }
  };
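  // Worked example (hypothetical sizes): with num_stages = 2 and
  // num_microbatches = 4, generate_1f1b_schedule() emits
  //   F(mb0,s0), F(mb1,s1), B(mb0,s1), F(mb2,s1), B(mb1,s1),
  //   F(mb3,s1), B(mb2,s1), B(mb3,s1)
  // i.e. after the warm-up forward, every forward is interleaved with the
  // backward of the microbatch issued num_stages - 1 steps earlier.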
 public:
  PipelineParallelManager(const std::vector<torch::nn::AnyModule>& stage_modules,
                          const std::vector<c10::DeviceIndex>& devices)
      : stages_() {
    TORCH_CHECK(stage_modules.size() == devices.size(),
                "Number of stages must match number of devices");
    // Initialize the pipeline stages and place each stage's parameters
    // on its assigned device
    for (size_t i = 0; i < stage_modules.size(); ++i) {
      stages_.emplace_back(stage_modules[i], devices[i], static_cast<int>(i));
      stages_.back().module.ptr()->to(torch::Device(torch::kCUDA, devices[i]));
    }
  }
  // Run pipelined training over one global batch
  void train_pipeline(const std::vector<at::Tensor>& inputs,
                      const std::vector<at::Tensor>& targets,
                      int num_microbatches) {
    // Split the global batch into microbatches
    auto microbatches = split_into_microbatches(inputs, targets, num_microbatches);
    // Generate the pipeline schedule
    PipelineScheduler scheduler(num_microbatches, static_cast<int>(stages_.size()));
    auto schedule = scheduler.generate_1f1b_schedule();
    // Per-microbatch activation cache (consumed by the backward pass)
    std::vector<std::vector<at::Tensor>> forward_cache(num_microbatches);
    // Execute the schedule; OPTIMIZER_STEP / COMMUNICATION ops are never
    // emitted by the 1F1B generator above, so they are ignored here
    for (const auto& op : schedule) {
      switch (op.type) {
        case OperationType::FORWARD:
          execute_forward_pass(microbatches[op.microbatch_id],
                               op.stage_id, forward_cache[op.microbatch_id]);
          break;
        case OperationType::BACKWARD:
          execute_backward_pass(microbatches[op.microbatch_id],
                                op.stage_id, forward_cache[op.microbatch_id]);
          break;
        default:
          break;
      }
    }
  }
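  // Note: with p stages and m microbatches, the fraction of time the
  // pipeline sits idle (the "bubble") is roughly (p - 1) / (m + p - 1),
  // which is why train_pipeline benefits from num_microbatches being much
  // larger than the number of stages.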
 private:
  // Split the global batch into microbatches
  std::vector<Microbatch> split_into_microbatches(
      const std::vector<at::Tensor>& inputs,
      const std::vector<at::Tensor>& targets,
      int num_microbatches) {
    std::vector<Microbatch> microbatches;
    microbatches.reserve(num_microbatches);
    int64_t batch_size = inputs[0].size(0);
    TORCH_CHECK(batch_size >= num_microbatches,
                "Batch size must be at least num_microbatches");
    int64_t microbatch_size = batch_size / num_microbatches;
    for (int i = 0; i < num_microbatches; ++i) {
      int64_t start_idx = i * microbatch_size;
      // The last microbatch absorbs any remainder
      int64_t end_idx = (i == num_microbatches - 1) ? batch_size
                                                    : (i + 1) * microbatch_size;
      // Slice inputs and targets along the batch dimension
      auto input_slice = inputs[0].narrow(0, start_idx, end_idx - start_idx);
      auto target_slice = targets[0].narrow(0, start_idx, end_idx - start_idx);
      microbatches.emplace_back(input_slice, target_slice, i);
    }
    return microbatches;
  }
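  // Example: batch_size = 10 with num_microbatches = 4 gives
  // microbatch_size = 10 / 4 = 2, so the slices are [0,2), [2,4), [4,6)
  // and [6,10): the entire remainder is folded into the last microbatch.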
  // Run a forward pass for one microbatch through one stage
  void execute_forward_pass(Microbatch& microbatch, int stage_id,
                            std::vector<at::Tensor>& activation_cache) {
    auto& stage = stages_[stage_id];
    // Set the device context for this stage
    c10::cuda::CUDAGuard guard(stage.device);
    // Move the data onto the stage's device if necessary
    if (microbatch.data.device().index() != stage.device) {
      microbatch.data = microbatch.data.to(torch::Device(torch::kCUDA, stage.device));
    }
    // Each stage keeps its own autograd graph: detach from the upstream
    // stage and mark the input as a leaf that requires grad, so gradients
    // cross stage boundaries explicitly via send/recv
    microbatch.data = microbatch.data.detach().requires_grad_();
    at::Tensor output;
    if (stage.use_checkpointing) {
      // Activation checkpointing (torch.utils.checkpoint is Python-only):
      // run the forward pass without building a graph; the backward pass
      // recomputes it from the cached stage input
      torch::NoGradGuard no_grad;
      output = stage.module.forward(microbatch.data);
    } else {
      // Plain forward pass
      output = stage.module.forward(microbatch.data);
    }
    // Cache the stage input (leaf) and output for the backward pass;
    // detaching the cached input here would cut it out of the graph
    activation_cache.push_back(microbatch.data);
    microbatch.cached_activations.push_back(output);
    // Advance the microbatch to the stage output
    microbatch.data = output;
    microbatch.current_stage = stage_id + 1;
    // If this is not the last stage, hand the output to the next stage
    if (stage_id < static_cast<int>(stages_.size()) - 1) {
      send_to_next_stage(microbatch, stage_id + 1);
    }
  }
  // Run a backward pass for one microbatch through one stage
  void execute_backward_pass(Microbatch& microbatch, int stage_id,
                             const std::vector<at::Tensor>& activation_cache) {
    auto& stage = stages_[stage_id];
    c10::cuda::CUDAGuard guard(stage.device);
    // Gradient of the loss w.r.t. this stage's output
    at::Tensor grad_output;
    if (stage_id == static_cast<int>(stages_.size()) - 1) {
      // Last stage: compute the loss and differentiate it
      auto target = microbatch.target.to(microbatch.data.device());
      auto loss = torch::nn::functional::cross_entropy(microbatch.data, target);
      grad_output = torch::autograd::grad({loss}, {microbatch.data})[0];
    } else {
      // Intermediate stage: receive the gradient from the next stage
      grad_output = receive_grad_from_next_stage(stage_id + 1);
    }
    // Backward through this stage's own graph
    at::Tensor stage_input = activation_cache[stage_id];
    at::Tensor grad_input;
    if (stage.use_checkpointing) {
      // Recompute the forward pass from the cached stage input; the
      // original forward ran under NoGradGuard and built no graph
      auto recomputed_output = stage.module.forward(stage_input);
      grad_input = torch::autograd::grad(
          {recomputed_output}, {stage_input}, {grad_output})[0];
    } else {
      // Differentiate the cached stage output w.r.t. the cached stage input
      auto stage_output = microbatch.cached_activations[stage_id];
      grad_input = torch::autograd::grad(
          {stage_output}, {stage_input}, {grad_output})[0];
    }
    // If this is not the first stage, send the gradient upstream
    if (stage_id > 0) {
      send_grad_to_prev_stage(grad_input, stage_id - 1);
    }
  }
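  // Trade-off: with use_checkpointing a stage stores only its input and
  // re-runs the forward pass once during backward, exchanging roughly one
  // extra forward computation per stage for not keeping intermediate
  // activations alive across the whole 1F1B schedule.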
  // Inter-stage communication: push a microbatch's activations downstream
  void send_to_next_stage(Microbatch& microbatch, int next_stage_id) {
    auto next_device = stages_[next_stage_id].device;
    if (can_use_p2p(microbatch.data.device().index(), next_device)) {
      // Direct P2P copy between devices on the same host
      microbatch.data = microbatch.data.to(
          torch::Device(torch::kCUDA, next_device), /*non_blocking=*/true);
    } else {
      // Fall back to a ProcessGroup point-to-point send
      std::vector<at::Tensor> bufs{microbatch.data};
      auto work = process_group_->send(bufs, next_stage_id, /*tag=*/0);
      work->wait();
    }
  }
  // Send the gradient of this stage's input to the previous stage
  void send_grad_to_prev_stage(const at::Tensor& grad, int prev_stage_id) {
    std::vector<at::Tensor> bufs{grad};
    auto work = process_group_->send(bufs, prev_stage_id, /*tag=*/1);
    work->wait();
  }
  at::Tensor receive_grad_from_next_stage(int next_stage_id) {
    // Receive the gradient from the next stage. Note: a real implementation
    // must know the activation shape in advance (e.g. by exchanging shape
    // metadata first); allocating from a parameter here is only a
    // placeholder for this sketch.
    auto grad_tensor = at::empty_like(
        stages_[next_stage_id - 1].module.ptr()->parameters().front());
    std::vector<at::Tensor> bufs{grad_tensor};
    auto work = process_group_->recv(bufs, next_stage_id, /*tag=*/1);
    work->wait();
    return bufs.front();
  }
  bool can_use_p2p(c10::DeviceIndex src_device, c10::DeviceIndex dst_device) {
    // Check whether CUDA peer-to-peer access is possible
    if (src_device == dst_device) return true;
    int can_access = 0;
    C10_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, src_device, dst_device));
    return can_access == 1;
  }
  // Process group used for point-to-point transfers when P2P is unavailable
  c10::intrusive_ptr<c10d::ProcessGroup> process_group_;
};
} // namespace torch::distributed::pipeline
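The listing above is a single-process teaching sketch rather than production code. Assuming it compiles against libtorch with CUDA support, a minimal driver might look like the following; the two Linear stages, layer sizes, batch shape, and device indices are illustrative assumptions, and the ProcessGroup fallback (process_group_) is left unwired, so this driver exercises only the same-host P2P branch:

#include <torch/torch.h>

int main() {
  using torch::distributed::pipeline::PipelineParallelManager;
  // Two hypothetical stages wrapped as AnyModule so stage types may differ
  std::vector<torch::nn::AnyModule> stage_modules{
      torch::nn::AnyModule(torch::nn::Linear(1024, 4096)),
      torch::nn::AnyModule(torch::nn::Linear(4096, 10))};
  std::vector<c10::DeviceIndex> devices{0, 1};  // one stage per GPU
  PipelineParallelManager manager(stage_modules, devices);
  // One global batch of 64 samples, split into 4 microbatches of 16
  auto inputs = torch::randn({64, 1024});
  auto targets = torch::randint(0, 10, {64}, torch::kLong);
  manager.train_pipeline({inputs}, {targets}, /*num_microbatches=*/4);
  return 0;
}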