Last active
October 10, 2024 03:05
-
-
Save vmolsa/51811f0c12218821a453322a55ea002f to your computer and use it in GitHub Desktop.
Stack allocated futures as Batch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::{future::Future, pin::Pin, task::{Context, Poll}}; | |
/// A fixed-size batch of futures that are driven together and awaited as one.
///
/// The futures are borrowed as pinned mutable references (for example created with
/// [`std::pin::pin!`]). Both the futures and their outputs are kept in fixed-size arrays,
/// so the batch itself performs no heap allocation.
///
/// Awaiting the batch polls each unfinished future in index order every time the batch
/// is polled. It resolves to `Ok` with all outputs, in input order, once every future has
/// succeeded, or to the first `Err` that any future returns. The futures make progress
/// concurrently inside the awaiting task; nothing is spawned.
///
/// # Type Parameters
///
/// * `'a` - The lifetime of the pinned borrows of the futures.
/// * `T` - The output type of the futures on successful completion.
/// * `E` - The error type of the futures on failure.
/// * `F` - The type of the futures being batched; every future in the batch has this type.
/// * `N` - The number of futures in the batch.
pub struct Batch<'a, T, E, F: Future<Output = Result<T, E>>, const N: usize> {
    /// Futures still in progress; a slot becomes `None` once its future has succeeded.
    stack: [Option<Pin<&'a mut F>>; N],
    /// Successful outputs, stored at the same index as the future that produced them.
    result: [Option<T>; N],
}

