
Commit b02b329

dulinriley authored and meta-codesync[bot] committed
Have proc_mesh.stop try to cleanly exit actors (#1717)
Summary:
Pull Request resolved: #1717

Non-local processes that are being stopped with v1 `proc_mesh.stop()` go straight to SIGTERM and don't give actors a chance to stop themselves. LocalHandle processes already use "destroy_and_wait". That didn't matter much before, but we'd like actors to have a chance to run cleanups in Drop, and in the future in the `Actor::cleanup` trait function.

This change asks the ProcMeshAgent to stop the actors on the proc before sending a SIGTERM, which gives a chance for more cooperative cleanup. Note that the process still gets SIGTERM, as "destroy_and_wait" (and, by extension, the StopAll message) does not exit the process. We make sure to `await` the StopAll message response so that we don't SIGTERM actors that are already stopping.

Reviewed By: shayne-fletcher

Differential Revision: D85795859

fbshipit-source-id: 6b4ec9a0f545ed51aaaa353bceef2e16ab0547dc
1 parent 228e8b3 commit b02b329
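
For orientation before the per-file diffs: the new stop sequence in `BootstrapProcHandle::terminate` boils down to the sketch below. This is a condensed, illustrative rendering of the bootstrap.rs hunk further down; `terminate_sketch` and `escalate_to_sigterm` are hypothetical names standing in for the real code paths, not functions introduced by this commit.

// Condensed sketch of the new ordering; error handling and status
// bookkeeping are trimmed. See the bootstrap.rs hunk below for the
// actual implementation.
async fn terminate_sketch(
    handle: &BootstrapProcHandle,
    cx: &impl context::Actor,
    timeout: Duration,
) -> Result<ProcStatus, TerminateError<ProcStatus>> {
    // 1. Cooperative phase: if the proc has a reachable ProcMeshAgent,
    //    ask it to stop every actor on the proc (resource::StopAll),
    //    bounded by the same timeout.
    if let Some(agent) = handle.agent_ref() {
        match RealClock.timeout(timeout, agent.stop_all(cx)).await {
            Err(elapsed) => tracing::warn!("agent did not respond in time: {elapsed}"),
            Ok(Err(e)) => tracing::warn!("agent failed to stop all actors: {e}"),
            Ok(Ok(_)) => {}
        }
    }
    // 2. Forceful phase: the process is still torn down with SIGTERM
    //    (escalating to SIGKILL), exactly as before this change.
    let _ = handle.mark_stopping();
    escalate_to_sigterm(handle, timeout).await // hypothetical stand-in for the existing signal path
}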

File tree

6 files changed: +105 -15 lines changed


hyperactor/src/host.rs

Lines changed: 13 additions & 2 deletions

@@ -72,6 +72,7 @@ use crate::channel::Rx;
 use crate::channel::Tx;
 use crate::clock::Clock;
 use crate::clock::RealClock;
+use crate::context;
 use crate::mailbox::BoxableMailboxSender;
 use crate::mailbox::DialMailboxRouter;
 use crate::mailbox::IntoBoxedMailboxSender as _;
@@ -404,6 +405,7 @@ pub trait SingleTerminate: Send + Sync {
     /// Returns a tuple of (polite shutdown actors vec, forceful stop actors vec)
     async fn terminate_proc(
         &self,
+        cx: &impl context::Actor,
         proc: &ProcId,
         timeout: std::time::Duration,
     ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error>;
@@ -444,6 +446,7 @@ pub trait BulkTerminate: Send + Sync {
     /// etc.).
     async fn terminate_all(
         &self,
+        cx: &impl context::Actor,
         timeout: std::time::Duration,
         max_in_flight: usize,
     ) -> TerminateSummary;
@@ -467,21 +470,23 @@ impl<M: ProcManager + BulkTerminate> Host<M> {
     /// terminations.
     pub async fn terminate_children(
         &self,
+        cx: &impl context::Actor,
         timeout: Duration,
         max_in_flight: usize,
     ) -> TerminateSummary {
-        self.manager.terminate_all(timeout, max_in_flight).await
+        self.manager.terminate_all(cx, timeout, max_in_flight).await
     }
 }

 #[async_trait::async_trait]
 impl<M: ProcManager + SingleTerminate> SingleTerminate for Host<M> {
     async fn terminate_proc(
         &self,
+        cx: &impl context::Actor,
         proc: &ProcId,
         timeout: Duration,
     ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
-        self.manager.terminate_proc(proc, timeout).await
+        self.manager.terminate_proc(cx, proc, timeout).await
     }
 }

@@ -566,6 +571,7 @@ pub trait ProcHandle: Clone + Send + Sync + 'static {
     /// termination.
     async fn terminate(
         &self,
+        cx: &impl context::Actor,
         timeout: Duration,
     ) -> Result<Self::TerminalStatus, TerminateError<Self::TerminalStatus>>;

@@ -657,6 +663,7 @@ where
 {
     async fn terminate_all(
         &self,
+        _cx: &impl context::Actor,
         timeout: std::time::Duration,
         max_in_flight: usize,
     ) -> TerminateSummary {
@@ -699,6 +706,7 @@ where
 {
     async fn terminate_proc(
         &self,
+        _cx: &impl context::Actor,
         proc: &ProcId,
         timeout: std::time::Duration,
     ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
@@ -783,6 +791,7 @@ impl<A: Actor + Referable> ProcHandle for LocalHandle<A> {

     async fn terminate(
         &self,
+        _cx: &impl context::Actor,
         timeout: Duration,
     ) -> Result<(), TerminateError<Self::TerminalStatus>> {
         let mut proc = {
@@ -1010,6 +1019,7 @@ impl<A: Actor + Referable> ProcHandle for ProcessHandle<A> {

     async fn terminate(
         &self,
+        _cx: &impl context::Actor,
         _deadline: Duration,
     ) -> Result<(), TerminateError<Self::TerminalStatus>> {
         Err(TerminateError::Unsupported)
@@ -1441,6 +1451,7 @@ mod tests {
     }
     async fn terminate(
         &self,
+        _cx: &impl context::Actor,
         _timeout: Duration,
     ) -> Result<Self::TerminalStatus, TerminateError<Self::TerminalStatus>> {
         Err(TerminateError::Unsupported)
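
For downstream implementers of `SingleTerminate` (and likewise `BulkTerminate` and `ProcHandle::terminate`), the only required change is accepting the new context argument. A minimal sketch, using a hypothetical `MyProcManager` type, of an impl that ignores the context:

// Hypothetical manager type, shown only to illustrate the new `cx`
// parameter; implementations that don't need the caller's context can
// ignore it, as LocalProcManager does in the hunks above.
struct MyProcManager;

#[async_trait::async_trait]
impl SingleTerminate for MyProcManager {
    async fn terminate_proc(
        &self,
        _cx: &impl context::Actor,
        proc: &ProcId,
        timeout: std::time::Duration,
    ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
        // No cooperative stop here: report that no actors were asked to
        // shut down politely and none were stopped forcefully.
        tracing::info!("terminate_proc requested for {:?} (timeout {:?})", proc, timeout);
        Ok((Vec::new(), Vec::new()))
    }
}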

hyperactor_mesh/src/bootstrap.rs

Lines changed: 37 additions & 4 deletions

@@ -43,6 +43,7 @@ use hyperactor::clock::RealClock;
 use hyperactor::config::CONFIG;
 use hyperactor::config::ConfigAttr;
 use hyperactor::config::global as config;
+use hyperactor::context;
 use hyperactor::declare_attrs;
 use hyperactor::host;
 use hyperactor::host::Host;
@@ -64,6 +65,7 @@ use tokio::sync::watch;
 use crate::logging::OutputTarget;
 use crate::logging::StreamFwder;
 use crate::proc_mesh::mesh_agent::ProcMeshAgent;
+use crate::resource::StopAllClient;
 use crate::v1;
 use crate::v1::host_mesh::mesh_agent::HostAgentMode;
 use crate::v1::host_mesh::mesh_agent::HostMeshAgent;
@@ -1242,6 +1244,7 @@ impl hyperactor::host::ProcHandle for BootstrapProcHandle {
     /// or the channel was lost.
     async fn terminate(
         &self,
+        cx: &impl context::Actor,
         timeout: Duration,
     ) -> Result<ProcStatus, hyperactor::host::TerminateError<Self::TerminalStatus>> {
         const HARD_WAIT_AFTER_KILL: Duration = Duration::from_secs(5);
@@ -1264,6 +1267,30 @@ impl hyperactor::host::ProcHandle for BootstrapProcHandle {
         })?;

         // Best-effort mark "Stopping" (ok if state races).
+
+        // Before sending SIGTERM, try to close actors normally. Only works if
+        // they are in the Ready state and have an Agent we can message.
+        let agent = self.agent_ref();
+        if let Some(agent) = agent {
+            let mailbox_result = RealClock.timeout(timeout, agent.stop_all(cx)).await;
+            if let Err(timeout_err) = mailbox_result {
+                // Agent didn't respond in time, proceed with SIGTERM.
+                tracing::warn!(
+                    "ProcMeshAgent {} didn't respond in time to stop proc: {}",
+                    agent.actor_id(),
+                    timeout_err,
+                );
+            } else if let Ok(Err(e)) = mailbox_result {
+                // Other mailbox error, proceed with SIGTERM.
+                tracing::warn!(
+                    "ProcMeshAgent {} did not successfully stop all actors: {}",
+                    agent.actor_id(),
+                    e
+                );
+            }
+        }
+        // After the stop all actors message may be successful, we still need
+        // to actually stop the process.
         let _ = self.mark_stopping();

         // Send SIGTERM (ESRCH is treated as "already gone").
@@ -1885,6 +1912,7 @@ impl hyperactor::host::SingleTerminate for BootstrapProcManager {
     /// Logs a warning for each failure.
     async fn terminate_proc(
         &self,
+        cx: &impl context::Actor,
         proc: &ProcId,
         timeout: Duration,
     ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
@@ -1895,7 +1923,7 @@ impl hyperactor::host::SingleTerminate for BootstrapProcManager {
         };

         if let Some(h) = proc_handle {
-            h.terminate(timeout)
+            h.terminate(cx, timeout)
                 .await
                 .map(|_| (Vec::new(), Vec::new()))
                 .map_err(|e| e.into())
@@ -1920,7 +1948,12 @@ impl hyperactor::host::BulkTerminate for BootstrapProcManager {
     /// those that were already terminal), and how many failed.
     ///
     /// Logs a warning for each failure.
-    async fn terminate_all(&self, timeout: Duration, max_in_flight: usize) -> TerminateSummary {
+    async fn terminate_all(
+        &self,
+        cx: &impl context::Actor,
+        timeout: Duration,
+        max_in_flight: usize,
+    ) -> TerminateSummary {
         // Snapshot to avoid holding the lock across awaits.
         let handles: Vec<BootstrapProcHandle> = {
             let guard = self.children.lock().await;
@@ -1931,7 +1964,7 @@ impl hyperactor::host::BulkTerminate for BootstrapProcManager {
         let mut ok = 0usize;

         let results = stream::iter(handles.into_iter().map(|h| async move {
-            match h.terminate(timeout).await {
+            match h.terminate(cx, timeout).await {
                 Ok(_) | Err(hyperactor::host::TerminateError::AlreadyTerminated(_)) => {
                     // Treat "already terminal" as success.
                     true
@@ -3321,7 +3354,7 @@ mod tests {

         let deadline = Duration::from_secs(2);
         match RealClock
-            .timeout(deadline * 2, handle.terminate(deadline))
+            .timeout(deadline * 2, handle.terminate(&instance, deadline))
             .await
         {
             Err(_) => panic!("terminate() future hung"),

hyperactor_mesh/src/proc_mesh/mesh_agent.rs

Lines changed: 28 additions & 0 deletions

@@ -219,6 +219,7 @@ pub(crate) fn update_event_actor_id(mut event: ActorSupervisionEvent) -> ActorSu
         MeshAgentMessage,
         resource::CreateOrUpdate<ActorSpec> { cast = true },
         resource::Stop { cast = true },
+        resource::StopAll { cast = true },
         resource::GetState<ActorState> { cast = true },
         resource::GetRankStatus { cast = true },
     ]
@@ -272,6 +273,14 @@ impl ProcMeshAgent {
         };
         proc.spawn::<Self>("agent", agent).await
     }
+
+    async fn destroy_and_wait<'a>(
+        &mut self,
+        cx: &Context<'a, Self>,
+        timeout: tokio::time::Duration,
+    ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
+        self.proc.destroy_and_wait::<Self>(timeout, Some(cx)).await
+    }
 }

 #[async_trait]
@@ -616,6 +625,25 @@ impl Handler<resource::Stop> for ProcMeshAgent {
     }
 }

+#[async_trait]
+impl Handler<resource::StopAll> for ProcMeshAgent {
+    async fn handle(
+        &mut self,
+        cx: &Context<Self>,
+        _message: resource::StopAll,
+    ) -> anyhow::Result<()> {
+        let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
+        // By passing in the self context, destroy_and_wait will stop this agent
+        // last, after all others are stopped.
+        let _stop_result = self.destroy_and_wait(cx, timeout).await?;
+        for (_, actor_state) in self.actor_states.iter_mut() {
+            // Mark all actors as stopped.
+            actor_state.stopped = true;
+        }
+        Ok(())
+    }
+}
+
 #[async_trait]
 impl Handler<resource::GetRankStatus> for ProcMeshAgent {
     async fn handle(

hyperactor_mesh/src/resource.rs

Lines changed: 16 additions & 0 deletions

@@ -273,6 +273,22 @@ pub struct Stop {
     pub reply: PortRef<StatusOverlay>,
 }

+/// Stop all resources owned by the receiver of this message.
+/// No reply, this is meant to force a stop without waiting for acknowledgement.
+#[derive(
+    Debug,
+    Clone,
+    Serialize,
+    Deserialize,
+    Named,
+    Handler,
+    HandleClient,
+    RefClient,
+    Bind,
+    Unbind
+)]
+pub struct StopAll {}
+
 /// Retrieve the current state of the resource.
 #[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
 pub struct GetState<S> {
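
The `Handler`/`HandleClient`/`RefClient` derives above are what generate the `StopAllClient` trait that bootstrap.rs now imports. A hypothetical call site (the helper function here is illustrative, not part of this change) would look roughly like:

use crate::resource::StopAllClient; // derive-generated client trait for StopAll

// Hypothetical helper: ask a proc's agent to stop every actor it owns
// before any signal is sent to the process.
async fn request_cooperative_stop(
    cx: &impl hyperactor::context::Actor,
    agent: &hyperactor::ActorRef<ProcMeshAgent>,
) -> anyhow::Result<()> {
    // Awaiting the call surfaces delivery/mailbox errors; the real
    // terminate path additionally wraps this in RealClock.timeout (see
    // the bootstrap.rs hunk) before falling back to SIGTERM.
    agent.stop_all(cx).await?;
    Ok(())
}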

hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs

Lines changed: 7 additions & 5 deletions

@@ -27,6 +27,7 @@ use hyperactor::Proc;
 use hyperactor::ProcId;
 use hyperactor::RefClient;
 use hyperactor::channel::ChannelTransport;
+use hyperactor::context;
 use hyperactor::host::Host;
 use hyperactor::host::HostError;
 use hyperactor::host::LocalProcManager;
@@ -75,13 +76,14 @@ impl HostAgentMode {

     async fn terminate_proc(
         &self,
+        cx: &impl context::Actor,
         proc: &ProcId,
         timeout: Duration,
     ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
         #[allow(clippy::match_same_arms)]
         match self {
-            HostAgentMode::Process(host) => host.terminate_proc(proc, timeout).await,
-            HostAgentMode::Local(host) => host.terminate_proc(proc, timeout).await,
+            HostAgentMode::Process(host) => host.terminate_proc(cx, proc, timeout).await,
+            HostAgentMode::Local(host) => host.terminate_proc(cx, proc, timeout).await,
         }
     }
 }
@@ -212,7 +214,7 @@ impl Handler<resource::Stop> for HostMeshAgent {
             !*stopped
         };
         if should_stop {
-            host.terminate_proc(proc_id, timeout).await?;
+            host.terminate_proc(&cx, proc_id, timeout).await?;
             *stopped = true;
         }
         // use Stopped as a successful result for Stop.
@@ -328,13 +330,13 @@ impl Handler<ShutdownHost> for HostMeshAgent {
         match host_mode {
             HostAgentMode::Process(host) => {
                 let summary = host
-                    .terminate_children(msg.timeout, msg.max_in_flight.clamp(1, 256))
+                    .terminate_children(cx, msg.timeout, msg.max_in_flight.clamp(1, 256))
                     .await;
                 tracing::info!(?summary, "terminated children on host");
             }
             HostAgentMode::Local(host) => {
                 let summary = host
-                    .terminate_children(msg.timeout, msg.max_in_flight)
+                    .terminate_children(cx, msg.timeout, msg.max_in_flight)
                     .await;
                 tracing::info!(?summary, "terminated children on local host");
             }

python/tests/test_actor_error.py

Lines changed: 4 additions & 4 deletions

@@ -749,13 +749,13 @@ async def test_slice_supervision() -> None:
     slice_2 = error_mesh.slice(gpus=2)
     slice_3 = error_mesh.slice(gpus=3)

-    # Trigger supervision error on gpus=3
-    with pytest.raises(SupervisionError, match="did not handle supervision event"):
-        await slice_3.fail_with_supervision_error.call()
-
     match = (
         "Actor .* (is unhealthy with reason:|exited because of the following reason:)"
     )
+    # Trigger supervision error on gpus=3
+    with pytest.raises(SupervisionError, match=match):
+        await slice_3.fail_with_supervision_error.call()
+
     # Mesh containing all gpus is unhealthy
     with pytest.raises(SupervisionError, match=match):
         await error_mesh.check.call()
