1use std::future::Future;
29use std::mem;
30use std::ops;
31use std::pin::Pin;
32use std::sync::{Arc, Condvar, Mutex};
33use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
34
35use tokio_stream::Stream;
36
37pub fn spawn<T>(task: T) -> Spawn<T> {
43 Spawn {
44 task: MockTask::new(),
45 future: Box::pin(task),
46 }
47}
48
49#[derive(Debug)]
52#[must_use = "futures do nothing unless you `.await` or poll them"]
53pub struct Spawn<T> {
54 task: MockTask,
55 future: Pin<Box<T>>,
56}
57
58#[derive(Debug, Clone)]
59struct MockTask {
60 waker: Arc<ThreadWaker>,
61}
62
/// Wake-up state shared by a [`MockTask`] and all `Waker`s created from it.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Signalled by `wake` when it finds the state set to `SLEEP`.
    condvar: Condvar,
}
68
/// No wake-up has been recorded since the state was last cleared.
const IDLE: usize = 0;
/// A wake-up has been recorded.
const WAKE: usize = IDLE + 1;
/// A waiter is presumably blocked on the condvar (the code that enters this
/// state is not in view); `ThreadWaker::wake` notifies the condvar from here.
const SLEEP: usize = WAKE + 1;
72
73impl<T> Spawn<T> {
74 pub fn into_inner(self) -> T
76 where
77 T: Unpin,
78 {
79 *Pin::into_inner(self.future)
80 }
81
82 pub fn is_woken(&self) -> bool {
85 self.task.is_woken()
86 }
87
88 pub fn waker_ref_count(&self) -> usize {
92 self.task.waker_ref_count()
93 }
94
95 pub fn enter<F, R>(&mut self, f: F) -> R
97 where
98 F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
99 {
100 let fut = self.future.as_mut();
101 self.task.enter(|cx| f(cx, fut))
102 }
103}
104
105impl<T: Unpin> ops::Deref for Spawn<T> {
106 type Target = T;
107
108 fn deref(&self) -> &T {
109 &self.future
110 }
111}
112
113impl<T: Unpin> ops::DerefMut for Spawn<T> {
114 fn deref_mut(&mut self) -> &mut T {
115 &mut self.future
116 }
117}
118
119impl<T: Future> Spawn<T> {
120 pub fn poll(&mut self) -> Poll<T::Output> {
123 let fut = self.future.as_mut();
124 self.task.enter(|cx| fut.poll(cx))
125 }
126}
127
128impl<T: Stream> Spawn<T> {
129 pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
132 let stream = self.future.as_mut();
133 self.task.enter(|cx| stream.poll_next(cx))
134 }
135}
136
137impl<T: Future> Future for Spawn<T> {
138 type Output = T::Output;
139
140 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141 self.future.as_mut().poll(cx)
142 }
143}
144
145impl<T: Stream> Stream for Spawn<T> {
146 type Item = T::Item;
147
148 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149 self.future.as_mut().poll_next(cx)
150 }
151
152 fn size_hint(&self) -> (usize, Option<usize>) {
153 self.future.size_hint()
154 }
155}
156
157impl MockTask {
158 fn new() -> Self {
160 MockTask {
161 waker: Arc::new(ThreadWaker::new()),
162 }
163 }
164
165 fn enter<F, R>(&mut self, f: F) -> R
170 where
171 F: FnOnce(&mut Context<'_>) -> R,
172 {
173 self.waker.clear();
174 let waker = self.waker();
175 let mut cx = Context::from_waker(&waker);
176
177 f(&mut cx)
178 }
179
180 fn is_woken(&self) -> bool {
183 self.waker.is_woken()
184 }
185
186 fn waker_ref_count(&self) -> usize {
190 Arc::strong_count(&self.waker)
191 }
192
193 fn waker(&self) -> Waker {
194 unsafe {
195 let raw = to_raw(self.waker.clone());
196 Waker::from_raw(raw)
197 }
198 }
199}
200
201impl Default for MockTask {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
207impl ThreadWaker {
208 fn new() -> Self {
209 ThreadWaker {
210 state: Mutex::new(IDLE),
211 condvar: Condvar::new(),
212 }
213 }
214
215 fn clear(&self) {
219 *self.state.lock().unwrap() = IDLE;
220 }
221
222 fn is_woken(&self) -> bool {
223 match *self.state.lock().unwrap() {
224 IDLE => false,
225 WAKE => true,
226 _ => unreachable!(),
227 }
228 }
229
230 fn wake(&self) {
231 let mut state = self.state.lock().unwrap();
233 let prev = *state;
234
235 if prev == WAKE {
236 return;
237 }
238
239 *state = WAKE;
240
241 if prev == IDLE {
242 return;
243 }
244
245 assert_eq!(prev, SLEEP);
247 self.condvar.notify_one();
248 }
249}
250
251static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
252
253unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
254 RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
255}
256
257unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
258 Arc::from_raw(raw as *const ThreadWaker)
259}
260
261unsafe fn clone(raw: *const ()) -> RawWaker {
262 let waker = from_raw(raw);
263
264 mem::forget(waker.clone());
266
267 to_raw(waker)
268}
269
270unsafe fn wake(raw: *const ()) {
271 let waker = from_raw(raw);
272 waker.wake();
273}
274
275unsafe fn wake_by_ref(raw: *const ()) {
276 let waker = from_raw(raw);
277 waker.wake();
278
279 mem::forget(waker);
281}
282
283unsafe fn drop_waker(raw: *const ()) {
284 let _ = from_raw(raw);
285}