tokio_test/
task.rs

1//! Futures task based helpers to easily test futures and manually written futures.
2//!
//! The [`Spawn`] type is used as a mock task harness that allows you to poll futures
//! without needing to set up pinning or context. Any future can be polled, but if the
//! future requires the tokio async context, you will need to ensure that you poll the
//! [`Spawn`] within a tokio context. In other words, as long as you are inside the
//! runtime, polling via [`Spawn`] will work.
8//!
9//! [`Spawn`] also supports [`Stream`] to call `poll_next` without pinning
10//! or context.
11//!
//! In addition to circumventing the need for pinning and context, [`Spawn`] also tracks
//! whether the future/task has been woken since it was last polled. This can be useful
//! to check that some leaf future notified the root task correctly.
15//!
16//! # Example
17//!
18//! ```
19//! use tokio_test::task;
20//!
21//! let fut = async {};
22//!
23//! let mut task = task::spawn(fut);
24//!
25//! assert!(task.poll().is_ready(), "Task was not ready!");
26//! ```
27
28use std::future::Future;
29use std::mem;
30use std::ops;
31use std::pin::Pin;
32use std::sync::{Arc, Condvar, Mutex};
33use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
34
35use tokio_stream::Stream;
36
37/// Spawn a future into a [`Spawn`] which wraps the future in a mocked executor.
38///
39/// This can be used to spawn a [`Future`] or a [`Stream`].
40///
41/// For more information, check the module docs.
42pub fn spawn<T>(task: T) -> Spawn<T> {
43    Spawn {
44        task: MockTask::new(),
45        future: Box::pin(task),
46    }
47}
48
49/// Future spawned on a mock task that can be used to poll the future or stream
50/// without needing pinning or context types.
51#[derive(Debug)]
52#[must_use = "futures do nothing unless you `.await` or poll them"]
53pub struct Spawn<T> {
54    task: MockTask,
55    future: Pin<Box<T>>,
56}
57
58#[derive(Debug, Clone)]
59struct MockTask {
60    waker: Arc<ThreadWaker>,
61}
62
/// Shared wake state behind a mock task's waker.
///
/// `state` holds one of the `IDLE` / `WAKE` / `SLEEP` state constants. `condvar`
/// is notified when a wake arrives while the state is `SLEEP`.
#[derive(Debug)]
struct ThreadWaker {
    state: Mutex<usize>,
    condvar: Condvar,
}
68
/// No wake has been recorded since the state was last cleared.
const IDLE: usize = 0;
/// A wake notification has been recorded.
const WAKE: usize = IDLE + 1;
/// A party is blocked on the condvar; waking from this state notifies it.
/// NOTE(review): nothing here appears to store this state — confirm it is still needed.
const SLEEP: usize = WAKE + 1;
72
73impl<T> Spawn<T> {
74    /// Consumes `self` returning the inner value
75    pub fn into_inner(self) -> T
76    where
77        T: Unpin,
78    {
79        *Pin::into_inner(self.future)
80    }
81
82    /// Returns `true` if the inner future has received a wake notification
83    /// since the last call to `enter`.
84    pub fn is_woken(&self) -> bool {
85        self.task.is_woken()
86    }
87
88    /// Returns the number of references to the task waker
89    ///
90    /// The task itself holds a reference. The return value will never be zero.
91    pub fn waker_ref_count(&self) -> usize {
92        self.task.waker_ref_count()
93    }
94
95    /// Enter the task context
96    pub fn enter<F, R>(&mut self, f: F) -> R
97    where
98        F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
99    {
100        let fut = self.future.as_mut();
101        self.task.enter(|cx| f(cx, fut))
102    }
103}
104
105impl<T: Unpin> ops::Deref for Spawn<T> {
106    type Target = T;
107
108    fn deref(&self) -> &T {
109        &self.future
110    }
111}
112
113impl<T: Unpin> ops::DerefMut for Spawn<T> {
114    fn deref_mut(&mut self) -> &mut T {
115        &mut self.future
116    }
117}
118
119impl<T: Future> Spawn<T> {
120    /// If `T` is a [`Future`] then poll it. This will handle pinning and the context
121    /// type for the future.
122    pub fn poll(&mut self) -> Poll<T::Output> {
123        let fut = self.future.as_mut();
124        self.task.enter(|cx| fut.poll(cx))
125    }
126}
127
128impl<T: Stream> Spawn<T> {
129    /// If `T` is a [`Stream`] then `poll_next` it. This will handle pinning and the context
130    /// type for the stream.
131    pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
132        let stream = self.future.as_mut();
133        self.task.enter(|cx| stream.poll_next(cx))
134    }
135}
136
137impl<T: Future> Future for Spawn<T> {
138    type Output = T::Output;
139
140    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        self.future.as_mut().poll(cx)
142    }
143}
144
145impl<T: Stream> Stream for Spawn<T> {
146    type Item = T::Item;
147
148    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149        self.future.as_mut().poll_next(cx)
150    }
151
152    fn size_hint(&self) -> (usize, Option<usize>) {
153        self.future.size_hint()
154    }
155}
156
157impl MockTask {
158    /// Creates new mock task
159    fn new() -> Self {
160        MockTask {
161            waker: Arc::new(ThreadWaker::new()),
162        }
163    }
164
165    /// Runs a closure from the context of the task.
166    ///
167    /// Any wake notifications resulting from the execution of the closure are
168    /// tracked.
169    fn enter<F, R>(&mut self, f: F) -> R
170    where
171        F: FnOnce(&mut Context<'_>) -> R,
172    {
173        self.waker.clear();
174        let waker = self.waker();
175        let mut cx = Context::from_waker(&waker);
176
177        f(&mut cx)
178    }
179
180    /// Returns `true` if the inner future has received a wake notification
181    /// since the last call to `enter`.
182    fn is_woken(&self) -> bool {
183        self.waker.is_woken()
184    }
185
186    /// Returns the number of references to the task waker
187    ///
188    /// The task itself holds a reference. The return value will never be zero.
189    fn waker_ref_count(&self) -> usize {
190        Arc::strong_count(&self.waker)
191    }
192
193    fn waker(&self) -> Waker {
194        unsafe {
195            let raw = to_raw(self.waker.clone());
196            Waker::from_raw(raw)
197        }
198    }
199}
200
201impl Default for MockTask {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207impl ThreadWaker {
208    fn new() -> Self {
209        ThreadWaker {
210            state: Mutex::new(IDLE),
211            condvar: Condvar::new(),
212        }
213    }
214
215    /// Clears any previously received wakes, avoiding potential spurious
216    /// wake notifications. This should only be called immediately before running the
217    /// task.
218    fn clear(&self) {
219        *self.state.lock().unwrap() = IDLE;
220    }
221
222    fn is_woken(&self) -> bool {
223        match *self.state.lock().unwrap() {
224            IDLE => false,
225            WAKE => true,
226            _ => unreachable!(),
227        }
228    }
229
230    fn wake(&self) {
231        // First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
232        let mut state = self.state.lock().unwrap();
233        let prev = *state;
234
235        if prev == WAKE {
236            return;
237        }
238
239        *state = WAKE;
240
241        if prev == IDLE {
242            return;
243        }
244
245        // The other half is sleeping, so we wake it up.
246        assert_eq!(prev, SLEEP);
247        self.condvar.notify_one();
248    }
249}
250
251static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
252
253unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
254    RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
255}
256
257unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
258    Arc::from_raw(raw as *const ThreadWaker)
259}
260
261unsafe fn clone(raw: *const ()) -> RawWaker {
262    let waker = from_raw(raw);
263
264    // Increment the ref count
265    mem::forget(waker.clone());
266
267    to_raw(waker)
268}
269
270unsafe fn wake(raw: *const ()) {
271    let waker = from_raw(raw);
272    waker.wake();
273}
274
275unsafe fn wake_by_ref(raw: *const ()) {
276    let waker = from_raw(raw);
277    waker.wake();
278
279    // We don't actually own a reference to the unparker
280    mem::forget(waker);
281}
282
283unsafe fn drop_waker(raw: *const ()) {
284    let _ = from_raw(raw);
285}