Skip to content

Commit 4212a46

Browse files
[BUG] Mem leak handles in scheduler (#5590)
## Description of changes

_Summarize the changes made by this PR._

- Improvements & Bug fixes
  - Fixes a memory leak of task handles in the scheduler by adding a drop guard for each handle and tracking the handles in a hash map, so that an entry is removed once its task finishes.
- New functionality
  - None

## Test plan

_How are these changes tested?_

Added a test to ensure the handle map is empty after a scheduled task has finished.

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Migration plan

None required

## Observability plan

None required

## Documentation Changes

None required

---------

Co-authored-by: propel-code-bot[bot] <203372662+propel-code-bot[bot]@users.noreply.github.com>
1 parent 848ccdc commit 4212a46

File tree

1 file changed

+137
-21
lines changed

1 file changed

+137
-21
lines changed

rust/system/src/scheduler.rs

Lines changed: 137 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,63 @@
11
use parking_lot::RwLock;
2-
use std::fmt::Debug;
3-
use std::sync::Arc;
2+
use std::sync::atomic::AtomicU64;
3+
use std::sync::{Arc, Weak};
44
use std::time::Duration;
5+
use std::{collections::HashMap, fmt::Debug};
56
use tokio::select;
67
use tracing::Span;
78

89
use super::{Component, ComponentContext, Handler, Message};
910

10-
#[derive(Debug)]
1111
pub(crate) struct SchedulerTaskHandle {
1212
join_handle: Option<tokio::task::JoinHandle<()>>,
1313
cancel: tokio_util::sync::CancellationToken,
1414
}
1515

16+
impl Debug for SchedulerTaskHandle {
17+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18+
f.debug_struct("SchedulerTaskHandle").finish()
19+
}
20+
}
21+
22+
/// Identifier assigned to a scheduled task.
///
/// Used as the key of the scheduler's handle map, hence the `Eq` + `Hash`
/// derives; it is `Copy` so it can be moved into a spawned task cheaply.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct TaskId(u64);
24+
25+
pub(crate) struct HandleGuard {
26+
weak_handles: Weak<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>,
27+
task_id: TaskId,
28+
}
29+
30+
impl Drop for HandleGuard {
31+
fn drop(&mut self) {
32+
if let Some(handles) = self.weak_handles.upgrade() {
33+
let mut handles = handles.write();
34+
handles.remove(&self.task_id);
35+
}
36+
}
37+
}
38+
1639
#[derive(Clone, Debug)]
1740
pub struct Scheduler {
18-
handles: Arc<RwLock<Vec<SchedulerTaskHandle>>>,
41+
handles: Arc<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>,
42+
next_id: Arc<AtomicU64>,
1943
}
2044

