tokio_util/task/
join_map.rs

1use hashbrown::hash_map::RawEntryMut;
2use hashbrown::HashMap;
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
/// A collection of tasks spawned on a Tokio runtime, associated with hash map
/// keys.
///
/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
/// addition of a set of keys associated with each task. These keys allow
/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
/// `JoinMap` based on their keys, or [test whether a task corresponding to a
/// given key exists][contains] in the `JoinMap`.
///
/// In addition, when tasks in the `JoinMap` complete, they will return the
/// associated key along with the value returned by the task, if any.
///
/// A `JoinMap` can be used to await the completion of some or all of the tasks
/// in the map. The map is not ordered, and the tasks will be returned in the
/// order they complete.
///
/// All of the tasks must have the same return type `V`.
///
/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
///
/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
/// documentation on unstable features][unstable] for details on how to enable
/// Tokio's unstable features.
///
/// # Examples
///
/// Spawn multiple tasks and wait for them:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     for i in 0..10 {
///         // Spawn a task on the `JoinMap` with `i` as its key.
///         map.spawn(i, async move { /* ... */ });
///     }
///
///     let mut seen = [false; 10];
///
///     // When a task completes, `join_next` returns the task's key along
///     // with its output.
///     while let Some((key, res)) = map.join_next().await {
///         seen[key] = true;
///         assert!(res.is_ok(), "task {} completed successfully!", key);
///     }
///
///     for i in 0..10 {
///         assert!(seen[i]);
///     }
/// }
/// ```
///
/// Cancel tasks based on their keys:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     map.spawn("hello world", async move { /* ... */ });
///     map.spawn("goodbye world", async move { /* ... */});
///
///     // Look up the "goodbye world" task in the map and abort it.
///     let aborted = map.abort("goodbye world");
///
///     // `JoinMap::abort` returns `true` if a task existed for the
///     // provided key.
///     assert!(aborted);
///
///     while let Some((key, res)) = map.join_next().await {
///         if key == "goodbye world" {
///             // The aborted task should complete with a cancelled `JoinError`.
///             assert!(res.unwrap_err().is_cancelled());
///         } else {
///             // Other tasks should complete normally.
///             assert!(res.is_ok());
///         }
///     }
/// }
/// ```
///
/// [`JoinSet`]: tokio::task::JoinSet
/// [unstable]: tokio#unstable-features
/// [abort]: fn@Self::abort
/// [abort_matching]: fn@Self::abort_matching
/// [contains]: fn@Self::contains_key
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
pub struct JoinMap<K, V, S = RandomState> {
    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
    /// indexed by their keys and task IDs.
    ///
    /// The [`Key`] type contains both the task's `K`-typed key provided when
    /// spawning tasks, and the task's ID. The IDs are stored here to resolve
    /// hash collisions when looking up tasks based on their pre-computed hash
    /// (as stored in the `hashes_by_task` map).
    ///
    /// Invariant: this map and `hashes_by_task` always have the same number of
    /// entries (checked by the debug assertions in `len` and `is_empty`).
    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,

