Skip to content

Commit 22fe769

Browse files
committed
[monarch] Make sure context is propagated through PythonTask.spawn_blocking
Similar to `PythonTask.spawn`, we also need to make sure that the monarch context is propagated through `PythonTask.spawn_blocking`. Differential Revision: [D85981930](https://our.internmc.facebook.com/intern/diff/D85981930/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D85981930/)! ghstack-source-id: 320167898 Pull Request resolved: #1730
1 parent b249a03 commit 22fe769

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

monarch_hyperactor/src/pytokio.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,9 @@ impl PyPythonTask {
357357
.import("monarch._src.actor.actor_mesh")?
358358
.getattr("_context")?;
359359
let old_context = _context.call_method1("get", (PyNone::get(py),))?;
360-
_context.call_method1("set", (monarch_context.clone_ref(py),))?;
360+
_context
361+
.call_method1("set", (monarch_context.clone_ref(py),))
362+
.expect("failed to set _context");
361363

362364
let result = match last {
363365
Ok(value) => coroutine_iterator.bind(py).call_method1("send", (value,)),
@@ -367,7 +369,9 @@ impl PyPythonTask {
367369
};
368370

369371
// Reset context() so that when this tokio thread yields, it has its original state.
370-
_context.call_method1("set", (old_context,))?;
372+
_context
373+
.call_method1("set", (old_context,))
374+
.expect("failed to restore _context");
371375
match result {
372376
Ok(task) => Ok(Action::Wait(
373377
task.extract::<Py<PyPythonTask>>()
@@ -413,14 +417,34 @@ impl PyPythonTask {
413417
}
414418

415419
#[staticmethod]
416-
fn spawn_blocking(f: PyObject) -> PyResult<PyShared> {
420+
fn spawn_blocking(py: Python<'_>, f: PyObject) -> PyResult<PyShared> {
417421
let (tx, rx) = watch::channel(None);
418422
let traceback = current_traceback()?;
419423
let traceback1 = traceback
420424
.as_ref()
421425
.map_or_else(|| None, |t| Python::with_gil(|py| Some(t.clone_ref(py))));
426+
let monarch_context = py
427+
.import("monarch._src.actor.actor_mesh")?
428+
.call_method0("context")?
429+
.unbind();
430+
// The `_context` contextvar needs to be propagated through to the thread that
431+
// runs the blocking tokio task. Upon completion, the original value of `_context`
432+
// is restored.
422433
let handle = get_tokio_runtime().spawn_blocking(move || {
423-
let result = Python::with_gil(|py| f.call0(py));
434+
let result = Python::with_gil(|py| {
435+
let _context = py
436+
.import("monarch._src.actor.actor_mesh")?
437+
.getattr("_context")?;
438+
let old_context = _context.call_method1("get", (PyNone::get(py),))?;
439+
_context
440+
.call_method1("set", (monarch_context.clone_ref(py),))
441+
.expect("failed to set _context");
442+
let result = f.call0(py);
443+
_context
444+
.call_method1("set", (old_context,))
445+
.expect("failed to restore _context");
446+
result
447+
});
424448
send_result(tx, result, traceback1);
425449
});
426450
Ok(PyShared {

python/tests/test_python_actors.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,3 +1686,35 @@ def test_login_job():
16861686
assert v == "hello!"
16871687

16881688
j.kill()
1689+
1690+
1691+
class TestPytokioActor(Actor):
1692+
@endpoint
1693+
def context_propagated_through_spawn(self) -> None:
1694+
cx = context()
1695+
1696+
async def task():
1697+
assert cx is context()
1698+
1699+
PythonTask.from_coroutine(coro=task()).spawn().block_on()
1700+
1701+
@endpoint
1702+
def context_propagated_through_spawn_blocking(self) -> None:
1703+
cx = context()
1704+
1705+
def task():
1706+
assert cx is context()
1707+
1708+
PythonTask.spawn_blocking(task).block_on()
1709+
1710+
1711+
def test_context_propagated_through_python_task_spawn():
1712+
p = this_host().spawn_procs()
1713+
a = p.spawn("test_pytokio_actor", TestPytokioActor)
1714+
a.context_propagated_through_spawn.call().get()
1715+
1716+
1717+
def test_context_propagated_through_python_task_spawn_blocking():
1718+
p = this_host().spawn_procs()
1719+
a = p.spawn("test_pytokio_actor", TestPytokioActor)
1720+
a.context_propagated_through_spawn_blocking.call().get()

0 commit comments

Comments
 (0)