2145
impl Scheduler {
2246
pub(crate) fn new() -> Scheduler {
2347
Scheduler {
24-
handles: Arc::new(RwLock::new(Vec::new())),
48+
handles: Arc::new(RwLock::new(HashMap::new())),
49+
next_id: Arc::new(AtomicU64::new(1)),
2550
}
2651
}
2752

53+
/// Allocate the next task ID.
54+
fn allocate_id(&self) -> TaskId {
55+
let id = self
56+
.next_id
57+
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
58+
TaskId(id)
59+
}
60+
2861
/// Schedule a message to be sent to the component after the specified duration.
2962
///
3063
/// `span_factory` is called immediately before sending the scheduled message to the component.
@@ -40,9 +73,17 @@ impl Scheduler {
4073
M: Message,
4174
S: (Fn() -> Option<Span>) + Send + Sync + 'static,
4275
{
76+
let id = self.allocate_id();
77+
let handles_weak = Arc::downgrade(&self.handles);
78+
4379
let cancel = ctx.cancellation_token.clone();
4480
let sender = ctx.receiver().clone();
4581
let handle = tokio::spawn(async move {
82+
let _guard = HandleGuard {
83+
weak_handles: handles_weak,
84+
task_id: id,
85+
};
86+
4687
select! {
4788
_ = cancel.cancelled() => {}
4889
_ = tokio::time::sleep(duration) => {
@@ -61,7 +102,7 @@ impl Scheduler {
61102
join_handle: Some(handle),
62103
cancel: ctx.cancellation_token.clone(),
63104
};
64-
self.handles.write().push(handle);
105+
self.handles.write().insert(id, handle);
65106
}
66107

67108
/// Schedule a message to be sent to the component at a regular interval.
@@ -80,11 +121,16 @@ impl Scheduler {
80121
M: Message + Clone,
81122
S: (Fn() -> Option<Span>) + Send + Sync + 'static,
82123
{
124+
let id = self.allocate_id();
125+
let handles_weak = Arc::downgrade(&self.handles);
83126
let cancel = ctx.cancellation_token.clone();
84-
85127
let sender = ctx.receiver().clone();
86128

87129
let handle = tokio::spawn(async move {
130+
let _guard = HandleGuard {
131+
weak_handles: handles_weak,
132+
task_id: id,
133+
};
88134
let mut counter = 0;
89135
while Self::should_continue(num_times, counter) {
90136
select! {
@@ -109,7 +155,7 @@ impl Scheduler {
109155
join_handle: Some(handle),
110156
cancel: ctx.cancellation_token.clone(),
111157
};
112-
self.handles.write().push(handle);
158+
self.handles.write().insert(id, handle);
113159
}
114160

115161
#[cfg(test)]
@@ -132,7 +178,7 @@ impl Scheduler {
132178
let mut handles = self.handles.write();
133179
handles
134180
.iter_mut()
135-
.flat_map(|h| h.join_handle.take())
181+
.flat_map(|(_, h)| h.join_handle.take())
136182
.collect::<Vec<_>>()
137183
};
138184
for join_handle in handles.iter_mut() {
@@ -148,7 +194,7 @@ impl Scheduler {
148194
pub(crate) fn stop(&self) {
149195
let handles = self.handles.read();
150196
for handle in handles.iter() {
151-
handle.cancel.cancel();
197+
handle.1.cancel.cancel();
152198
}
153199
}
154200
}
@@ -157,45 +203,43 @@ impl Scheduler {
157203
mod tests {
158204
use super::*;
159205
use crate::system::System;
160-
161206
use async_trait::async_trait;
207+
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
162208
use std::sync::Arc;
163209
use std::time::Duration;
164210

165-
use std::sync::atomic::{AtomicUsize, Ordering};
166-
167211
/// Test component that counts the `ScheduleMessage`s it handles.
#[derive(Debug)]
struct SimpleScheduleIntervalComponent {
    /// Queue size reported by the component.
    queue_size: usize,
    /// Shared counter, incremented once per handled message.
    counter: Arc<AtomicUsize>,
}
172216

173217
/// Payload-free message type that the tests hand to the scheduler.
#[derive(Clone, Debug)]
struct ScheduleMessage {}
175219

176-
impl TestComponent {
220+
impl SimpleScheduleIntervalComponent {
177221
fn new(queue_size: usize, counter: Arc<AtomicUsize>) -> Self {
178-
TestComponent {
222+
SimpleScheduleIntervalComponent {
179223
queue_size,
180224
counter,
181225
}
182226
}
183227
}
184228
#[async_trait]
185-
impl Handler<ScheduleMessage> for TestComponent {
229+
impl Handler<ScheduleMessage> for SimpleScheduleIntervalComponent {
186230
type Result = ();
187231

188232
async fn handle(
189233
&mut self,
190234
_message: ScheduleMessage,
191-
_ctx: &ComponentContext<TestComponent>,
235+
_ctx: &ComponentContext<SimpleScheduleIntervalComponent>,
192236
) {
193237
self.counter.fetch_add(1, Ordering::SeqCst);
194238
}
195239
}
196240

197241
#[async_trait]
198-
impl Component for TestComponent {
242+
impl Component for SimpleScheduleIntervalComponent {
199243
fn get_name() -> &'static str {
200244
"Test component"
201245
}
@@ -204,7 +248,10 @@ mod tests {
204248
self.queue_size
205249
}
206250

207-
async fn on_start(&mut self, ctx: &ComponentContext<TestComponent>) -> () {
251+
async fn on_start(
252+
&mut self,
253+
ctx: &ComponentContext<SimpleScheduleIntervalComponent>,
254+
) -> () {
208255
let duration = Duration::from_millis(100);
209256
ctx.scheduler
210257
.schedule(ScheduleMessage {}, duration, ctx, || None);
@@ -224,12 +271,81 @@ mod tests {
224271
async fn test_schedule() {
225272
let system = System::new();
226273
let counter = Arc::new(AtomicUsize::new(0));
227-
let component = TestComponent::new(10, counter.clone());
274+
let component = SimpleScheduleIntervalComponent::new(10, counter.clone());
228275
let _handle = system.start_component(component);
229276
// yield to allow the component to process the messages
230277
tokio::task::yield_now().await;
231278
// We should have scheduled the message once
232279
system.join().await;
233280
assert_eq!(counter.load(Ordering::SeqCst), 5);
234281
}
282+
283+
/// Test component that schedules a single message and records, while handling
/// it, whether the scheduler's handle map was empty.
#[derive(Debug)]
struct OneMessageComponent {
    /// Queue size reported by the component.
    queue_size: usize,
    /// Shared counter, incremented once per handled message.
    counter: Arc<AtomicUsize>,
    /// Set by the handler to whether the handle map was empty at that moment.
    handles_empty_after: Arc<AtomicBool>,
}
289+
290+
impl OneMessageComponent {
291+
fn new(
292+
queue_size: usize,
293+
counter: Arc<AtomicUsize>,
294+
handles_empty_after: Arc<AtomicBool>,
295+
) -> Self {
296+
OneMessageComponent {
297+
queue_size,
298+
counter,
299+
handles_empty_after,
300+
}
301+
}
302+
}
303+
304+
#[async_trait]
305+
impl Component for OneMessageComponent {
306+
fn get_name() -> &'static str {
307+
"OneMessageComponent"
308+
}
309+
310+
fn queue_size(&self) -> usize {
311+
self.queue_size
312+
}
313+
314+
async fn on_start(&mut self, ctx: &ComponentContext<OneMessageComponent>) -> () {
315+
let duration = Duration::from_millis(100);
316+
ctx.scheduler
317+
.schedule(ScheduleMessage {}, duration, ctx, || None);
318+
}
319+
}
320+
321+
#[async_trait]
322+
impl Handler<ScheduleMessage> for OneMessageComponent {
323+
type Result = ();
324+
325+
async fn handle(
326+
&mut self,
327+
_message: ScheduleMessage,
328+
ctx: &ComponentContext<OneMessageComponent>,
329+
) {
330+
self.counter.fetch_add(1, Ordering::SeqCst);
331+
self.handles_empty_after
332+
.store(ctx.scheduler.handles.read().is_empty(), Ordering::SeqCst);
333+
}
334+
}
335+
336+
#[tokio::test]
337+
async fn test_handle_cleaned_up() {
338+
let system = System::new();
339+
let counter = Arc::new(AtomicUsize::new(0));
340+
let handles_empty_after = Arc::new(AtomicBool::new(false));
341+
let component = OneMessageComponent::new(10, counter.clone(), handles_empty_after.clone());
342+
let _handle = system.start_component(component);
343+
// Wait for the 100ms schedule to trigger
344+
tokio::time::sleep(Duration::from_millis(500)).await;
345+
// yield to allow the component to process the messages
346+
tokio::task::yield_now().await;
347+
assert!(handles_empty_after.load(Ordering::SeqCst));
348+
// We should have scheduled the message once
349+
system.join().await;
350+
}
235351
}

0 commit comments

Comments
 (0)