Python ray pool for custom ray actors
__author__ = "Elad Nachmias"
__email__ = "[email protected]"
__date__ = "2023-10-19"
import sys
import time
import uuid
import logging
import threading
import traceback
import dataclasses
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Set, Any, TypeVar, Tuple
import numpy as np
import ray
TaskReturnType = TypeVar('TaskReturnType')
"""
TODO: PoolStatus enum.
TODO: invoke internal ray actor remote call to terminate the actors gracefully before agressively killing.
TODO: fix terminology of close() and gracefully_close() to correspond with pythonic convension.
TODO: make interface for clarity
TODO: refactor exception/errors handling (currently there are several methods for this) and add tests
"""
@dataclasses.dataclass(unsafe_hash=True)
class TaskInfo:
submission_id: str
future: ray.ObjectRef
worker: ray.actor.ActorHandle = dataclasses.field(hash=False, compare=False)
submit_time: float = dataclasses.field(hash=False, compare=False)
@dataclasses.dataclass(unsafe_hash=True)
class WorkerInfo:
actor_handle: ray.actor.ActorHandle
total_nr_submitted_tasks: int = dataclasses.field(default=0, hash=False, compare=False)
tasks: Dict[ray.ObjectRef, TaskInfo] = dataclasses.field(default_factory=dict, hash=False, compare=False)
last_finished_task_time: Optional[float] = dataclasses.field(default=None, hash=False, compare=False)
@property
def latest_task_submit_time(self) -> Optional[float]:
return max((task_info.submit_time for task_info in self.tasks.values()), default=None)
@property
def latest_worker_event(self) -> Optional[float]:
latest_worker_event = self.latest_task_submit_time
if self.last_finished_task_time is not None and (
latest_worker_event is None or self.last_finished_task_time > latest_worker_event):
            # Note that we submit multiple tasks per worker at any given point in time. These submitted tasks are added to a
            # queue of pending tasks in the worker's underlying ray implementation. At any point in time the actor serves only a
            # single task. Thus, whenever we submit a task, the worker might still be working on a previous one, and it might take a
            # while until it begins to serve the just-submitted task. Therefore, it's not sufficient to regard the task's submit time
            # as its effective start time in practice. To alleviate this, we also consider the latest event observed for this worker;
            # that is, we take the latest finished task into account while calculating the latest worker event.
            # Note that it is also not sufficient to only consider the time that has passed since this worker's last finished task,
            # as the pool might have been paused since then, and thus the worker might have been idle for a while. Hence, we take the
            # latest of the last finished task time and the last task submit time.
latest_worker_event = self.last_finished_task_time
return latest_worker_event
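    # Worked example (illustrative values): if this worker's pending tasks were submitted at t=10s and t=12s and its last task
    # finished at t=15s, then `latest_worker_event` is 15s; `_enforce_tasks_timeout()` measures the task timeout from that point.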
@property
def is_busy(self) -> bool:
return len(self.tasks) > 0
class RayPool(ABC):
"""
    Manages distributed execution of remote actor-tasks. Adapted for a workflow in which:
    (i) The actors are the user's custom ray actors;
    (ii) Actors can be dynamically instantiated/terminated:
        (a) With respect to varied utilization (temporarily pausing tasks dispatching), and/or
        (b) Due to actor failures that require killing and re-creating an actor;
    (iii) Continuous tasks dispatching:
        (a) The set of tasks is not entirely dispatched in advance. Rather, tasks are continuously dispatched over the course of the
            execution whenever there are available workers and the pool execution is not paused.
        (b) Tasks are uniformly distributed among workers. The task-submission policy is on-the-fly (as described above) to avoid
            cases where some workers are idle while others are at full capacity. Ideally, this policy aims to be agnostic to
            variation in the workers' throughputs and tasks' runtimes.
        (c) Allows systematically dispatching more than a single task per worker (a configurable parameter), so that an available
            task call is ready in each worker's pending-tasks queue. Whenever a worker finishes executing its current task it
            immediately starts executing the next pending task, without being idle until the pool's main run() loop receives the
            previous task's result and dispatches the next task to this worker.
    (iv) Centralized tasks dispatching:
        Whenever the pool calls the dispatching method to invoke a new remote task call, the dispatcher can decide which actual task
        to dispatch and which params to pass. For example, the dispatcher can (potentially) maintain its own data-structures to track
        which tasks should be invoked next. The pool keeps track of completed/canceled dispatched tasks, so the dispatcher can use
        this information (by additionally overriding methods like `_remove_worker_and_submitted_tasks()` and
        `_clear_all_workers_and_tasks()`).
    This is a base abstract class that implements the core logic and functionality, and leaves the abstract methods for tasks
    dispatching, return-value handling, and actor instantiation unimplemented. The inheritor is expected to implement these ad-hoc
    for its needs.
    The main logic happens inside the long-lived call to `run()`. Basically, it loops as long as the pool is not stopped. In each
    iteration we wait for completed tasks, submit new tasks, and terminate/instantiate workers as needed (based on utilization and/or
    workers' health). This work is mostly IO-bound - meaning most of the time the process/thread that executes `run()` is sleeping on
    a blocked call (e.g.: ray.wait()).
    The methods are thread-safe. That is, the user can call control operations (e.g.: pause(), resume()) concurrently while run() is
    running. The operations can be called from different concurrent threads. If the pool itself is executed from a ray actor, then a
    threaded actor can be used.
"""
def __init__(
self,
num_actors: int,
logger: Optional[logging.Logger] = None,
min_tasks_to_process: int = 5,
min_tasks_to_process_timeout: float = 900.,
waiting_for_tasks_wait_time: float = 1.,
nr_desired_simultaneous_tasks_per_worker: int = 2,
task_timeout: Optional[float] = None,
handle_worker_exceptions_by_default: bool = True, # Used if `allowed_exceptions` is None
allowed_exceptions: Optional[List[str]] = None, # Pass `None` for default handling policy
wait_for_tasks_end_on_reset: bool = False):
"""
:param num_actors: Maximum number of actors to instantiate when fully active (not paused)
:param logger: Optional. All status updates during the execution are being printed to this logger
        :param min_tasks_to_process: Minimum number of tasks whose completion ray.wait() waits for before handling them and
                                     proceeding with the main loop
        :param min_tasks_to_process_timeout: Timeout for each ray.wait() call on the active submitted tasks - clear all tasks if reached
        :param waiting_for_tasks_wait_time: Wait time between consecutive polls while paused and there are no active tasks to wait for
        :param nr_desired_simultaneous_tasks_per_worker: Number of simultaneously co-existing remote task invocations per worker.
                                                         Ideally, it should be big enough for small tasks to ensure the worker always
                                                         has a pending task to begin working on just after finishing the current one,
                                                         before the main loop detects it and submits the next task
        :param task_timeout: Optional. If given, terminate workers that haven't finished any task for at least this period of time
:param handle_worker_exceptions_by_default: When `allowed_exceptions` is not given, how to handle workers' exceptions
:param allowed_exceptions: Optional. If given, don't terminate workers that raise a corresponding exception
:param wait_for_tasks_end_on_reset: Whether to actively wait for tasks completion on reset
"""
self._num_actors = num_actors
self._logger = _get_logger() if logger is None else logger
self._min_tasks_to_process = min_tasks_to_process
self._min_tasks_to_process_timeout = min_tasks_to_process_timeout
self._waiting_for_tasks_wait_time = waiting_for_tasks_wait_time
self._control_lock = threading.RLock() # protects the status flags
# TODO: Consider using a dedicated `PoolStatus` enum instead of the following flags.
self._should_submit_new_tasks = False
self._should_stop = False
self._should_terminate = False # graceful termination - wait for completion of active tasks before exiting run()'s main loop
self._is_idle = True
self._is_main_loop_active = False # avoid having two active main loops concurrently
self._tasks_info: Dict[ray.ObjectRef, TaskInfo] = {}
self._workers_info: Dict[ray.actor.ActorHandle, WorkerInfo] = {}
self._nr_desired_simultaneous_tasks_per_worker = nr_desired_simultaneous_tasks_per_worker
self._task_timeout = task_timeout
self._handle_worker_exceptions_by_default = handle_worker_exceptions_by_default
self._allowed_exceptions = allowed_exceptions
self._wait_for_tasks_end_on_reset = wait_for_tasks_end_on_reset
def run(self):
"""
        Here is where the main logic happens. It's a long-lived call. Dynamically instantiate/terminate workers and submit tasks
        according to the pool's status (whether paused/stopped). Wait for dispatched tasks to complete, and handle completed tasks'
        return values. Monitor and track the health of the workers. Exit after `stop()` is called. Each loop iteration either waits
        for pending tasks or only polls the resuming flag `self._should_submit_new_tasks` (if paused).
"""
try:
with self._control_lock:
self._should_submit_new_tasks = True
self._should_stop = False
self._should_terminate = False
self._is_idle = False
if self._is_main_loop_active:
                    # This check is necessary to support a concurrent actor with multiple (2) threads, where one thread handles the
                    # user's calls to other methods (e.g.: stop(), update_*_network(), etc.) while another runs this long-lived call.
raise RuntimeError(f'{self.run.__name__}() should be called only once concurrently!')
self._is_main_loop_active = True
# Initialize all workers
self._workers_info = {}
self._instantiate_and_register_workers(self._num_actors)
self._logger.info("Finished initializing")
# Initialize empty tasks list. Tasks will be submitted to workers within the loop below.
self._tasks_info = {}
self.on_before_main_loop()
last_waiting_time_for_tasks_logging_time = None
while not self._should_stop:
self.on_main_loop_iteration_start()
if len(self._tasks_info) == 0 and not self._should_submit_new_tasks:
self._clear_all_workers_and_tasks()
if self._should_terminate:
break
else:
# No tasks exist, wait for 1 second and try again
if last_waiting_time_for_tasks_logging_time is None or \
time.monotonic() - last_waiting_time_for_tasks_logging_time >= 60.:
self._logger.info("No tasks exist, waiting")
last_waiting_time_for_tasks_logging_time = time.monotonic()
time.sleep(self._waiting_for_tasks_wait_time)
continue
self._enforce_tasks_timeout()
self._fill_workers_tasks_queues()
just_finished_tasks_futures = self._wait_for_pending_tasks()
while len(just_finished_tasks_futures) > 0:
finished_task_future = just_finished_tasks_futures.pop()
self._handle_finished_task_handle(finished_task_future)
self.on_main_loop_iteration_end()
except Exception as ex:
self._logger.exception(f"Major Exception, {ex.args}, {type(ex)}, {ex}")
self._logger.exception(traceback.format_exc())
finally:
# Softly terminate workers.
self._clear_all_workers_and_tasks()
self._is_idle = True
self._should_submit_new_tasks = False
self._is_main_loop_active = False
self._logger.info("All workers and task softly cleared.")
self._logger.info("Exiting.")
self._flush_logger() # Ensure logs are flushed to output. Helpful to make potential errors being visible after the exit.
def stop(self):
"""
        Exit the run() main loop (if active) as soon as possible, without waiting for tasks to finish.
"""
self._logger.info("Stopping")
with self._control_lock:
self._should_stop = True
self._should_submit_new_tasks = False
def gracefully_terminate(self):
"""
Stop submission of new tasks. Exit run()'s main loop when all active tasks are completed (or failed).
"""
self._logger.info("Marked to terminate gracefully (stop tasks submission and wait for active tasks to complete)..")
with self._control_lock:
self._should_terminate = True
self._should_submit_new_tasks = False
def pause(self):
"""
Cease submitting new tasks to workers (and terminate idle workers). Keep the main run() loop active and collect tasks' results or
wait for resumption.
"""
self._logger.info("Pausing")
with self._control_lock:
self._should_submit_new_tasks = False
def resume(self):
"""
Resume submitting new tasks to workers if paused before (and instantiate new workers if needed).
"""
self._logger.info("Resuming")
with self._control_lock:
self._should_submit_new_tasks = True
def reset(self):
"""
Pause (wait for all active tasks to finish without submitting new tasks), call a callback to potentially clear internal buffers of
the inheritor if needed, then resume.
"""
self._logger.info("Resetting ..")
self.pause()
if self._wait_for_tasks_end_on_reset:
while not self._is_idle:
self._logger.info("Waiting for tasks for finish after reset, sleeping for 30 seconds.")
time.sleep(30)
self._logger.info("All tasks finished.")
# Support potentially clearing internal buffers of the inheritor if needed.
self.on_reset_before_resumption()
self.resume()
def _flush_logger(self):
"""
        Ensure logs are flushed to the output. Helpful for making potential errors visible to the user.
"""
for handler in self._logger.handlers:
handler.flush()
sys.stdout.flush()
sys.stderr.flush()
        # Yield the active thread/process for a short period to give other kinds of output a chance to get flushed.
# https://www.geeksforgeeks.org/python-sys-stdout-flush/
time.sleep(1)
def _wait_for_pending_tasks(self) -> List[ray.ObjectRef]:
"""
This is one of the main methods being called directly from run()'s main loop. In each iteration we wait for some of the remotely
submitted active tasks to be completed.
"""
        # Keeping the list of pending task futures local to this method elegantly avoids holding an extra redundant reference to
        # the tasks, which would otherwise defer their release after the last tracked reference to them is lost.
pending_tasks_list = list(self._tasks_info.keys())
assert len(pending_tasks_list) >= 0
just_finished_tasks, non_finished_tasks = ray.wait(
pending_tasks_list,
num_returns=min(self._min_tasks_to_process, len(pending_tasks_list)),
timeout=self._min_tasks_to_process_timeout)
# Ray wait timeout exceeded without a finished task.
if len(just_finished_tasks) == 0 and len(non_finished_tasks) > 0:
# failed to finish any task - release resources
self._logger.info("Cluster stuck. Releasing all resources !!!")
self._clear_all_workers_and_tasks()
return just_finished_tasks
def _fill_worker_tasks_queue(self, worker: ray.actor.ActorHandle):
"""
Ensure that enough remote tasks are being invoked for the given worker.
"""
if not self._should_submit_new_tasks:
return
        # Request multiple tasks simultaneously so there's always a task waiting in the worker's queue even if the main loop doesn't
        # keep up (`ray.wait()` waits until the finished task's data is fetched locally, which might take time).
while len(self._workers_info[worker].tasks) < self._nr_desired_simultaneous_tasks_per_worker:
self._submit_tracked_task_to_worker(worker=worker)
def _submit_tracked_task_to_worker(self, worker: ray.actor.ActorHandle):
"""
        Create a new remote task call and store the future handle in the internal data-structures for tracking and waiting for it. To
        keep the actual performed task as general as possible, triggering the relevant remote call is encapsulated by
        `invoke_worker_remote_task_call()`.
"""
submission_id = self._generate_unique_submission_id()
task_future = self.invoke_worker_remote_task_call(worker=worker, submission_id=submission_id)
if task_future is None:
return
with self._control_lock:
self._tasks_info[task_future] = task_info = TaskInfo(
submission_id=submission_id, future=task_future, worker=worker, submit_time=time.monotonic())
worker_info = self._workers_info[worker]
assert task_future not in worker_info.tasks
worker_info.tasks[task_future] = task_info
worker_info.total_nr_submitted_tasks += 1
self._is_idle = False
def _generate_unique_submission_id(self) -> str:
"""
        Used for assigning a unique identifier to each submitted task. This approach allows the user (who inherits from RayPool) to
        track tasks, especially tasks that fail or are discarded without returning a value. This is the default implementation; it
        can be overridden by the user to produce custom identifiers.
"""
return uuid.uuid4().hex
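    # Illustrative override sketch (hypothetical, not part of the pool's API): prefix the id with a per-pool counter for easier
    # log correlation:
    #   def _generate_unique_submission_id(self) -> str:
    #       self._submission_counter = getattr(self, '_submission_counter', 0) + 1
    #       return f'{self._submission_counter:08d}-{uuid.uuid4().hex[:8]}'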
def _remove_worker_and_submitted_tasks(self, worker: ray.actor.ActorHandle, kill: bool = False):
"""
        Removing a worker is useful both when it becomes idle and when a failure occurs (that requires discarding the worker). Ensures
        that the tracked remote tasks submitted to this worker are also discarded. Optionally, forcefully kill the worker; if not,
        only lose all the references to it and let ray release it.
"""
discarded_task_submission_ids = set()
with self._control_lock:
worker_info = self._workers_info.pop(worker, None)
if worker_info is not None:
for task_info in worker_info.tasks.values():
self._tasks_info.pop(task_info.future)
discarded_task_submission_ids.add(task_info.submission_id)
if kill:
try:
ray.kill(worker)
except:
pass
for submission_id in discarded_task_submission_ids:
self.task_discarded(submission_id=submission_id)
def _clear_all_tasks(self):
"""
        Lose all tracked references to all submitted remote tasks. Keep the workers if they exist.
"""
discarded_task_submission_ids = set()
with self._control_lock:
while len(self._tasks_info) > 0:
_, task_info = self._tasks_info.popitem()
discarded_task_submission_ids.add(task_info.submission_id)
for worker_info in self._workers_info.values():
assert all(task_info.submission_id in discarded_task_submission_ids for task_info in worker_info.tasks.values())
worker_info.tasks.clear()
if not self._should_submit_new_tasks:
self._is_idle = True
for submission_id in discarded_task_submission_ids:
self.task_discarded(submission_id=submission_id)
def _clear_all_workers_and_tasks(self):
"""
        Lose all tracked references to all submitted remote tasks and all workers. Used when exiting or when a global error has occurred.
"""
with self._control_lock:
self._clear_all_tasks()
self._workers_info.clear()
def _pop_finished_task(self, task_future: ray.ObjectRef) -> Optional[TaskInfo]:
"""
        Remove all references to the given task handle from the internal data-structures.
        :return: the TaskInfo of the given task (including a handle to the worker it was assigned to), or None if it is not tracked.
"""
if task_future not in self._tasks_info:
# Can happen during the loop of handling the finished tasks, if several finished tasks came from the same worker that caused an
# exception in ray get.
return None
task_info = self._tasks_info.pop(task_future)
worker_info = self._workers_info[task_info.worker]
worker_info.tasks.pop(task_future)
worker_info.last_finished_task_time = time.monotonic()
return task_info
def _handle_finished_task_handle(self, task_handle: ray.ObjectRef):
"""
Called when a remote task has been finished and fetched locally. Should clear all relevant tracked references to this task, obtain
its return value, and pass it to the user.
"""
task_info = self._pop_finished_task(task_future=task_handle)
if task_info is None:
# Continue in case task is no longer in task worker map keys. This can happen if several
# finished tasks came from the same worker that caused an exception in ray get.
return
try:
task_returned_value = ray.get(task_handle)
except Exception as ex:
self._handle_worker_failure(worker=task_info.worker, failed_submission_id=task_info.submission_id, exception=ex)
return
is_worker_healthy = self.is_task_results_healthy(
task_returned_value=task_returned_value, submission_id=task_info.submission_id)
if not is_worker_healthy:
self._logger.error(f'Task return value is not healthy. Killing worker.')
self._terminate_failed_worker(worker=task_info.worker, related_failed_submission_id=task_info.submission_id, force_kill=True)
return
# If we want to continue performing tasks, add a new task for this worker
self._fill_worker_tasks_queue(worker=task_info.worker)
self.handle_task_results(task_returned_value=task_returned_value, submission_id=task_info.submission_id)
def _handle_worker_failure(self, worker: ray.actor.ActorHandle, failed_submission_id: str, exception: Exception):
"""
        Called whenever a worker failure has occurred (raised by `ray.get()`). This implementation terminates the worker and gracefully
        proceeds with the pool's run() execution, depending on the actual error that occurred. If an explicit `allowed_exceptions` list
        is given, then non-allowed exceptions propagate, so that the entire run() execution crashes. This logic is encapsulated in a
        standalone method so it can be overridden by the inheritor to extend/modify the failure-handling functionality.
"""
self._logger.exception("Exception occurred in worker")
self._logger.exception(traceback.format_exc())
try:
# Handle exception based on the allowed exceptions list. If the exception is not allowed, it will be raised from within the
# handling method.
if self._allowed_exceptions is None:
if not self._handle_worker_exceptions_by_default:
raise exception
else:
_handle_allowed_exceptions(exception=exception, allowed_exception_types_partial_names=self._allowed_exceptions)
except:
self._terminate()
raise # propagate the error upward (will terminate the `run()` call)
        # This finer handling case is reachable only if the error was not propagated.
        # Lose the references to the failed worker and to its related tasks, as the worker might be dead.
        # The main loop will add a new worker later if the pool is still dispatching tasks.
        # In addition, all tasks that are scheduled on the problematic worker are removed.
self._terminate_failed_worker(worker=worker, related_failed_submission_id=failed_submission_id, related_exception=exception)
def _terminate_failed_worker(
self,
worker: ray.actor.ActorHandle,
related_failed_submission_id: Optional[str] = None,
related_exception: Optional[Exception] = None,
force_kill: bool = False):
"""
Called whenever a worker should be terminated due to a failure.
Can be overridden by the inheritor to track workers' failures.
"""
self._remove_worker_and_submitted_tasks(worker=worker, kill=force_kill)
def _terminate(self):
"""
Encapsulating and unifying the necessary steps for terminating run()'s main loop from auxiliary methods executed within the loop.
For example, whenever a non-recoverable error has occurred (or an error that is marked to not be handled).
Note: Expected to be called from within the main thread that executes `run()` and not from control methods that might be called
from other concurrent threads, as it modifies un-protected structures.
"""
with self._control_lock:
assert self._is_main_loop_active
self._should_submit_new_tasks = False
self._should_stop = True
self._clear_all_tasks()
def _enforce_tasks_timeout(self):
"""
Called by the main loop. Kill tasks that exceeded the timeout if set.
"""
if self._task_timeout is None:
return # limitation is unset
now = time.monotonic()
workers_to_remove = []
for worker_handle, worker_info in self._workers_info.items():
if worker_info.latest_worker_event is not None and now - worker_info.latest_worker_event > self._task_timeout:
# Nothing happened with this worker for too long. Mark it to be killed.
workers_to_remove.append(worker_handle)
if workers_to_remove:
self._logger.warning(f'Killing {len(workers_to_remove)} workers with exceeded tasks timeout.')
        # We perform the actual workers removal only after iterating through the workers, to avoid mutating the dict while iterating over it.
for worker_handle in workers_to_remove:
self._remove_worker_and_submitted_tasks(worker_handle, kill=True)
def _fill_workers_tasks_queues(self):
"""
Invoke enough tasks for each worker to meet the desired number of simultaneous submitted tasks.
"""
# Generate tasks for all non-busy workers
if self._should_submit_new_tasks:
# Add workers according to required amount
if self._num_actors > len(self._workers_info):
self._instantiate_and_register_workers(self._num_actors - len(self._workers_info))
for worker in self._workers_info.keys():
self._fill_worker_tasks_queue(worker=worker)
else:
self._remove_idle_workers()
def _get_busy_workers(self) -> Set[ray.actor.ActorHandle]:
"""
:return: a set of handles to workers that have an active/pending task assigned.
"""
return set(worker_info.actor_handle for worker_info in self._workers_info.values() if worker_info.is_busy)
def _get_idle_workers(self) -> Set[ray.actor.ActorHandle]:
"""
:return: a set of handles to workers that don't have an active/pending task assigned.
"""
return set(self._workers_info.keys()) - self._get_busy_workers()
def _remove_idle_workers(self):
"""
Called by the main loop to erase the workers without any assigned active/pending task.
"""
# Remove non-busy workers from list so ray will free the cpus
idle_workers = self._get_idle_workers()
if len(idle_workers) == 0:
return
self._logger.info(f"Removing {len(idle_workers)} idle workers because shouldn't generate now.")
# Note: It's ok we're not removing verify why we're not removing tasks here
for worker_handle in idle_workers:
self._remove_worker_and_submitted_tasks(worker=worker_handle)
def _instantiate_and_register_workers(self, num_actors: int):
for worker_handle in self.instantiate_workers(num_actors=num_actors):
assert worker_handle not in self._workers_info
self._workers_info[worker_handle] = WorkerInfo(actor_handle=worker_handle)
@abstractmethod
def instantiate_workers(self, num_actors: int) -> List[ray.actor.ActorHandle]:
"""
Should be implemented by the inheritor. Creates the actual actor of the desired type. It is called by the main loop within run().
Workers can be dynamically instantiated and terminated multiple times during the entire lifetime of run()'s main loop.
"""
raise NotImplementedError
@abstractmethod
def handle_task_results(self, task_returned_value: Any, submission_id: str):
"""
        Being called by `run()`'s main loop whenever a task has been completed and fetched locally. The purpose of this method is to
        process these results locally however needed, and store them in a local structure or pass them to the relevant consumer.
:param task_returned_value: the value produced and returned from the completed task
:param submission_id: the unique identifier assigned to this task upon its submission
"""
raise NotImplementedError
def task_discarded(self, submission_id: str):
"""
        Being called by `run()`'s main loop whenever a task has been discarded. Note that the task's execution status is unknown at
        this stage; that is, it's unknown whether it started and was interrupted / started and finished / was canceled before even
        starting. It is promised, however, that `handle_task_results()` won't be and hasn't been called for this task.
:param submission_id: the unique identifier assigned to this task upon its submission
"""
pass
@abstractmethod
def invoke_worker_remote_task_call(self, worker: ray.actor.ActorHandle, submission_id: str) -> Optional[ray.ObjectRef]:
"""
        A separate method for only creating the remote call to the worker. It only returns the just-created task, without tracking /
        registering it in the internal pool's data-structures (that is performed by the caller). Should be overridden by the inheritor.
        Can support, for example, multiple kinds of tasks and/or centralize here the distribution and passing of the parameters to the
        submitted tasks.
        :param worker: Handle to the ray actor to submit this task to. RayPool is responsible for balancing the submitted jobs per
                       worker; that is, the user (the inheritor) implements this function without caring about actor instantiation /
                       destruction, nor about when to submit remote tasks to which workers.
        :param submission_id: The unique identifier for the new task. This allows the user to track the tasks. For example, the user
                              can pass this id as a parameter to the remote call. If the remote call acquires resources, it can log
                              this id together with the acquisition. Then, when a task fails or is discarded, the same identifier of
                              that task is known here and the user can safely release the acquired resources assigned to this task id.
        :return: The future of the just-submitted remote task, or None if there is currently no task to submit.
"""
raise NotImplementedError
def is_task_results_healthy(self, task_returned_value: Any, submission_id: str) -> bool:
"""
        A means to detect faulty runs and support terminating actors in such cases. Can be overridden by the inheritor. Returning True
        keeps the regular behavior; returning False discards this result and also kills the corresponding worker.
"""
        return True  # default implementation; allows overriding, but does not require explicit implementation by the inheritor
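    # Illustrative override sketch (hypothetical; assumes float-valued results and `import math`): treat NaN results as a sign of
    # a faulty worker:
    #   def is_task_results_healthy(self, task_returned_value: Any, submission_id: str) -> bool:
    #       return not (isinstance(task_returned_value, float) and math.isnan(task_returned_value))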
def on_main_loop_iteration_start(self):
"""
Being called at the very beginning of each iteration of the main long-lived loop at run().
Default implementation does nothing. Overridable by the inheritor for extending the pool.
        Use case example: periodically sending an update to all active workers.
"""
pass
def on_main_loop_iteration_end(self):
"""
        Being called at the very end of each iteration of the main long-lived loop at run(). Note that, in contrast to
        `on_main_loop_iteration_start()`, this callback is called only after waiting for active/pending tasks has completed, and thus
        it is NOT called while the pool is paused and no assigned tasks are available to wait for.
Default implementation does nothing. Overridable by the inheritor for extending the pool.
Use case examples: check for custom exit conditions; periodically log the pool's progress.
"""
pass
def on_before_main_loop(self):
"""
        Being called just before entering the main long-lived loop at run().
        Default implementation does nothing. Overridable by the inheritor for extending the pool.
        Use case examples: resetting custom members that the inheritor added for extended functionality (like periodic-log timing tracking).
"""
pass
def on_reset_before_resumption(self):
"""
        Being called upon reset(), after optionally waiting for the active tasks to complete and just before resuming.
        Default implementation does nothing. Overridable by the inheritor for extending the pool.
        Use case examples: resetting custom members that the inheritor added for extended functionality.
"""
pass
def _handle_allowed_exceptions(exception: Exception, allowed_exception_types_partial_names: Optional[List[str]]):
if allowed_exception_types_partial_names is None:
return
exc_str = traceback.format_exc()
if not any(allowed_exc_str in exc_str for allowed_exc_str in allowed_exception_types_partial_names):
raise exception
def _get_logger() -> logging.Logger:
return logging.getLogger('RayPool')
@ray.remote(num_cpus=1, memory=6_000_000_000)
class MockActor:
def do_work(self, submission_id: str) -> Tuple[str, float]:
time.sleep(1.)
value = float(np.random.random())
return submission_id, value
class MockRayPool(RayPool):
def __init__(self, *args, max_nr_tasks: int = 100, **kwargs):
self._nr_invoked_tasks = 0
self._nr_completed_tasks = 0
self._nr_discarded_tasks = 0
self._max_nr_tasks = max_nr_tasks
RayPool.__init__(self, *args, **kwargs)
@property
def nr_completed_tasks(self) -> int:
return self._nr_completed_tasks
@property
def nr_discarded_tasks(self) -> int:
return self._nr_discarded_tasks
@property
def nr_invoked_tasks(self) -> int:
return self._nr_invoked_tasks
def instantiate_workers(self, num_actors: int) -> List[ray.actor.ActorHandle]:
return [MockActor.remote() for _ in range(num_actors)]
def handle_task_results(self, task_returned_value: Any, submission_id: str):
self._nr_completed_tasks += 1
assert isinstance(task_returned_value, tuple)
assert task_returned_value[0] == submission_id
assert isinstance(task_returned_value[1], float)
def update_metrics_on_task_end(self, task_returned_value: Any):
pass
def invoke_worker_remote_task_call(self, worker: ray.actor.ActorHandle, submission_id: str) -> Optional[ray.ObjectRef]:
if self._nr_invoked_tasks >= self._max_nr_tasks:
return None
self._nr_invoked_tasks += 1
task = worker.do_work.remote(submission_id=submission_id)
return task
def on_main_loop_iteration_end(self):
if self._nr_invoked_tasks >= self._max_nr_tasks:
self.pause()
def on_main_loop_iteration_start(self):
if self._nr_completed_tasks + self._nr_discarded_tasks >= self._max_nr_tasks:
self.stop()
def task_discarded(self, submission_id: str):
self._nr_discarded_tasks += 1
class ThreadedMockRayPool(MockRayPool, threading.Thread):
def __init__(self, *args, **kwargs):
MockRayPool.__init__(self, *args, **kwargs)
threading.Thread.__init__(self)
def test_ray_pool_basic():
logger = _get_logger()
nr_tasks = 100
pool = MockRayPool(logger=logger, num_actors=10, max_nr_tasks=nr_tasks)
pool.run()
assert pool.nr_completed_tasks == nr_tasks
assert pool.nr_discarded_tasks == 0
def test_ray_pool_concurrent_with_terminate():
logger = _get_logger()
nr_tasks = 100
threaded_pool = ThreadedMockRayPool(logger=logger, num_actors=10, max_nr_tasks=nr_tasks)
threaded_pool.start()
threaded_pool.gracefully_terminate()
nr_invoked_tasks_after_terminate_command = threaded_pool.nr_invoked_tasks
threaded_pool.join()
assert threaded_pool.nr_completed_tasks == nr_invoked_tasks_after_terminate_command
assert threaded_pool.nr_discarded_tasks == 0
def test_ray_pool_task_discarded():
logger = _get_logger()
nr_tasks = 100
threaded_pool = ThreadedMockRayPool(logger=logger, num_actors=10, max_nr_tasks=nr_tasks)
threaded_pool.start()
while threaded_pool.nr_completed_tasks == 0:
time.sleep(.1)
    # We stop the generation just after the first results are ready. Then we can expect some results to be discarded before being ready.
    # Apart from checking that the `task_discarded()` method is called properly, we also check that for each task exactly one
    # method is called - either task_discarded() or handle_task_results(). This also characterizes the desired flow where results are
    # reported as soon as they're ready, rather than all of the results being passed together once all are ready (assuming
    # #tasks >> #workers).
threaded_pool.stop()
nr_invoked_tasks_after_terminate_command = threaded_pool.nr_invoked_tasks
threaded_pool.join()
assert threaded_pool.nr_discarded_tasks > 0
assert threaded_pool.nr_completed_tasks > 0
assert threaded_pool.nr_completed_tasks + threaded_pool.nr_discarded_tasks == nr_invoked_tasks_after_terminate_command
if __name__ == '__main__':
import numpy as np
test_ray_pool_basic()
test_ray_pool_concurrent_with_terminate()
test_ray_pool_task_discarded()