Last active
May 29, 2022 19:46
-
-
Save TerryGeng/77fd2c524a5be69131d8b4176accc3dc to your computer and use it in GitHub Desktop.
Multiprocessing in Jupyter Notebook, with a progress bar and ETA display.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import time | |
| from datetime import timedelta | |
| from multiprocess import Process, Pipe | |
| import ipywidgets as widgets | |
| from IPython.display import display | |
def run_parallel_tasks(func, tasks, processes=4):
    """Run ``func(*task)`` for every task in worker processes, showing a progress bar.

    ``tasks`` is split into ``processes`` contiguous chunks whose sizes differ by
    at most one. Each chunk is handled by one ``Process``. Workers send every
    result back to the parent, which updates an ``IntProgress`` bar and an ETA
    label. Updates happen at most once per 0.5 s so Jupyter's message channel
    is not flooded.

    Args:
        func: Callable invoked in a worker as ``func(*task)``.
        tasks: Sized, sliceable sequence of positional-argument tuples,
            e.g. ``[(1,), (2,), (3,)]``. Each element is unpacked with ``*``.
        processes: Number of worker processes to start; must be >= 1.

    Returns:
        A list of the ``func(*task)`` results, in the same order as ``tasks``.

    Raises:
        ValueError: If ``processes`` is less than 1.
        RuntimeError: If ``func`` raises (or its result cannot be sent back) in
            any worker. The message includes the worker's traceback, and all
            remaining workers are terminated.
    """
    # A single Pipe end shared by several writer processes can interleave and
    # corrupt messages. SimpleQueue guards writers with a lock and pickles
    # inside put(), so pickling errors surface in the worker's try block below.
    from multiprocess import SimpleQueue

    if processes < 1:
        raise ValueError(f"processes must be >= 1, got {processes}")

    n_tasks = len(tasks)
    progbar = widgets.IntProgress(
        value=0,
        min=0,
        max=n_tasks,
        step=1,
        description='Loading:',
        bar_style='',  # 'success', 'info', 'warning', 'danger' or ''
        orientation='horizontal'
    )
    eta_label = widgets.Label(value="ETA: inf")
    display(progbar, eta_label)

    # Contiguous, near-equal chunks. Every task lands in exactly one chunk, so
    # concatenating per-worker results in worker order preserves task order.
    tasks_sub = [
        tasks[i * n_tasks // processes:(i + 1) * n_tasks // processes]
        for i in range(processes)
    ]

    def task_func(ind, func, sub_tasks, status_queue):
        """Worker body: report each result, then 'RET' on success or 'ERR' on failure."""
        import traceback
        try:
            for i, task in enumerate(sub_tasks):
                status_queue.put([ind, 'DATA', (i, func(*task))])
        except Exception:
            # Without this message the parent would wait for 'RET' forever.
            status_queue.put([ind, 'ERR', traceback.format_exc()])
            return
        status_queue.put([ind, 'RET', None])

    status_queue = SimpleQueue()
    results = [[] for _ in range(processes)]
    procs = [
        Process(target=task_func, args=(i, func, tasks_sub[i], status_queue))
        for i in range(processes)
    ]

    started = []
    finished_cnt = 0
    completed = 0
    try:
        for proc in procs:
            proc.start()
            started.append(proc)
        start_time = time.time()
        last_report_time = 0

        while finished_cnt < processes:
            worker, kind, val = status_queue.get()
            if kind == 'DATA':
                _, data = val
                results[worker].append(data)
                completed += 1  # one DATA message per finished task
                now = time.time()
                if now - last_report_time > 0.5:  # do not flood jupyter's message system
                    progbar.value = completed
                    elapsed = now - start_time
                    # remaining / rate, with rate = completed / elapsed (completed >= 1 here)
                    remaining = (n_tasks - completed) * elapsed / completed
                    eta_label.value = f"ETA: {timedelta(seconds=int(remaining))}"
                    last_report_time = now
            elif kind == 'RET':
                print(f'Task {worker} done.')
                finished_cnt += 1
            elif kind == 'ERR':
                progbar.bar_style = 'danger'
                raise RuntimeError(f'Task {worker} failed:\n{val}')
    finally:
        # On an error or KeyboardInterrupt, do not leave orphaned workers running.
        if finished_cnt < processes:
            for proc in started:
                proc.terminate()
        for proc in started:
            proc.join()

    # Throttled updates may have left the display behind; show the final state.
    progbar.value = completed
    eta_label.value = f"ETA: {timedelta(0)}"
    progbar.bar_style = 'success'
    return [item for chunk in results for item in chunk]
# ==== EXAMPLE ====
def square(x):
    """Return ``x`` squared."""
    return x**2


if __name__ == "__main__":
    # run_parallel_tasks calls func(*task), so each task must be a tuple of
    # positional arguments; bare ints (as from range(20)) cannot be unpacked.
    # The __main__ guard is required where worker processes are spawned.
    tasks = [(x,) for x in range(20)]
    print(run_parallel_tasks(square, tasks))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment