11use parking_lot:: RwLock ;
2- use std:: fmt :: Debug ;
3- use std:: sync:: Arc ;
2+ use std:: sync :: atomic :: AtomicU64 ;
3+ use std:: sync:: { Arc , Weak } ;
44use std:: time:: Duration ;
5+ use std:: { collections:: HashMap , fmt:: Debug } ;
56use tokio:: select;
67use tracing:: Span ;
78
89use super :: { Component , ComponentContext , Handler , Message } ;
910
10- #[ derive( Debug ) ]
1111pub ( crate ) struct SchedulerTaskHandle {
1212 join_handle : Option < tokio:: task:: JoinHandle < ( ) > > ,
1313 cancel : tokio_util:: sync:: CancellationToken ,
1414}
1515
16+ impl Debug for SchedulerTaskHandle {
17+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
18+ f. debug_struct ( "SchedulerTaskHandle" ) . finish ( )
19+ }
20+ }
21+
22+ #[ derive( Copy , Clone , Eq , PartialEq , Hash , Debug ) ]
23+ pub ( crate ) struct TaskId ( u64 ) ;
24+
25+ pub ( crate ) struct HandleGuard {
26+ weak_handles : Weak < RwLock < HashMap < TaskId , SchedulerTaskHandle > > > ,
27+ task_id : TaskId ,
28+ }
29+
30+ impl Drop for HandleGuard {
31+ fn drop ( & mut self ) {
32+ if let Some ( handles) = self . weak_handles . upgrade ( ) {
33+ let mut handles = handles. write ( ) ;
34+ handles. remove ( & self . task_id ) ;
35+ }
36+ }
37+ }
38+
1639#[ derive( Clone , Debug ) ]
1740pub struct Scheduler {
18- handles : Arc < RwLock < Vec < SchedulerTaskHandle > > > ,
41+ handles : Arc < RwLock < HashMap < TaskId , SchedulerTaskHandle > > > ,
42+ next_id : Arc < AtomicU64 > ,
1943}
2044
2145impl Scheduler {
2246 pub ( crate ) fn new ( ) -> Scheduler {
2347 Scheduler {
24- handles : Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ,
48+ handles : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
49+ next_id : Arc :: new ( AtomicU64 :: new ( 1 ) ) ,
2550 }
2651 }
2752
53+ /// Allocate the next task ID.
54+ fn allocate_id ( & self ) -> TaskId {
55+ let id = self
56+ . next_id
57+ . fetch_add ( 1 , std:: sync:: atomic:: Ordering :: Relaxed ) ;
58+ TaskId ( id)
59+ }
60+
2861 /// Schedule a message to be sent to the component after the specified duration.
2962 ///
3063 /// `span_factory` is called immediately before sending the scheduled message to the component.
@@ -40,9 +73,17 @@ impl Scheduler {
4073 M : Message ,
4174 S : ( Fn ( ) -> Option < Span > ) + Send + Sync + ' static ,
4275 {
76+ let id = self . allocate_id ( ) ;
77+ let handles_weak = Arc :: downgrade ( & self . handles ) ;
78+
4379 let cancel = ctx. cancellation_token . clone ( ) ;
4480 let sender = ctx. receiver ( ) . clone ( ) ;
4581 let handle = tokio:: spawn ( async move {
82+ let _guard = HandleGuard {
83+ weak_handles : handles_weak,
84+ task_id : id,
85+ } ;
86+
4687 select ! {
4788 _ = cancel. cancelled( ) => { }
4889 _ = tokio:: time:: sleep( duration) => {
@@ -61,7 +102,7 @@ impl Scheduler {
61102 join_handle : Some ( handle) ,
62103 cancel : ctx. cancellation_token . clone ( ) ,
63104 } ;
64- self . handles . write ( ) . push ( handle) ;
105+ self . handles . write ( ) . insert ( id , handle) ;
65106 }
66107
67108 /// Schedule a message to be sent to the component at a regular interval.
@@ -80,11 +121,16 @@ impl Scheduler {
80121 M : Message + Clone ,
81122 S : ( Fn ( ) -> Option < Span > ) + Send + Sync + ' static ,
82123 {
124+ let id = self . allocate_id ( ) ;
125+ let handles_weak = Arc :: downgrade ( & self . handles ) ;
83126 let cancel = ctx. cancellation_token . clone ( ) ;
84-
85127 let sender = ctx. receiver ( ) . clone ( ) ;
86128
87129 let handle = tokio:: spawn ( async move {
130+ let _guard = HandleGuard {
131+ weak_handles : handles_weak,
132+ task_id : id,
133+ } ;
88134 let mut counter = 0 ;
89135 while Self :: should_continue ( num_times, counter) {
90136 select ! {
@@ -109,7 +155,7 @@ impl Scheduler {
109155 join_handle : Some ( handle) ,
110156 cancel : ctx. cancellation_token . clone ( ) ,
111157 } ;
112- self . handles . write ( ) . push ( handle) ;
158+ self . handles . write ( ) . insert ( id , handle) ;
113159 }
114160
115161 #[ cfg( test) ]
@@ -132,7 +178,7 @@ impl Scheduler {
132178 let mut handles = self . handles . write ( ) ;
133179 handles
134180 . iter_mut ( )
135- . flat_map ( |h | h. join_handle . take ( ) )
181+ . flat_map ( |( _ , h ) | h. join_handle . take ( ) )
136182 . collect :: < Vec < _ > > ( )
137183 } ;
138184 for join_handle in handles. iter_mut ( ) {
@@ -148,7 +194,7 @@ impl Scheduler {
148194 pub ( crate ) fn stop ( & self ) {
149195 let handles = self . handles . read ( ) ;
150196 for handle in handles. iter ( ) {
151- handle. cancel . cancel ( ) ;
197+ handle. 1 . cancel . cancel ( ) ;
152198 }
153199 }
154200}
@@ -157,45 +203,43 @@ impl Scheduler {
157203mod tests {
158204 use super :: * ;
159205 use crate :: system:: System ;
160-
161206 use async_trait:: async_trait;
207+ use std:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
162208 use std:: sync:: Arc ;
163209 use std:: time:: Duration ;
164210
165- use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
166-
167211 #[ derive( Debug ) ]
168- struct TestComponent {
212+ struct SimpleScheduleIntervalComponent {
169213 queue_size : usize ,
170214 counter : Arc < AtomicUsize > ,
171215 }
172216
173217 #[ derive( Clone , Debug ) ]
174218 struct ScheduleMessage { }
175219
176- impl TestComponent {
220+ impl SimpleScheduleIntervalComponent {
177221 fn new ( queue_size : usize , counter : Arc < AtomicUsize > ) -> Self {
178- TestComponent {
222+ SimpleScheduleIntervalComponent {
179223 queue_size,
180224 counter,
181225 }
182226 }
183227 }
184228 #[ async_trait]
185- impl Handler < ScheduleMessage > for TestComponent {
229+ impl Handler < ScheduleMessage > for SimpleScheduleIntervalComponent {
186230 type Result = ( ) ;
187231
188232 async fn handle (
189233 & mut self ,
190234 _message : ScheduleMessage ,
191- _ctx : & ComponentContext < TestComponent > ,
235+ _ctx : & ComponentContext < SimpleScheduleIntervalComponent > ,
192236 ) {
193237 self . counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
194238 }
195239 }
196240
197241 #[ async_trait]
198- impl Component for TestComponent {
242+ impl Component for SimpleScheduleIntervalComponent {
199243 fn get_name ( ) -> & ' static str {
200244 "Test component"
201245 }
@@ -204,7 +248,10 @@ mod tests {
204248 self . queue_size
205249 }
206250
207- async fn on_start ( & mut self , ctx : & ComponentContext < TestComponent > ) -> ( ) {
251+ async fn on_start (
252+ & mut self ,
253+ ctx : & ComponentContext < SimpleScheduleIntervalComponent > ,
254+ ) -> ( ) {
208255 let duration = Duration :: from_millis ( 100 ) ;
209256 ctx. scheduler
210257 . schedule ( ScheduleMessage { } , duration, ctx, || None ) ;
@@ -224,12 +271,81 @@ mod tests {
224271 async fn test_schedule ( ) {
225272 let system = System :: new ( ) ;
226273 let counter = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
227- let component = TestComponent :: new ( 10 , counter. clone ( ) ) ;
274+ let component = SimpleScheduleIntervalComponent :: new ( 10 , counter. clone ( ) ) ;
228275 let _handle = system. start_component ( component) ;
229276 // yield to allow the component to process the messages
230277 tokio:: task:: yield_now ( ) . await ;
231278 // We should have scheduled the message once
232279 system. join ( ) . await ;
233280 assert_eq ! ( counter. load( Ordering :: SeqCst ) , 5 ) ;
234281 }
282+
283+ #[ derive( Debug ) ]
284+ struct OneMessageComponent {
285+ queue_size : usize ,
286+ counter : Arc < AtomicUsize > ,
287+ handles_empty_after : Arc < AtomicBool > ,
288+ }
289+
290+ impl OneMessageComponent {
291+ fn new (
292+ queue_size : usize ,
293+ counter : Arc < AtomicUsize > ,
294+ handles_empty_after : Arc < AtomicBool > ,
295+ ) -> Self {
296+ OneMessageComponent {
297+ queue_size,
298+ counter,
299+ handles_empty_after,
300+ }
301+ }
302+ }
303+
304+ #[ async_trait]
305+ impl Component for OneMessageComponent {
306+ fn get_name ( ) -> & ' static str {
307+ "OneMessageComponent"
308+ }
309+
310+ fn queue_size ( & self ) -> usize {
311+ self . queue_size
312+ }
313+
314+ async fn on_start ( & mut self , ctx : & ComponentContext < OneMessageComponent > ) -> ( ) {
315+ let duration = Duration :: from_millis ( 100 ) ;
316+ ctx. scheduler
317+ . schedule ( ScheduleMessage { } , duration, ctx, || None ) ;
318+ }
319+ }
320+
321+ #[ async_trait]
322+ impl Handler < ScheduleMessage > for OneMessageComponent {
323+ type Result = ( ) ;
324+
325+ async fn handle (
326+ & mut self ,
327+ _message : ScheduleMessage ,
328+ ctx : & ComponentContext < OneMessageComponent > ,
329+ ) {
330+ self . counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
331+ self . handles_empty_after
332+ . store ( ctx. scheduler . handles . read ( ) . is_empty ( ) , Ordering :: SeqCst ) ;
333+ }
334+ }
335+
336+ #[ tokio:: test]
337+ async fn test_handle_cleaned_up ( ) {
338+ let system = System :: new ( ) ;
339+ let counter = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
340+ let handles_empty_after = Arc :: new ( AtomicBool :: new ( false ) ) ;
341+ let component = OneMessageComponent :: new ( 10 , counter. clone ( ) , handles_empty_after. clone ( ) ) ;
342+ let _handle = system. start_component ( component) ;
343+ // Wait for the 100ms schedule to trigger
344+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
345+ // yield to allow the component to process the messages
346+ tokio:: task:: yield_now ( ) . await ;
347+ assert ! ( handles_empty_after. load( Ordering :: SeqCst ) ) ;
348+ // We should have scheduled the message once
349+ system. join ( ) . await ;
350+ }
235351}
0 commit comments