    /// A map from task IDs to the hash of the key associated with that task.
    ///
    /// This map is used to perform reverse lookups of tasks in the
    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
    /// of that task's key, and then remove it from the `tasks_by_key` map using
    /// the raw hash code, resolving collisions by comparing task IDs.
    hashes_by_task: HashMap<Id, u64, S>,

    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
    /// `JoinMap`.
    ///
    /// Note that a task replaced by a later spawn with the same key is aborted
    /// and dropped from the two maps above, but it stays in this set until it
    /// is joined; its ID is then no longer found in `hashes_by_task`.
    tasks: JoinSet<V>,
}
127
/// Configures a task before spawning it onto a [`JoinMap`].
///
/// This is the `JoinMap` flavor of [`task::Builder`]: instead of spawning onto
/// the current default runtime, the tasks it builds are spawned onto the
/// borrowed [`JoinMap`].
///
/// [`task::Builder`]: tokio::task::Builder
#[cfg(feature = "tracing")]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
)]
pub struct Builder<'a, K, V, S> {
    /// The name to give the task, if one has been configured.
    name: Option<&'a str>,
    /// The `JoinMap` the configured task will be spawned onto.
    joinmap: &'a mut JoinMap<K, V, S>,
}
141
/// A [`JoinMap`] key.
///
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
///
/// This allows looking up a task using either an actual key (such as when the
/// user queries the map with a key), *or* using a task ID and a hash (such as
/// when removing completed tasks from the map).
#[derive(Debug)]
struct Key<K> {
    /// The user-supplied key the task was spawned with; compared using `K`'s
    /// `Eq` impl by key-based lookups such as `get_by_key` and `insert`.
    key: K,
    /// The ID of the task stored under this key; compared by ID-based lookups
    /// such as `get_by_id` and `remove_by_id`.
    id: Id,
}
156
157impl<K, V> JoinMap<K, V> {
158    /// Creates a new empty `JoinMap`.
159    ///
160    /// The `JoinMap` is initially created with a capacity of 0, so it will not
161    /// allocate until a task is first spawned on it.
162    ///
163    /// # Examples
164    ///
165    /// ```
166    /// use tokio_util::task::JoinMap;
167    /// let map: JoinMap<&str, i32> = JoinMap::new();
168    /// ```
169    #[inline]
170    #[must_use]
171    pub fn new() -> Self {
172        Self::with_hasher(RandomState::new())
173    }
174
175    /// Creates an empty `JoinMap` with the specified capacity.
176    ///
177    /// The `JoinMap` will be able to hold at least `capacity` tasks without
178    /// reallocating.
179    ///
180    /// # Examples
181    ///
182    /// ```
183    /// use tokio_util::task::JoinMap;
184    /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
185    /// ```
186    #[inline]
187    #[must_use]
188    pub fn with_capacity(capacity: usize) -> Self {
189        JoinMap::with_capacity_and_hasher(capacity, Default::default())
190    }
191}
192
193impl<K, V, S: Clone> JoinMap<K, V, S> {
194    /// Creates an empty `JoinMap` which will use the given hash builder to hash
195    /// keys.
196    ///
197    /// The created map has the default initial capacity.
198    ///
199    /// Warning: `hash_builder` is normally randomly generated, and
200    /// is designed to allow `JoinMap` to be resistant to attacks that
201    /// cause many collisions and very poor performance. Setting it
202    /// manually using this function can expose a DoS attack vector.
203    ///
204    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
205    /// the `JoinMap` to be useful, see its documentation for details.
206    #[inline]
207    #[must_use]
208    pub fn with_hasher(hash_builder: S) -> Self {
209        Self::with_capacity_and_hasher(0, hash_builder)
210    }
211
212    /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
213    /// to hash the keys.
214    ///
215    /// The `JoinMap` will be able to hold at least `capacity` elements without
216    /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
217    ///
218    /// Warning: `hash_builder` is normally randomly generated, and
219    /// is designed to allow HashMaps to be resistant to attacks that
220    /// cause many collisions and very poor performance. Setting it
221    /// manually using this function can expose a DoS attack vector.
222    ///
223    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
224    /// the `JoinMap`to be useful, see its documentation for details.
225    ///
226    /// # Examples
227    ///
228    /// ```
229    /// # #[tokio::main]
230    /// # async fn main() {
231    /// use tokio_util::task::JoinMap;
232    /// use std::collections::hash_map::RandomState;
233    ///
234    /// let s = RandomState::new();
235    /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
236    /// map.spawn(1, async move { "hello world!" });
237    /// # }
238    /// ```
239    #[inline]
240    #[must_use]
241    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
242        Self {
243            tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
244            hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
245            tasks: JoinSet::new(),
246        }
247    }
248
249    /// Returns the number of tasks currently in the `JoinMap`.
250    pub fn len(&self) -> usize {
251        let len = self.tasks_by_key.len();
252        debug_assert_eq!(len, self.hashes_by_task.len());
253        len
254    }
255
256    /// Returns whether the `JoinMap` is empty.
257    pub fn is_empty(&self) -> bool {
258        let empty = self.tasks_by_key.is_empty();
259        debug_assert_eq!(empty, self.hashes_by_task.is_empty());
260        empty
261    }
262
263    /// Returns the number of tasks the map can hold without reallocating.
264    ///
265    /// This number is a lower bound; the `JoinMap` might be able to hold
266    /// more, but is guaranteed to be able to hold at least this many.
267    ///
268    /// # Examples
269    ///
270    /// ```
271    /// use tokio_util::task::JoinMap;
272    ///
273    /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
274    /// assert!(map.capacity() >= 100);
275    /// ```
276    #[inline]
277    pub fn capacity(&self) -> usize {
278        let capacity = self.tasks_by_key.capacity();
279        debug_assert_eq!(capacity, self.hashes_by_task.capacity());
280        capacity
281    }
282}
283
284impl<K, V, S> JoinMap<K, V, S>
285where
286    K: Hash + Eq,
287    V: 'static,
288    S: BuildHasher,
289{
290    /// Returns a [`Builder`] that can be used to configure a task prior to spawning it on this
291    /// [`JoinMap`].
292    ///
293    /// # Examples
294    ///
295    /// ```
296    /// use tokio_util::task::JoinMap;
297    ///
298    /// #[tokio::main]
299    /// async fn main() -> std::io::Result<()> {
300    ///     let mut map = JoinMap::new();
301    ///
302    ///     // Use the builder to configure the task's name before spawning it.
303    ///     map.build_task()
304    ///         .name("my_task")
305    ///         .spawn(42, async { /* ... */ });
306    ///
307    ///     Ok(())
308    /// }
309    /// ```
310    #[cfg(feature = "tracing")]
311    #[cfg_attr(
312        docsrs,
313        doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
314    )]
315    pub fn build_task(&mut self) -> Builder<'_, K, V, S> {
316        Builder {
317            joinmap: self,
318            name: None,
319        }
320    }
321
322    /// Spawn the provided task and store it in this `JoinMap` with the provided
323    /// key.
324    ///
325    /// If a task previously existed in the `JoinMap` for this key, that task
326    /// will be cancelled and replaced with the new one. The previous task will
327    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
328    /// *not* return a cancelled [`JoinError`] for that task.
329    ///
330    /// # Panics
331    ///
332    /// This method panics if called outside of a Tokio runtime.
333    ///
334    /// [`join_next`]: Self::join_next
335    #[track_caller]
336    pub fn spawn<F>(&mut self, key: K, task: F)
337    where
338        F: Future<Output = V>,
339        F: Send + 'static,
340        V: Send,
341    {
342        let task = self.tasks.spawn(task);
343        self.insert(key, task)
344    }
345
346    /// Spawn the provided task on the provided runtime and store it in this
347    /// `JoinMap` with the provided key.
348    ///
349    /// If a task previously existed in the `JoinMap` for this key, that task
350    /// will be cancelled and replaced with the new one. The previous task will
351    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
352    /// *not* return a cancelled [`JoinError`] for that task.
353    ///
354    /// [`join_next`]: Self::join_next
355    #[track_caller]
356    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
357    where
358        F: Future<Output = V>,
359        F: Send + 'static,
360        V: Send,
361    {
362        let task = self.tasks.spawn_on(task, handle);
363        self.insert(key, task);
364    }
365
366    /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
367    /// key.
368    ///
369    /// If a task previously existed in the `JoinMap` for this key, that task
370    /// will be cancelled and replaced with the new one. The previous task will
371    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
372    /// *not* return a cancelled [`JoinError`] for that task.
373    ///
374    /// Note that blocking tasks cannot be cancelled after execution starts.
375    /// Replaced blocking tasks will still run to completion if the task has begun
376    /// to execute when it is replaced. A blocking task which is replaced before
377    /// it has been scheduled on a blocking worker thread will be cancelled.
378    ///
379    /// # Panics
380    ///
381    /// This method panics if called outside of a Tokio runtime.
382    ///
383    /// [`join_next`]: Self::join_next
384    #[track_caller]
385    pub fn spawn_blocking<F>(&mut self, key: K, f: F)
386    where
387        F: FnOnce() -> V,
388        F: Send + 'static,
389        V: Send,
390    {
391        let task = self.tasks.spawn_blocking(f);
392        self.insert(key, task)
393    }
394
395    /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
396    /// `JoinMap` with the provided key.
397    ///
398    /// If a task previously existed in the `JoinMap` for this key, that task
399    /// will be cancelled and replaced with the new one. The previous task will
400    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
401    /// *not* return a cancelled [`JoinError`] for that task.
402    ///
403    /// Note that blocking tasks cannot be cancelled after execution starts.
404    /// Replaced blocking tasks will still run to completion if the task has begun
405    /// to execute when it is replaced. A blocking task which is replaced before
406    /// it has been scheduled on a blocking worker thread will be cancelled.
407    ///
408    /// [`join_next`]: Self::join_next
409    #[track_caller]
410    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
411    where
412        F: FnOnce() -> V,
413        F: Send + 'static,
414        V: Send,
415    {
416        let task = self.tasks.spawn_blocking_on(f, handle);
417        self.insert(key, task);
418    }
419
420    /// Spawn the provided task on the current [`LocalSet`] and store it in this
421    /// `JoinMap` with the provided key.
422    ///
423    /// If a task previously existed in the `JoinMap` for this key, that task
424    /// will be cancelled and replaced with the new one. The previous task will
425    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
426    /// *not* return a cancelled [`JoinError`] for that task.
427    ///
428    /// # Panics
429    ///
430    /// This method panics if it is called outside of a `LocalSet`.
431    ///
432    /// [`LocalSet`]: tokio::task::LocalSet
433    /// [`join_next`]: Self::join_next
434    #[track_caller]
435    pub fn spawn_local<F>(&mut self, key: K, task: F)
436    where
437        F: Future<Output = V>,
438        F: 'static,
439    {
440        let task = self.tasks.spawn_local(task);
441        self.insert(key, task);
442    }
443
444    /// Spawn the provided task on the provided [`LocalSet`] and store it in
445    /// this `JoinMap` with the provided key.
446    ///
447    /// If a task previously existed in the `JoinMap` for this key, that task
448    /// will be cancelled and replaced with the new one. The previous task will
449    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
450    /// *not* return a cancelled [`JoinError`] for that task.
451    ///
452    /// [`LocalSet`]: tokio::task::LocalSet
453    /// [`join_next`]: Self::join_next
454    #[track_caller]
455    pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
456    where
457        F: Future<Output = V>,
458        F: 'static,
459    {
460        let task = self.tasks.spawn_local_on(task, local_set);
461        self.insert(key, task)
462    }
463
464    fn insert(&mut self, key: K, abort: AbortHandle) {
465        let hash = self.hash(&key);
466        let id = abort.id();
467        let map_key = Key { id, key };
468
469        // Insert the new key into the map of tasks by keys.
470        let entry = self
471            .tasks_by_key
472            .raw_entry_mut()
473            .from_hash(hash, |k| k.key == map_key.key);
474        match entry {
475            RawEntryMut::Occupied(mut occ) => {
476                // There was a previous task spawned with the same key! Cancel
477                // that task, and remove its ID from the map of hashes by task IDs.
478                let Key { id: prev_id, .. } = occ.insert_key(map_key);
479                occ.insert(abort).abort();
480                let _prev_hash = self.hashes_by_task.remove(&prev_id);
481                debug_assert_eq!(Some(hash), _prev_hash);
482            }
483            RawEntryMut::Vacant(vac) => {
484                vac.insert(map_key, abort);
485            }
486        };
487
488        // Associate the key's hash with this task's ID, for looking up tasks by ID.
489        let _prev = self.hashes_by_task.insert(id, hash);
490        debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
491    }
492
493    /// Waits until one of the tasks in the map completes and returns its
494    /// output, along with the key corresponding to that task.
495    ///
496    /// Returns `None` if the map is empty.
497    ///
498    /// # Cancel Safety
499    ///
500    /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
501    /// statement and some other branch completes first, it is guaranteed that no tasks were
502    /// removed from this `JoinMap`.
503    ///
504    /// # Returns
505    ///
506    /// This function returns:
507    ///
508    ///  * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
509    ///    completed. The `value` is the return value of that ask, and `key` is
510    ///    the key associated with the task.
511    ///  * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
512    ///    panicked or been aborted. `key` is the key associated  with the task
513    ///    that panicked or was aborted.
514    ///  * `None` if the `JoinMap` is empty.
515    ///
516    /// [`tokio::select!`]: tokio::select
517    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
518        let (res, id) = match self.tasks.join_next_with_id().await {
519            Some(Ok((id, output))) => (Ok(output), id),
520            Some(Err(e)) => {
521                let id = e.id();
522                (Err(e), id)
523            }
524            None => return None,
525        };
526        let key = self.remove_by_id(id)?;
527        Some((key, res))
528    }
529
530    /// Aborts all tasks and waits for them to finish shutting down.
531    ///
532    /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
533    /// a loop until it returns `None`.
534    ///
535    /// This method ignores any panics in the tasks shutting down. When this call returns, the
536    /// `JoinMap` will be empty.
537    ///
538    /// [`abort_all`]: fn@Self::abort_all
539    /// [`join_next`]: fn@Self::join_next
540    pub async fn shutdown(&mut self) {
541        self.abort_all();
542        while self.join_next().await.is_some() {}
543    }
544
545    /// Abort the task corresponding to the provided `key`.
546    ///
547    /// If this `JoinMap` contains a task corresponding to `key`, this method
548    /// will abort that task and return `true`. Otherwise, if no task exists for
549    /// `key`, this method returns `false`.
550    ///
551    /// # Examples
552    ///
553    /// Aborting a task by key:
554    ///
555    /// ```
556    /// use tokio_util::task::JoinMap;
557    ///
558    /// # #[tokio::main]
559    /// # async fn main() {
560    /// let mut map = JoinMap::new();
561    ///
562    /// map.spawn("hello world", async move { /* ... */ });
563    /// map.spawn("goodbye world", async move { /* ... */});
564    ///
565    /// // Look up the "goodbye world" task in the map and abort it.
566    /// map.abort("goodbye world");
567    ///
568    /// while let Some((key, res)) = map.join_next().await {
569    ///     if key == "goodbye world" {
570    ///         // The aborted task should complete with a cancelled `JoinError`.
571    ///         assert!(res.unwrap_err().is_cancelled());
572    ///     } else {
573    ///         // Other tasks should complete normally.
574    ///         assert!(res.is_ok());
575    ///     }
576    /// }
577    /// # }
578    /// ```
579    ///
580    /// `abort` returns `true` if a task was aborted:
581    /// ```
582    /// use tokio_util::task::JoinMap;
583    ///
584    /// # #[tokio::main]
585    /// # async fn main() {
586    /// let mut map = JoinMap::new();
587    ///
588    /// map.spawn("hello world", async move { /* ... */ });
589    /// map.spawn("goodbye world", async move { /* ... */});
590    ///
591    /// // A task for the key "goodbye world" should exist in the map:
592    /// assert!(map.abort("goodbye world"));
593    ///
594    /// // Aborting a key that does not exist will return `false`:
595    /// assert!(!map.abort("goodbye universe"));
596    /// # }
597    /// ```
598    pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
599    where
600        Q: Hash + Eq,
601        K: Borrow<Q>,
602    {
603        match self.get_by_key(key) {
604            Some((_, handle)) => {
605                handle.abort();
606                true
607            }
608            None => false,
609        }
610    }
611
612    /// Aborts all tasks with keys matching `predicate`.
613    ///
614    /// `predicate` is a function called with a reference to each key in the
615    /// map. If it returns `true` for a given key, the corresponding task will
616    /// be cancelled.
617    ///
618    /// # Examples
619    /// ```
620    /// use tokio_util::task::JoinMap;
621    ///
622    /// # // use the current thread rt so that spawned tasks don't
623    /// # // complete in the background before they can be aborted.
624    /// # #[tokio::main(flavor = "current_thread")]
625    /// # async fn main() {
626    /// let mut map = JoinMap::new();
627    ///
628    /// map.spawn("hello world", async move {
629    ///     // ...
630    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
631    /// });
632    /// map.spawn("goodbye world", async move {
633    ///     // ...
634    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
635    /// });
636    /// map.spawn("hello san francisco", async move {
637    ///     // ...
638    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
639    /// });
640    /// map.spawn("goodbye universe", async move {
641    ///     // ...
642    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
643    /// });
644    ///
645    /// // Abort all tasks whose keys begin with "goodbye"
646    /// map.abort_matching(|key| key.starts_with("goodbye"));
647    ///
648    /// let mut seen = 0;
649    /// while let Some((key, res)) = map.join_next().await {
650    ///     seen += 1;
651    ///     if key.starts_with("goodbye") {
652    ///         // The aborted task should complete with a cancelled `JoinError`.
653    ///         assert!(res.unwrap_err().is_cancelled());
654    ///     } else {
655    ///         // Other tasks should complete normally.
656    ///         assert!(key.starts_with("hello"));
657    ///         assert!(res.is_ok());
658    ///     }
659    /// }
660    ///
661    /// // All spawned tasks should have completed.
662    /// assert_eq!(seen, 4);
663    /// # }
664    /// ```
665    pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
666        // Note: this method iterates over the tasks and keys *without* removing
667        // any entries, so that the keys from aborted tasks can still be
668        // returned when calling `join_next` in the future.
669        for (Key { ref key, .. }, task) in &self.tasks_by_key {
670            if predicate(key) {
671                task.abort();
672            }
673        }
674    }
675
676    /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
677    ///
678    /// If a task has completed, but its output hasn't yet been consumed by a
679    /// call to [`join_next`], this method will still return its key.
680    ///
681    /// [`join_next`]: fn@Self::join_next
682    pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
683        JoinMapKeys {
684            iter: self.tasks_by_key.keys(),
685            _value: PhantomData,
686        }
687    }
688
689    /// Returns `true` if this `JoinMap` contains a task for the provided key.
690    ///
691    /// If the task has completed, but its output hasn't yet been consumed by a
692    /// call to [`join_next`], this method will still return `true`.
693    ///
694    /// [`join_next`]: fn@Self::join_next
695    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
696    where
697        Q: Hash + Eq,
698        K: Borrow<Q>,
699    {
700        self.get_by_key(key).is_some()
701    }
702
703    /// Returns `true` if this `JoinMap` contains a task with the provided
704    /// [task ID].
705    ///
706    /// If the task has completed, but its output hasn't yet been consumed by a
707    /// call to [`join_next`], this method will still return `true`.
708    ///
709    /// [`join_next`]: fn@Self::join_next
710    /// [task ID]: tokio::task::Id
711    pub fn contains_task(&self, task: &Id) -> bool {
712        self.get_by_id(task).is_some()
713    }
714
715    /// Reserves capacity for at least `additional` more tasks to be spawned
716    /// on this `JoinMap` without reallocating for the map of task keys. The
717    /// collection may reserve more space to avoid frequent reallocations.
718    ///
719    /// Note that spawning a task will still cause an allocation for the task
720    /// itself.
721    ///
722    /// # Panics
723    ///
724    /// Panics if the new allocation size overflows [`usize`].
725    ///
726    /// # Examples
727    ///
728    /// ```
729    /// use tokio_util::task::JoinMap;
730    ///
731    /// let mut map: JoinMap<&str, i32> = JoinMap::new();
732    /// map.reserve(10);
733    /// ```
734    #[inline]
735    pub fn reserve(&mut self, additional: usize) {
736        self.tasks_by_key.reserve(additional);
737        self.hashes_by_task.reserve(additional);
738    }
739
740    /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
741    /// down as much as possible while maintaining the internal rules
742    /// and possibly leaving some space in accordance with the resize policy.
743    ///
744    /// # Examples
745    ///
746    /// ```
747    /// # #[tokio::main]
748    /// # async fn main() {
749    /// use tokio_util::task::JoinMap;
750    ///
751    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
752    /// map.spawn(1, async move { 2 });
753    /// map.spawn(3, async move { 4 });
754    /// assert!(map.capacity() >= 100);
755    /// map.shrink_to_fit();
756    /// assert!(map.capacity() >= 2);
757    /// # }
758    /// ```
759    #[inline]
760    pub fn shrink_to_fit(&mut self) {
761        self.hashes_by_task.shrink_to_fit();
762        self.tasks_by_key.shrink_to_fit();
763    }
764
765    /// Shrinks the capacity of the map with a lower limit. It will drop
766    /// down no lower than the supplied limit while maintaining the internal rules
767    /// and possibly leaving some space in accordance with the resize policy.
768    ///
769    /// If the current capacity is less than the lower limit, this is a no-op.
770    ///
771    /// # Examples
772    ///
773    /// ```
774    /// # #[tokio::main]
775    /// # async fn main() {
776    /// use tokio_util::task::JoinMap;
777    ///
778    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
779    /// map.spawn(1, async move { 2 });
780    /// map.spawn(3, async move { 4 });
781    /// assert!(map.capacity() >= 100);
782    /// map.shrink_to(10);
783    /// assert!(map.capacity() >= 10);
784    /// map.shrink_to(0);
785    /// assert!(map.capacity() >= 2);
786    /// # }
787    /// ```
788    #[inline]
789    pub fn shrink_to(&mut self, min_capacity: usize) {
790        self.hashes_by_task.shrink_to(min_capacity);
791        self.tasks_by_key.shrink_to(min_capacity)
792    }
793
794    /// Look up a task in the map by its key, returning the key and abort handle.
795    fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
796    where
797        Q: Hash + Eq,
798        K: Borrow<Q>,
799    {
800        let hash = self.hash(key);
801        self.tasks_by_key
802            .raw_entry()
803            .from_hash(hash, |k| k.key.borrow() == key)
804    }
805
806    /// Look up a task in the map by its task ID, returning the key and abort handle.
807    fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
808        let hash = self.hashes_by_task.get(id)?;
809        self.tasks_by_key
810            .raw_entry()
811            .from_hash(*hash, |k| &k.id == id)
812    }
813
814    /// Remove a task from the map by ID, returning the key for that task.
815    fn remove_by_id(&mut self, id: Id) -> Option<K> {
816        // Get the hash for the given ID.
817        let hash = self.hashes_by_task.remove(&id)?;
818
819        // Remove the entry for that hash.
820        let entry = self
821            .tasks_by_key
822            .raw_entry_mut()
823            .from_hash(hash, |k| k.id == id);
824        let (Key { id: _key_id, key }, handle) = match entry {
825            RawEntryMut::Occupied(entry) => entry.remove_entry(),
826            _ => return None,
827        };
828        debug_assert_eq!(_key_id, id);
829        debug_assert_eq!(id, handle.id());
830        self.hashes_by_task.remove(&id);
831        Some(key)
832    }
833
834    /// Returns the hash for a given key.
835    #[inline]
836    fn hash<Q: ?Sized>(&self, key: &Q) -> u64
837    where
838        Q: Hash,
839    {
840        let mut hasher = self.tasks_by_key.hasher().build_hasher();
841        key.hash(&mut hasher);
842        hasher.finish()
843    }
844}
845
846impl<K, V, S> JoinMap<K, V, S>
847where
848    V: 'static,
849{
850    /// Aborts all tasks on this `JoinMap`.
851    ///
852    /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
853    /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
854    pub fn abort_all(&mut self) {
855        self.tasks.abort_all()
856    }
857
858    /// Removes all tasks from this `JoinMap` without aborting them.
859    ///
860    /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
861    /// is dropped. They may still be aborted by key.
862    pub fn detach_all(&mut self) {
863        self.tasks.detach_all();
864        self.tasks_by_key.clear();
865        self.hashes_by_task.clear();
866    }
867}
868
869// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
870// Debug`, since no value is ever actually stored in the map.
871impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
872    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
873        // format the task keys and abort handles a little nicer by just
874        // printing the key and task ID pairs, without format the `Key` struct
875        // itself or the `AbortHandle`, which would just format the task's ID
876        // again.
877        struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
878        impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
879            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
880                f.debug_map()
881                    .entries(self.0.keys().map(|Key { key, id }| (key, id)))
882                    .finish()
883            }
884        }
885
886        f.debug_struct("JoinMap")
887            // The `tasks_by_key` map is the only one that contains information
888            // that's really worth formatting for the user, since it contains
889            // the tasks' keys and IDs. The other fields are basically
890            // implementation details.
891            .field("tasks", &KeySet(&self.tasks_by_key))
892            .finish()
893    }
894}
895
896impl<K, V> Default for JoinMap<K, V> {
897    fn default() -> Self {
898        Self::new()
899    }
900}
901
902// === impl Builder ===
903
#[cfg(feature = "tracing")]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
)]
impl<'a, K, V, S> Builder<'a, K, V, S>
where
    K: Hash + Eq,
    V: 'static,
    S: BuildHasher,
{
    /// Assigns a name to the task which will be spawned.
    pub fn name(mut self, name: &'a str) -> Self {
        self.name = Some(name);
        self
    }

    /// Spawn the provided task with this builder's settings and store it in this `JoinMap` with
    /// the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying task builder if the task could not be spawned. In
    /// that case nothing is inserted, and any existing task for `key` is left in place.
    ///
    /// # Panics
    ///
    /// This method panics if called outside of a Tokio runtime.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn<F>(self, key: K, task: F) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        let builder = self.joinmap.tasks.build_task();
        let builder = if let Some(name) = self.name {
            builder.name(name)
        } else {
            builder
        };
        let abort = builder.spawn(task)?;

        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the provided task on the provided runtime and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying task builder if the task could not be spawned. In
    /// that case nothing is inserted, and any existing task for `key` is left in place.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        let builder = self.joinmap.tasks.build_task();
        let builder = if let Some(name) = self.name {
            builder.name(name)
        } else {
            builder
        };
        let abort = builder.spawn_on(task, handle)?;

        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
    /// key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// Note that blocking tasks cannot be cancelled after execution starts.
    /// Replaced blocking tasks will still run to completion if the task has begun
    /// to execute when it is replaced. A blocking task which is replaced before
    /// it has been scheduled on a blocking worker thread will be cancelled.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying task builder if the task could not be spawned. In
    /// that case nothing is inserted, and any existing task for `key` is left in place.
    ///
    /// # Panics
    ///
    /// This method panics if called outside of a Tokio runtime.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_blocking<F>(&mut self, key: K, f: F) -> std::io::Result<()>
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        let builder = self.joinmap.tasks.build_task();
        let builder = if let Some(name) = self.name {
            builder.name(name)
        } else {
            builder
        };
        let abort = builder.spawn_blocking(f)?;

        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// Note that blocking tasks cannot be cancelled after execution starts.
    /// Replaced blocking tasks will still run to completion if the task has begun
    /// to execute when it is replaced. A blocking task which is replaced before
    /// it has been scheduled on a blocking worker thread will be cancelled.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying task builder if the task could not be spawned. In
    /// that case nothing is inserted, and any existing task for `key` is left in place.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle) -> std::io::Result<()>
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        let builder = self.joinmap.tasks.build_task();
        let builder = if let Some(name) = self.name {
            builder.name(name)
        } else {
            builder
        };
        let abort = builder.spawn_blocking_on(f, handle)?;

        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the provided task on the current [`LocalSet`] and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying task builder if the task could not be spawned. In
    /// that case nothing is inserted, and any existing task for `key` is left in place.
    ///
    /// # Panics
    ///
    /// This method panics if it is called outside of a `LocalSet`.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_local<F>(&mut self, key: K, task: F) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: 'static,
    {
        let builder = self.joinmap.tasks.build_task();
        let builder = if let Some(name) = self.name {
            builder.name(name)
        } else {
            builder
        };
        let abort = builder.spawn_local(task)?;

        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the provided task on the provided [`LocalSet`] and store it in
    /// this `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying task builder if the task could not be spawned. In
    /// that case nothing is inserted, and any existing task for `key` is left in place.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_local_on<F>(
        &mut self,
        key: K,
        task: F,
        local_set: &LocalSet,
    ) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: 'static,
    {
        let builder = self.joinmap.tasks.build_task();
        let builder = if let Some(name) = self.name {
            builder.name(name)
        } else {
            builder
        };
        let abort = builder.spawn_local_on(task, local_set)?;

        self.joinmap.insert(key, abort);
        Ok(())
    }
}
1110
// `Debug` is written by hand (not derived) so that `Builder` implements it
// whenever `K: Debug`, without also demanding `V: Debug` or `S: Debug`.
#[cfg(feature = "tracing")]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
)]
impl<K, V, S> fmt::Debug for Builder<'_, K, V, S>
where
    K: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("join_map::Builder");
        out.field("joinmap", &self.joinmap);
        out.field("name", &self.name);
        out.finish()
    }
}
1125
1126// === impl Key ===
1127
1128impl<K: Hash> Hash for Key<K> {
1129    // Don't include the task ID in the hash.
1130    #[inline]
1131    fn hash<H: Hasher>(&self, hasher: &mut H) {
1132        self.key.hash(hasher);
1133    }
1134}
1135
// `PartialEq` is implemented by hand to match the `Hash` impl above: both
// consider only the user-supplied key and ignore the task ID, so two `Key`s
// with equal user keys are treated as the same entry in the map.
1138impl<K: PartialEq> PartialEq for Key<K> {
1139    #[inline]
1140    fn eq(&self, other: &Self) -> bool {
1141        self.key == other.key
1142    }
1143}
1144
1145impl<K: Eq> Eq for Key<K> {}
1146
1147/// An iterator over the keys of a [`JoinMap`].
1148#[derive(Debug, Clone)]
1149pub struct JoinMapKeys<'a, K, V> {
1150    iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
1151    /// To make it easier to change `JoinMap` in the future, keep V as a generic
1152    /// parameter.
1153    _value: PhantomData<&'a V>,
1154}
1155
1156impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
1157    type Item = &'a K;
1158
1159    fn next(&mut self) -> Option<&'a K> {
1160        self.iter.next().map(|key| &key.key)
1161    }
1162
1163    fn size_hint(&self) -> (usize, Option<usize>) {
1164        self.iter.size_hint()
1165    }
1166}
1167
1168impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
1169    fn len(&self) -> usize {
1170        self.iter.len()
1171    }
1172}
1173
1174impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}