Skip to content

Commit 7fa1e53

Browse files
committed
[monarch] Make sure context is propagated through PythonTask.spawn_blocking
Similar to `PythonTask.spawn`, we also need to make sure that the monarch context is propagated through `PythonTask.spawn_blocking`. Differential Revision: [D85981930](https://our.internmc.facebook.com/intern/diff/D85981930/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D85981930/)! ghstack-source-id: 320165614 Pull Request resolved: #1728
1 parent b249a03 commit 7fa1e53

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

monarch_hyperactor/src/pytokio.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,14 +413,30 @@ impl PyPythonTask {
413413
}
414414

415415
#[staticmethod]
416-
fn spawn_blocking(f: PyObject) -> PyResult<PyShared> {
416+
fn spawn_blocking(py: Python<'_>, f: PyObject) -> PyResult<PyShared> {
417417
let (tx, rx) = watch::channel(None);
418418
let traceback = current_traceback()?;
419419
let traceback1 = traceback
420420
.as_ref()
421421
.map_or_else(|| None, |t| Python::with_gil(|py| Some(t.clone_ref(py))));
422+
let monarch_context = py
423+
.import("monarch._src.actor.actor_mesh")?
424+
.call_method0("context")?
425+
.unbind();
426+
// The `_context` contextvar needs to be propagated through to the thread that
427+
// runs the blocking tokio task. Upon completion, the original value of `_context`
428+
// is restored.
422429
let handle = get_tokio_runtime().spawn_blocking(move || {
423-
let result = Python::with_gil(|py| f.call0(py));
430+
let result = Python::with_gil(|py| {
431+
let _context = py
432+
.import("monarch._src.actor.actor_mesh")?
433+
.getattr("_context")?;
434+
let old_context = _context.call_method1("get", (PyNone::get(py),))?;
435+
_context.call_method1("set", (monarch_context.clone_ref(py),))?;
436+
let result = f.call0(py);
437+
_context.call_method1("set", (old_context,))?;
438+
result
439+
});
424440
send_result(tx, result, traceback1);
425441
});
426442
Ok(PyShared {

python/tests/test_python_actors.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,3 +1686,35 @@ def test_login_job():
16861686
assert v == "hello!"
16871687

16881688
j.kill()
1689+
1690+
1691+
class TestPytokioActor(Actor):
1692+
@endpoint
1693+
def context_propagated_through_spawn(self) -> None:
1694+
cx = context()
1695+
1696+
async def task():
1697+
assert cx is context()
1698+
1699+
PythonTask.from_coroutine(coro=task()).spawn().block_on()
1700+
1701+
@endpoint
1702+
def context_propagated_through_spawn_blocking(self) -> None:
1703+
cx = context()
1704+
1705+
def task():
1706+
assert cx is context()
1707+
1708+
PythonTask.spawn_blocking(task).block_on()
1709+
1710+
1711+
def test_context_propagated_through_python_task_spawn():
1712+
p = this_host().spawn_procs()
1713+
a = p.spawn("test_pytokio_actor", TestPytokioActor)
1714+
a.context_propagated_through_spawn.call().get()
1715+
1716+
1717+
def test_context_propagated_through_python_task_spawn_blocking():
1718+
p = this_host().spawn_procs()
1719+
a = p.spawn("test_pytokio_actor", TestPytokioActor)
1720+
a.context_propagated_through_spawn_blocking.call().get()

0 commit comments

Comments
 (0)