Python ray pool for custom ray actors
| __author__ = "Elad Nachmias" | |
| __email__ = "[email protected]" | |
| __date__ = "2023-10-19" | |
| import sys | |
| import time | |
| import uuid | |
| import logging | |
| import threading | |
| import traceback | |
| import dataclasses | |
| from abc import ABC, abstractmethod | |
| from typing import List, Dict, Optional, Set, Any, TypeVar, Tuple | |
| import ray | |
| TaskReturnType = TypeVar('TaskReturnType') | |
| """ | |
| TODO: PoolStatus enum. | |
| TODO: invoke internal ray actor remote call to terminate the actors gracefully before agressively killing. | |
| TODO: fix terminology of close() and gracefully_close() to correspond with pythonic convension. | |
| TODO: make interface for clarity | |
| TODO: refactor exception/errors handling (currently there are several methods for this) and add tests | |
| """ | |


@dataclasses.dataclass(unsafe_hash=True)
class TaskInfo:
    submission_id: str
    future: ray.ObjectRef
    worker: ray.actor.ActorHandle = dataclasses.field(hash=False, compare=False)
    submit_time: float = dataclasses.field(hash=False, compare=False)


@dataclasses.dataclass(unsafe_hash=True)
class WorkerInfo:
    actor_handle: ray.actor.ActorHandle
    total_nr_submitted_tasks: int = dataclasses.field(default=0, hash=False, compare=False)
    tasks: Dict[ray.ObjectRef, TaskInfo] = dataclasses.field(default_factory=dict, hash=False, compare=False)
    last_finished_task_time: Optional[float] = dataclasses.field(default=None, hash=False, compare=False)

    @property
    def latest_task_submit_time(self) -> Optional[float]:
        return max((task_info.submit_time for task_info in self.tasks.values()), default=None)

    @property
    def latest_worker_event(self) -> Optional[float]:
        latest_worker_event = self.latest_task_submit_time
        if self.last_finished_task_time is not None and (
                latest_worker_event is None or self.last_finished_task_time > latest_worker_event):
            # Note that we submit multiple tasks per worker at any given point in time. These submitted tasks are added to a
            # queue of pending tasks in the worker's underlying ray implementation. At each point in time the actor serves only a
            # single task. Thus, whenever we submit a task, the worker might still be working on a previous one, and it might take a
            # while until it begins to serve the just-submitted task. Therefore, it's not sufficient to regard a task's submit time
            # as its effective start time in practice. To alleviate this, we also consider the latest event produced by this worker;
            # that is, we take the latest finished task into account when calculating the latest worker event.
            # Note that it is also not sufficient to only consider the time that has passed since the last task finished by this
            # worker, as the pool might have been paused for a while since then, and thus the worker might have been idle for
            # a while. Hence, we take the latest of the last finished-task time and the last task-submission time.
            latest_worker_event = self.last_finished_task_time
        return latest_worker_event

    @property
    def is_busy(self) -> bool:
        return len(self.tasks) > 0


class RayPool(ABC):
    """
    Manages distributed execution of remote actor-tasks. Adapted for a workflow in which:
      (i) The actors are the user's custom ray actors;
      (ii) Actors can be dynamically instantiated/terminated:
        (a) w.r.t. varying utilization (temporarily pausing task dispatching), and/or
        (b) due to actor failures that require killing and re-creating an actor;
      (iii) Tasks are dispatched continuously:
        (a) The set of tasks is not entirely dispatched in advance. Rather, tasks are continuously dispatched over the course of the
            execution, whenever there are available workers and the pool execution is not paused.
        (b) Tasks are distributed uniformly among workers. The task-submission policy is on-the-fly (as described above) to avoid
            cases where some workers are idle while others are at full capacity. Ideally, this policy aims to be agnostic to
            variation in the workers' throughputs and task runtimes.
        (c) More than a single task per worker may be dispatched systematically (a configurable parameter), so that there is always a
            task call ready to be executed in each worker's pending-tasks queue. Whenever the worker finishes executing the current
            task, it immediately starts executing the next pending task, without being idle until the pool's main run() loop receives
            the previous task's result and dispatches the next task to this worker.
      (iv) Task dispatching is centralized:
        Whenever the pool calls the dispatching method to invoke a new remote task call, the dispatcher can decide which actual task
        to dispatch and which params to pass. For example, the dispatcher can (potentially) maintain its own data structures to track
        which tasks should be invoked next. The pool keeps track of completed/canceled dispatched tasks, so the dispatcher can use
        this information (by additionally overriding methods like `_remove_worker_and_submitted_tasks()` and
        `_clear_all_workers_and_tasks()`).
    This is an abstract base class that implements the core logic and functionality, while leaving the abstract methods for task
    dispatching, return-value handling, and actor instantiation unimplemented. The inheritor is expected to implement these ad hoc to
    its needs (see `MockRayPool` at the bottom of this file for a minimal concrete example).
    The main logic happens inside the long-lived call to `run()`. Basically, it loops as long as the pool is not stopped. In each
    iteration we wait for completed tasks, submit new tasks, and terminate/instantiate workers as needed (based on utilization and/or
    the workers' health). This work is mostly IO-bound - meaning most of the time the process/thread that executes `run()` is sleeping
    on a blocked call (e.g: ray.wait()).
    The methods are thread-safe. That is, the user can call control operations (e.g: pause(), resume()) concurrently while run() is
    running. The operations can be called from different concurrent threads. If the pool itself is executed from a ray actor, then a
    threaded actor can be used.
    """
    def __init__(
            self,
            num_actors: int,
            logger: Optional[logging.Logger] = None,
            min_tasks_to_process: int = 5,
            min_tasks_to_process_timeout: float = 900.,
            waiting_for_tasks_wait_time: float = 1.,
            nr_desired_simultaneous_tasks_per_worker: int = 2,
            task_timeout: Optional[float] = None,
            handle_worker_exceptions_by_default: bool = True,  # Used if `allowed_exceptions` is None
            allowed_exceptions: Optional[List[str]] = None,  # Pass `None` for the default handling policy
            wait_for_tasks_end_on_reset: bool = False):
        """
        :param num_actors: Maximum number of actors to instantiate when fully active (not paused).
        :param logger: Optional. All status updates during the execution are printed to this logger.
        :param min_tasks_to_process: Minimum number of tasks to ray.wait() for before handling them and proceeding with the main loop.
        :param min_tasks_to_process_timeout: Timeout for each ray.wait() call on the active submitted tasks - clear all tasks if reached.
        :param waiting_for_tasks_wait_time: Wait time between consecutive pollings while paused and there are no active tasks to wait for.
        :param nr_desired_simultaneous_tasks_per_worker: Number of simultaneously co-existing remote task invocations to maintain per
                                                         worker. Ideally, it should be big enough for small tasks to ensure the worker
                                                         always has a pending task to begin working on just after finishing the current
                                                         one, before the main loop detects it and submits the next task.
        :param task_timeout: Optional. If given, terminate workers that haven't finished any task for at least this period of time.
        :param handle_worker_exceptions_by_default: When `allowed_exceptions` is not given, whether to handle workers' exceptions.
        :param allowed_exceptions: Optional. If given, don't terminate workers that raise a corresponding exception.
        :param wait_for_tasks_end_on_reset: Whether to actively wait for tasks' completion on reset.
        """
        self._num_actors = num_actors
        self._logger = _get_logger() if logger is None else logger
        self._min_tasks_to_process = min_tasks_to_process
        self._min_tasks_to_process_timeout = min_tasks_to_process_timeout
        self._waiting_for_tasks_wait_time = waiting_for_tasks_wait_time
        self._control_lock = threading.RLock()  # protects the status flags
        # TODO: Consider using a dedicated `PoolStatus` enum instead of the following flags.
        self._should_submit_new_tasks = False
        self._should_stop = False
        self._should_terminate = False  # graceful termination - wait for completion of active tasks before exiting run()'s main loop
        self._is_idle = True
        self._is_main_loop_active = False  # avoid having two active main loops concurrently
        self._tasks_info: Dict[ray.ObjectRef, TaskInfo] = {}
        self._workers_info: Dict[ray.actor.ActorHandle, WorkerInfo] = {}
        self._nr_desired_simultaneous_tasks_per_worker = nr_desired_simultaneous_tasks_per_worker
        self._task_timeout = task_timeout
        self._handle_worker_exceptions_by_default = handle_worker_exceptions_by_default
        self._allowed_exceptions = allowed_exceptions
        self._wait_for_tasks_end_on_reset = wait_for_tasks_end_on_reset

    def run(self):
        """
        Here is where the main logic happens. It's a long-lived call. Dynamically instantiate/terminate workers and submit tasks
        according to the pool's status (whether paused/stopped). Wait for dispatched tasks to complete, and handle completed tasks'
        return values. Monitor and track the health of the workers. Exit after `stop()` is called or a fatal exception is raised.
        Each loop iteration contains either waiting for pending tasks or only polling the resuming flag
        `self._should_submit_new_tasks` (if paused).
        """
        try:
            with self._control_lock:
                self._should_submit_new_tasks = True
                self._should_stop = False
                self._should_terminate = False
                self._is_idle = False
                if self._is_main_loop_active:
                    # This verification is necessary to support a concurrent actor with multiple threads (2) for handling the user's
                    # calls to other methods (e.g: stop(), update_*_network() etc.) while running this infinite task.
                    raise RuntimeError(f'{self.run.__name__}() should be called only once concurrently!')
                self._is_main_loop_active = True
            # Initialize all workers
            self._workers_info = {}
            self._instantiate_and_register_workers(self._num_actors)
            self._logger.info("Finished initializing")
            # Initialize an empty tasks list. Tasks will be submitted to workers within the loop below.
            self._tasks_info = {}
            self.on_before_main_loop()
            last_waiting_time_for_tasks_logging_time = None
            while not self._should_stop:
                self.on_main_loop_iteration_start()
                if len(self._tasks_info) == 0 and not self._should_submit_new_tasks:
                    self._clear_all_workers_and_tasks()
                    if self._should_terminate:
                        break
                    else:
                        # No tasks exist, wait for 1 second and try again
                        if last_waiting_time_for_tasks_logging_time is None or \
                                time.monotonic() - last_waiting_time_for_tasks_logging_time >= 60.:
                            self._logger.info("No tasks exist, waiting")
                            last_waiting_time_for_tasks_logging_time = time.monotonic()
                        time.sleep(self._waiting_for_tasks_wait_time)
                        continue
                self._enforce_tasks_timeout()
                self._fill_workers_tasks_queues()
                just_finished_tasks_futures = self._wait_for_pending_tasks()
                while len(just_finished_tasks_futures) > 0:
                    finished_task_future = just_finished_tasks_futures.pop()
                    self._handle_finished_task_handle(finished_task_future)
                self.on_main_loop_iteration_end()
        except Exception as ex:
            self._logger.exception(f"Major Exception, {ex.args}, {type(ex)}, {ex}")
            self._logger.exception(traceback.format_exc())
        finally:
            # Softly terminate workers.
            self._clear_all_workers_and_tasks()
            self._is_idle = True
            self._should_submit_new_tasks = False
            self._is_main_loop_active = False
            self._logger.info("All workers and tasks softly cleared.")
            self._logger.info("Exiting.")
            self._flush_logger()  # Ensure logs are flushed to the output. Helps make potential errors visible after the exit.

    def stop(self):
        """
        Exit the run() main loop (if active) as soon as possible, without waiting for tasks to finish.
        """
        self._logger.info("Stopping")
        with self._control_lock:
            self._should_stop = True
            self._should_submit_new_tasks = False

    def gracefully_terminate(self):
        """
        Stop submission of new tasks. Exit run()'s main loop once all active tasks have completed (or failed).
        """
        self._logger.info("Marked to terminate gracefully (stop task submission and wait for active tasks to complete)..")
        with self._control_lock:
            self._should_terminate = True
            self._should_submit_new_tasks = False

    def pause(self):
        """
        Cease submitting new tasks to workers (and terminate idle workers). Keep the main run() loop active and collect tasks' results
        or wait for resumption.
        """
        self._logger.info("Pausing")
        with self._control_lock:
            self._should_submit_new_tasks = False

    def resume(self):
        """
        Resume submitting new tasks to workers if previously paused (and instantiate new workers if needed).
        """
        self._logger.info("Resuming")
        with self._control_lock:
            self._should_submit_new_tasks = True

    def reset(self):
        """
        Pause (wait for all active tasks to finish without submitting new tasks), call a callback to potentially clear internal
        buffers of the inheritor if needed, then resume.
        """
        self._logger.info("Resetting ..")
        self.pause()
        if self._wait_for_tasks_end_on_reset:
            while not self._is_idle:
                self._logger.info("Waiting for tasks to finish after reset, sleeping for 30 seconds.")
                time.sleep(30)
            self._logger.info("All tasks finished.")
        # Support potentially clearing internal buffers of the inheritor if needed.
        self.on_reset_before_resumption()
        self.resume()

    def _flush_logger(self):
        """
        Ensure logs are flushed to the output. Helps make potential errors visible to the user.
        """
        for handler in self._logger.handlers:
            handler.flush()
        sys.stdout.flush()
        sys.stderr.flush()
        # Yield the active thread/process for a short period to give other kinds of output a chance to get flushed.
        # https://www.geeksforgeeks.org/python-sys-stdout-flush/
        time.sleep(1)

    def _wait_for_pending_tasks(self) -> List[ray.ObjectRef]:
        """
        This is one of the main methods called directly from run()'s main loop. In each iteration we wait for some of the remotely
        submitted active tasks to complete.
        """
        # Encapsulating this in a dedicated method elegantly avoids keeping an extra redundant reference to the tasks list, which
        # would otherwise defer its release after we drop the last reference to it.
        pending_tasks_list = list(self._tasks_info.keys())
        assert len(pending_tasks_list) >= 0
        just_finished_tasks, non_finished_tasks = ray.wait(
            pending_tasks_list,
            num_returns=min(self._min_tasks_to_process, len(pending_tasks_list)),
            timeout=self._min_tasks_to_process_timeout)
        # Ray wait timeout exceeded without a finished task.
        if len(just_finished_tasks) == 0 and len(non_finished_tasks) > 0:
            # Failed to finish any task - release resources.
            self._logger.info("Cluster stuck. Releasing all resources !!!")
            self._clear_all_workers_and_tasks()
        return just_finished_tasks

    def _fill_worker_tasks_queue(self, worker: ray.actor.ActorHandle):
        """
        Ensure that enough remote tasks are invoked for the given worker.
        """
        if not self._should_submit_new_tasks:
            return
        # Request multiple tasks simultaneously so there's always a task in place in the worker's queue if the main loop doesn't keep
        # up, as `ray.wait()` waits until the finished task's data is fetched locally, which might take time.
        while len(self._workers_info[worker].tasks) < self._nr_desired_simultaneous_tasks_per_worker:
            if self._submit_tracked_task_to_worker(worker=worker) is None:
                # The dispatcher has no task to submit right now (`invoke_worker_remote_task_call()` returned None) - stop filling
                # this worker's queue to avoid busy-looping.
                break

    def _submit_tracked_task_to_worker(self, worker: ray.actor.ActorHandle) -> Optional[ray.ObjectRef]:
        """
        Create a new remote task call and store the future handle in the internal data structures for tracking and waiting for it. To
        keep the actual performed task as general as possible, triggering the relevant remote call is encapsulated by
        `invoke_worker_remote_task_call()`.
        """
        submission_id = self._generate_unique_submission_id()
        task_future = self.invoke_worker_remote_task_call(worker=worker, submission_id=submission_id)
        if task_future is None:
            return None
        with self._control_lock:
            self._tasks_info[task_future] = task_info = TaskInfo(
                submission_id=submission_id, future=task_future, worker=worker, submit_time=time.monotonic())
            worker_info = self._workers_info[worker]
            assert task_future not in worker_info.tasks
            worker_info.tasks[task_future] = task_info
            worker_info.total_nr_submitted_tasks += 1
            self._is_idle = False
        return task_future

    def _generate_unique_submission_id(self) -> str:
        """
        Used for assigning a unique identifier to each submitted task. This approach allows the user (who inherits from RayPool) to
        track tasks - especially tasks that failed or were discarded without returning a value. This is the default implementation;
        it can be overridden by the user to produce custom identifiers.
        """
        return uuid.uuid4().hex

    def _remove_worker_and_submitted_tasks(self, worker: ray.actor.ActorHandle, kill: bool = False):
        """
        Removing a worker is useful both when it becomes idle and when a failure occurred (that requires discarding the worker).
        Ensure that the tracked remote tasks submitted to this worker are also discarded. Optionally, forcibly kill the worker; if
        not, only lose all the references to it and let ray release it.
        """
        discarded_task_submission_ids = set()
        with self._control_lock:
            worker_info = self._workers_info.pop(worker, None)
            if worker_info is not None:
                for task_info in worker_info.tasks.values():
                    self._tasks_info.pop(task_info.future)
                    discarded_task_submission_ids.add(task_info.submission_id)
            if kill:
                try:
                    ray.kill(worker)
                except Exception:
                    pass
        for submission_id in discarded_task_submission_ids:
            self.task_discarded(submission_id=submission_id)

    def _clear_all_tasks(self):
        """
        Lose all tracked references to all submitted remote tasks. Keep the workers if they exist.
        """
        discarded_task_submission_ids = set()
        with self._control_lock:
            while len(self._tasks_info) > 0:
                _, task_info = self._tasks_info.popitem()
                discarded_task_submission_ids.add(task_info.submission_id)
            for worker_info in self._workers_info.values():
                assert all(task_info.submission_id in discarded_task_submission_ids for task_info in worker_info.tasks.values())
                worker_info.tasks.clear()
            if not self._should_submit_new_tasks:
                self._is_idle = True
        for submission_id in discarded_task_submission_ids:
            self.task_discarded(submission_id=submission_id)

    def _clear_all_workers_and_tasks(self):
        """
        Lose all tracked references to all submitted remote tasks and all workers. Used when exiting or when a global error occurred.
        """
        with self._control_lock:
            self._clear_all_tasks()
            self._workers_info.clear()

    def _pop_finished_task(self, task_future: ray.ObjectRef) -> Optional[TaskInfo]:
        """
        Remove all references to the given task handle from the internal data structures.
        :return: the `TaskInfo` of the given task (which also holds a handle to the worker that the task was assigned to), or `None`
                 if the task is no longer tracked.
        """
        if task_future not in self._tasks_info:
            # Can happen during the loop handling the finished tasks, if several finished tasks came from the same worker that caused
            # an exception in ray.get().
            return None
        task_info = self._tasks_info.pop(task_future)
        worker_info = self._workers_info[task_info.worker]
        worker_info.tasks.pop(task_future)
        worker_info.last_finished_task_time = time.monotonic()
        return task_info

    def _handle_finished_task_handle(self, task_handle: ray.ObjectRef):
        """
        Called when a remote task has finished and been fetched locally. Should clear all relevant tracked references to this task,
        obtain its return value, and pass it to the user.
        """
        task_info = self._pop_finished_task(task_future=task_handle)
        if task_info is None:
            # Continue in case the task is no longer in the task-worker map keys. This can happen if several
            # finished tasks came from the same worker that caused an exception in ray.get().
            return
        try:
            task_returned_value = ray.get(task_handle)
        except Exception as ex:
            self._handle_worker_failure(worker=task_info.worker, failed_submission_id=task_info.submission_id, exception=ex)
            return
        is_worker_healthy = self.is_task_results_healthy(
            task_returned_value=task_returned_value, submission_id=task_info.submission_id)
        if not is_worker_healthy:
            self._logger.error('Task return value is not healthy. Killing worker.')
            self._terminate_failed_worker(worker=task_info.worker, related_failed_submission_id=task_info.submission_id, force_kill=True)
            return
        # If we want to continue performing tasks, add a new task for this worker.
        self._fill_worker_tasks_queue(worker=task_info.worker)
        self.handle_task_results(task_returned_value=task_returned_value, submission_id=task_info.submission_id)

    def _handle_worker_failure(self, worker: ray.actor.ActorHandle, failed_submission_id: str, exception: Exception):
        """
        Called whenever a worker failure has occurred (raised by `ray.get()`). This implementation terminates the worker and
        gracefully proceeds with the main pool's run() execution, depending on the actual error that occurred. If an explicit
        `allowed_exceptions` list is given, then non-allowed exceptions are propagated, so that the entire run() execution crashes.
        These logics are encapsulated within a standalone method so that the inheritor can override it to extend/modify the
        failure-handling functionality.
        """
        self._logger.exception("Exception occurred in worker")
        self._logger.exception(traceback.format_exc())
        try:
            # Handle the exception based on the allowed-exceptions list. If the exception is not allowed, it will be raised from
            # within the handling method.
            if self._allowed_exceptions is None:
                if not self._handle_worker_exceptions_by_default:
                    raise exception
            else:
                _handle_allowed_exceptions(exception=exception, allowed_exception_types_partial_names=self._allowed_exceptions)
        except Exception:
            self._terminate()
            raise  # propagate the error upward (will terminate the `run()` call)
        # This finer handling case is reachable only if the error is not propagated.
        # Lose the reference to the failed worker and to its related tasks, as it might be dead.
        # The main loop will add a new worker later if task submission is still enabled.
        # In addition, remove all tasks that are scheduled on the problematic worker.
        self._terminate_failed_worker(worker=worker, related_failed_submission_id=failed_submission_id, related_exception=exception)

    def _terminate_failed_worker(
            self,
            worker: ray.actor.ActorHandle,
            related_failed_submission_id: Optional[str] = None,
            related_exception: Optional[Exception] = None,
            force_kill: bool = False):
        """
        Called whenever a worker should be terminated due to a failure.
        Can be overridden by the inheritor to track workers' failures.
        """
        self._remove_worker_and_submitted_tasks(worker=worker, kill=force_kill)

    def _terminate(self):
        """
        Encapsulates and unifies the necessary steps for terminating run()'s main loop from auxiliary methods executed within the
        loop - for example, whenever a non-recoverable error has occurred (or an error that is marked not to be handled).
        Note: Expected to be called from within the main thread that executes `run()` and not from control methods that might be
        called from other concurrent threads, as it modifies unprotected structures.
        """
        with self._control_lock:
            assert self._is_main_loop_active
            self._should_submit_new_tasks = False
            self._should_stop = True
            self._clear_all_tasks()

    def _enforce_tasks_timeout(self):
        """
        Called by the main loop. Kill workers whose tasks exceeded the timeout, if a timeout is set.
        """
        if self._task_timeout is None:
            return  # the limitation is unset
        now = time.monotonic()
        workers_to_remove = []
        for worker_handle, worker_info in self._workers_info.items():
            if worker_info.latest_worker_event is not None and now - worker_info.latest_worker_event > self._task_timeout:
                # Nothing has happened with this worker for too long. Mark it to be killed.
                workers_to_remove.append(worker_handle)
        if workers_to_remove:
            self._logger.warning(f'Killing {len(workers_to_remove)} workers that exceeded the tasks timeout.')
        # We perform the actual workers removal only after iterating over the workers, to avoid mutating the dict while iterating it.
        for worker_handle in workers_to_remove:
            self._remove_worker_and_submitted_tasks(worker_handle, kill=True)

    def _fill_workers_tasks_queues(self):
        """
        Invoke enough tasks for each worker to meet the desired number of simultaneously submitted tasks.
        """
        # Generate tasks for all non-busy workers.
        if self._should_submit_new_tasks:
            # Add workers according to the required amount.
            if self._num_actors > len(self._workers_info):
                self._instantiate_and_register_workers(self._num_actors - len(self._workers_info))
            for worker in self._workers_info.keys():
                self._fill_worker_tasks_queue(worker=worker)
        else:
            self._remove_idle_workers()

    def _get_busy_workers(self) -> Set[ray.actor.ActorHandle]:
        """
        :return: a set of handles to workers that have an active/pending task assigned.
        """
        return set(worker_info.actor_handle for worker_info in self._workers_info.values() if worker_info.is_busy)

    def _get_idle_workers(self) -> Set[ray.actor.ActorHandle]:
        """
        :return: a set of handles to workers that don't have an active/pending task assigned.
        """
        return set(self._workers_info.keys()) - self._get_busy_workers()

    def _remove_idle_workers(self):
        """
        Called by the main loop to remove the workers that have no assigned active/pending task.
        """
        # Remove non-busy workers from the list so ray will free the CPUs.
        idle_workers = self._get_idle_workers()
        if len(idle_workers) == 0:
            return
        self._logger.info(f"Removing {len(idle_workers)} idle workers because we shouldn't generate tasks now.")
        # Note: idle workers have no assigned tasks, so there are no tracked tasks to remove for them here.
        for worker_handle in idle_workers:
            self._remove_worker_and_submitted_tasks(worker=worker_handle)

    def _instantiate_and_register_workers(self, num_actors: int):
        for worker_handle in self.instantiate_workers(num_actors=num_actors):
            assert worker_handle not in self._workers_info
            self._workers_info[worker_handle] = WorkerInfo(actor_handle=worker_handle)

    @abstractmethod
    def instantiate_workers(self, num_actors: int) -> List[ray.actor.ActorHandle]:
        """
        Should be implemented by the inheritor. Creates the actual actors of the desired type. It is called by the main loop within
        run(). Workers can be dynamically instantiated and terminated multiple times during the entire lifetime of run()'s main loop.
        """
        raise NotImplementedError

    @abstractmethod
    def handle_task_results(self, task_returned_value: Any, submission_id: str):
        """
        Called by `run()`'s main loop whenever a task has been completed and fetched locally. The purpose of this method is to
        process these results locally however needed, and store them in a local structure or pass them to the relevant consumer.
        :param task_returned_value: the value produced and returned by the completed task
        :param submission_id: the unique identifier assigned to this task upon its submission
        """
        raise NotImplementedError

    def task_discarded(self, submission_id: str):
        """
        Called by `run()`'s main loop whenever a task has been discarded. Note that the task's execution status is unknown at this
        stage; that is, it's unknown whether it started and was interrupted / started and finished / was canceled before even
        starting. It is guaranteed, however, that `handle_task_results()` has not been and will not be called for this task.
        :param submission_id: the unique identifier assigned to this task upon its submission
        """
        pass

    @abstractmethod
    def invoke_worker_remote_task_call(self, worker: ray.actor.ActorHandle, submission_id: str) -> Optional[ray.ObjectRef]:
        """
        A separate method for only creating the remote call to the worker. It only returns the just-created task, without tracking /
        registering it in the pool's internal data structures (that is performed by the caller). Should be overridden by the
        inheritor. It can support, for example, multiple kinds of tasks and/or centralize the distribution and passing of parameters
        to the submitted tasks. Returning `None` indicates that there is currently no task to dispatch for this worker.
        :param worker: Handle to the ray actor to submit this task to. RayPool is responsible for balancing the submitted jobs per
                       worker; that is, the user (the inheritor) implements this function without caring about actor instantiation /
                       destruction, nor about when to submit remote tasks to which workers.
        :param submission_id: The unique identifier for the new task. This allows the user to track the tasks. For example, the user
                              can pass this id as a parameter to the remote call. If the remote call acquires resources, it can log
                              this id together with the acquisition. Then, when a task fails or is discarded, the same identifier of
                              that task is known here and the user can safely release the acquired resources assigned to this task id.
        :return: The future of the just-submitted remote task, or `None` if there is no task to submit.
        """
        raise NotImplementedError

    def is_task_results_healthy(self, task_returned_value: Any, submission_id: str) -> bool:
        """
        A means to detect faulty runs and support terminating actors in such cases. Can be overridden by the inheritor. Returning
        true keeps the same behavior. Returning false discards this result and also kills the corresponding worker.
        """
        return True  # default implementation; allows overriding, but does not require explicit implementation by the inheritor

    def on_main_loop_iteration_start(self):
        """
        Called at the very beginning of each iteration of the main long-lived loop in run().
        The default implementation does nothing. Overridable by the inheritor for extending the pool.
        Use case example: periodically sending an update to all the active workers.
        """
        pass

    def on_main_loop_iteration_end(self):
        """
        Called at the very end of each iteration of the main long-lived loop in run(). Note that, in contrast to
        `on_main_loop_iteration_start()`, this callback is called only after the waiting for active/pending tasks has completed, and
        thus it is NOT called while the pool is paused and there are no assigned tasks to wait for.
        The default implementation does nothing. Overridable by the inheritor for extending the pool.
        Use case examples: checking for custom exit conditions; periodically logging the pool's progress.
        """
        pass

    def on_before_main_loop(self):
        """
        Called just before entering the main long-lived loop in run().
        The default implementation does nothing. Overridable by the inheritor for extending the pool.
        Use case example: resetting custom members that the inheritor added for extended functionality (like periodic-log timing
        tracking).
        """
        pass

    def on_reset_before_resumption(self):
        """
        Called upon reset(), after the waiting for active tasks has completed and just before resuming.
        The default implementation does nothing. Overridable by the inheritor for extending the pool.
        Use case example: resetting custom members that the inheritor added for extended functionality.
        """
        pass


def _handle_allowed_exceptions(exception: Exception, allowed_exception_types_partial_names: Optional[List[str]]):
    if allowed_exception_types_partial_names is None:
        return
    exc_str = traceback.format_exc()
    if not any(allowed_exc_str in exc_str for allowed_exc_str in allowed_exception_types_partial_names):
        raise exception


def _get_logger() -> logging.Logger:
    return logging.getLogger('RayPool')


@ray.remote(num_cpus=1, memory=6_000_000_000)
class MockActor:
    def do_work(self, submission_id: str) -> Tuple[str, float]:
        time.sleep(1.)
        value = float(np.random.random())
        return submission_id, value


class MockRayPool(RayPool):
    def __init__(self, *args, max_nr_tasks: int = 100, **kwargs):
        self._nr_invoked_tasks = 0
        self._nr_completed_tasks = 0
        self._nr_discarded_tasks = 0
        self._max_nr_tasks = max_nr_tasks
        RayPool.__init__(self, *args, **kwargs)

    @property
    def nr_completed_tasks(self) -> int:
        return self._nr_completed_tasks

    @property
    def nr_discarded_tasks(self) -> int:
        return self._nr_discarded_tasks

    @property
    def nr_invoked_tasks(self) -> int:
        return self._nr_invoked_tasks

    def instantiate_workers(self, num_actors: int) -> List[ray.actor.ActorHandle]:
        return [MockActor.remote() for _ in range(num_actors)]

    def handle_task_results(self, task_returned_value: Any, submission_id: str):
        self._nr_completed_tasks += 1
        assert isinstance(task_returned_value, tuple)
        assert task_returned_value[0] == submission_id
        assert isinstance(task_returned_value[1], float)

    def update_metrics_on_task_end(self, task_returned_value: Any):
        pass

    def invoke_worker_remote_task_call(self, worker: ray.actor.ActorHandle, submission_id: str) -> Optional[ray.ObjectRef]:
        if self._nr_invoked_tasks >= self._max_nr_tasks:
            return None
        self._nr_invoked_tasks += 1
        task = worker.do_work.remote(submission_id=submission_id)
        return task

    def on_main_loop_iteration_end(self):
        if self._nr_invoked_tasks >= self._max_nr_tasks:
            self.pause()

    def on_main_loop_iteration_start(self):
        if self._nr_completed_tasks + self._nr_discarded_tasks >= self._max_nr_tasks:
            self.stop()

    def task_discarded(self, submission_id: str):
        self._nr_discarded_tasks += 1


class ThreadedMockRayPool(MockRayPool, threading.Thread):
    def __init__(self, *args, **kwargs):
        MockRayPool.__init__(self, *args, **kwargs)
        threading.Thread.__init__(self)
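

# ------------------------------------------------------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original pool implementation): controlling a running pool from another thread, as the
# RayPool docstring describes. `ThreadedMockRayPool` above executes run() inside a background thread, so pause(), resume() and
# gracefully_terminate() can be issued concurrently from the calling thread. The function name, worker count, task count and sleep
# durations below are arbitrary choices for demonstration only; this function is not invoked by the tests or by __main__.
# ------------------------------------------------------------------------------------------------------------------------------------
def example_concurrent_control(num_actors: int = 4, max_nr_tasks: int = 200) -> Tuple[int, int]:
    pool = ThreadedMockRayPool(logger=_get_logger(), num_actors=num_actors, max_nr_tasks=max_nr_tasks)
    pool.start()  # threading.Thread.start() executes RayPool.run() in a background thread
    time.sleep(5.)  # let some tasks get dispatched and completed
    pool.pause()  # stop dispatching new tasks; the main loop releases workers once they become idle
    time.sleep(5.)
    pool.resume()  # re-instantiate workers (up to num_actors) and continue dispatching tasks
    time.sleep(5.)
    pool.gracefully_terminate()  # stop dispatching and exit run() once the remaining active tasks finish
    pool.join()
    return pool.nr_completed_tasks, pool.nr_discarded_tasks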


def test_ray_pool_basic():
    logger = _get_logger()
    nr_tasks = 100
    pool = MockRayPool(logger=logger, num_actors=10, max_nr_tasks=nr_tasks)
    pool.run()
    assert pool.nr_completed_tasks == nr_tasks
    assert pool.nr_discarded_tasks == 0


def test_ray_pool_concurrent_with_terminate():
    logger = _get_logger()
    nr_tasks = 100
    threaded_pool = ThreadedMockRayPool(logger=logger, num_actors=10, max_nr_tasks=nr_tasks)
    threaded_pool.start()
    threaded_pool.gracefully_terminate()
    nr_invoked_tasks_after_terminate_command = threaded_pool.nr_invoked_tasks
    threaded_pool.join()
    assert threaded_pool.nr_completed_tasks == nr_invoked_tasks_after_terminate_command
    assert threaded_pool.nr_discarded_tasks == 0


def test_ray_pool_task_discarded():
    logger = _get_logger()
    nr_tasks = 100
    threaded_pool = ThreadedMockRayPool(logger=logger, num_actors=10, max_nr_tasks=nr_tasks)
    threaded_pool.start()
    while threaded_pool.nr_completed_tasks == 0:
        time.sleep(.1)
    # We stop the generation just after the first results are ready, so we can expect some results to be discarded before being ready.
    # Apart from checking that the `task_discarded()` method is called properly, we also check that for each task exactly one
    # method is called - either task_discarded() or handle_task_results(). This also characterizes a desired flow where the results
    # are reported just when they're ready, and not all of the results are passed together when all are ready (assuming
    # #tasks >> #workers).
    threaded_pool.stop()
    nr_invoked_tasks_after_terminate_command = threaded_pool.nr_invoked_tasks
    threaded_pool.join()
    assert threaded_pool.nr_discarded_tasks > 0
    assert threaded_pool.nr_completed_tasks > 0
    assert threaded_pool.nr_completed_tasks + threaded_pool.nr_discarded_tasks == nr_invoked_tasks_after_terminate_command
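

# ------------------------------------------------------------------------------------------------------------------------------------
# Illustrative configuration sketch (not part of the original pool implementation): how the resiliency-related constructor parameters
# compose. `task_timeout` makes the pool kill and replace workers that showed no activity (no task submitted to them and no task
# finished by them) for that long, and `allowed_exceptions` lists substrings that are matched against the formatted traceback of a
# worker failure; on a match the failing worker is replaced instead of crashing run(). The parameter values and exception names below
# are arbitrary examples, not recommendations; this factory is not invoked by the tests or by __main__.
# ------------------------------------------------------------------------------------------------------------------------------------
def make_resilient_mock_pool() -> MockRayPool:
    return MockRayPool(
        logger=_get_logger(),
        num_actors=8,
        max_nr_tasks=100,
        nr_desired_simultaneous_tasks_per_worker=2,
        task_timeout=300.,  # replace a worker after 5 minutes without any submission/finish event
        allowed_exceptions=['RayActorError', 'ConnectionError'],  # tolerated worker failures (substring match on the traceback)
    )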


if __name__ == '__main__':
    test_ray_pool_basic()
    test_ray_pool_concurrent_with_terminate()
    test_ray_pool_task_discarded()