Skip to content

Commit a6d1eb1

Browse files
Robert Ruschmeta-codesync[bot]
authored andcommitted
Migrator monarch/python/monarch/_src/actor to pyre-strict (#1710)
Summary: Pull Request resolved: #1710 pyre had a couple of gaps for how we handled endpoints that interfered with full turning typing on that are mostly addressed with pyrefly. This is part 1 of a series of diffs to turn on pyre-strict, leaning someone generously on Any's for now, which then will be followed up by enabling pyrefly. Unlike last time, I think this should also generally preserve the existing behavior on how endpoints resolve, but happy to modify that, since how they pass through args can be funky. Reviewed By: shayne-fletcher Differential Revision: D85356082 fbshipit-source-id: fcd2f5d2c49cf1166080e97de19a83999ac927be
1 parent 958b033 commit a6d1eb1

25 files changed

+318
-216
lines changed

python/monarch/_src/actor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-unsafe
7+
# pyre-strict
88

99
"""
1010
Monarch Actor API

python/monarch/_src/actor/actor_mesh.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-unsafe
7+
# pyre-strict
88

99
import abc
1010
import collections
@@ -89,6 +89,8 @@
8989
from monarch._src.actor.sync_state import fake_sync_state
9090
from monarch._src.actor.telemetry import METER
9191
from monarch._src.actor.tensor_engine_shim import actor_rref, actor_send
92+
from opentelemetry.metrics import Counter
93+
from opentelemetry.trace import Tracer
9294
from typing_extensions import Self
9395

9496
if TYPE_CHECKING:
@@ -99,11 +101,9 @@
99101
from monarch._src.actor.proc_mesh import _ControllerController, ProcMesh
100102
from monarch._src.actor.telemetry import get_monarch_tracer
101103

102-
CallMethod = PythonMessageKind.CallMethod
103-
104104
logger: logging.Logger = logging.getLogger(__name__)
105105

106-
TRACER = get_monarch_tracer()
106+
TRACER: Tracer = get_monarch_tracer()
107107

108108
Allocator = ProcessAllocator | LocalAllocator
109109

@@ -164,9 +164,9 @@ def proc(self) -> "ProcMesh":
164164

165165
# this property is used to hold the handles to actors and processes launched by this actor
166166
# in order to keep them alive until this actor exits.
167-
_children: "Optional[List[ActorMesh | ProcMesh]]"
167+
_children: "Optional[List[ActorMesh[Any] | ProcMesh]]"
168168

169-
def _add_child(self, child: "ActorMesh | ProcMesh") -> None:
169+
def _add_child(self, child: "ActorMesh[Any] | ProcMesh") -> None:
170170
if self._children is None:
171171
self._children = [child]
172172
else:
@@ -377,7 +377,7 @@ def stop(self, instance: HyInstance) -> "PythonTask[None]":
377377
raise NotImplementedError("stop()")
378378

379379
def initialized(self) -> "PythonTask[None]":
380-
async def empty():
380+
async def empty() -> None:
381381
pass
382382

383383
return PythonTask.from_coroutine(empty())
@@ -402,10 +402,10 @@ def __init__(
402402
self._signature: inspect.Signature = inspect.signature(impl)
403403
self._explicit_response_port = explicit_response_port
404404

405-
def _call_name(self) -> Any:
405+
def _call_name(self) -> MethodSpecifier:
406406
return self._name
407407

408-
def _check_arguments(self, args, kwargs):
408+
def _check_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
409409
if self._explicit_response_port:
410410
self._signature.bind(None, None, *args, **kwargs)
411411
else:
@@ -415,7 +415,7 @@ def _send(
415415
self,
416416
args: Tuple[Any, ...],
417417
kwargs: Dict[str, Any],
418-
port: "Optional[Port]" = None,
418+
port: "Optional[Port[R]]" = None,
419419
selection: Selection = "all",
420420
) -> Extent:
421421
"""
@@ -449,7 +449,7 @@ def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]":
449449
r._set_monitor(monitor)
450450
return (p, r)
451451

