tokio_util/task/join_map.rs
1use hashbrown::hash_map::RawEntryMut;
2use hashbrown::HashMap;
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
/// A collection of tasks spawned on a Tokio runtime, associated with hash map
/// keys.
///
/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
/// addition of a set of keys associated with each task. These keys allow
/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
/// `JoinMap` based on their keys, or [test whether a task corresponding to a
/// given key exists][contains] in the `JoinMap`.
///
/// In addition, when tasks in the `JoinMap` complete, they will return the
/// associated key along with the value returned by the task, if any.
///
/// A `JoinMap` can be used to await the completion of some or all of the tasks
/// in the map. The map is not ordered, and the tasks will be returned in the
/// order they complete.
///
/// All of the tasks must have the same return type `V`.
///
/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
///
/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
/// documentation on unstable features][unstable] for details on how to enable
/// Tokio's unstable features.
///
/// # Examples
///
/// Spawn multiple tasks and wait for them:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     for i in 0..10 {
///         // Spawn a task on the `JoinMap` with `i` as its key.
///         map.spawn(i, async move { /* ... */ });
///     }
///
///     let mut seen = [false; 10];
///
///     // When a task completes, `join_next` returns the task's key along
///     // with its output.
///     while let Some((key, res)) = map.join_next().await {
///         seen[key] = true;
///         assert!(res.is_ok(), "task {} completed successfully!", key);
///     }
///
///     for i in 0..10 {
///         assert!(seen[i]);
///     }
/// }
/// ```
///
/// Cancel tasks based on their keys:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     map.spawn("hello world", async move { /* ... */ });
///     map.spawn("goodbye world", async move { /* ... */});
///
///     // Look up the "goodbye world" task in the map and abort it.
///     let aborted = map.abort("goodbye world");
///
///     // `JoinMap::abort` returns `true` if a task existed for the
///     // provided key.
///     assert!(aborted);
///
///     while let Some((key, res)) = map.join_next().await {
///         if key == "goodbye world" {
///             // The aborted task should complete with a cancelled `JoinError`.
///             assert!(res.unwrap_err().is_cancelled());
///         } else {
///             // Other tasks should complete normally.
///             assert!(res.is_ok());
///         }
///     }
/// }
/// ```
///
/// [`JoinSet`]: tokio::task::JoinSet
/// [unstable]: tokio#unstable-features
/// [abort]: fn@Self::abort
/// [abort_matching]: fn@Self::abort_matching
/// [contains]: fn@Self::contains_key
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
pub struct JoinMap<K, V, S = RandomState> {
    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
    /// indexed by their keys and task IDs.
    ///
    /// The [`Key`] type contains both the task's `K`-typed key provided when
    /// spawning tasks, and the task's ID. The IDs are stored here to resolve
    /// hash collisions when looking up tasks based on their pre-computed hash
    /// (as stored in the `hashes_by_task` map).
    ///
    /// This map and `hashes_by_task` always hold the same set of tasks; only
    /// tasks that are currently reachable by key are present.
    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,

