Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
import collections
import contextvars
import functools
import importlib
import inspect
import itertools
import logging
import threading
from abc import abstractproperty

from dataclasses import dataclass

from functools import cache
from pprint import pformat
from textwrap import indent
from traceback import TracebackException
Expand Down Expand Up @@ -257,6 +260,51 @@ def _root_client_context() -> "Context": ...
"monarch.actor_mesh._context"
)


@cache
def _monarch_actor() -> Any:
return importlib.import_module("monarch.actor")


class _ActorFilter(logging.Filter):
def __init__(self) -> None:
super().__init__()

def filter(self, record: Any) -> bool:
fn = _monarch_actor().per_actor_logging_prefix
ctx = _context.get(None)
if ctx is not None and fn is not None:
record.msg = fn(ctx.actor_instance) + record.msg
return True


def per_actor_logging_prefix(instance: Instance | CreatorInstance) -> str:
return f"[actor={instance}] "


@cache
def _init_context_log_handler() -> None:
af: _ActorFilter = _ActorFilter()
logger = logging.getLogger()
for handler in logger.handlers:
handler.addFilter(af)

_original_addHandler: Any = logging.Logger.addHandler

def _patched_addHandler(self: logging.Logger, hdlr: logging.Handler) -> None:
_original_addHandler(self, hdlr)
if af not in hdlr.filters:
hdlr.addFilter(af)

# pyre-ignore[8]: Intentionally monkey-patching Logger.addHandler
logging.Logger.addHandler = _patched_addHandler


def _set_context(c: Context) -> None:
_init_context_log_handler()
_context.set(c)


T = TypeVar("T")


Expand Down Expand Up @@ -305,7 +353,7 @@ def context() -> Context:
c = _context.get(None)
if c is None:
c = Context._root_client_context()
_context.set(c)
_set_context(c)

from monarch._src.actor.host_mesh import create_local_host_mesh
from monarch._src.actor.proc_mesh import _get_controller_controller
Expand Down Expand Up @@ -919,7 +967,7 @@ async def handle(
# response_port can be None. If so, then sending to port will drop the response,
# and raise any exceptions to the caller.
try:
_context.set(ctx)
_set_context(ctx)

DebugContext.set(DebugContext())

Expand Down Expand Up @@ -1053,7 +1101,7 @@ def _post_mortem_debug(self, exc_tb: Any) -> None:
def _handle_undeliverable_message(
self, cx: Context, message: UndeliverableMessageEnvelope
) -> bool:
_context.set(cx)
_set_context(cx)
handle_undeliverable = getattr(
self.instance, "_handle_undeliverable_message", None
)
Expand All @@ -1063,7 +1111,7 @@ def _handle_undeliverable_message(
return False

def __supervise__(self, cx: Context, *args: Any, **kwargs: Any) -> object:
_context.set(cx)
_set_context(cx)
instance = self.instance
if instance is None:
# This could happen because of the following reasons. Both
Expand Down
2 changes: 2 additions & 0 deletions python/monarch/actor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
current_size,
enable_transport,
Endpoint,
per_actor_logging_prefix,
Point,
Port,
PortReceiver,
Expand Down Expand Up @@ -107,5 +108,6 @@
"Context",
"ChannelTransport",
"unhandled_fault_hook",
"per_actor_logging_prefix",
"MeshFailure",
]
41 changes: 40 additions & 1 deletion python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import ctypes
import importlib.resources
import io
import logging
import operator
import os
Expand All @@ -20,6 +21,7 @@
import time
import unittest
import unittest.mock
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from types import ModuleType
from typing import Any, cast, Tuple
Expand Down Expand Up @@ -68,7 +70,6 @@
from monarch.tools.config import defaults
from typing_extensions import assert_type


needs_cuda = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA not available",
Expand Down Expand Up @@ -1733,9 +1734,31 @@ def test_setup_async() -> None:
time.sleep(10)


class CaptureLogs:
def __init__(self):
log_stream = io.StringIO()
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter("%(message)s"))

logger = logging.getLogger("capture")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

self.log_stream = log_stream
self.logger = logger

@property
def contents(self) -> str:
return self.log_stream.getvalue()


class Named(Actor):
@endpoint
def report(self) -> Any:
logs = CaptureLogs()
logs.logger.error("HUH")
assert "test_python_actors.Named the_name{'f': 0/2}>" in logs.contents

return context().actor_instance.creator, str(context().actor_instance)


Expand All @@ -1751,3 +1774,19 @@ def test_instance_name():
assert "test_python_actors.Named the_name{'f': 0/2}>" in result
assert cr.name == "root"
assert str(context().actor_instance) == "<root>"

logs = CaptureLogs()
logs.logger.error("HUH")
assert "actor=<root>" in logs.contents
default = monarch.actor.per_actor_logging_prefix
try:
monarch.actor.per_actor_logging_prefix = lambda inst: "<test>"
logs = CaptureLogs()
logs.logger.error("HUH")
assert "<test>" in logs.contents
monarch.actor.per_actor_logging_prefix = None
# make sure we can set _per_actor_logging_prefix to none.
logs = CaptureLogs()
logs.logger.error("HUH")
finally:
monarch.actor.per_actor_logging_prefix = default
Loading