452-
def _rref(self, args, kwargs):
452+
def _rref(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> R:
453453
self._check_arguments(args, kwargs)
454454
refs, buffer = flatten((args, kwargs), _is_ref_or_mailbox)
455455

@@ -479,7 +479,7 @@ def as_endpoint(
479479
*,
480480
propagate: Propagator = None,
481481
explicit_response_port: bool = False,
482-
):
482+
) -> Any:
483483
if not isinstance(not_an_endpoint, NotAnEndpoint):
484484
raise ValueError("expected an method of a spawned actor")
485485
kind = (
@@ -557,7 +557,7 @@ def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
557557
remapped = [self._hy.get(pos[g]) for g in shape.ranks()]
558558
return ValueMesh(shape, remapped)
559559

560-
def item(self, **kwargs) -> R:
560+
def item(self, **kwargs: int) -> R:
561561
"""
562562
Get the value at the given coordinates.
563563
@@ -621,10 +621,10 @@ def _ndslice(self) -> NDSlice:
621621
def _labels(self) -> Iterable[str]:
622622
return self._shape.labels
623623

624-
def __getstate__(self):
624+
def __getstate__(self) -> Dict[str, Any]:
625625
return {"shape": self._shape, "values": self._hy.values()}
626626

627-
def __setstate__(self, state):
627+
def __setstate__(self, state: Dict[str, Any]) -> None:
628628
self._shape = state["shape"]
629629
vals = state["values"]
630630
self._hy = HyValueMesh(self._shape, vals)
@@ -634,7 +634,7 @@ def send(
634634
endpoint: Endpoint[P, R],
635635
args: Tuple[Any, ...],
636636
kwargs: Dict[str, Any],
637-
port: "Optional[Port]" = None,
637+
port: "Optional[Port[R]]" = None,
638638
selection: Selection = "all",
639639
) -> None:
640640
"""
@@ -690,15 +690,17 @@ def exception(self, obj: Exception) -> None:
690690
PythonMessage(PythonMessageKind.Exception(self._rank), _pickle(obj)),
691691
)
692692

693-
def __reduce__(self):
693+
def __reduce__(self) -> Tuple[Any, Tuple[Any, ...]]:
694694
"""
695695
When Port is sent over the wire, we do not want to send the actor instance
696696
from the current context. Instead, we want to reconstruct the Port with
697697
the receiver's context, since that is where the message will be sent
698698
from through this port.
699699
"""
700700

701-
def _reconstruct_port(port_ref, rank):
701+
def _reconstruct_port(
702+
port_ref: PortRef | OncePortRef, rank: Optional[int]
703+
) -> "Port[R]":
702704
instance = context().actor_instance._as_rust()
703705
return Port(port_ref, instance, rank)
704706

