import multiprocessing as mp
import os
import queue
import sys
import threading
import time

import numpy as np
import torch

class MultiprocessingWorkerManager:
    """Multiprocessing worker manager (based on the parallel-processing optimization analysis)."""

    def __init__(self, dataset, num_workers, collate_fn, worker_init_fn=None,
                 multiprocessing_context=None):
        self.dataset = dataset
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn

        # Multiprocessing context
        if multiprocessing_context is None:
            # Pick the best start method for the platform
            if hasattr(mp, 'get_context'):
                if sys.platform == 'win32':
                    self.mp_context = mp.get_context('spawn')
                else:
                    self.mp_context = mp.get_context('fork')
            else:
                self.mp_context = mp
        else:
            self.mp_context = multiprocessing_context

        # Inter-process communication queues
        self.index_queues = []          # main process -> worker index queues
        self.data_queue = None          # workers -> main process data queue
        self.worker_processes = []      # worker process handles
        self.workers_done_event = None  # shutdown event

        # Worker status monitoring
        self.worker_status = {}
        self.worker_pids = {}
        self.worker_queue_sizes = [0] * num_workers

        # Performance monitoring
        self.task_completion_times = {}
        self.worker_utilization = [0.0] * num_workers

    def start_workers(self):
        """Start all worker processes."""
        # Create the inter-process communication queues
        for i in range(self.num_workers):
            # One index queue per worker
            self.index_queues.append(self.mp_context.Queue())
        # Shared data queue
        self.data_queue = self.mp_context.Queue()
        # Shutdown event
        self.workers_done_event = self.mp_context.Event()

        # Launch the worker processes
        for worker_id in range(self.num_workers):
            worker_process = self.mp_context.Process(
                target=self._worker_loop,
                args=(
                    worker_id,
                    self.dataset,
                    self.index_queues[worker_id],
                    self.data_queue,
                    self.workers_done_event,
                    self.collate_fn,
                    self.worker_init_fn
                )
            )
            worker_process.daemon = True
            worker_process.start()
            self.worker_processes.append(worker_process)
            self.worker_pids[worker_id] = worker_process.pid
            self.worker_status[worker_id] = 'running'

        # Start the status-monitoring thread
        self._start_monitoring_thread()

    def send_task(self, worker_id: int, task: 'DataLoaderTask'):
        """Send a task to the given worker process."""
        try:
            self.index_queues[worker_id].put(task, timeout=1.0)
            self.worker_queue_sizes[worker_id] += 1
        except queue.Full:
            raise RuntimeError(f"Worker {worker_id} queue is full")

    def get_least_busy_worker(self) -> int:
        """Return the id of the least busy worker."""
        return min(range(self.num_workers),
                   key=lambda i: self.worker_queue_sizes[i])

    def shutdown_workers(self):
        """Shut down all worker processes."""
        if self.workers_done_event:
            self.workers_done_event.set()

        # Send a termination sentinel to every worker
        for i in range(self.num_workers):
            try:
                self.index_queues[i].put(None, timeout=1.0)
            except queue.Full:
                pass

        # Wait for the workers to exit
        for worker_process in self.worker_processes:
            worker_process.join(timeout=5.0)
            if worker_process.is_alive():
                worker_process.terminate()
                worker_process.join()

    @staticmethod
    def _worker_loop(worker_id, dataset, index_queue, data_queue, done_event,
                     collate_fn, worker_init_fn):
        """Main loop of a worker process."""
        try:
            # Seed per worker so workers do not produce identical random streams
            seed = torch.initial_seed() + worker_id
            torch.manual_seed(seed)
            np.random.seed(seed % 2**32)  # NumPy seeds must fit in 32 bits

            # Run the user-supplied worker initialization hook
            if worker_init_fn is not None:
                worker_init_fn(worker_id)

            # Tune this worker process (CPU affinity, priority, ...)
            optimize_worker_process(worker_id)

            # Main work loop
            task = None  # so the error path below is safe if get() itself fails
            while not done_event.is_set():
                try:
                    # Fetch the next task from the index queue
                    task = index_queue.get(timeout=1.0)
                    if task is None:  # termination sentinel
                        break

                    # Load the batch
                    batch_data = load_batch_data(dataset, task.indices, collate_fn)

                    # Ship the result back to the main process
                    result = DataLoaderResult(
                        task_id=task.task_id,
                        worker_id=worker_id,
                        data=batch_data,
                        completion_time=time.time()
                    )
                    data_queue.put(result)
                except queue.Empty:
                    continue
                except Exception as e:
                    # Forward the exception to the main process
                    error_result = DataLoaderResult(
                        task_id=getattr(task, 'task_id', -1),
                        worker_id=worker_id,
                        error=e,
                        completion_time=time.time()
                    )
                    data_queue.put(error_result)
        except Exception as e:
            print(f"Worker {worker_id} crashed with error: {e}")

    def _start_monitoring_thread(self):
        """Start the worker-monitoring thread."""
        def monitor_workers():
            while not self.workers_done_event.is_set():
                time.sleep(1.0)  # poll once per second

                for worker_id, process in enumerate(self.worker_processes):
                    # Health check: restart dead workers
                    if not process.is_alive():
                        self.worker_status[worker_id] = 'dead'
                        self._restart_worker(worker_id)

                    # Refresh the queue-size statistics
                    try:
                        queue_size = self.index_queues[worker_id].qsize()
                        self.worker_queue_sizes[worker_id] = queue_size
                    except NotImplementedError:
                        # qsize() is unavailable on some platforms (e.g. macOS)
                        pass

                # Recompute per-worker utilization
                self._update_worker_utilization()

        monitor_thread = threading.Thread(target=monitor_workers, daemon=True)
        monitor_thread.start()
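
    # The monitoring loop above references _update_worker_utilization without
    # defining it; the method below is a minimal placeholder sketch so the
    # class is runnable, not the original analysis' implementation.
    def _update_worker_utilization(self):
        """Rough utilization estimate (placeholder).

        A worker is counted as fully utilized whenever it has tasks queued;
        a real implementation would track busy time per worker instead.
        """
        for worker_id in range(self.num_workers):
            self.worker_utilization[worker_id] = (
                1.0 if self.worker_queue_sizes[worker_id] > 0 else 0.0
            )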

    def _restart_worker(self, worker_id):
        """Restart a failed worker process."""
        if self.worker_status[worker_id] == 'dead':
            # Clean up the old process
            old_process = self.worker_processes[worker_id]
            if old_process.is_alive():
                old_process.terminate()
                old_process.join()

            # Spawn a replacement worker
            new_process = self.mp_context.Process(
                target=self._worker_loop,
                args=(
                    worker_id,
                    self.dataset,
                    self.index_queues[worker_id],
                    self.data_queue,
                    self.workers_done_event,
                    self.collate_fn,
                    self.worker_init_fn
                )
            )
            new_process.daemon = True
            new_process.start()
            self.worker_processes[worker_id] = new_process
            self.worker_pids[worker_id] = new_process.pid
            self.worker_status[worker_id] = 'running'


def optimize_worker_process(worker_id):
    """Tune the worker process (based on the system optimization analysis)."""
    try:
        import psutil

        # Pin the worker to one CPU core to reduce cache misses
        cpu_count = os.cpu_count()
        if cpu_count and cpu_count > 1:
            available_cpus = list(range(cpu_count))
            assigned_cpu = available_cpus[worker_id % len(available_cpus)]
            process = psutil.Process()
            process.cpu_affinity([assigned_cpu])

        # Raise the process priority (requires privileges)
        if hasattr(os, 'nice'):
            os.nice(-5)

        # Memory placement policy
        if hasattr(os, 'sched_setaffinity'):
            # On NUMA systems, memory locality matters as well
            numa_node = worker_id % 2  # simplification: assume 2 NUMA nodes
            # A real implementation needs proper NUMA topology detection
    except (ImportError, AttributeError):
        # psutil unavailable, or cpu_affinity unsupported on this platform
        pass
    except PermissionError:
        # Insufficient privileges: skip the optimizations
        pass


def load_batch_data(dataset, indices, collate_fn):
    """Load one batch of data (optimized version)."""
    # Batched prefetch fast path
    if hasattr(dataset, 'batch_load'):
        # The dataset supports loading a whole batch at once
        return dataset.batch_load(indices, collate_fn)

    # Conventional sample-by-sample loading
    samples = []
    for idx in indices:
        try:
            sample = dataset[idx]
            samples.append(sample)
        except Exception as e:
            # Handle per-sample loading failures
            print(f"Error loading sample {idx}: {e}")
            # Either skip the sample or substitute a default one
            continue

    # Apply the collate function
    if samples:
        return collate_fn(samples)
    else:
        return None


# Definitions of the data-loading task and result
class DataLoaderTask:
    """A data-loading task."""

    def __init__(self, task_id, indices, dataset_kind):
        self.task_id = task_id
        self.indices = indices
        self.dataset_kind = dataset_kind
        self.creation_time = time.time()


class DataLoaderResult:
    """The result of a data-loading task."""

    def __init__(self, task_id, worker_id, data=None, error=None, completion_time=None):
        self.task_id = task_id
        self.worker_id = worker_id
        self.data = data
        self.error = error
        self.completion_time = completion_time or time.time()
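

# --- Illustrative usage sketch (not part of the analyzed DataLoader code) ---
# A minimal end-to-end example of how the pieces above fit together.
# ToyDataset and simple_collate are hypothetical stand-ins introduced for
# the demo; the manager works with any indexable dataset and collate function.
class ToyDataset:
    """Tiny map-style dataset used only for the demo."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return idx * 2  # stand-in for a real sample


def simple_collate(samples):
    return list(samples)  # stand-in for torch's default_collate


if __name__ == '__main__':
    manager = MultiprocessingWorkerManager(
        dataset=ToyDataset(), num_workers=2, collate_fn=simple_collate)
    manager.start_workers()

    # Dispatch one batch to the least busy worker and wait for the result
    task = DataLoaderTask(task_id=0, indices=[0, 1, 2, 3], dataset_kind='map')
    manager.send_task(manager.get_least_busy_worker(), task)
    result = manager.data_queue.get(timeout=10.0)
    print(result.task_id, result.data, result.error)  # -> 0 [0, 2, 4, 6] None

    manager.shutdown_workers()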