/// Builds a [`Batch`] from an array of pinned futures, preserving their order.
///
/// The output slots are created with `std::array::from_fn` rather than `[None; N]`, so
/// `T` does not have to be `Copy`.
impl<'a, T, E, F: Future<Output = Result<T, E>>, const N: usize> From<[Pin<&'a mut F>; N]>
    for Batch<'a, T, E, F, N>
{
    fn from(value: [Pin<&'a mut F>; N]) -> Self {
        Self {
            stack: value.map(Some),
            result: std::array::from_fn(|_| None),
        }
    }
}

/// Drives the batched futures towards completion.
///
/// Each call polls, in index order, every future that has not yet succeeded. A successful
/// output is stored and its future is not polled again; the first error is returned
/// immediately. Once no future is pending, the stored outputs are moved out and returned
/// as `[T; N]` in the same order as the input futures.
///
/// # Panics
///
/// Panics if polled again after returning `Poll::Ready(Ok(_))`, because the outputs have
/// already been moved out. The `Future` contract allows this.
impl<T: Unpin, E, F: Future<Output = Result<T, E>>, const N: usize> Future
    for Batch<'_, T, E, F, N>
{
    type Output = Result<[T; N], E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The futures sit behind `Pin<&mut F>`, so `Batch` is `Unpin` whenever `T` is.
        // That makes plain mutable access sound here.
        let this = self.get_mut();
        let mut all_done = true;

        for (slot, out) in this.stack.iter_mut().zip(this.result.iter_mut()) {
            let Some(fut) = slot.as_mut() else {
                continue; // Already succeeded on an earlier poll.
            };
            match fut.as_mut().poll(cx) {
                Poll::Ready(Ok(val)) => {
                    *out = Some(val);
                    // Release the finished future so it is never polled again.
                    *slot = None;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => all_done = false,
            }
        }

        if !all_done {
            return Poll::Pending;
        }

        // A slot in `stack` is cleared only after its output has been stored, and every
        // slot is now clear, so every entry of `result` is `Some`. Moving the values out
        // (instead of copying them) is what lets `T` be any type, not just `Copy`.
        Poll::Ready(Ok(std::array::from_fn(|i| {
            this.result[i]
                .take()
                .expect("Batch: output missing; polled again after completion?")
        })))
    }
}
/// A dynamically sized batch of boxed, pinned futures that are driven together and
/// awaited as one.
///
/// This is the heap-allocated counterpart of `Batch`:
/// - the number of futures is chosen at run time;
/// - each future is pinned in its own `Box`;
/// - outputs are collected into a `Vec`.
///
/// Build it with `collect()` from an iterator of `Pin<Box<F>>`.
///
/// Awaiting the batch polls each unfinished future in index order every time the batch
/// is polled. It resolves to `Ok` with all outputs, in input order, once every future has
/// succeeded, or to the first `Err` that any future returns. An empty batch resolves
/// immediately to `Ok(vec![])`.
///
/// # Type Parameters
///
/// * `T` - The output type of the futures on successful completion.
/// * `E` - The error type of the futures on failure.
/// * `F` - The type of the futures being batched; every future in the batch has this type.
pub struct BatchBox<T, E, F: Future<Output = Result<T, E>>> {
    /// Futures still in progress; a slot becomes `None` once its future has succeeded.
    stack: Vec<Option<Pin<Box<F>>>>,
    /// Successful outputs, stored at the same index as the future that produced them.
    result: Vec<Option<T>>,
}

/// Collects pinned, boxed futures into a [`BatchBox`], preserving their order.
///
/// The output slots are filled with `None` one by one rather than via `vec![None; len]`,
/// so `T` does not have to be `Clone`/`Copy`.
impl<T, E, F: Future<Output = Result<T, E>>> FromIterator<Pin<Box<F>>> for BatchBox<T, E, F> {
    fn from_iter<I: IntoIterator<Item = Pin<Box<F>>>>(iter: I) -> Self {
        let stack: Vec<_> = iter.into_iter().map(Some).collect();
        let result = (0..stack.len()).map(|_| None).collect();
        Self { stack, result }
    }
}

/// Drives the batched futures towards completion.
///
/// Each call polls, in index order, every future that has not yet succeeded. A successful
/// output is stored and its boxed future is dropped; the first error is returned
/// immediately. Once no future is pending, the stored outputs are moved out and returned
/// as a `Vec<T>` in the same order as the input futures.
///
/// # Panics
///
/// Panics if polled again after returning `Poll::Ready(Ok(_))` on a non-empty batch,
/// because the outputs have already been moved out. The `Future` contract allows this.
impl<T: Unpin, E, F: Future<Output = Result<T, E>>> Future for BatchBox<T, E, F> {
    type Output = Result<Vec<T>, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The futures are individually pinned in `Box`es, so `BatchBox` is `Unpin`
        // whenever `T` is. That makes plain mutable access sound here.
        let this = self.get_mut();
        let mut all_done = true;

        for (slot, out) in this.stack.iter_mut().zip(this.result.iter_mut()) {
            let Some(fut) = slot.as_mut() else {
                continue; // Already succeeded on an earlier poll.
            };
            match fut.as_mut().poll(cx) {
                Poll::Ready(Ok(val)) => {
                    *out = Some(val);
                    // Drop the finished future so it is never polled again.
                    *slot = None;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => all_done = false,
            }
        }

        if !all_done {
            return Poll::Pending;
        }

        // A slot in `stack` is cleared only after its output has been stored, and every
        // slot is now clear, so every entry of `result` is `Some`. Moving the values out
        // (instead of copying them) is what lets `T` be any type, not just `Copy`.
        Poll::Ready(Ok(this
            .result
            .iter_mut()
            .map(|out| {
                out.take()
                    .expect("BatchBox: output missing; polled again after completion?")
            })
            .collect()))
    }
}
| #[tokio::test] | |
| async fn test_batch_with_tokio_sleep() { | |
| use std::time::Duration; | |
| use std::pin::pin; | |
| async fn sleep_and_return_result(duration: Duration) -> Result<(), ()> { | |
| tokio::time::sleep(duration).await; | |
| Ok(()) | |
| } | |
| let fut1 = pin!(sleep_and_return_result(Duration::from_millis(100))); | |
| let fut2 = pin!(sleep_and_return_result(Duration::from_millis(200))); | |
| let fut3 = pin!(sleep_and_return_result(Duration::from_millis(300))); | |
| let fut4 = pin!(sleep_and_return_result(Duration::from_millis(400))); | |
| let fut5 = pin!(sleep_and_return_result(Duration::from_millis(500))); | |
| let fut6 = pin!(sleep_and_return_result(Duration::from_millis(600))); | |
| let fut7 = pin!(sleep_and_return_result(Duration::from_millis(700))); | |
| let fut8 = pin!(sleep_and_return_result(Duration::from_millis(800))); | |
| let fut9 = pin!(sleep_and_return_result(Duration::from_millis(900))); | |
| let batch = Batch::from([ fut1, fut2, fut3, fut4, fut5, fut6, fut7, fut8, fut9 ]); | |
| let _ = batch.await; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment