tokio/runtime/scheduler/multi_thread_alt/queue.rs

//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::scheduler::multi_thread_alt::{Overflow, Stats};
use crate::runtime::task;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

#[repr(align(128))]
pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The `UnsignedShort` in the LSB
    /// half is the "real" head of the queue. The `UnsignedShort` in the MSB
    /// half is set by a stealer in the process of stealing values; it is the
    /// index of the first value being stolen in the batch. The `UnsignedShort`
    /// indices are intentionally wider than strictly required for buffer
    /// indexing in order to provide ABA mitigation and make it possible to
    /// distinguish between full and empty buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
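    ///
    /// For example, with the 64-bit configuration (`UnsignedShort = u32`),
    /// `pack(steal, real)` stores `real` in the low 32 bits and `steal` in the
    /// high 32 bits, so a head value of `0x0000_0005_0000_0007` unpacks to
    /// `steal == 5` and `real == 7`.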
    head: AtomicUnsignedLong,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>]>,

    mask: usize,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

/// Create a new local run-queue
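///
/// `capacity` must be a power of two, as slots are indexed by masking
/// positions with `mask = capacity - 1`.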
pub(crate) fn local<T: 'static>(capacity: usize) -> (Steal<T>, Local<T>) {
    assert!(capacity <= 4096);
    assert!(capacity >= 1);

    let mut buffer = Vec::with_capacity(capacity);

    for _ in 0..capacity {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: buffer.into_boxed_slice(),
        mask: capacity - 1,
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}

impl<T> Local<T> {
    /// How many tasks can be pushed into the queue
    pub(crate) fn remaining_slots(&self) -> usize {
        self.inner.remaining_slots()
    }

    pub(crate) fn max_capacity(&self) -> usize {
        self.inner.buffer.len()
    }

    /// Returns `true` if there are no entries in the queue
    pub(crate) fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

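    /// Returns `true` if the queue has enough free slots to receive the
    /// largest batch a single steal can move, i.e. at least
    /// `max_capacity() - max_capacity() / 2` of its slots are unoccupied.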
    pub(crate) fn can_steal(&self) -> bool {
        self.remaining_slots() >= self.max_capacity() - self.max_capacity() / 2
    }

    /// Pushes a batch of tasks to the back of the queue. All tasks must fit in
    /// the local queue.
    ///
    /// # Panics
    ///
    /// The method panics if there is not enough capacity in the queue to fit
    /// the whole batch.
    pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
        let len = tasks.len();
        assert!(len <= self.inner.buffer.len());

        if len == 0 {
            // Nothing to do
            return;
        }

        let head = self.inner.head.load(Acquire);
        let (steal, real) = unpack(head);

        // safety: this is the **only** thread that updates this cell.
        let mut tail = unsafe { self.inner.tail.unsync_load() };

        if tail.wrapping_sub(steal) <= (self.inner.buffer.len() - len) as UnsignedShort {
            // Yes, this if condition is structured a bit weirdly (the first
            // block does nothing, the second panics). It is this way to match
            // `push_back_or_overflow`.
        } else {
            panic!(
                "not enough capacity; len={}; tail={}; steal={}; real={}",
                len, tail, steal, real
            );
        }

        for task in tasks {
            let idx = tail as usize & self.inner.mask;

            self.inner.buffer[idx].with_mut(|ptr| {
                // Write the task to the slot
                //
                // Safety: There is only one producer and the above `if`
                // condition ensures we don't touch a cell if there is a
                // value, thus no consumer.
                unsafe {
                    ptr::write((*ptr).as_mut_ptr(), task);
                }
            });

            tail = tail.wrapping_add(1);
        }

        self.inner.tail.store(tail, Release);
    }

    /// Pushes a task to the back of the local queue. If there is not enough
    /// capacity in the queue, this triggers the overflow operation.
    ///
    /// When the queue overflows, half of the current contents of the queue is
    /// moved to the given injection queue. This frees up capacity for more
    /// tasks to be pushed into the local queue.
    pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
        &mut self,
        mut task: task::Notified<T>,
        overflow: &O,
        stats: &mut Stats,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < self.inner.buffer.len() as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                super::counters::inc_num_overflows();
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                overflow.push(task);
                return;
            } else {
                super::counters::inc_num_overflows();
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, overflow, stats) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        self.push_back_finish(task, tail);
    }

    // Second half of `push_back`
    fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
        // Map the position to a slot index.
        let idx = tail as usize & self.inner.mask;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the above `if`
            // condition ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
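    ///
    /// Returns `Err(task)` if claiming half of the queue loses a race with a
    /// concurrent stealer; the caller then retries the push.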
    #[inline(never)]
    fn push_overflow<O: Overflow<T>>(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        overflow: &O,
        stats: &mut Stats,
    ) -> Result<(), task::Notified<T>> {
        // How many elements are we taking from the local queue.
        //
        // This is one less than the number of tasks pushed to the inject
        // queue as we are also inserting the `task` argument.
        let num_tasks_taken: UnsignedShort = (self.inner.buffer.len() / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            self.inner.buffer.len(),
            "queue is not full; tail = {}; head = {}",
            tail,
            head
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(num_tasks_taken),
                    head.wrapping_add(num_tasks_taken),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>],
            mask: usize,
            head: UnsignedLong,
            i: UnsignedLong,
            num: UnsignedShort,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(self.num) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & self.mask;
                    let slot = &self.buffer[i_idx];

                    // safety: The CAS above claimed exclusive ownership of the
                    // task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            mask: self.inner.mask,
            head: head as UnsignedLong,
            i: 0,
            num: num_tasks_taken,
        };
        overflow.push_batch(batch_iter.chain(std::iter::once(task)));

        // Record the overflow. The batch pushed above contains the task
        // currently being scheduled in addition to the tasks taken from the
        // buffer.
        stats.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & self.inner.mask,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}


impl<T> Steal<T> {
    pub(crate) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_stats: &mut Stats,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > self.0.buffer.len() as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        super::counters::inc_num_steals();

        dst_stats.incr_steal_count(n as u16);
        dst_stats.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & dst.inner.mask;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
            // Steal half of them, rounding up: `n - n / 2` is `ceil(n / 2)`.
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        debug_assert!(
            n <= (self.0.buffer.len() - self.0.buffer.len() / 2) as UnsignedShort,
            "actual = {}",
            n
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots, using each queue's own mask.
            let src_idx = src_pos as usize & self.0.mask;
            let dst_idx = dst_pos as usize & dst.inner.mask;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` has enough free slots for the stolen batch (checked
            // by the caller in `steal_into`) and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

cfg_unstable_metrics! {
    impl<T> Steal<T> {
        pub(crate) fn len(&self) -> usize {
            self.0.len() as _
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn remaining_slots(&self) -> usize {
        let (steal, _) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        self.buffer.len() - (tail.wrapping_sub(steal) as usize)
    }

    fn len(&self) -> UnsignedShort {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        tail.wrapping_sub(head)
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}
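
// Minimal sanity checks for `pack`/`unpack` and the wrapping index arithmetic
// used throughout this file (an illustrative sketch, not an exhaustive suite).
#[cfg(test)]
mod pack_tests {
    use super::{pack, unpack, UnsignedShort};

    #[test]
    fn pack_unpack_round_trip() {
        // `pack` stores `real` in the low half and `steal` in the high half;
        // `unpack` returns `(steal, real)`.
        assert_eq!(unpack(pack(5, 7)), (5, 7));
        assert_eq!(unpack(pack(0, UnsignedShort::MAX)), (0, UnsignedShort::MAX));
        assert_eq!(unpack(pack(UnsignedShort::MAX, 0)), (UnsignedShort::MAX, 0));
    }

    #[test]
    fn wrapping_len_arithmetic() {
        // The queue length is computed as `tail.wrapping_sub(head)`, which
        // stays correct even after the indices wrap around the integer range.
        let head: UnsignedShort = UnsignedShort::MAX - 1;
        let tail: UnsignedShort = head.wrapping_add(3);
        assert_eq!(tail.wrapping_sub(head), 3);
    }
}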