Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from abc import abstractproperty

from dataclasses import dataclass

from functools import cache
from pprint import pformat
from textwrap import indent
from traceback import TracebackException
Expand Down Expand Up @@ -257,6 +259,41 @@ def _root_client_context() -> "Context": ...
"monarch.actor_mesh._context"
)


class _ActorFilter(logging.Filter):
def __init__(self) -> None:
super().__init__()

def filter(self, record: Any) -> bool:
ctx = _context.get(None)
if ctx is not None:
record.msg = f"[actor={ctx.actor_instance}] {record.msg}"
return True


@cache
def _init_context_log_handler() -> None:
af: _ActorFilter = _ActorFilter()
logger = logging.getLogger()
for handler in logger.handlers:
handler.addFilter(af)

_original_addHandler: Any = logging.Logger.addHandler

def _patched_addHandler(self: Any, handler: Any) -> None:
_original_addHandler(self, handler)
if af not in handler.filters:
handler.addFilter(af)

# typing: ignore
logging.Logger.addHandler = _patched_addHandler


def _set_context(c: Context) -> None:
_init_context_log_handler()
_context.set(c)


T = TypeVar("T")


Expand Down Expand Up @@ -305,7 +342,7 @@ def context() -> Context:
c = _context.get(None)
if c is None:
c = Context._root_client_context()
_context.set(c)
_set_context(c)

from monarch._src.actor.host_mesh import create_local_host_mesh
from monarch._src.actor.proc_mesh import _get_controller_controller
Expand Down Expand Up @@ -919,7 +956,7 @@ async def handle(
# response_port can be None. If so, then sending to port will drop the response,
# and raise any exceptions to the caller.
try:
_context.set(ctx)
_set_context(ctx)

DebugContext.set(DebugContext())

Expand Down Expand Up @@ -1053,7 +1090,7 @@ def _post_mortem_debug(self, exc_tb: Any) -> None:
def _handle_undeliverable_message(
self, cx: Context, message: UndeliverableMessageEnvelope
) -> bool:
_context.set(cx)
_set_context(cx)
handle_undeliverable = getattr(
self.instance, "_handle_undeliverable_message", None
)
Expand All @@ -1063,7 +1100,7 @@ def _handle_undeliverable_message(
return False

def __supervise__(self, cx: Context, *args: Any, **kwargs: Any) -> object:
_context.set(cx)
_set_context(cx)
instance = self.instance
if instance is None:
# This could happen because of the following reasons. Both
Expand Down
32 changes: 31 additions & 1 deletion python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import ctypes
import importlib.resources
import io
import logging
import operator
import os
Expand All @@ -20,6 +21,7 @@
import time
import unittest
import unittest.mock
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from types import ModuleType
from typing import cast, Tuple
Expand Down Expand Up @@ -68,7 +70,6 @@
from monarch.tools.config import defaults
from typing_extensions import assert_type


needs_cuda = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA not available",
Expand Down Expand Up @@ -1688,9 +1689,34 @@ def test_login_job():
j.kill()


class CaptureLogs:
def __init__(self):
log_stream = io.StringIO()
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter("%(message)s"))

logger = logging.getLogger("capture")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

self.log_stream = log_stream
self.logger = logger

@property
def contents(self) -> str:
return self.log_stream.getvalue()


class Named(Actor):
@endpoint
def report(self):
logs = CaptureLogs()
logs.logger.error("HUH")
assert (
"actor=<root>.<tests.test_python_actors.Named the_name{'f': 0/2}>"
in logs.contents
)

return context().actor_instance.creator, str(context().actor_instance)


Expand All @@ -1706,3 +1732,7 @@ def test_instance_name():
assert result == "<root>.<tests.test_python_actors.Named the_name{'f': 0/2}>"
assert cr.name == "root"
assert str(context().actor_instance) == "<root>"

logs = CaptureLogs()
logs.logger.error("HUH")
assert "actor=<root>" in logs.contents
Loading