@@ -714,7 +716,7 @@ class DroppingPort:
714716
Makes sure any exception sent to it causes the actor to report an exception.
715717
"""
716718

717-
def __init__(self):
719+
def __init__(self) -> None:
718720
pass
719721

720722
def send(self, obj: Any) -> None:
@@ -810,7 +812,7 @@ def recv(self) -> "Future[R]":
810812
def ranked(self) -> "RankedPortReceiver[R]":
811813
return RankedPortReceiver[R](self._mailbox, self._receiver, self._monitor)
812814

813-
def _set_monitor(self, monitor: "Optional[Shared[Exception]]"):
815+
def _set_monitor(self, monitor: "Optional[Shared[Exception]]") -> None:
814816
self._monitor = monitor
815817

816818

@@ -834,7 +836,7 @@ def _process(self, msg: PythonMessage) -> Tuple[int, R]:
834836
# we need to signal to the consumer of the PythonTask object that the thread really isn't in an async context.
835837
# We do this by blanking out the running event loop during the call to the synchronous actor function.
836838

837-
MESSAGES_HANDLED = METER.create_counter("py_mesages_handled")
839+
MESSAGES_HANDLED: Counter = METER.create_counter("py_mesages_handled")
838840

839841

840842
class _Actor:
@@ -968,15 +970,15 @@ async def handle(
968970
pass
969971
raise
970972

971-
def _maybe_exit_debugger(self, do_continue=True) -> None:
973+
def _maybe_exit_debugger(self, do_continue: bool = True) -> None:
972974
if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
973975
if do_continue:
974976
pdb_wrapper.clear_all_breaks()
975977
pdb_wrapper.do_continue("")
976978
pdb_wrapper.end_debug_session()
977979
DebugContext.set(DebugContext())
978980

979-
def _post_mortem_debug(self, exc_tb) -> None:
981+
def _post_mortem_debug(self, exc_tb: Any) -> None:
980982
from monarch._src.actor.debugger.debug_controller import debug_controller
981983

982984
if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
@@ -1005,7 +1007,7 @@ def _handle_undeliverable_message(
10051007
else:
10061008
return False
10071009

1008-
def __supervise__(self, cx: Context, *args, **kwargs) -> object:
1010+
def __supervise__(self, cx: Context, *args: Any, **kwargs: Any) -> object:
10091011
_context.set(cx)
10101012
instance = self.instance
10111013
if instance is None:
@@ -1083,7 +1085,7 @@ def _new_with_shape(self, shape: Shape) -> Self:
10831085
)
10841086

10851087
@property
1086-
def initialized(self):
1088+
def initialized(self) -> Any:
10871089
raise NotImplementedError(
10881090
"actor implementations are not meshes, but we can't convince the typechecker of it..."
10891091
)
@@ -1164,9 +1166,9 @@ def _endpoint(
11641166
self,
11651167
name: MethodSpecifier,
11661168
impl: Callable[Concatenate[Any, P], Awaitable[R]],
1167-
propagator: Any,
1169+
propagator: Propagator,
11681170
explicit_response_port: bool,
1169-
):
1171+
) -> Any:
11701172
return ActorEndpoint(
11711173
self._inner,
11721174
self._shape,
@@ -1215,7 +1217,9 @@ def from_actor_id(
12151217
) -> "ActorMesh[T]":
12161218
return cls(Class, _SingletonActorAdapator(actor_id), singleton_shape, None)
12171219

1218-
def __reduce_ex__(self, protocol: ...) -> "Tuple[Type[ActorMesh], Tuple[Any, ...]]":
1220+
def __reduce_ex__(
1221+
self, protocol: Any
1222+
) -> "Tuple[Type[ActorMesh[T]], Tuple[Any, ...]]":
12191223
return ActorMesh, (self._class, self._inner, self._shape, self._proc_mesh)
12201224

12211225
@property
@@ -1269,7 +1273,7 @@ def __init__(
12691273
)
12701274
for s in actor_mesh_ref_tb
12711275
)
1272-
self.exception_formatted = "".join(actor_mesh_ref_tb)
1276+
self.exception_formatted: str = "".join(actor_mesh_ref_tb)
12731277
self.message = message
12741278

12751279
def __str__(self) -> str:

python/monarch/_src/actor/bootstrap.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-unsafe
7+
# pyre-strict
88

99

1010
from pathlib import Path
@@ -28,11 +28,12 @@
2828
CA = Union[bytes, Path, Literal["trust_all_connections"]]
2929

3030

31-
def _as_python_task(s: str | Future[str]) -> PythonTask:
31+
def _as_python_task(s: str | Future[str]) -> "PythonTask[str]":
3232
if isinstance(s, str):
33+
s_str: str = s
3334

34-
async def just():
35-
return s
35+
async def just() -> str:
36+
return s_str
3637

3738
return PythonTask.from_coroutine(just())
3839
else:

python/monarch/_src/actor/bootstrap_main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-unsafe
7+
# pyre-strict
88

99
"""
1010
This is the main function for the boostrapping a new process using a ProcessAllocator.
@@ -29,13 +29,14 @@
2929
import monarch._rust_bindings # @manual # noqa: F401
3030

3131

32-
async def main():
32+
async def main() -> None:
3333
from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main
3434

35+
# pyre-ignore[12]: bootstrap_main is async but imported from Rust bindings
3536
await bootstrap_main()
3637

3738

38-
def invoke_main():
39+
def invoke_main() -> None:
3940
# if this is invoked with the stdout piped somewhere, then print
4041
# changes its buffering behavior. So we default to the standard
4142
# behavior of std out as if it were a terminal.

python/monarch/_src/actor/code_sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-unsafe
7+
# pyre-strict
88

99
from monarch._rust_bindings.monarch_extension.code_sync import ( # noqa: F401
1010
CodeSyncMeshClient,

0 commit comments

Comments
 (0)