    /// A map from task IDs to the hash of the key associated with that task.
    ///
    /// This map is used to perform reverse lookups of tasks in the
    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
    /// of that task's key, and then remove it from the `tasks_by_key` map using
    /// the raw hash code, resolving collisions by comparing task IDs.
    hashes_by_task: HashMap<Id, u64, S>,

    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
    /// `JoinMap`.
    ///
    /// Besides the tasks tracked by the two maps above, this set may also hold
    /// tasks that were replaced by spawning a new task under the same key:
    /// those are aborted and their IDs are removed from `hashes_by_task`, but
    /// they stay in this set until they finish running.
    tasks: JoinSet<V>,
}
127
/// Configures a task before spawning it into a [`JoinMap`].
///
/// This is the `JoinMap` counterpart of [`task::Builder`]: instead of spawning
/// onto the current default runtime, the resulting task is stored in the
/// borrowed `JoinMap` under the key passed to the spawn method.
///
/// [`task::Builder`]: tokio::task::Builder
#[cfg(feature = "tracing")]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
)]
pub struct Builder<'a, K, V, S> {
    /// Name to assign to the spawned task, or `None` if no name was set.
    name: Option<&'a str>,
    /// The `JoinMap` the spawned task is inserted into.
    joinmap: &'a mut JoinMap<K, V, S>,
}
141
/// A [`JoinMap`] key.
///
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
///
/// This allows looking up a task using either an actual key (such as when the
/// user queries the map with a key), *or* using a task ID and a hash (such as
/// when removing completed tasks from the map).
///
/// Note that the hash used in both cases is always the hash of `key` alone,
/// computed with the `tasks_by_key` map's hasher.
#[derive(Debug)]
struct Key<K> {
    /// The key supplied by the user when the task was spawned.
    key: K,
    /// The ID of the task currently associated with `key`.
    id: Id,
}
156
157impl<K, V> JoinMap<K, V> {
158 /// Creates a new empty `JoinMap`.
159 ///
160 /// The `JoinMap` is initially created with a capacity of 0, so it will not
161 /// allocate until a task is first spawned on it.
162 ///
163 /// # Examples
164 ///
165 /// ```
166 /// use tokio_util::task::JoinMap;
167 /// let map: JoinMap<&str, i32> = JoinMap::new();
168 /// ```
169 #[inline]
170 #[must_use]
171 pub fn new() -> Self {
172 Self::with_hasher(RandomState::new())
173 }
174
175 /// Creates an empty `JoinMap` with the specified capacity.
176 ///
177 /// The `JoinMap` will be able to hold at least `capacity` tasks without
178 /// reallocating.
179 ///
180 /// # Examples
181 ///
182 /// ```
183 /// use tokio_util::task::JoinMap;
184 /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
185 /// ```
186 #[inline]
187 #[must_use]
188 pub fn with_capacity(capacity: usize) -> Self {
189 JoinMap::with_capacity_and_hasher(capacity, Default::default())
190 }
191}
192
193impl<K, V, S: Clone> JoinMap<K, V, S> {
194 /// Creates an empty `JoinMap` which will use the given hash builder to hash
195 /// keys.
196 ///
197 /// The created map has the default initial capacity.
198 ///
199 /// Warning: `hash_builder` is normally randomly generated, and
200 /// is designed to allow `JoinMap` to be resistant to attacks that
201 /// cause many collisions and very poor performance. Setting it
202 /// manually using this function can expose a DoS attack vector.
203 ///
204 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
205 /// the `JoinMap` to be useful, see its documentation for details.
206 #[inline]
207 #[must_use]
208 pub fn with_hasher(hash_builder: S) -> Self {
209 Self::with_capacity_and_hasher(0, hash_builder)
210 }
211
212 /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
213 /// to hash the keys.
214 ///
215 /// The `JoinMap` will be able to hold at least `capacity` elements without
216 /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
217 ///
218 /// Warning: `hash_builder` is normally randomly generated, and
219 /// is designed to allow HashMaps to be resistant to attacks that
220 /// cause many collisions and very poor performance. Setting it
221 /// manually using this function can expose a DoS attack vector.
222 ///
223 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
224 /// the `JoinMap`to be useful, see its documentation for details.
225 ///
226 /// # Examples
227 ///
228 /// ```
229 /// # #[tokio::main]
230 /// # async fn main() {
231 /// use tokio_util::task::JoinMap;
232 /// use std::collections::hash_map::RandomState;
233 ///
234 /// let s = RandomState::new();
235 /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
236 /// map.spawn(1, async move { "hello world!" });
237 /// # }
238 /// ```
239 #[inline]
240 #[must_use]
241 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
242 Self {
243 tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
244 hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
245 tasks: JoinSet::new(),
246 }
247 }
248
249 /// Returns the number of tasks currently in the `JoinMap`.
250 pub fn len(&self) -> usize {
251 let len = self.tasks_by_key.len();
252 debug_assert_eq!(len, self.hashes_by_task.len());
253 len
254 }
255
256 /// Returns whether the `JoinMap` is empty.
257 pub fn is_empty(&self) -> bool {
258 let empty = self.tasks_by_key.is_empty();
259 debug_assert_eq!(empty, self.hashes_by_task.is_empty());
260 empty
261 }
262
263 /// Returns the number of tasks the map can hold without reallocating.
264 ///
265 /// This number is a lower bound; the `JoinMap` might be able to hold
266 /// more, but is guaranteed to be able to hold at least this many.
267 ///
268 /// # Examples
269 ///
270 /// ```
271 /// use tokio_util::task::JoinMap;
272 ///
273 /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
274 /// assert!(map.capacity() >= 100);
275 /// ```
276 #[inline]
277 pub fn capacity(&self) -> usize {
278 let capacity = self.tasks_by_key.capacity();
279 debug_assert_eq!(capacity, self.hashes_by_task.capacity());
280 capacity
281 }
282}
283
284impl<K, V, S> JoinMap<K, V, S>
285where
286 K: Hash + Eq,
287 V: 'static,
288 S: BuildHasher,
289{
290 /// Returns a [`Builder`] that can be used to configure a task prior to spawning it on this
291 /// [`JoinMap`].
292 ///
293 /// # Examples
294 ///
295 /// ```
296 /// use tokio_util::task::JoinMap;
297 ///
298 /// #[tokio::main]
299 /// async fn main() -> std::io::Result<()> {
300 /// let mut map = JoinMap::new();
301 ///
302 /// // Use the builder to configure the task's name before spawning it.
303 /// map.build_task()
304 /// .name("my_task")
305 /// .spawn(42, async { /* ... */ });
306 ///
307 /// Ok(())
308 /// }
309 /// ```
310 #[cfg(feature = "tracing")]
311 #[cfg_attr(
312 docsrs,
313 doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
314 )]
315 pub fn build_task(&mut self) -> Builder<'_, K, V, S> {
316 Builder {
317 joinmap: self,
318 name: None,
319 }
320 }
321
322 /// Spawn the provided task and store it in this `JoinMap` with the provided
323 /// key.
324 ///
325 /// If a task previously existed in the `JoinMap` for this key, that task
326 /// will be cancelled and replaced with the new one. The previous task will
327 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
328 /// *not* return a cancelled [`JoinError`] for that task.
329 ///
330 /// # Panics
331 ///
332 /// This method panics if called outside of a Tokio runtime.
333 ///
334 /// [`join_next`]: Self::join_next
335 #[track_caller]
336 pub fn spawn<F>(&mut self, key: K, task: F)
337 where
338 F: Future<Output = V>,
339 F: Send + 'static,
340 V: Send,
341 {
342 let task = self.tasks.spawn(task);
343 self.insert(key, task)
344 }
345
346 /// Spawn the provided task on the provided runtime and store it in this
347 /// `JoinMap` with the provided key.
348 ///
349 /// If a task previously existed in the `JoinMap` for this key, that task
350 /// will be cancelled and replaced with the new one. The previous task will
351 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
352 /// *not* return a cancelled [`JoinError`] for that task.
353 ///
354 /// [`join_next`]: Self::join_next
355 #[track_caller]
356 pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
357 where
358 F: Future<Output = V>,
359 F: Send + 'static,
360 V: Send,
361 {
362 let task = self.tasks.spawn_on(task, handle);
363 self.insert(key, task);
364 }
365
366 /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
367 /// key.
368 ///
369 /// If a task previously existed in the `JoinMap` for this key, that task
370 /// will be cancelled and replaced with the new one. The previous task will
371 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
372 /// *not* return a cancelled [`JoinError`] for that task.
373 ///
374 /// Note that blocking tasks cannot be cancelled after execution starts.
375 /// Replaced blocking tasks will still run to completion if the task has begun
376 /// to execute when it is replaced. A blocking task which is replaced before
377 /// it has been scheduled on a blocking worker thread will be cancelled.
378 ///
379 /// # Panics
380 ///
381 /// This method panics if called outside of a Tokio runtime.
382 ///
383 /// [`join_next`]: Self::join_next
384 #[track_caller]
385 pub fn spawn_blocking<F>(&mut self, key: K, f: F)
386 where
387 F: FnOnce() -> V,
388 F: Send + 'static,
389 V: Send,
390 {
391 let task = self.tasks.spawn_blocking(f);
392 self.insert(key, task)
393 }
394
395 /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
396 /// `JoinMap` with the provided key.
397 ///
398 /// If a task previously existed in the `JoinMap` for this key, that task
399 /// will be cancelled and replaced with the new one. The previous task will
400 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
401 /// *not* return a cancelled [`JoinError`] for that task.
402 ///
403 /// Note that blocking tasks cannot be cancelled after execution starts.
404 /// Replaced blocking tasks will still run to completion if the task has begun
405 /// to execute when it is replaced. A blocking task which is replaced before
406 /// it has been scheduled on a blocking worker thread will be cancelled.
407 ///
408 /// [`join_next`]: Self::join_next
409 #[track_caller]
410 pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
411 where
412 F: FnOnce() -> V,
413 F: Send + 'static,
414 V: Send,
415 {
416 let task = self.tasks.spawn_blocking_on(f, handle);
417 self.insert(key, task);
418 }
419
420 /// Spawn the provided task on the current [`LocalSet`] and store it in this
421 /// `JoinMap` with the provided key.
422 ///
423 /// If a task previously existed in the `JoinMap` for this key, that task
424 /// will be cancelled and replaced with the new one. The previous task will
425 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
426 /// *not* return a cancelled [`JoinError`] for that task.
427 ///
428 /// # Panics
429 ///
430 /// This method panics if it is called outside of a `LocalSet`.
431 ///
432 /// [`LocalSet`]: tokio::task::LocalSet
433 /// [`join_next`]: Self::join_next
434 #[track_caller]
435 pub fn spawn_local<F>(&mut self, key: K, task: F)
436 where
437 F: Future<Output = V>,
438 F: 'static,
439 {
440 let task = self.tasks.spawn_local(task);
441 self.insert(key, task);
442 }
443
444 /// Spawn the provided task on the provided [`LocalSet`] and store it in
445 /// this `JoinMap` with the provided key.
446 ///
447 /// If a task previously existed in the `JoinMap` for this key, that task
448 /// will be cancelled and replaced with the new one. The previous task will
449 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
450 /// *not* return a cancelled [`JoinError`] for that task.
451 ///
452 /// [`LocalSet`]: tokio::task::LocalSet
453 /// [`join_next`]: Self::join_next
454 #[track_caller]
455 pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
456 where
457 F: Future<Output = V>,
458 F: 'static,
459 {
460 let task = self.tasks.spawn_local_on(task, local_set);
461 self.insert(key, task)
462 }
463
464 fn insert(&mut self, key: K, abort: AbortHandle) {
465 let hash = self.hash(&key);
466 let id = abort.id();
467 let map_key = Key { id, key };
468
469 // Insert the new key into the map of tasks by keys.
470 let entry = self
471 .tasks_by_key
472 .raw_entry_mut()
473 .from_hash(hash, |k| k.key == map_key.key);
474 match entry {
475 RawEntryMut::Occupied(mut occ) => {
476 // There was a previous task spawned with the same key! Cancel
477 // that task, and remove its ID from the map of hashes by task IDs.
478 let Key { id: prev_id, .. } = occ.insert_key(map_key);
479 occ.insert(abort).abort();
480 let _prev_hash = self.hashes_by_task.remove(&prev_id);
481 debug_assert_eq!(Some(hash), _prev_hash);
482 }
483 RawEntryMut::Vacant(vac) => {
484 vac.insert(map_key, abort);
485 }
486 };
487
488 // Associate the key's hash with this task's ID, for looking up tasks by ID.
489 let _prev = self.hashes_by_task.insert(id, hash);
490 debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
491 }
492
493 /// Waits until one of the tasks in the map completes and returns its
494 /// output, along with the key corresponding to that task.
495 ///
496 /// Returns `None` if the map is empty.
497 ///
498 /// # Cancel Safety
499 ///
500 /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
501 /// statement and some other branch completes first, it is guaranteed that no tasks were
502 /// removed from this `JoinMap`.
503 ///
504 /// # Returns
505 ///
506 /// This function returns:
507 ///
508 /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
509 /// completed. The `value` is the return value of that ask, and `key` is
510 /// the key associated with the task.
511 /// * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
512 /// panicked or been aborted. `key` is the key associated with the task
513 /// that panicked or was aborted.
514 /// * `None` if the `JoinMap` is empty.
515 ///
516 /// [`tokio::select!`]: tokio::select
517 pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
518 let (res, id) = match self.tasks.join_next_with_id().await {
519 Some(Ok((id, output))) => (Ok(output), id),
520 Some(Err(e)) => {
521 let id = e.id();
522 (Err(e), id)
523 }
524 None => return None,
525 };
526 let key = self.remove_by_id(id)?;
527 Some((key, res))
528 }
529
530 /// Aborts all tasks and waits for them to finish shutting down.
531 ///
532 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
533 /// a loop until it returns `None`.
534 ///
535 /// This method ignores any panics in the tasks shutting down. When this call returns, the
536 /// `JoinMap` will be empty.
537 ///
538 /// [`abort_all`]: fn@Self::abort_all
539 /// [`join_next`]: fn@Self::join_next
540 pub async fn shutdown(&mut self) {
541 self.abort_all();
542 while self.join_next().await.is_some() {}
543 }
544
545 /// Abort the task corresponding to the provided `key`.
546 ///
547 /// If this `JoinMap` contains a task corresponding to `key`, this method
548 /// will abort that task and return `true`. Otherwise, if no task exists for
549 /// `key`, this method returns `false`.
550 ///
551 /// # Examples
552 ///
553 /// Aborting a task by key:
554 ///
555 /// ```
556 /// use tokio_util::task::JoinMap;
557 ///
558 /// # #[tokio::main]
559 /// # async fn main() {
560 /// let mut map = JoinMap::new();
561 ///
562 /// map.spawn("hello world", async move { /* ... */ });
563 /// map.spawn("goodbye world", async move { /* ... */});
564 ///
565 /// // Look up the "goodbye world" task in the map and abort it.
566 /// map.abort("goodbye world");
567 ///
568 /// while let Some((key, res)) = map.join_next().await {
569 /// if key == "goodbye world" {
570 /// // The aborted task should complete with a cancelled `JoinError`.
571 /// assert!(res.unwrap_err().is_cancelled());
572 /// } else {
573 /// // Other tasks should complete normally.
574 /// assert!(res.is_ok());
575 /// }
576 /// }
577 /// # }
578 /// ```
579 ///
580 /// `abort` returns `true` if a task was aborted:
581 /// ```
582 /// use tokio_util::task::JoinMap;
583 ///
584 /// # #[tokio::main]
585 /// # async fn main() {
586 /// let mut map = JoinMap::new();
587 ///
588 /// map.spawn("hello world", async move { /* ... */ });
589 /// map.spawn("goodbye world", async move { /* ... */});
590 ///
591 /// // A task for the key "goodbye world" should exist in the map:
592 /// assert!(map.abort("goodbye world"));
593 ///
594 /// // Aborting a key that does not exist will return `false`:
595 /// assert!(!map.abort("goodbye universe"));
596 /// # }
597 /// ```
598 pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
599 where
600 Q: Hash + Eq,
601 K: Borrow<Q>,
602 {
603 match self.get_by_key(key) {
604 Some((_, handle)) => {
605 handle.abort();
606 true
607 }
608 None => false,
609 }
610 }
611
612 /// Aborts all tasks with keys matching `predicate`.
613 ///
614 /// `predicate` is a function called with a reference to each key in the
615 /// map. If it returns `true` for a given key, the corresponding task will
616 /// be cancelled.
617 ///
618 /// # Examples
619 /// ```
620 /// use tokio_util::task::JoinMap;
621 ///
622 /// # // use the current thread rt so that spawned tasks don't
623 /// # // complete in the background before they can be aborted.
624 /// # #[tokio::main(flavor = "current_thread")]
625 /// # async fn main() {
626 /// let mut map = JoinMap::new();
627 ///
628 /// map.spawn("hello world", async move {
629 /// // ...
630 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
631 /// });
632 /// map.spawn("goodbye world", async move {
633 /// // ...
634 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
635 /// });
636 /// map.spawn("hello san francisco", async move {
637 /// // ...
638 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
639 /// });
640 /// map.spawn("goodbye universe", async move {
641 /// // ...
642 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
643 /// });
644 ///
645 /// // Abort all tasks whose keys begin with "goodbye"
646 /// map.abort_matching(|key| key.starts_with("goodbye"));
647 ///
648 /// let mut seen = 0;
649 /// while let Some((key, res)) = map.join_next().await {
650 /// seen += 1;
651 /// if key.starts_with("goodbye") {
652 /// // The aborted task should complete with a cancelled `JoinError`.
653 /// assert!(res.unwrap_err().is_cancelled());
654 /// } else {
655 /// // Other tasks should complete normally.
656 /// assert!(key.starts_with("hello"));
657 /// assert!(res.is_ok());
658 /// }
659 /// }
660 ///
661 /// // All spawned tasks should have completed.
662 /// assert_eq!(seen, 4);
663 /// # }
664 /// ```
665 pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
666 // Note: this method iterates over the tasks and keys *without* removing
667 // any entries, so that the keys from aborted tasks can still be
668 // returned when calling `join_next` in the future.
669 for (Key { ref key, .. }, task) in &self.tasks_by_key {
670 if predicate(key) {
671 task.abort();
672 }
673 }
674 }
675
676 /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
677 ///
678 /// If a task has completed, but its output hasn't yet been consumed by a
679 /// call to [`join_next`], this method will still return its key.
680 ///
681 /// [`join_next`]: fn@Self::join_next
682 pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
683 JoinMapKeys {
684 iter: self.tasks_by_key.keys(),
685 _value: PhantomData,
686 }
687 }
688
689 /// Returns `true` if this `JoinMap` contains a task for the provided key.
690 ///
691 /// If the task has completed, but its output hasn't yet been consumed by a
692 /// call to [`join_next`], this method will still return `true`.
693 ///
694 /// [`join_next`]: fn@Self::join_next
695 pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
696 where
697 Q: Hash + Eq,
698 K: Borrow<Q>,
699 {
700 self.get_by_key(key).is_some()
701 }
702
703 /// Returns `true` if this `JoinMap` contains a task with the provided
704 /// [task ID].
705 ///
706 /// If the task has completed, but its output hasn't yet been consumed by a
707 /// call to [`join_next`], this method will still return `true`.
708 ///
709 /// [`join_next`]: fn@Self::join_next
710 /// [task ID]: tokio::task::Id
711 pub fn contains_task(&self, task: &Id) -> bool {
712 self.get_by_id(task).is_some()
713 }
714
715 /// Reserves capacity for at least `additional` more tasks to be spawned
716 /// on this `JoinMap` without reallocating for the map of task keys. The
717 /// collection may reserve more space to avoid frequent reallocations.
718 ///
719 /// Note that spawning a task will still cause an allocation for the task
720 /// itself.
721 ///
722 /// # Panics
723 ///
724 /// Panics if the new allocation size overflows [`usize`].
725 ///
726 /// # Examples
727 ///
728 /// ```
729 /// use tokio_util::task::JoinMap;
730 ///
731 /// let mut map: JoinMap<&str, i32> = JoinMap::new();
732 /// map.reserve(10);
733 /// ```
734 #[inline]
735 pub fn reserve(&mut self, additional: usize) {
736 self.tasks_by_key.reserve(additional);
737 self.hashes_by_task.reserve(additional);
738 }
739
740 /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
741 /// down as much as possible while maintaining the internal rules
742 /// and possibly leaving some space in accordance with the resize policy.
743 ///
744 /// # Examples
745 ///
746 /// ```
747 /// # #[tokio::main]
748 /// # async fn main() {
749 /// use tokio_util::task::JoinMap;
750 ///
751 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
752 /// map.spawn(1, async move { 2 });
753 /// map.spawn(3, async move { 4 });
754 /// assert!(map.capacity() >= 100);
755 /// map.shrink_to_fit();
756 /// assert!(map.capacity() >= 2);
757 /// # }
758 /// ```
759 #[inline]
760 pub fn shrink_to_fit(&mut self) {
761 self.hashes_by_task.shrink_to_fit();
762 self.tasks_by_key.shrink_to_fit();
763 }
764
765 /// Shrinks the capacity of the map with a lower limit. It will drop
766 /// down no lower than the supplied limit while maintaining the internal rules
767 /// and possibly leaving some space in accordance with the resize policy.
768 ///
769 /// If the current capacity is less than the lower limit, this is a no-op.
770 ///
771 /// # Examples
772 ///
773 /// ```
774 /// # #[tokio::main]
775 /// # async fn main() {
776 /// use tokio_util::task::JoinMap;
777 ///
778 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
779 /// map.spawn(1, async move { 2 });
780 /// map.spawn(3, async move { 4 });
781 /// assert!(map.capacity() >= 100);
782 /// map.shrink_to(10);
783 /// assert!(map.capacity() >= 10);
784 /// map.shrink_to(0);
785 /// assert!(map.capacity() >= 2);
786 /// # }
787 /// ```
788 #[inline]
789 pub fn shrink_to(&mut self, min_capacity: usize) {
790 self.hashes_by_task.shrink_to(min_capacity);
791 self.tasks_by_key.shrink_to(min_capacity)
792 }
793
794 /// Look up a task in the map by its key, returning the key and abort handle.
795 fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
796 where
797 Q: Hash + Eq,
798 K: Borrow<Q>,
799 {
800 let hash = self.hash(key);
801 self.tasks_by_key
802 .raw_entry()
803 .from_hash(hash, |k| k.key.borrow() == key)
804 }
805
806 /// Look up a task in the map by its task ID, returning the key and abort handle.
807 fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
808 let hash = self.hashes_by_task.get(id)?;
809 self.tasks_by_key
810 .raw_entry()
811 .from_hash(*hash, |k| &k.id == id)
812 }
813
814 /// Remove a task from the map by ID, returning the key for that task.
815 fn remove_by_id(&mut self, id: Id) -> Option<K> {
816 // Get the hash for the given ID.
817 let hash = self.hashes_by_task.remove(&id)?;
818
819 // Remove the entry for that hash.
820 let entry = self
821 .tasks_by_key
822 .raw_entry_mut()
823 .from_hash(hash, |k| k.id == id);
824 let (Key { id: _key_id, key }, handle) = match entry {
825 RawEntryMut::Occupied(entry) => entry.remove_entry(),
826 _ => return None,
827 };
828 debug_assert_eq!(_key_id, id);
829 debug_assert_eq!(id, handle.id());
830 self.hashes_by_task.remove(&id);
831 Some(key)
832 }
833
834 /// Returns the hash for a given key.
835 #[inline]
836 fn hash<Q: ?Sized>(&self, key: &Q) -> u64
837 where
838 Q: Hash,
839 {
840 let mut hasher = self.tasks_by_key.hasher().build_hasher();
841 key.hash(&mut hasher);
842 hasher.finish()
843 }
844}
845
846impl<K, V, S> JoinMap<K, V, S>
847where
848 V: 'static,
849{
850 /// Aborts all tasks on this `JoinMap`.
851 ///
852 /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
853 /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
854 pub fn abort_all(&mut self) {
855 self.tasks.abort_all()
856 }
857
858 /// Removes all tasks from this `JoinMap` without aborting them.
859 ///
860 /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
861 /// is dropped. They may still be aborted by key.
862 pub fn detach_all(&mut self) {
863 self.tasks.detach_all();
864 self.tasks_by_key.clear();
865 self.hashes_by_task.clear();
866 }
867}
868
869// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
870// Debug`, since no value is ever actually stored in the map.
871impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
872 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
873 // format the task keys and abort handles a little nicer by just
874 // printing the key and task ID pairs, without format the `Key` struct
875 // itself or the `AbortHandle`, which would just format the task's ID
876 // again.
877 struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
878 impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
879 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
880 f.debug_map()
881 .entries(self.0.keys().map(|Key { key, id }| (key, id)))
882 .finish()
883 }
884 }
885
886 f.debug_struct("JoinMap")
887 // The `tasks_by_key` map is the only one that contains information
888 // that's really worth formatting for the user, since it contains
889 // the tasks' keys and IDs. The other fields are basically
890 // implementation details.
891 .field("tasks", &KeySet(&self.tasks_by_key))
892 .finish()
893 }
894}
895
896impl<K, V> Default for JoinMap<K, V> {
897 fn default() -> Self {
898 Self::new()
899 }
900}
901
902// === impl Builder ===
903
#[cfg(feature = "tracing")]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
)]
impl<'a, K, V, S> Builder<'a, K, V, S>
where
    K: Hash + Eq,
    V: 'static,
    S: BuildHasher,
{
    /// Assigns a name to the task which will be spawned.
    pub fn name(mut self, name: &'a str) -> Self {
        self.name = Some(name);
        self
    }

    /// Returns a task builder for the map's underlying [`JoinSet`], with this
    /// builder's name applied if one was set via [`Builder::name`].
    ///
    /// Every `spawn*` method below starts from this, so the naming logic lives
    /// in exactly one place.
    fn task_builder(&mut self) -> tokio::task::join_set::Builder<'_, V> {
        let builder = self.joinmap.tasks.build_task();
        match self.name {
            Some(name) => builder.name(name),
            None => builder,
        }
    }

    /// Spawn the provided task with this builder's settings and store it in this `JoinMap` with
    /// the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying [`JoinSet`] builder if the
    /// task could not be spawned. In that case the `JoinMap` is left
    /// unchanged: nothing is inserted and no existing task for `key` is
    /// cancelled.
    ///
    /// # Panics
    ///
    /// This method panics if called outside of a Tokio runtime.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn<F>(mut self, key: K, task: F) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        // Spawn first: if it fails, `?` returns before the map is touched.
        let abort = self.task_builder().spawn(task)?;
        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the provided task on the provided runtime and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying [`JoinSet`] builder if the
    /// task could not be spawned. In that case the `JoinMap` is left
    /// unchanged: nothing is inserted and no existing task for `key` is
    /// cancelled.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        let abort = self.task_builder().spawn_on(task, handle)?;
        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
    /// key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// Note that blocking tasks cannot be cancelled after execution starts.
    /// Replaced blocking tasks will still run to completion if the task has begun
    /// to execute when it is replaced. A blocking task which is replaced before
    /// it has been scheduled on a blocking worker thread will be cancelled.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying [`JoinSet`] builder if the
    /// task could not be spawned. In that case the `JoinMap` is left
    /// unchanged: nothing is inserted and no existing task for `key` is
    /// cancelled.
    ///
    /// # Panics
    ///
    /// This method panics if called outside of a Tokio runtime.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_blocking<F>(&mut self, key: K, f: F) -> std::io::Result<()>
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        let abort = self.task_builder().spawn_blocking(f)?;
        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// Note that blocking tasks cannot be cancelled after execution starts.
    /// Replaced blocking tasks will still run to completion if the task has begun
    /// to execute when it is replaced. A blocking task which is replaced before
    /// it has been scheduled on a blocking worker thread will be cancelled.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying [`JoinSet`] builder if the
    /// task could not be spawned. In that case the `JoinMap` is left
    /// unchanged: nothing is inserted and no existing task for `key` is
    /// cancelled.
    ///
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle) -> std::io::Result<()>
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        let abort = self.task_builder().spawn_blocking_on(f, handle)?;
        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the provided task on the current [`LocalSet`] and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying [`JoinSet`] builder if the
    /// task could not be spawned. In that case the `JoinMap` is left
    /// unchanged: nothing is inserted and no existing task for `key` is
    /// cancelled.
    ///
    /// # Panics
    ///
    /// This method panics if it is called outside of a `LocalSet`.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_local<F>(&mut self, key: K, task: F) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: 'static,
    {
        let abort = self.task_builder().spawn_local(task)?;
        self.joinmap.insert(key, abort);
        Ok(())
    }

    /// Spawn the provided task on the provided [`LocalSet`] and store it in
    /// this `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying [`JoinSet`] builder if the
    /// task could not be spawned. In that case the `JoinMap` is left
    /// unchanged: nothing is inserted and no existing task for `key` is
    /// cancelled.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    /// [`join_next`]: JoinMap::join_next
    #[track_caller]
    pub fn spawn_local_on<F>(
        &mut self,
        key: K,
        task: F,
        local_set: &LocalSet,
    ) -> std::io::Result<()>
    where
        F: Future<Output = V>,
        F: 'static,
    {
        let abort = self.task_builder().spawn_local_on(task, local_set)?;
        self.joinmap.insert(key, abort);
        Ok(())
    }
}
1110
// Written by hand rather than derived: a derive would demand `V: Debug` and
// `S: Debug`, whereas this impl only needs the keys (`K`) to be printable.
#[cfg(feature = "tracing")]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "rt", feature = "tracing", tokio_unstable)))
)]
impl<'a, K: fmt::Debug, V, S> fmt::Debug for Builder<'a, K, V, S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("join_map::Builder");
        out.field("joinmap", &self.joinmap);
        out.field("name", &self.name);
        out.finish()
    }
}
1125
1126// === impl Key ===
1127
1128impl<K: Hash> Hash for Key<K> {
1129 // Don't include the task ID in the hash.
1130 #[inline]
1131 fn hash<H: Hasher>(&self, hasher: &mut H) {
1132 self.key.hash(hasher);
1133 }
1134}
1135
1136// Because we override `Hash` for this type, we must also override the
1137// `PartialEq` impl, so that all instances with the same hash are equal.
1138impl<K: PartialEq> PartialEq for Key<K> {
1139 #[inline]
1140 fn eq(&self, other: &Self) -> bool {
1141 self.key == other.key
1142 }
1143}
1144
1145impl<K: Eq> Eq for Key<K> {}
1146
1147/// An iterator over the keys of a [`JoinMap`].
1148#[derive(Debug, Clone)]
1149pub struct JoinMapKeys<'a, K, V> {
1150 iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
1151 /// To make it easier to change `JoinMap` in the future, keep V as a generic
1152 /// parameter.
1153 _value: PhantomData<&'a V>,
1154}
1155
1156impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
1157 type Item = &'a K;
1158
1159 fn next(&mut self) -> Option<&'a K> {
1160 self.iter.next().map(|key| &key.key)
1161 }
1162
1163 fn size_hint(&self) -> (usize, Option<usize>) {
1164 self.iter.size_hint()
1165 }
1166}
1167
1168impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
1169 fn len(&self) -> usize {
1170 self.iter.len()
1171 }
1172}
1173
1174impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}