From 0095c7fe2924a2e77dbfafff8be4a2ccb1d2572d Mon Sep 17 00:00:00 2001 From: noah Date: Mon, 3 Feb 2025 21:41:42 -0600 Subject: [PATCH 1/9] rt: overhaul task hooks This change overhauls the entire task hooks system so that users can propagate arbitrary information between task hook invocations and pass context data between the hook "harnesses" for parent and child tasks at time of spawn. This is intended to be significantly more extensible and long-term maintainable than the current task hooks system, and should ultimately be much easier to stabilize. --- tokio/src/lib.rs | 3 - tokio/src/runtime/blocking/pool.rs | 5 + tokio/src/runtime/blocking/schedule.rs | 24 +- tokio/src/runtime/builder.rs | 231 +-------- tokio/src/runtime/config.rs | 19 +- tokio/src/runtime/context.rs | 68 ++- tokio/src/runtime/handle.rs | 23 +- tokio/src/runtime/local_runtime/runtime.rs | 4 +- tokio/src/runtime/mod.rs | 7 +- tokio/src/runtime/runtime.rs | 9 + .../runtime/scheduler/current_thread/mod.rs | 156 ++++-- tokio/src/runtime/scheduler/mod.rs | 35 +- .../runtime/scheduler/multi_thread/handle.rs | 82 +++- .../runtime/scheduler/multi_thread/worker.rs | 43 +- tokio/src/runtime/task/core.rs | 28 +- tokio/src/runtime/task/harness.rs | 39 +- tokio/src/runtime/task/list.rs | 25 +- tokio/src/runtime/task/mod.rs | 36 +- tokio/src/runtime/task/raw.rs | 49 +- tokio/src/runtime/task_hooks.rs | 81 ---- tokio/src/runtime/task_hooks/mod.rs | 99 ++++ tokio/src/runtime/tests/mod.rs | 22 +- tokio/src/runtime/tests/queue.rs | 1 - tokio/src/runtime/tests/task.rs | 19 +- tokio/src/task/builder.rs | 29 +- tokio/src/task/join_set.rs | 10 +- tokio/src/task/local.rs | 41 +- tokio/src/task/mod.rs | 4 + tokio/src/task/spawn.rs | 36 +- tokio/tests/rt_poll_callbacks.rs | 128 ----- tokio/tests/task_builder.rs | 30 +- tokio/tests/task_hooks.rs | 450 ++++++++++++++++-- tokio/tests/tracing_task.rs | 8 +- 33 files changed, 1159 insertions(+), 685 deletions(-) delete mode 100644 tokio/src/runtime/task_hooks.rs create mode 100644 tokio/src/runtime/task_hooks/mod.rs delete mode 100644 tokio/tests/rt_poll_callbacks.rs diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 6b0f48bd105..cf287b0dac8 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -351,10 +351,7 @@ //! - [`task::Builder`] //! - Some methods on [`task::JoinSet`] //! - [`runtime::RuntimeMetrics`] -//! - [`runtime::Builder::on_task_spawn`] -//! - [`runtime::Builder::on_task_terminate`] //! - [`runtime::Builder::unhandled_panic`] -//! - [`runtime::TaskMeta`] //! //! This flag enables **unstable** features. The public API of these features //! may break in 1.x releases. To enable these features, the `--cfg diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index 23180dc5245..990a1fd4a7b 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -375,10 +375,15 @@ impl Spawner { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + // let parent = with_c let id = task::Id::next(); let fut = blocking_task::>(BlockingTask::new(func), spawn_meta, id.as_u64()); + #[cfg(tokio_unstable)] + let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id, None); + + #[cfg(not(tokio_unstable))] let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id); let spawned = self.spawn_task(Task::new(task, is_mandatory), rt); diff --git a/tokio/src/runtime/blocking/schedule.rs b/tokio/src/runtime/blocking/schedule.rs index 0e97c5aeaf4..aad4da4f3f4 100644 --- a/tokio/src/runtime/blocking/schedule.rs +++ b/tokio/src/runtime/blocking/schedule.rs @@ -1,7 +1,9 @@ #[cfg(feature = "test-util")] use crate::runtime::scheduler; -use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks}; +use crate::runtime::task::{self, Task}; use crate::runtime::Handle; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; /// `task::Schedule` implementation that does nothing (except some bookkeeping /// in test-util builds). This is unique to the blocking scheduler as tasks @@ -12,7 +14,8 @@ use crate::runtime::Handle; pub(crate) struct BlockingSchedule { #[cfg(feature = "test-util")] handle: Handle, - hooks: TaskHarnessScheduleHooks, + #[cfg(tokio_unstable)] + hooks_factory: OptionalTaskHooksFactory, } impl BlockingSchedule { @@ -31,9 +34,8 @@ impl BlockingSchedule { BlockingSchedule { #[cfg(feature = "test-util")] handle: handle.clone(), - hooks: TaskHarnessScheduleHooks { - task_terminate_callback: handle.inner.hooks().task_terminate_callback.clone(), - }, + #[cfg(tokio_unstable)] + hooks_factory: handle.inner.hooks_factory(), } } } @@ -58,9 +60,13 @@ impl task::Schedule for BlockingSchedule { unreachable!(); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.hooks_factory.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.hooks_factory.as_ref().map(AsRef::as_ref) } } diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 994fcfa5c73..4fbe5565dca 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -1,15 +1,19 @@ #![cfg_attr(loom, allow(unused_imports))] +use crate::runtime::blocking::BlockingPool; use crate::runtime::handle::Handle; -use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; +use crate::runtime::scheduler::CurrentThread; +use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime}; #[cfg(tokio_unstable)] -use crate::runtime::{metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta}; +use crate::runtime::{ + metrics::HistogramConfiguration, LocalOptions, LocalRuntime, OptionalTaskHooksFactory, + TaskHookHarnessFactory, +}; use crate::util::rand::{RngSeed, RngSeedGenerator}; - -use crate::runtime::blocking::BlockingPool; -use crate::runtime::scheduler::CurrentThread; use std::fmt; use std::io; +#[cfg(tokio_unstable)] +use std::sync::Arc; use std::thread::ThreadId; use std::time::Duration; @@ -85,19 +89,8 @@ pub struct Builder { /// To run after each thread is unparked. pub(super) after_unpark: Option, - /// To run before each task is spawned. - pub(super) before_spawn: Option, - - /// To run before each poll #[cfg(tokio_unstable)] - pub(super) before_poll: Option, - - /// To run after each poll - #[cfg(tokio_unstable)] - pub(super) after_poll: Option, - - /// To run after each task is terminated. - pub(super) after_termination: Option, + pub(super) task_hook_harness_factory: OptionalTaskHooksFactory, /// Customizable keep alive timeout for `BlockingPool` pub(super) keep_alive: Option, @@ -287,13 +280,8 @@ impl Builder { before_park: None, after_unpark: None, - before_spawn: None, - after_termination: None, - #[cfg(tokio_unstable)] - before_poll: None, - #[cfg(tokio_unstable)] - after_poll: None, + task_hook_harness_factory: None, keep_alive: None, @@ -685,188 +673,19 @@ impl Builder { self } - /// Executes function `f` just before a task is spawned. - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// This can be used for bookkeeping or monitoring purposes. - /// - /// Note: There can only be one spawn callback for a runtime; calling this function more - /// than once replaces the last callback defined, rather than adding to it. - /// - /// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use tokio::runtime; - /// # pub fn main() { - /// let runtime = runtime::Builder::new_current_thread() - /// .on_task_spawn(|_| { - /// println!("spawning task"); - /// }) - /// .build() - /// .unwrap(); - /// - /// runtime.block_on(async { - /// tokio::task::spawn(std::future::ready(())); - /// - /// for _ in 0..64 { - /// tokio::task::yield_now().await; - /// } - /// }) - /// # } - /// ``` - #[cfg(all(not(loom), tokio_unstable))] - #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] - pub fn on_task_spawn(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.before_spawn = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just before a task is polled - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use std::sync::{atomic::AtomicUsize, Arc}; - /// # use tokio::task::yield_now; - /// # pub fn main() { - /// let poll_start_counter = Arc::new(AtomicUsize::new(0)); - /// let poll_start = poll_start_counter.clone(); - /// let rt = tokio::runtime::Builder::new_multi_thread() - /// .enable_all() - /// .on_before_task_poll(move |meta| { - /// println!("task {} is about to be polled", meta.id()) - /// }) - /// .build() - /// .unwrap(); - /// let task = rt.spawn(async { - /// yield_now().await; - /// }); - /// let _ = rt.block_on(task); - /// - /// # } - /// ``` - #[cfg(tokio_unstable)] - pub fn on_before_task_poll(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.before_poll = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just after a task is polled - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use std::sync::{atomic::AtomicUsize, Arc}; - /// # use tokio::task::yield_now; - /// # pub fn main() { - /// let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - /// let poll_stop = poll_stop_counter.clone(); - /// let rt = tokio::runtime::Builder::new_multi_thread() - /// .enable_all() - /// .on_after_task_poll(move |meta| { - /// println!("task {} completed polling", meta.id()); - /// }) - /// .build() - /// .unwrap(); - /// let task = rt.spawn(async { - /// yield_now().await; - /// }); - /// let _ = rt.block_on(task); - /// - /// # } - /// ``` - #[cfg(tokio_unstable)] - pub fn on_after_task_poll(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.after_poll = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just after a task is terminated. - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called. - /// - /// This can be used for bookkeeping or monitoring purposes. - /// - /// Note: There can only be one task termination callback for a runtime; calling this - /// function more than once replaces the last callback defined, rather than adding to it. + /// Factory method for producing "fallback" task hook harnesses. /// - /// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use tokio::runtime; - /// # pub fn main() { - /// let runtime = runtime::Builder::new_current_thread() - /// .on_task_terminate(|_| { - /// println!("killing task"); - /// }) - /// .build() - /// .unwrap(); - /// - /// runtime.block_on(async { - /// tokio::task::spawn(std::future::ready(())); - /// - /// for _ in 0..64 { - /// tokio::task::yield_now().await; - /// } - /// }) - /// # } - /// ``` + /// The order of operations for assigning the hook harness for a task are as follows: + /// 1. [`crate::task::spawn_with_hooks`], if used. + /// 2. [`crate::runtime::task_hooks::TaskHookHarnessFactory`], if it returns something other than [Option::None]. + /// 3. This function. #[cfg(all(not(loom), tokio_unstable))] #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] - pub fn on_task_terminate(&mut self, f: F) -> &mut Self + pub fn hook_harness_factory(&mut self, hooks: T) -> &mut Self where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, + T: TaskHookHarnessFactory + Send + Sync + 'static, { - self.after_termination = Some(std::sync::Arc::new(f)); + self.task_hook_harness_factory = Some(Arc::new(hooks)); self } @@ -1475,12 +1294,8 @@ impl Builder { Config { before_park: self.before_park.clone(), after_unpark: self.after_unpark.clone(), - before_spawn: self.before_spawn.clone(), - #[cfg(tokio_unstable)] - before_poll: self.before_poll.clone(), #[cfg(tokio_unstable)] - after_poll: self.after_poll.clone(), - after_termination: self.after_termination.clone(), + task_hook_factory: self.task_hook_harness_factory.clone(), global_queue_interval: self.global_queue_interval, event_interval: self.event_interval, #[cfg(tokio_unstable)] @@ -1628,12 +1443,8 @@ cfg_rt_multi_thread! { Config { before_park: self.before_park.clone(), after_unpark: self.after_unpark.clone(), - before_spawn: self.before_spawn.clone(), - #[cfg(tokio_unstable)] - before_poll: self.before_poll.clone(), #[cfg(tokio_unstable)] - after_poll: self.after_poll.clone(), - after_termination: self.after_termination.clone(), + task_hook_factory: self.task_hook_harness_factory.clone(), global_queue_interval: self.global_queue_interval, event_interval: self.event_interval, #[cfg(tokio_unstable)] diff --git a/tokio/src/runtime/config.rs b/tokio/src/runtime/config.rs index b79df96e1e2..8537adc5dcd 100644 --- a/tokio/src/runtime/config.rs +++ b/tokio/src/runtime/config.rs @@ -2,7 +2,10 @@ any(not(all(tokio_unstable, feature = "full")), target_family = "wasm"), allow(dead_code) )] -use crate::runtime::{Callback, TaskCallback}; + +use crate::runtime::Callback; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooksFactory; use crate::util::RngSeedGenerator; pub(crate) struct Config { @@ -18,19 +21,9 @@ pub(crate) struct Config { /// Callback for a worker unparking itself pub(crate) after_unpark: Option, - /// To run before each task is spawned. - pub(crate) before_spawn: Option, - - /// To run after each task is terminated. - pub(crate) after_termination: Option, - - /// To run before each poll - #[cfg(tokio_unstable)] - pub(crate) before_poll: Option, - - /// To run after each poll + /// Called on task spawn to generate the attached task hook harness. #[cfg(tokio_unstable)] - pub(crate) after_poll: Option, + pub(crate) task_hook_factory: OptionalTaskHooksFactory, /// The multi-threaded scheduler includes a per-worker LIFO slot used to /// store the last scheduled task. This can improve certain usage patterns, diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index e8f17bb374a..c0fcc64aa15 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -1,10 +1,14 @@ +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::loom::cell::UnsafeCell; use crate::loom::thread::AccessError; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::{OptionalTaskHooksMut, OptionalTaskHooksWeak, TaskHookHarness}; use crate::task::coop; - -use std::cell::Cell; - #[cfg(any(feature = "rt", feature = "macros", feature = "time"))] use crate::util::rand::FastRand; +use std::cell::Cell; +#[cfg(all(feature = "rt", tokio_unstable))] +use std::ptr::NonNull; cfg_rt! { mod blocking; @@ -49,6 +53,10 @@ struct Context { #[cfg(feature = "rt")] current_task_id: Cell>, + /// Tracks the current set of task hooks, + #[cfg(all(feature = "rt", tokio_unstable))] + current_task_hooks: OptionalTaskHooksWeak, + /// Tracks if the current thread is currently driving a runtime. /// Note, that if this is set to "entered", the current scheduler /// handle may not reference the runtime currently executing. This @@ -92,6 +100,9 @@ tokio_thread_local! { #[cfg(feature = "rt")] current_task_id: Cell::new(None), + #[cfg(all(feature = "rt", tokio_unstable))] + current_task_hooks: UnsafeCell::new(None), + // Tracks if the current thread is currently driving a runtime. // Note, that if this is set to "entered", the current scheduler // handle may not reference the runtime currently executing. This @@ -139,6 +150,16 @@ pub(crate) fn budget(f: impl FnOnce(&Cell) -> R) -> Result>) -> Result { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|x| { + unsafe { + *x = hooks; + } + }) + })?; + + Ok(SetTaskHooksGuard) + } + + #[track_caller] + #[cfg(tokio_unstable)] + pub(super) fn clear_task_hooks() -> Result<(), AccessError> { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|x| { + unsafe { + *x = None; + } + }) + })?; + + Ok(()) + } + + #[track_caller] + #[cfg(tokio_unstable)] + pub(super) fn with_task_hooks(f: impl FnOnce(OptionalTaskHooksMut<'_>) -> R) -> Result { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|ptr| { + let hooks = unsafe { &mut *ptr }; + unsafe { + f(hooks.as_mut().map(|x| x.as_mut())) + } + }) + }) + } + #[track_caller] pub(crate) fn defer(waker: &Waker) { with_scheduler(|maybe_scheduler| { diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 7aaba2ff243..7ca76578dd9 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -1,5 +1,5 @@ #[cfg(tokio_unstable)] -use crate::runtime; +use crate::runtime::{self, OptionalTaskHooks}; use crate::runtime::{context, scheduler, RuntimeFlavor, RuntimeMetrics}; /// Handle to the runtime. @@ -191,6 +191,13 @@ impl Handle { F::Output: Send + 'static, { let fut_size = mem::size_of::(); + #[cfg(tokio_unstable)] + return if fut_size > BOX_FUTURE_THRESHOLD { + self.spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + self.spawn_named(future, SpawnMeta::new_unnamed(fut_size), None) + }; + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { self.spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -329,7 +336,12 @@ impl Handle { } #[track_caller] - pub(crate) fn spawn_named(&self, future: F, _meta: SpawnMeta<'_>) -> JoinHandle + pub(crate) fn spawn_named( + &self, + future: F, + _meta: SpawnMeta<'_>, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, + ) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, @@ -345,6 +357,9 @@ impl Handle { let future = super::task::trace::Trace::root(future); #[cfg(all(tokio_unstable, feature = "tracing"))] let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + #[cfg(tokio_unstable)] + return self.inner.spawn(future, id, parent); + #[cfg(not(tokio_unstable))] self.inner.spawn(future, id) } @@ -354,6 +369,7 @@ impl Handle { &self, future: F, _meta: SpawnMeta<'_>, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: Future + 'static, @@ -370,6 +386,9 @@ impl Handle { let future = super::task::trace::Trace::root(future); #[cfg(all(tokio_unstable, feature = "tracing"))] let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + #[cfg(tokio_unstable)] + return self.inner.spawn_local(future, id, hooks_override); + #[cfg(not(tokio_unstable))] self.inner.spawn_local(future, id) } diff --git a/tokio/src/runtime/local_runtime/runtime.rs b/tokio/src/runtime/local_runtime/runtime.rs index 358a771956b..11fdc097b17 100644 --- a/tokio/src/runtime/local_runtime/runtime.rs +++ b/tokio/src/runtime/local_runtime/runtime.rs @@ -155,9 +155,9 @@ impl LocalRuntime { // safety: spawn_local can only be called from `LocalRuntime`, which this is unsafe { if std::mem::size_of::() > BOX_FUTURE_THRESHOLD { - self.handle.spawn_local_named(Box::pin(future), meta) + self.handle.spawn_local_named(Box::pin(future), meta, None) } else { - self.handle.spawn_local_named(future, meta) + self.handle.spawn_local_named(future, meta, None) } } } diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 78a0114f48e..026bdd7ef68 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -380,13 +380,10 @@ cfg_rt! { pub use dump::Dump; } - mod task_hooks; - pub(crate) use task_hooks::{TaskHooks, TaskCallback}; cfg_unstable! { - pub use task_hooks::TaskMeta; + mod task_hooks; + pub use task_hooks::*; } - #[cfg(not(tokio_unstable))] - pub(crate) use task_hooks::TaskMeta; mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/src/runtime/runtime.rs b/tokio/src/runtime/runtime.rs index 2f2b07d322c..9af21d31249 100644 --- a/tokio/src/runtime/runtime.rs +++ b/tokio/src/runtime/runtime.rs @@ -233,6 +233,15 @@ impl Runtime { F::Output: Send + 'static, { let fut_size = mem::size_of::(); + #[cfg(tokio_unstable)] + return if fut_size > BOX_FUTURE_THRESHOLD { + self.handle + .spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + self.handle + .spawn_named(future, SpawnMeta::new_unnamed(fut_size), None) + }; + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { self.handle .spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index 72e12fae895..a42528ba7d3 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -1,17 +1,19 @@ use crate::loom::sync::atomic::AtomicBool; use crate::loom::sync::Arc; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::driver::{self, Driver}; use crate::runtime::scheduler::{self, Defer, Inject}; -use crate::runtime::task::{ - self, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks, -}; +use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::{blocking, context, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics}; +#[cfg(tokio_unstable)] use crate::runtime::{ - blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics, + OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, OptionalTaskHooks, + OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef, }; use crate::sync::notify::Notify; use crate::util::atomic_cell::AtomicCell; use crate::util::{waker_ref, RngSeedGenerator, Wake, WakerRef}; - use std::cell::RefCell; use std::collections::VecDeque; use std::future::{poll_fn, Future}; @@ -20,7 +22,7 @@ use std::task::Poll::{Pending, Ready}; use std::task::Waker; use std::thread::ThreadId; use std::time::Duration; -use std::{fmt, thread}; +use std::{fmt, panic, thread}; /// Executes tasks on the current thread pub(crate) struct CurrentThread { @@ -47,7 +49,8 @@ pub(crate) struct Handle { pub(crate) seed_generator: RngSeedGenerator, /// User-supplied hooks to invoke for things - pub(crate) task_hooks: TaskHooks, + #[cfg(tokio_unstable)] + pub(crate) task_hooks: OptionalTaskHooksFactory, /// If this is a `LocalRuntime`, flags the owning thread ID. pub(crate) local_tid: Option, @@ -142,14 +145,8 @@ impl CurrentThread { .unwrap_or(DEFAULT_GLOBAL_QUEUE_INTERVAL); let handle = Arc::new(Handle { - task_hooks: TaskHooks { - task_spawn_callback: config.before_spawn.clone(), - task_terminate_callback: config.after_termination.clone(), - #[cfg(tokio_unstable)] - before_poll_callback: config.before_poll.clone(), - #[cfg(tokio_unstable)] - after_poll_callback: config.after_poll.clone(), - }, + #[cfg(tokio_unstable)] + task_hooks: config.task_hook_factory.clone(), shared: Shared { inject: Inject::new(), owned: OwnedTasks::new(1), @@ -448,19 +445,61 @@ impl Handle { pub(crate) fn spawn( me: &Arc, future: F, - id: crate::runtime::task::Id, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); - - me.task_hooks.spawn(&TaskMeta { - id, - _phantom: Default::default(), + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent.on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + } else { + None + } + }) }); + #[cfg(tokio_unstable)] + let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, hooks); + + #[cfg(not(tokio_unstable))] + let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); + if let Some(notified) = notified { me.schedule(notified); } @@ -477,18 +516,62 @@ impl Handle { pub(crate) unsafe fn spawn_local( me: &Arc, future: F, - id: crate::runtime::task::Id, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: crate::future::Future + 'static, F::Output: 'static, { - let (handle, notified) = me.shared.owned.bind_local(future, me.clone(), id); + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent.on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + } else { + None + } + }) + }); - me.task_hooks.spawn(&TaskMeta { + let (handle, notified) = me.shared.owned.bind_local( + future, + me.clone(), id, - _phantom: Default::default(), - }); + #[cfg(tokio_unstable)] + hooks, + ); if let Some(notified) = notified { me.schedule(notified); @@ -654,10 +737,14 @@ impl Schedule for Arc { }); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.task_hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.task_hooks.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.task_hooks.as_ref().map(AsRef::as_ref) } cfg_unstable! { @@ -770,17 +857,8 @@ impl CoreGuard<'_> { let task = context.handle.shared.owned.assert_owner(task); - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - let (c, ()) = context.run_task(core, || { - #[cfg(tokio_unstable)] - context.handle.task_hooks.poll_start_callback(task_id); - task.run(); - - #[cfg(tokio_unstable)] - context.handle.task_hooks.poll_stop_callback(task_id); }); core = c; diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index 8241b57c1de..6b2b2cf645a 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -8,8 +8,6 @@ cfg_rt! { pub(crate) mod inject; pub(crate) use inject::Inject; - use crate::runtime::TaskHooks; - use crate::runtime::WorkerMetrics; } @@ -25,6 +23,10 @@ cfg_rt_multi_thread! { } use crate::runtime::driver; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::task::Schedule; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::{OptionalTaskHooks, OptionalTaskHooksFactory}; #[derive(Debug, Clone)] pub(crate) enum Handle { @@ -117,11 +119,24 @@ cfg_rt! { } } - pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle + pub(crate) fn spawn(&self, + future: F, + id: Id, + #[cfg(tokio_unstable)] + hooks_override: OptionalTaskHooks + ) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, { + #[cfg(tokio_unstable)] + return match self { + Handle::CurrentThread(h) => current_thread::Handle::spawn(h, future, id, hooks_override), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(h) => multi_thread::Handle::spawn(h, future, id, hooks_override), + }; + #[cfg(not(tokio_unstable))] match self { Handle::CurrentThread(h) => current_thread::Handle::spawn(h, future, id), @@ -136,12 +151,15 @@ cfg_rt! { /// This should only be called in `LocalRuntime` if the runtime has been verified to be owned /// by the current thread. #[allow(irrefutable_let_patterns)] - pub(crate) unsafe fn spawn_local(&self, future: F, id: Id) -> JoinHandle + pub(crate) unsafe fn spawn_local(&self, future: F, id: Id, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where F: Future + 'static, F::Output: 'static, { if let Handle::CurrentThread(h) = self { + #[cfg(tokio_unstable)] + return current_thread::Handle::spawn_local(h, future, id, hooks_override); + #[cfg(not(tokio_unstable))] current_thread::Handle::spawn_local(h, future, id) } else { panic!("Only current_thread and LocalSet have spawn_local internals implemented") @@ -169,12 +187,9 @@ cfg_rt! { } } - pub(crate) fn hooks(&self) -> &TaskHooks { - match self { - Handle::CurrentThread(h) => &h.task_hooks, - #[cfg(feature = "rt-multi-thread")] - Handle::MultiThread(h) => &h.task_hooks, - } + #[cfg(tokio_unstable)] + pub(crate) fn hooks_factory(&self) -> OptionalTaskHooksFactory { + match_flavor!(self, Handle(h) => h.hooks_factory()) } } diff --git a/tokio/src/runtime/scheduler/multi_thread/handle.rs b/tokio/src/runtime/scheduler/multi_thread/handle.rs index 4075713c979..fa8973a8fab 100644 --- a/tokio/src/runtime/scheduler/multi_thread/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread/handle.rs @@ -1,14 +1,20 @@ use crate::future::Future; use crate::loom::sync::Arc; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::scheduler::multi_thread::worker; +#[cfg(tokio_unstable)] +use crate::runtime::task::Schedule; use crate::runtime::{ blocking, driver, task::{self, JoinHandle}, - TaskHooks, TaskMeta, }; +#[cfg(tokio_unstable)] +use crate::runtime::{OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, OptionalTaskHooks}; use crate::util::RngSeedGenerator; - use std::fmt; +#[cfg(tokio_unstable)] +use std::panic; mod metrics; @@ -29,18 +35,24 @@ pub(crate) struct Handle { /// Current random number generator seed pub(crate) seed_generator: RngSeedGenerator, - - /// User-supplied hooks to invoke for things - pub(crate) task_hooks: TaskHooks, } impl Handle { /// Spawns a future onto the thread pool - pub(crate) fn spawn(me: &Arc, future: F, id: task::Id) -> JoinHandle + pub(crate) fn spawn( + me: &Arc, + future: F, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, + ) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { + #[cfg(tokio_unstable)] + return Self::bind_new_task(me, future, id, hooks_override); + + #[cfg(not(tokio_unstable))] Self::bind_new_task(me, future, id) } @@ -48,17 +60,65 @@ impl Handle { self.close(); } - pub(super) fn bind_new_task(me: &Arc, future: T, id: task::Id) -> JoinHandle + pub(super) fn bind_new_task( + me: &Arc, + future: T, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, + ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent.on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + } else { + None + } + }) + }); - me.task_hooks.spawn(&TaskMeta { + let (handle, notified) = me.shared.owned.bind( + future, + me.clone(), id, - _phantom: Default::default(), - }); + #[cfg(tokio_unstable)] + hooks, + ); me.schedule_option_task_without_yield(notified); diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index e33b9baea2c..2bd2438f378 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -58,13 +58,15 @@ use crate::loom::sync::{Arc, Mutex}; use crate::runtime; +use crate::runtime::context; use crate::runtime::scheduler::multi_thread::{ idle, queue, Counters, Handle, Idle, Overflow, Parker, Stats, TraceStatus, Unparker, }; use crate::runtime::scheduler::{inject, Defer, Lock}; -use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks}; +use crate::runtime::task::OwnedTasks; use crate::runtime::{blocking, driver, scheduler, task, Config, SchedulerMetrics, WorkerMetrics}; -use crate::runtime::{context, TaskHooks}; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use crate::task::coop; use crate::util::atomic_cell::AtomicCell; use crate::util::rand::{FastRand, RngSeedGenerator}; @@ -281,7 +283,6 @@ pub(super) fn create( let remotes_len = remotes.len(); let handle = Arc::new(Handle { - task_hooks: TaskHooks::from_config(&config), shared: Shared { remotes: remotes.into_boxed_slice(), inject, @@ -570,9 +571,6 @@ impl Context { } fn run_task(&self, task: Notified, mut core: Box) -> RunResult { - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - let task = self.worker.handle.shared.owned.assert_owner(task); // Make sure the worker is not in the **searching** state. This enables @@ -592,16 +590,8 @@ impl Context { // Run the task coop::budget(|| { - // Unlike the poll time above, poll start callback is attached to the task id, - // so it is tightly associated with the actual poll invocation. - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_start_callback(task_id); - task.run(); - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_stop_callback(task_id); - let mut lifo_polls = 0; // As long as there is budget remaining and a task exists in the @@ -665,16 +655,7 @@ impl Context { *self.core.borrow_mut() = Some(core); let task = self.worker.handle.shared.owned.assert_owner(task); - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_start_callback(task_id); - task.run(); - - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_stop_callback(task_id); } }) } @@ -1063,10 +1044,18 @@ impl task::Schedule for Arc { self.schedule_task(task, false); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.task_hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.shared.config.task_hook_factory.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.shared + .config + .task_hook_factory + .as_ref() + .map(AsRef::as_ref) } fn yield_now(&self, task: Notified) { diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 5d3ca0e00c9..7f122581c29 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -14,7 +14,9 @@ use crate::loom::cell::UnsafeCell; use crate::runtime::context; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks}; +use crate::runtime::task::{Id, Schedule}; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooks; use crate::util::linked_list; use std::num::NonZeroU64; @@ -186,7 +188,8 @@ pub(super) struct Trailer { /// Consumer task waiting on completion of this task. pub(super) waker: UnsafeCell>, /// Optional hooks needed in the harness. - pub(super) hooks: TaskHarnessScheduleHooks, + #[cfg(tokio_unstable)] + pub(super) hooks: UnsafeCell, } generate_addr_of_methods! { @@ -208,7 +211,13 @@ pub(super) enum Stage { impl Cell { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box> { + pub(super) fn new( + future: T, + scheduler: S, + state: State, + task_id: Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, + ) -> Box> { // Separated into a non-generic function to reduce LLVM codegen fn new_header( state: State, @@ -229,7 +238,13 @@ impl Cell { let tracing_id = future.id(); let vtable = raw::vtable::(); let result = Box::new(Cell { - trailer: Trailer::new(scheduler.hooks()), + #[cfg(tokio_unstable)] + trailer: Trailer::new( + #[cfg(tokio_unstable)] + hooks, + ), + #[cfg(not(tokio_unstable))] + trailer: Trailer::new(), header: new_header( state, vtable, @@ -462,11 +477,12 @@ impl Header { } impl Trailer { - fn new(hooks: TaskHarnessScheduleHooks) -> Self { + fn new(#[cfg(tokio_unstable)] hooks: OptionalTaskHooks) -> Self { Trailer { waker: UnsafeCell::new(None), owned: linked_list::Pointers::new(), - hooks, + #[cfg(tokio_unstable)] + hooks: UnsafeCell::new(hooks), } } diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 9bf73b74fbf..f039731a51c 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -1,10 +1,12 @@ use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::task::core::{Cell, Core, Header, Trailer}; use crate::runtime::task::state::{Snapshot, State}; use crate::runtime::task::waker::waker_ref; use crate::runtime::task::{Id, JoinError, Notified, RawTask, Schedule, Task}; - -use crate::runtime::TaskMeta; +#[cfg(tokio_unstable)] +use crate::runtime::{AfterTaskPollContext, OnTaskTerminateContext}; use std::any::Any; use std::mem; use std::mem::ManuallyDrop; @@ -150,8 +152,21 @@ where /// All necessary state checks and transitions are performed. /// Panics raised while polling the future are handled. pub(super) fn poll(self) { + let res = self.poll_inner(); + + #[cfg(tokio_unstable)] + let _ = with_task_hooks(|t| { + if let Some(hooks) = t { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.after_poll(&mut AfterTaskPollContext { + _phantom: Default::default(), + }) + })); + } + }); + // We pass our ref-count to `poll_inner`. - match self.poll_inner() { + match res { PollFuture::Notified => { // The `poll_inner` call has given us two ref-counts back. // We give one of them to a new task and call `yield_now`. @@ -367,14 +382,16 @@ where // // We call this in a separate block so that it runs after the task appears to have // completed and will still run if the destructor panics. - if let Some(f) = self.trailer().hooks.task_terminate_callback.as_ref() { - let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { - f(&TaskMeta { - id: self.core().task_id, - _phantom: Default::default(), - }) - })); - } + #[cfg(tokio_unstable)] + let _ = with_task_hooks(|t| { + if let Some(hooks) = t { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_task_terminate(&mut OnTaskTerminateContext { + _phantom: Default::default(), + }) + })); + } + }); // The task has completed execution and will no longer be scheduled. let num_release = self.release(); diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 54bfc01aafb..91e89dd4ffa 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -13,9 +13,10 @@ use crate::util::linked_list::{Link, LinkedList}; use crate::util::sharded_list; use crate::loom::sync::atomic::{AtomicBool, Ordering}; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooks; use std::marker::PhantomData; use std::num::NonZeroU64; - // The id from the module below is used to verify whether a given task is stored // in this OwnedTasks, or some other task. The counter starts at one so we can // use `None` for tasks not owned by any list. @@ -91,13 +92,20 @@ impl OwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + #[cfg(tokio_unstable)] + hooks, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -111,13 +119,20 @@ impl OwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + #[cfg(tokio_unstable)] + parent, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -258,12 +273,16 @@ impl LocalOwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { + #[cfg(tokio_unstable)] + let (task, notified, join) = super::new_task(task, scheduler, id, parent); + #[cfg(not(tokio_unstable))] let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 7d314c3b176..7cf2f1e98f7 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -221,10 +221,10 @@ cfg_taskdump! { } use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooks, OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use crate::util::linked_list; use crate::util::sharded_list; - -use crate::runtime::TaskCallback; use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; @@ -256,13 +256,6 @@ pub(crate) struct LocalNotified { _not_send: PhantomData<*const ()>, } -impl LocalNotified { - #[cfg(tokio_unstable)] - pub(crate) fn task_id(&self) -> Id { - self.task.id() - } -} - /// A task that is not owned by any `OwnedTasks`. Used for blocking tasks. /// This type holds two ref-counts. pub(crate) struct UnownedTask { @@ -277,12 +270,6 @@ unsafe impl Sync for UnownedTask {} /// Task result sent back. pub(crate) type Result = std::result::Result; -/// Hooks for scheduling tasks which are needed in the task harness. -#[derive(Clone)] -pub(crate) struct TaskHarnessScheduleHooks { - pub(crate) task_terminate_callback: Option, -} - pub(crate) trait Schedule: Sync + Sized + 'static { /// The task has completed work and is ready to be released. The scheduler /// should release it immediately and return it. The task module will batch @@ -294,7 +281,11 @@ pub(crate) trait Schedule: Sync + Sized + 'static { /// Schedule the task fn schedule(&self, task: Notified); - fn hooks(&self) -> TaskHarnessScheduleHooks; + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory; + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_>; /// Schedule the task to run in the near future, yielding the thread to /// other tasks. @@ -317,13 +308,19 @@ cfg_rt! { task: T, scheduler: S, id: Id, + #[cfg(tokio_unstable)] + hooks: OptionalTaskHooks ) -> (Task, Notified, JoinHandle) where S: Schedule, T: Future + 'static, T::Output: 'static, { + #[cfg(tokio_unstable)] + let raw = RawTask::new::(task, scheduler, id, hooks); + #[cfg(not(tokio_unstable))] let raw = RawTask::new::(task, scheduler, id); + let task = Task { raw, _p: PhantomData, @@ -341,12 +338,16 @@ cfg_rt! { /// only when the task is not going to be stored in an `OwnedTasks` list. /// /// Currently only blocking tasks use this method. - pub(crate) fn unowned(task: T, scheduler: S, id: Id) -> (UnownedTask, JoinHandle) + pub(crate) fn unowned(task: T, scheduler: S, id: Id, #[cfg(tokio_unstable)] hooks: OptionalTaskHooks) -> (UnownedTask, JoinHandle) where S: Schedule, T: Send + Future + 'static, T::Output: Send + 'static, { + #[cfg(tokio_unstable)] + let (task, notified, join) = new_task(task, scheduler, id, hooks); + + #[cfg(not(tokio_unstable))] let (task, notified, join) = new_task(task, scheduler, id); // This transfers the ref-count of task and notified into an UnownedTask. @@ -459,6 +460,7 @@ impl LocalNotified { /// Runs the task. pub(crate) fn run(self) { let raw = self.task.raw; + mem::forget(self); raw.poll(); } diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index 6699551f3ec..ad2f30677dc 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -1,7 +1,12 @@ use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::context::set_task_hooks; use crate::runtime::task::core::{Core, Trailer}; use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State}; - +#[cfg(tokio_unstable)] +use crate::runtime::{BeforeTaskPollContext, OptionalTaskHooks, TaskHookHarness}; +#[cfg(tokio_unstable)] +use std::panic; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -157,12 +162,24 @@ const fn get_id_offset( } impl RawTask { - pub(super) fn new(task: T, scheduler: S, id: Id) -> RawTask + pub(super) fn new( + task: T, + scheduler: S, + id: Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, + ) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id)); + let ptr = Box::into_raw(Cell::<_, S>::new( + task, + scheduler, + State::new(), + id, + #[cfg(tokio_unstable)] + hooks, + )); let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) }; RawTask { ptr } @@ -197,8 +214,30 @@ impl RawTask { /// Safety: mutual exclusion is required to call this function. pub(crate) fn poll(self) { - let vtable = self.header().vtable; - unsafe { (vtable.poll)(self.ptr) } + #[cfg(tokio_unstable)] + self.trailer().hooks.with_mut(|ptr| unsafe { + let _guard = ptr.as_mut().and_then(|x| { + x.as_mut().map(|x| { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + x.before_poll(&mut BeforeTaskPollContext { + _phantom: Default::default(), + }) + })); + + set_task_hooks(NonNull::new( + (&mut **x) as *mut (dyn TaskHookHarness + Send + Sync + 'static), + )) + }) + }); + + let vtable = self.header().vtable; + (vtable.poll)(self.ptr); + }); + #[cfg(not(tokio_unstable))] + unsafe { + let vtable = self.header().vtable; + (vtable.poll)(self.ptr); + } } pub(super) fn schedule(self) { diff --git a/tokio/src/runtime/task_hooks.rs b/tokio/src/runtime/task_hooks.rs deleted file mode 100644 index 13865ed515d..00000000000 --- a/tokio/src/runtime/task_hooks.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::marker::PhantomData; - -use super::Config; - -impl TaskHooks { - pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) { - if let Some(f) = self.task_spawn_callback.as_ref() { - f(meta) - } - } - - #[allow(dead_code)] - pub(crate) fn from_config(config: &Config) -> Self { - Self { - task_spawn_callback: config.before_spawn.clone(), - task_terminate_callback: config.after_termination.clone(), - #[cfg(tokio_unstable)] - before_poll_callback: config.before_poll.clone(), - #[cfg(tokio_unstable)] - after_poll_callback: config.after_poll.clone(), - } - } - - #[cfg(tokio_unstable)] - #[inline] - pub(crate) fn poll_start_callback(&self, id: super::task::Id) { - if let Some(poll_start) = &self.before_poll_callback { - (poll_start)(&TaskMeta { - id, - _phantom: std::marker::PhantomData, - }) - } - } - - #[cfg(tokio_unstable)] - #[inline] - pub(crate) fn poll_stop_callback(&self, id: super::task::Id) { - if let Some(poll_stop) = &self.after_poll_callback { - (poll_stop)(&TaskMeta { - id, - _phantom: std::marker::PhantomData, - }) - } - } -} - -#[derive(Clone)] -pub(crate) struct TaskHooks { - pub(crate) task_spawn_callback: Option, - pub(crate) task_terminate_callback: Option, - #[cfg(tokio_unstable)] - pub(crate) before_poll_callback: Option, - #[cfg(tokio_unstable)] - pub(crate) after_poll_callback: Option, -} - -/// Task metadata supplied to user-provided hooks for task events. -/// -/// **Note**: This is an [unstable API][unstable]. The public API of this type -/// may break in 1.x releases. See [the documentation on unstable -/// features][unstable] for details. -/// -/// [unstable]: crate#unstable-features -#[allow(missing_debug_implementations)] -#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] -pub struct TaskMeta<'a> { - /// The opaque ID of the task. - pub(crate) id: super::task::Id, - pub(crate) _phantom: PhantomData<&'a ()>, -} - -impl<'a> TaskMeta<'a> { - /// Return the opaque ID of the task. - #[cfg_attr(not(tokio_unstable), allow(unreachable_pub, dead_code))] - pub fn id(&self) -> super::task::Id { - self.id - } -} - -/// Runs on specific task-related events -pub(crate) type TaskCallback = std::sync::Arc) + Send + Sync>; diff --git a/tokio/src/runtime/task_hooks/mod.rs b/tokio/src/runtime/task_hooks/mod.rs new file mode 100644 index 00000000000..c54c2787991 --- /dev/null +++ b/tokio/src/runtime/task_hooks/mod.rs @@ -0,0 +1,99 @@ +use super::task; +use crate::loom::cell::UnsafeCell; +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::Arc; + +/// A factory which produces new [`TaskHookHarness`] objects for tasks which either have been +/// spawned in "detached mode" via the builder, or which were spawned from outside the runtime or +/// from another context where no [`TaskHookHarness`] was present. +pub trait TaskHookHarnessFactory { + /// Create a new [`TaskHookHarness`] object which the runtime will attach to a given task. + fn on_top_level_spawn( + &self, + ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option>; +} + +/// Trait for user-provided "harness" objects which are attached to tasks and provide hook +/// implementations. +#[allow(unused_variables)] +pub trait TaskHookHarness { + /// Pre-poll task hook which runs arbitrary user logic. + fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) {} + + /// Post-poll task hook which runs arbitrary user logic. + fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) {} + + /// Task hook which runs when this task spawns a child, unless that child is explicitly spawned + /// detached from the parent. + /// + /// This hook creates a harness for the child, or detaches the child from any instrumentation. + fn on_child_spawn( + &mut self, + ctx: &mut OnChildTaskSpawnContext<'_>, + ) -> Option> { + None + } + + /// Task hook which runs on task termination. + fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) {} +} + +pub(crate) type OptionalTaskHooksFactory = + Option>; +pub(crate) type OptionalTaskHooks = Option>; + +pub(crate) type OptionalTaskHooksWeak = + UnsafeCell>>; + +pub(crate) type OptionalTaskHooksMut<'a> = + Option<&'a mut (dyn TaskHookHarness + Send + Sync + 'static)>; +pub(crate) type OptionalTaskHooksFactoryRef<'a> = + Option<&'a (dyn TaskHookHarnessFactory + Send + Sync + 'static)>; + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnTopLevelTaskSpawnContext<'a> { + pub(crate) id: task::Id, + pub(crate) _phantom: PhantomData<&'a ()>, +} + +impl<'a> OnTopLevelTaskSpawnContext<'a> { + /// Returns the ID of the task. + pub fn id(&self) -> task::Id { + self.id + } +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnChildTaskSpawnContext<'a> { + pub(crate) id: task::Id, + pub(crate) _phantom: PhantomData<&'a ()>, +} + +impl<'a> OnChildTaskSpawnContext<'a> { + /// Returns the ID of the task. + pub fn id(&self) -> task::Id { + self.id + } +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnTaskTerminateContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct BeforeTaskPollContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct AfterTaskPollContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index 6fcf8a2ec09..6115c6c429b 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -6,7 +6,9 @@ use self::noop_scheduler::NoopSchedule; use self::unowned_wrapper::unowned; mod noop_scheduler { - use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks}; + use crate::runtime::task::{self, Task}; + #[cfg(tokio_unstable)] + use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; /// `task::Schedule` implementation that does nothing, for testing. pub(crate) struct NoopSchedule; @@ -20,10 +22,14 @@ mod noop_scheduler { unreachable!(); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } } } @@ -41,6 +47,9 @@ mod unowned_wrapper { use tracing::Instrument; let span = tracing::trace_span!("test_span"); let task = task.instrument(span); + #[cfg(tokio_unstable)] + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next(), None); + #[cfg(not(tokio_unstable))] let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } @@ -51,6 +60,9 @@ mod unowned_wrapper { T: std::future::Future + Send + 'static, T::Output: Send + 'static, { + #[cfg(tokio_unstable)] + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next(), None); + #[cfg(not(tokio_unstable))] let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } diff --git a/tokio/src/runtime/tests/queue.rs b/tokio/src/runtime/tests/queue.rs index 9047f4ad7af..b44e8992fd9 100644 --- a/tokio/src/runtime/tests/queue.rs +++ b/tokio/src/runtime/tests/queue.rs @@ -1,5 +1,4 @@ use crate::runtime::scheduler::multi_thread::{queue, Stats}; - use std::cell::RefCell; use std::thread; use std::time::Duration; diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index ea48b8e5199..4cf0de69cf0 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -1,8 +1,7 @@ -use crate::runtime::task::{ - self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks, -}; +use crate::runtime::task::{self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task}; use crate::runtime::tests::NoopSchedule; - +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use std::collections::VecDeque; use std::future::Future; use std::sync::atomic::{AtomicBool, Ordering}; @@ -447,9 +446,13 @@ impl Schedule for Runtime { self.0.core.try_lock().unwrap().queue.push_back(task); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } } diff --git a/tokio/src/task/builder.rs b/tokio/src/task/builder.rs index 6053352a01c..c34a2ca462e 100644 --- a/tokio/src/task/builder.rs +++ b/tokio/src/task/builder.rs @@ -44,8 +44,12 @@ use std::{future::Future, io, mem}; /// loop { /// let (socket, _) = listener.accept().await?; /// -/// tokio::task::Builder::new() -/// .name("tcp connection handler") +/// let mut builder = tokio::task::Builder::new(); +/// +/// builder +/// .name("tcp connection handler"); +/// +/// builder /// .spawn(async move { /// // Process each socket concurrently. /// process(socket).await @@ -71,8 +75,9 @@ impl<'a> Builder<'a> { } /// Assigns a name to the task which will be spawned. - pub fn name(&self, name: &'a str) -> Self { - Self { name: Some(name) } + pub fn name(&mut self, name: &'a str) -> &mut Self { + self.name = Some(name); + self } /// Spawns a task with this builder's settings on the current runtime. @@ -91,9 +96,9 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - super::spawn::spawn_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size), None) } else { - super::spawn::spawn_inner(future, SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner(future, SpawnMeta::new(self.name, fut_size), None) }) } @@ -112,9 +117,9 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - handle.spawn_named(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + handle.spawn_named(Box::pin(future), SpawnMeta::new(self.name, fut_size), None) } else { - handle.spawn_named(future, SpawnMeta::new(self.name, fut_size)) + handle.spawn_named(future, SpawnMeta::new(self.name, fut_size), None) }) } @@ -140,9 +145,13 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - super::local::spawn_local_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + super::local::spawn_local_inner( + Box::pin(future), + SpawnMeta::new(self.name, fut_size), + None, + ) } else { - super::local::spawn_local_inner(future, SpawnMeta::new(self.name, fut_size)) + super::local::spawn_local_inner(future, SpawnMeta::new(self.name, fut_size), None) }) } diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index a156719a067..2c43b9c989d 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -641,9 +641,13 @@ where #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))] impl<'a, T: 'static> Builder<'a, T> { /// Assigns a name to the task which will be spawned. - pub fn name(self, name: &'a str) -> Self { - let builder = self.builder.name(name); - Self { builder, ..self } + pub fn name(mut self, name: &'a str) -> Self { + self.builder.name(name); + + Self { + builder: self.builder, + ..self + } } /// Spawn the provided task with this builder's settings and store it in the diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 95bd6404bec..f100e6dc2ca 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,9 +1,11 @@ //! Runs `!Send` futures on the current thread. use crate::loom::cell::UnsafeCell; use crate::loom::sync::{Arc, Mutex}; +use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; #[cfg(tokio_unstable)] -use crate::runtime; -use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task, TaskHarnessScheduleHooks}; +use crate::runtime::{ + self, OptionalTaskHooks, OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef, +}; use crate::runtime::{context, ThreadId, BOX_FUTURE_THRESHOLD}; use crate::sync::AtomicWaker; use crate::util::trace::SpawnMeta; @@ -371,6 +373,13 @@ cfg_rt! { F::Output: 'static, { let fut_size = std::mem::size_of::(); + #[cfg(tokio_unstable)] + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_local_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + spawn_local_inner(future, SpawnMeta::new_unnamed(fut_size), None) + } + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { spawn_local_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -380,7 +389,7 @@ cfg_rt! { #[track_caller] - pub(super) fn spawn_local_inner(future: F, meta: SpawnMeta<'_>) -> JoinHandle + pub(super) fn spawn_local_inner(future: F, meta: SpawnMeta<'_>, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where F: Future + 'static, F::Output: 'static { @@ -412,6 +421,9 @@ cfg_rt! { let task = crate::util::trace::task(future, "task", meta, id.as_u64()); // safety: we have verified that this is a `LocalRuntime` owned by the current thread + #[cfg(tokio_unstable)] + unsafe { handle.spawn_local(task, id, hooks_override) } + #[cfg(not(tokio_unstable))] unsafe { handle.spawn_local(task, id) } } else { match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { @@ -1004,6 +1016,15 @@ impl Context { let future = crate::util::trace::task(future, "local", meta, id.as_u64()); // Safety: called from the thread that owns the `LocalSet` + #[cfg(tokio_unstable)] + let (handle, notified) = { + self.shared.local_state.assert_called_from_owner_thread(); + self.shared + .local_state + .owned + .bind(future, self.shared.clone(), id, None) + }; + #[cfg(not(tokio_unstable))] let (handle, notified) = { self.shared.local_state.assert_called_from_owner_thread(); self.shared @@ -1117,11 +1138,15 @@ impl task::Schedule for Arc { Shared::schedule(self, task); } - // localset does not currently support task hooks - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + // localset does not support task hooks + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } cfg_unstable! { diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index f0c6f71c15a..601793f2f99 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -311,6 +311,10 @@ cfg_rt! { pub use crate::runtime::task::{Id, id, try_id}; + cfg_unstable! { + pub use spawn::spawn_with_hooks; + } + cfg_trace! { mod builder; pub use builder::Builder; diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index 7c748226121..ad1fc6eb19c 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -1,4 +1,6 @@ use crate::runtime::BOX_FUTURE_THRESHOLD; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooks, TaskHookHarness}; use crate::task::JoinHandle; use crate::util::trace::SpawnMeta; @@ -169,6 +171,13 @@ cfg_rt! { F::Output: Send + 'static, { let fut_size = std::mem::size_of::(); + #[cfg(tokio_unstable)] + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + spawn_inner(future, SpawnMeta::new_unnamed(fut_size), None) + } + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -176,8 +185,26 @@ cfg_rt! { } } + /// Spawn a future with a custom set of task hooks + #[track_caller] + #[cfg(tokio_unstable)] + pub fn spawn_with_hooks(future: F, hooks: T) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + T: TaskHookHarness + Send + Sync + 'static, + { + let fut_size = std::mem::size_of::(); + + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), Some(Box::new(hooks))) + } else { + spawn_inner(future, SpawnMeta::new_unnamed(fut_size), Some(Box::new(hooks))) + } + } + #[track_caller] - pub(super) fn spawn_inner(future: T, meta: SpawnMeta<'_>) -> JoinHandle + pub(super) fn spawn_inner(future: T, meta: SpawnMeta<'_>, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, @@ -199,6 +226,13 @@ cfg_rt! { let id = task::Id::next(); let task = crate::util::trace::task(future, "task", meta, id.as_u64()); + #[cfg(tokio_unstable)] + return match context::with_current(|handle| handle.spawn(task, id, hooks_override)) { + Ok(join_handle) => join_handle, + Err(e) => panic!("{}", e), + }; + + #[cfg(not(tokio_unstable))] match context::with_current(|handle| handle.spawn(task, id)) { Ok(join_handle) => join_handle, Err(e) => panic!("{}", e), diff --git a/tokio/tests/rt_poll_callbacks.rs b/tokio/tests/rt_poll_callbacks.rs deleted file mode 100644 index 8ccff385772..00000000000 --- a/tokio/tests/rt_poll_callbacks.rs +++ /dev/null @@ -1,128 +0,0 @@ -#![allow(unknown_lints, unexpected_cfgs)] -#![cfg(tokio_unstable)] - -use std::sync::{atomic::AtomicUsize, Arc, Mutex}; - -use tokio::task::yield_now; - -#[cfg(not(target_os = "wasi"))] -#[test] -fn callbacks_fire_multi_thread() { - let poll_start_counter = Arc::new(AtomicUsize::new(0)); - let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - let poll_start = poll_start_counter.clone(); - let poll_stop = poll_stop_counter.clone(); - - let before_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - let after_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - - let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); - let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .on_before_task_poll(move |task_meta| { - before_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .on_after_task_poll(move |task_meta| { - after_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .build() - .unwrap(); - let task = rt.spawn(async { - yield_now().await; - yield_now().await; - yield_now().await; - }); - - let spawned_task_id = task.id(); - - rt.block_on(task).expect("task should succeed"); - // We need to drop the runtime to guarantee the workers have exited (and thus called the callback) - drop(rt); - - assert_eq!( - before_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!( - after_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - let actual_count = 4; - assert_eq!( - poll_start.load(std::sync::atomic::Ordering::Relaxed), - actual_count, - "unexpected number of poll starts" - ); - assert_eq!( - poll_stop.load(std::sync::atomic::Ordering::Relaxed), - actual_count, - "unexpected number of poll stops" - ); -} - -#[test] -fn callbacks_fire_current_thread() { - let poll_start_counter = Arc::new(AtomicUsize::new(0)); - let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - let poll_start = poll_start_counter.clone(); - let poll_stop = poll_stop_counter.clone(); - - let before_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - let after_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - - let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); - let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .on_before_task_poll(move |task_meta| { - before_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .on_after_task_poll(move |task_meta| { - after_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .build() - .unwrap(); - - let task = rt.spawn(async { - yield_now().await; - yield_now().await; - yield_now().await; - }); - - let spawned_task_id = task.id(); - - let _ = rt.block_on(task); - drop(rt); - - assert_eq!( - before_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!( - after_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!(poll_start.load(std::sync::atomic::Ordering::Relaxed), 4); - assert_eq!(poll_stop.load(std::sync::atomic::Ordering::Relaxed), 4); -} diff --git a/tokio/tests/task_builder.rs b/tokio/tests/task_builder.rs index c700f229f9f..63cd9d925f1 100644 --- a/tokio/tests/task_builder.rs +++ b/tokio/tests/task_builder.rs @@ -8,22 +8,22 @@ use tokio::{ #[test] async fn spawn_with_name() { - let result = Builder::new() - .name("name") - .spawn(async { "task executed" }) - .unwrap() - .await; + let mut b = Builder::new(); + + b.name("name"); + + let result = b.spawn(async { "task executed" }).unwrap().await; assert_eq!(result.unwrap(), "task executed"); } #[test] async fn spawn_blocking_with_name() { - let result = Builder::new() - .name("name") - .spawn_blocking(|| "task executed") - .unwrap() - .await; + let mut b = Builder::new(); + + b.name("name"); + + let result = b.spawn_blocking(|| "task executed").unwrap().await; assert_eq!(result.unwrap(), "task executed"); } @@ -33,11 +33,11 @@ async fn spawn_local_with_name() { let unsend_data = Rc::new("task executed"); let result = LocalSet::new() .run_until(async move { - Builder::new() - .name("name") - .spawn_local(async move { unsend_data }) - .unwrap() - .await + let mut b = Builder::new(); + + b.name("name"); + + b.spawn_local(async move { unsend_data }).unwrap().await }) .await; diff --git a/tokio/tests/task_hooks.rs b/tokio/tests/task_hooks.rs index 185b9126cca..2af3cc594ee 100644 --- a/tokio/tests/task_hooks.rs +++ b/tokio/tests/task_hooks.rs @@ -1,75 +1,433 @@ -#![warn(rust_2018_idioms)] -#![cfg(all(feature = "full", tokio_unstable, target_has_atomic = "64"))] +#![cfg(all( + feature = "full", + tokio_unstable, + target_has_atomic = "64", + not(target_arch = "wasm32") +))] -use std::collections::HashSet; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::runtime; +use tokio::runtime::{ + AfterTaskPollContext, BeforeTaskPollContext, OnChildTaskSpawnContext, OnTaskTerminateContext, + OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory, +}; -use tokio::runtime::Builder; +#[test] +fn runtime_default_factory() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_runtime_default_factory(ct); + run_runtime_default_factory(mt); +} + +#[test] +fn parent_child_chaining() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_parent_child_chaining(ct); + run_parent_child_chaining(mt); +} -const TASKS: usize = 8; -const ITERATIONS: usize = 64; -/// Assert that the spawn task hook always fires when set. #[test] -fn spawn_task_hook_fires() { - let count = Arc::new(AtomicUsize::new(0)); - let count2 = Arc::clone(&count); +fn before_poll() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); - let ids = Arc::new(Mutex::new(HashSet::new())); - let ids2 = Arc::clone(&ids); + run_before_poll(ct); + run_before_poll(mt); +} + +#[test] +fn after_poll() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); - let runtime = Builder::new_current_thread() - .on_task_spawn(move |data| { - ids2.lock().unwrap().insert(data.id()); + run_after_poll(ct); + run_after_poll(mt); +} + +#[test] +fn terminate() { + let ct = runtime::Builder::new_current_thread(); + + run_terminate(ct); +} + +#[test] +fn hook_switching() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_hook_switching(ct); + run_hook_switching(mt); +} - count2.fetch_add(1, Ordering::SeqCst); +#[test] +fn override_hooks() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_override(ct); + run_override(mt); +} + +fn run_runtime_default_factory(mut builder: runtime::Builder) { + struct TestFactory { + counter: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + self.counter.fetch_add(1, Ordering::SeqCst); + None + } + } + + let counter = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + counter: counter.clone(), }) .build() .unwrap(); - for _ in 0..TASKS { - runtime.spawn(std::future::pending::<()>()); + rt.spawn(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 1); + + let handle = rt.handle(); + + handle.spawn(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 2); + + rt.block_on(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 2); + + rt.block_on(async { tokio::spawn(async {}) }); + + assert_eq!(counter.load(Ordering::SeqCst), 3); + + // block on a future which spawns a future and waits for it, which in turn spawns another future + // + // this checks that stuff works from on-worker within a multithreaded runtime + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}) }).await }); + + assert_eq!(counter.load(Ordering::SeqCst), 5); +} + +fn run_parent_child_chaining(mut builder: runtime::Builder) { + struct TestFactory { + parent_spawns: Arc, + child_spawns: Arc, + } + + struct TestHooks { + spawns: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + self.parent_spawns.fetch_add(1, Ordering::SeqCst); + + Some(Box::new(TestHooks { + spawns: self.child_spawns.clone(), + })) + } } - let count_realized = count.load(Ordering::SeqCst); - assert_eq!( - TASKS, count_realized, - "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}", - count_realized - ); + impl TaskHookHarness for TestHooks { + fn on_child_spawn( + &mut self, + _ctx: &mut OnChildTaskSpawnContext<'_>, + ) -> Option> { + self.spawns.fetch_add(1, Ordering::SeqCst); + + Some(Box::new(Self { + spawns: self.spawns.clone(), + })) + } + } + + let parent_spawns = Arc::new(AtomicUsize::new(0)); + let child_spawns = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + parent_spawns: parent_spawns.clone(), + child_spawns: child_spawns.clone(), + }) + .build() + .unwrap(); + + rt.spawn(async {}); - let count_ids_realized = ids.lock().unwrap().len(); + assert_eq!(parent_spawns.load(Ordering::SeqCst), 1); + assert_eq!(child_spawns.load(Ordering::SeqCst), 0); - assert_eq!( - TASKS, count_ids_realized, - "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}", - count_realized - ); + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}) }).await }); + + assert_eq!(parent_spawns.load(Ordering::SeqCst), 2); + assert_eq!(child_spawns.load(Ordering::SeqCst), 1); } -/// Assert that the terminate task hook always fires when set. -#[test] -fn terminate_task_hook_fires() { - let count = Arc::new(AtomicUsize::new(0)); - let count2 = Arc::clone(&count); +fn run_before_poll(mut builder: runtime::Builder) { + struct TestFactory { + polls: Arc, + } + + struct TestHooks { + polls: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + polls: self.polls.clone(), + })) + } + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + self.polls.fetch_add(1, Ordering::SeqCst); + } + } + + let polls = Arc::new(AtomicUsize::new(0)); - let runtime = Builder::new_current_thread() - .on_task_terminate(move |_data| { - count2.fetch_add(1, Ordering::SeqCst); + let rt = builder + .hook_harness_factory(TestFactory { + polls: polls.clone(), }) .build() .unwrap(); - for _ in 0..TASKS { - runtime.spawn(std::future::ready(())); + rt.block_on(async {}); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 4); +} + +fn run_after_poll(mut builder: runtime::Builder) { + struct TestFactory { + polls: Arc, + } + + struct TestHooks { + polls: Arc, } - runtime.block_on(async { - // tick the runtime a bunch to close out tasks - for _ in 0..ITERATIONS { - tokio::task::yield_now().await; + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + polls: self.polls.clone(), + })) } + } + + impl TaskHookHarness for TestHooks { + fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) { + self.polls.fetch_add(1, Ordering::SeqCst); + } + } + + let polls = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + polls: polls.clone(), + }) + .build() + .unwrap(); + + rt.block_on(async {}); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 4); +} + +fn run_terminate(mut builder: runtime::Builder) { + struct TestFactory { + terminations: Arc, + } + + struct TestHooks { + terminations: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + terminations: self.terminations.clone(), + })) + } + } + + impl TaskHookHarness for TestHooks { + fn on_task_terminate(&mut self, _ctx: &mut OnTaskTerminateContext<'_>) { + self.terminations.fetch_add(1, Ordering::SeqCst); + } + } + + let terminations = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + terminations: terminations.clone(), + }) + .build() + .unwrap(); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + + assert_eq!(terminations.load(Ordering::SeqCst), 2); +} + +fn run_hook_switching(mut builder: runtime::Builder) { + struct TestFactory { + next_id: Arc, + flag: Arc, + } + + struct TestHooks { + id: usize, + flag: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + id: self.next_id.fetch_add(1, Ordering::SeqCst), + flag: self.flag.clone(), + })) + } + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + self.flag.store(self.id, Ordering::SeqCst); + } + } + + let polls = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + next_id: Arc::new(Default::default()), + flag: polls.clone(), + }) + .build() + .unwrap(); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 3); +} + +fn run_override(mut builder: runtime::Builder) { + struct TestFactory { + counter: Arc, + } + + struct TestHooks { + counter: Arc, + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + self.counter.fetch_add(1, Ordering::SeqCst); + } + + fn on_child_spawn( + &mut self, + _ctx: &mut OnChildTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(Self { + counter: self.counter.clone(), + })) + } + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + self.counter.fetch_add(1, Ordering::SeqCst); + None + } + } + + let factory_counter = Arc::new(AtomicUsize::new(0)); + let builder_counter = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + counter: factory_counter.clone(), + }) + .build() + .unwrap(); + + rt.spawn(async {}); + + assert_eq!(factory_counter.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { + tokio::task::spawn_with_hooks( + async {}, + TestHooks { + counter: builder_counter.clone(), + }, + ) + .await + }); + + assert_eq!(factory_counter.load(Ordering::SeqCst), 1); + assert_eq!(builder_counter.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { + let counter = builder_counter.clone(); + tokio::spawn(async { tokio::task::spawn_with_hooks(async {}, TestHooks { counter }).await }) + .await }); - assert_eq!(TASKS, count.load(Ordering::SeqCst)); + assert_eq!(factory_counter.load(Ordering::SeqCst), 2); + assert_eq!(builder_counter.load(Ordering::SeqCst), 2); } diff --git a/tokio/tests/tracing_task.rs b/tokio/tests/tracing_task.rs index a9317bf5b12..f2adf573a9d 100644 --- a/tokio/tests/tracing_task.rs +++ b/tokio/tests/tracing_task.rs @@ -64,9 +64,11 @@ async fn task_builder_name_recorded() { { let _guard = tracing::subscriber::set_default(subscriber); - task::Builder::new() - .name("test-task") - .spawn(futures::future::ready(())) + let mut b = task::Builder::new(); + + b.name("test-task"); + + b.spawn(futures::future::ready(())) .unwrap() .await .expect("failed to await join handle"); From ef9fcf999d2a1a5e30e8688b45d0e80bb2e5e3aa Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 2 May 2025 15:03:51 -0500 Subject: [PATCH 2/9] fix some loom issues --- .github/workflows/ci.yml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9816c46a947..4428d1a0239 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,8 +1,8 @@ on: push: - branches: ["master", "tokio-*.x"] + branches: [ "master", "tokio-*.x" ] pull_request: - branches: ["master", "tokio-*.x"] + branches: [ "master", "tokio-*.x" ] name: CI @@ -107,7 +107,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -139,7 +139,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -169,7 +169,7 @@ jobs: - name: Install Rust ${{ env.rust_nightly }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_nightly }} + toolchain: ${{ env.rust_nightly }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -197,7 +197,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-hack uses: taiki-e/install-action@v2 with: @@ -237,7 +237,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Enable parking_lot send_guard feature # Inserts the line "plsend = ["parking_lot/send_guard"]" right after [features] @@ -256,7 +256,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.82 + toolchain: 1.82 - name: Install Valgrind uses: taiki-e/install-action@valgrind @@ -295,7 +295,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 @@ -329,7 +329,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 @@ -363,7 +363,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -842,7 +842,7 @@ jobs: toolchain: ${{ env.rust_stable }} - uses: Swatinem/rust-cache@v2 - name: build --cfg loom - run: cargo test --no-run --lib --features full + run: cargo test --no-run --lib --release --features full working-directory: tokio env: RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings From 855be676da3c0652b29fe986ffced0ed192bcf54 Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 2 May 2025 15:34:41 -0500 Subject: [PATCH 3/9] try lloyd's fix --- tokio/src/runtime/task/raw.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index ad2f30677dc..90b3477ecc8 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -215,8 +215,8 @@ impl RawTask { /// Safety: mutual exclusion is required to call this function. pub(crate) fn poll(self) { #[cfg(tokio_unstable)] - self.trailer().hooks.with_mut(|ptr| unsafe { - let _guard = ptr.as_mut().and_then(|x| { + let _guard = self.trailer().hooks.with_mut(|ptr| unsafe { + ptr.as_mut().and_then(|x| { x.as_mut().map(|x| { let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { x.before_poll(&mut BeforeTaskPollContext { @@ -228,12 +228,9 @@ impl RawTask { (&mut **x) as *mut (dyn TaskHookHarness + Send + Sync + 'static), )) }) - }); - - let vtable = self.header().vtable; - (vtable.poll)(self.ptr); + }) }); - #[cfg(not(tokio_unstable))] + unsafe { let vtable = self.header().vtable; (vtable.poll)(self.ptr); From eb54da46f50a4457bff77a2612cc8142d76f6f4b Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 2 May 2025 15:40:06 -0500 Subject: [PATCH 4/9] dont add release --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4428d1a0239..380dd5c5814 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -842,7 +842,7 @@ jobs: toolchain: ${{ env.rust_stable }} - uses: Swatinem/rust-cache@v2 - name: build --cfg loom - run: cargo test --no-run --lib --release --features full + run: cargo test --no-run --lib --features full working-directory: tokio env: RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings From dd6c1b72c072dccc7144654cd36624249f21d71d Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 2 May 2025 16:24:43 -0500 Subject: [PATCH 5/9] loom release --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 380dd5c5814..ec73e28dd05 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -842,10 +842,10 @@ jobs: toolchain: ${{ env.rust_stable }} - uses: Swatinem/rust-cache@v2 - name: build --cfg loom - run: cargo test --no-run --lib --features full + run: cargo test --no-run --lib --release --features full working-directory: tokio env: - RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings + RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings -Cdebug check-readme: name: Check README From 2769520afe33fcaf4a5fb219b0638315e1da6405 Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 2 May 2025 16:30:42 -0500 Subject: [PATCH 6/9] fix docs --- tokio/src/runtime/task_hooks/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio/src/runtime/task_hooks/mod.rs b/tokio/src/runtime/task_hooks/mod.rs index c54c2787991..e2f94323052 100644 --- a/tokio/src/runtime/task_hooks/mod.rs +++ b/tokio/src/runtime/task_hooks/mod.rs @@ -5,7 +5,7 @@ use std::ptr::NonNull; use std::sync::Arc; /// A factory which produces new [`TaskHookHarness`] objects for tasks which either have been -/// spawned in "detached mode" via the builder, or which were spawned from outside the runtime or +/// spawned in "detached mode" via [`tokio::task::spawn_with_hooks`], or which were spawned from outside the runtime or /// from another context where no [`TaskHookHarness`] was present. pub trait TaskHookHarnessFactory { /// Create a new [`TaskHookHarness`] object which the runtime will attach to a given task. From a70726cbc7fe53ad1bb2929ff0bb72b6d7b0b0d7 Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 2 May 2025 16:31:24 -0500 Subject: [PATCH 7/9] debug assertions properly --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ec73e28dd05..4370783d2d4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -845,7 +845,7 @@ jobs: run: cargo test --no-run --lib --release --features full working-directory: tokio env: - RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings -Cdebug + RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings -Cdebug-assertions check-readme: name: Check README From 77eb27a9fb5aab29f5e9a0e743f56714caa2ac14 Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 2 May 2025 17:21:31 -0500 Subject: [PATCH 8/9] lol --- tokio/src/runtime/task_hooks/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio/src/runtime/task_hooks/mod.rs b/tokio/src/runtime/task_hooks/mod.rs index e2f94323052..c28a0b9dd05 100644 --- a/tokio/src/runtime/task_hooks/mod.rs +++ b/tokio/src/runtime/task_hooks/mod.rs @@ -5,7 +5,7 @@ use std::ptr::NonNull; use std::sync::Arc; /// A factory which produces new [`TaskHookHarness`] objects for tasks which either have been -/// spawned in "detached mode" via [`tokio::task::spawn_with_hooks`], or which were spawned from outside the runtime or +/// spawned in "detached mode" via [`crate::task::spawn_with_hooks`], or which were spawned from outside the runtime or /// from another context where no [`TaskHookHarness`] was present. pub trait TaskHookHarnessFactory { /// Create a new [`TaskHookHarness`] object which the runtime will attach to a given task. From 200cd93a9b2df51ec996bddbe0cdde2cd745cb72 Mon Sep 17 00:00:00 2001 From: noah Date: Sun, 4 May 2025 10:55:03 -0500 Subject: [PATCH 9/9] add actions --- .../runtime/scheduler/current_thread/mod.rs | 40 +++--- .../runtime/scheduler/multi_thread/handle.rs | 20 +-- tokio/src/runtime/task_hooks/mod.rs | 83 +++++++++++-- tokio/tests/task_hooks.rs | 114 ++++++++++++------ 4 files changed, 182 insertions(+), 75 deletions(-) diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index a42528ba7d3..0f018ada1c4 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -462,10 +462,12 @@ impl Handle { parent .map(|parent| { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - parent.on_child_spawn(&mut OnChildTaskSpawnContext { - id, - _phantom: Default::default(), - }) + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -479,10 +481,12 @@ impl Handle { .or_else(|| { if let Some(hooks) = me.hooks_factory_ref() { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { - id, - _phantom: Default::default(), - }) + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -533,10 +537,12 @@ impl Handle { parent .map(|parent| { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - parent.on_child_spawn(&mut OnChildTaskSpawnContext { - id, - _phantom: Default::default(), - }) + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -550,10 +556,12 @@ impl Handle { .or_else(|| { if let Some(hooks) = me.hooks_factory_ref() { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { - id, - _phantom: Default::default(), - }) + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { diff --git a/tokio/src/runtime/scheduler/multi_thread/handle.rs b/tokio/src/runtime/scheduler/multi_thread/handle.rs index fa8973a8fab..030910d30f3 100644 --- a/tokio/src/runtime/scheduler/multi_thread/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread/handle.rs @@ -80,10 +80,12 @@ impl Handle { parent .map(|parent| { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - parent.on_child_spawn(&mut OnChildTaskSpawnContext { - id, - _phantom: Default::default(), - }) + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -97,10 +99,12 @@ impl Handle { .or_else(|| { if let Some(hooks) = me.hooks_factory_ref() { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { - id, - _phantom: Default::default(), - }) + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { diff --git a/tokio/src/runtime/task_hooks/mod.rs b/tokio/src/runtime/task_hooks/mod.rs index c28a0b9dd05..51b70cf5a91 100644 --- a/tokio/src/runtime/task_hooks/mod.rs +++ b/tokio/src/runtime/task_hooks/mod.rs @@ -8,11 +8,9 @@ use std::sync::Arc; /// spawned in "detached mode" via [`crate::task::spawn_with_hooks`], or which were spawned from outside the runtime or /// from another context where no [`TaskHookHarness`] was present. pub trait TaskHookHarnessFactory { - /// Create a new [`TaskHookHarness`] object which the runtime will attach to a given task. - fn on_top_level_spawn( - &self, - ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option>; + /// Runs a hook which may produce a new [`TaskHookHarness`] object which the runtime will attach to a given task. + fn on_top_level_spawn(&self, ctx: &mut OnTopLevelTaskSpawnContext<'_>) + -> OnTopLevelSpawnAction; } /// Trait for user-provided "harness" objects which are attached to tasks and provide hook @@ -20,24 +18,27 @@ pub trait TaskHookHarnessFactory { #[allow(unused_variables)] pub trait TaskHookHarness { /// Pre-poll task hook which runs arbitrary user logic. - fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) {} + fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { + BeforeTaskPollAction::default() + } /// Post-poll task hook which runs arbitrary user logic. - fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) {} + fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction { + AfterTaskPollAction::default() + } /// Task hook which runs when this task spawns a child, unless that child is explicitly spawned /// detached from the parent. /// /// This hook creates a harness for the child, or detaches the child from any instrumentation. - fn on_child_spawn( - &mut self, - ctx: &mut OnChildTaskSpawnContext<'_>, - ) -> Option> { - None + fn on_child_spawn(&mut self, ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { + OnChildSpawnAction::default() } /// Task hook which runs on task termination. - fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) {} + fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) -> OnTaskTerminateAction { + OnTaskTerminateAction::default() + } } pub(crate) type OptionalTaskHooksFactory = @@ -97,3 +98,59 @@ pub struct BeforeTaskPollContext<'a> { pub struct AfterTaskPollContext<'a> { pub(crate) _phantom: PhantomData<&'a ()>, } + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnTopLevelSpawnAction { + pub(crate) hooks: Option>, +} + +impl OnTopLevelSpawnAction { + /// Pass in a set of task hooks for the task. + pub fn set_hooks(&mut self, hooks: T) -> &mut Self + where + T: TaskHookHarness + Send + Sync + 'static, + { + self.hooks = Some(Box::new(hooks)); + self + } +} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnChildSpawnAction { + pub(crate) hooks: Option>, +} + +impl OnChildSpawnAction { + /// Pass in a set of task hooks for the child task. + pub fn set_hooks(&mut self, hooks: T) -> &mut Self + where + T: TaskHookHarness + Send + Sync + 'static, + { + self.hooks = Some(Box::new(hooks)); + self + } +} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnTaskTerminateAction {} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct BeforeTaskPollAction {} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct AfterTaskPollAction {} diff --git a/tokio/tests/task_hooks.rs b/tokio/tests/task_hooks.rs index 2af3cc594ee..e2127388e58 100644 --- a/tokio/tests/task_hooks.rs +++ b/tokio/tests/task_hooks.rs @@ -9,8 +9,9 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::runtime; use tokio::runtime::{ - AfterTaskPollContext, BeforeTaskPollContext, OnChildTaskSpawnContext, OnTaskTerminateContext, - OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory, + AfterTaskPollAction, AfterTaskPollContext, BeforeTaskPollAction, BeforeTaskPollContext, + OnChildSpawnAction, OnChildTaskSpawnContext, OnTaskTerminateAction, OnTaskTerminateContext, + OnTopLevelSpawnAction, OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory, }; #[test] @@ -83,9 +84,10 @@ fn run_runtime_default_factory(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { + ) -> OnTopLevelSpawnAction { self.counter.fetch_add(1, Ordering::SeqCst); - None + + Default::default() } } @@ -138,25 +140,30 @@ fn run_parent_child_chaining(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { + ) -> OnTopLevelSpawnAction { self.parent_spawns.fetch_add(1, Ordering::SeqCst); - Some(Box::new(TestHooks { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { spawns: self.child_spawns.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn on_child_spawn( - &mut self, - _ctx: &mut OnChildTaskSpawnContext<'_>, - ) -> Option> { + fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { self.spawns.fetch_add(1, Ordering::SeqCst); - Some(Box::new(Self { + let mut a = OnChildSpawnAction::default(); + + a.set_hooks(Self { spawns: self.spawns.clone(), - })) + }); + + a } } @@ -195,16 +202,22 @@ fn run_before_poll(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { polls: self.polls.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { self.polls.fetch_add(1, Ordering::SeqCst); + + Default::default() } } @@ -240,16 +253,22 @@ fn run_after_poll(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { polls: self.polls.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) { + fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction { self.polls.fetch_add(1, Ordering::SeqCst); + + Default::default() } } @@ -285,16 +304,25 @@ fn run_terminate(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { terminations: self.terminations.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn on_task_terminate(&mut self, _ctx: &mut OnTaskTerminateContext<'_>) { + fn on_task_terminate( + &mut self, + _ctx: &mut OnTaskTerminateContext<'_>, + ) -> OnTaskTerminateAction { self.terminations.fetch_add(1, Ordering::SeqCst); + + Default::default() } } @@ -327,17 +355,23 @@ fn run_hook_switching(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { id: self.next_id.fetch_add(1, Ordering::SeqCst), flag: self.flag.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { self.flag.store(self.id, Ordering::SeqCst); + + Default::default() } } @@ -371,17 +405,20 @@ fn run_override(mut builder: runtime::Builder) { } impl TaskHookHarness for TestHooks { - fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { self.counter.fetch_add(1, Ordering::SeqCst); + + Default::default() } - fn on_child_spawn( - &mut self, - _ctx: &mut OnChildTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(Self { + fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { + let mut a = OnChildSpawnAction::default(); + + a.set_hooks(Self { counter: self.counter.clone(), - })) + }); + + a } } @@ -389,9 +426,10 @@ fn run_override(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { + ) -> OnTopLevelSpawnAction { self.counter.fetch_add(1, Ordering::SeqCst); - None + + Default::default() } }