From 02ccf8153565e0f4234f4f6cf0daf769c68094a0 Mon Sep 17 00:00:00 2001 From: zdevito Date: Fri, 31 Oct 2025 15:29:58 -0700 Subject: [PATCH] [Logging] put the actor name in the python logger Differential Revision: [D85994688](https://our.internmc.facebook.com/intern/diff/D85994688/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D85994688/)! [ghstack-poisoned] --- python/monarch/_src/actor/actor_mesh.py | 45 ++++++++++++++++++++++--- python/tests/test_python_actors.py | 32 +++++++++++++++++- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 2b15d48a8..5ee5532bf 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -17,6 +17,8 @@ from abc import abstractproperty from dataclasses import dataclass + +from functools import cache from pprint import pformat from textwrap import indent from traceback import TracebackException @@ -257,6 +259,41 @@ def _root_client_context() -> "Context": ... 
"monarch.actor_mesh._context" ) + +class _ActorFilter(logging.Filter): + def __init__(self) -> None: + super().__init__() + + def filter(self, record: Any) -> bool: + ctx = _context.get(None) + if ctx is not None: + record.msg = f"[actor={ctx.actor_instance}] {record.msg}" + return True + + +@cache +def _init_context_log_handler() -> None: + af: _ActorFilter = _ActorFilter() + logger = logging.getLogger() + for handler in logger.handlers: + handler.addFilter(af) + + _original_addHandler: Any = logging.Logger.addHandler + + def _patched_addHandler(self: Any, handler: Any) -> None: + _original_addHandler(self, handler) + if af not in handler.filters: + handler.addFilter(af) + + # typing: ignore + logging.Logger.addHandler = _patched_addHandler + + +def _set_context(c: Context) -> None: + _init_context_log_handler() + _context.set(c) + + T = TypeVar("T") @@ -305,7 +342,7 @@ def context() -> Context: c = _context.get(None) if c is None: c = Context._root_client_context() - _context.set(c) + _set_context(c) from monarch._src.actor.host_mesh import create_local_host_mesh from monarch._src.actor.proc_mesh import _get_controller_controller @@ -919,7 +956,7 @@ async def handle( # response_port can be None. If so, then sending to port will drop the response, # and raise any exceptions to the caller. try: - _context.set(ctx) + _set_context(ctx) DebugContext.set(DebugContext()) @@ -1053,7 +1090,7 @@ def _post_mortem_debug(self, exc_tb: Any) -> None: def _handle_undeliverable_message( self, cx: Context, message: UndeliverableMessageEnvelope ) -> bool: - _context.set(cx) + _set_context(cx) handle_undeliverable = getattr( self.instance, "_handle_undeliverable_message", None ) @@ -1063,7 +1100,7 @@ def _handle_undeliverable_message( return False def __supervise__(self, cx: Context, *args: Any, **kwargs: Any) -> object: - _context.set(cx) + _set_context(cx) instance = self.instance if instance is None: # This could happen because of the following reasons. 
import io
import logging


class CaptureLogs:
    """Capture records logged to the ``"capture"`` logger in a memory buffer.

    Typical use, as in the accompanying tests::

        logs = CaptureLogs()
        logs.logger.error("HUH")
        assert "actor=" in logs.contents

    The instance can also be used as a context manager, which detaches its
    handler on exit. Call ``close()`` explicitly otherwise; without it every
    instance leaves one more handler on the shared ``"capture"`` logger.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # In-memory sink holding the formatted log output.
        self.log_stream = io.StringIO()
        # Kept on the instance so close() can detach exactly this handler.
        self.handler = logging.StreamHandler(self.log_stream)
        self.handler.setFormatter(logging.Formatter("%(message)s"))

        # Shared named logger; INFO and above are captured.
        self.logger = logging.getLogger("capture")
        self.logger.setLevel(logging.INFO)
        # Registered through Logger.addHandler so that any wrapper installed
        # around addHandler also sees this handler.
        self.logger.addHandler(self.handler)

    @property
    def contents(self) -> str:
        """Return everything captured so far, one message per line."""
        return self.log_stream.getvalue()

    def close(self) -> None:
        """Detach this capture's handler; safe to call more than once.

        The buffer is left open, so ``contents`` stays readable afterwards.
        """
        self.logger.removeHandler(self.handler)

    def __enter__(self) -> "CaptureLogs":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()