44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- # pyre-unsafe
7+ # pyre-strict
88
99import abc
1010import collections
8989from monarch ._src .actor .sync_state import fake_sync_state
9090from monarch ._src .actor .telemetry import METER
9191from monarch ._src .actor .tensor_engine_shim import actor_rref , actor_send
92+ from opentelemetry .metrics import Counter
93+ from opentelemetry .trace import Tracer
9294from typing_extensions import Self
9395
9496if TYPE_CHECKING :
99101 from monarch ._src .actor .proc_mesh import _ControllerController , ProcMesh
100102from monarch ._src .actor .telemetry import get_monarch_tracer
101103
102- CallMethod = PythonMessageKind .CallMethod
103-
104104logger : logging .Logger = logging .getLogger (__name__ )
105105
106- TRACER = get_monarch_tracer ()
106+ TRACER : Tracer = get_monarch_tracer ()
107107
108108Allocator = ProcessAllocator | LocalAllocator
109109
@@ -164,9 +164,9 @@ def proc(self) -> "ProcMesh":
164164
165165 # this property is used to hold the handles to actors and processes launched by this actor
166166 # in order to keep them alive until this actor exits.
167- _children : "Optional[List[ActorMesh | ProcMesh]]"
167+ _children : "Optional[List[ActorMesh[Any] | ProcMesh]]"
168168
169- def _add_child (self , child : "ActorMesh | ProcMesh" ) -> None :
169+ def _add_child (self , child : "ActorMesh[Any] | ProcMesh" ) -> None :
170170 if self ._children is None :
171171 self ._children = [child ]
172172 else :
@@ -377,7 +377,7 @@ def stop(self, instance: HyInstance) -> "PythonTask[None]":
377377 raise NotImplementedError ("stop()" )
378378
379379 def initialized (self ) -> "PythonTask[None]" :
380- async def empty ():
380+ async def empty () -> None :
381381 pass
382382
383383 return PythonTask .from_coroutine (empty ())
@@ -402,10 +402,10 @@ def __init__(
402402 self ._signature : inspect .Signature = inspect .signature (impl )
403403 self ._explicit_response_port = explicit_response_port
404404
405- def _call_name (self ) -> Any :
405+ def _call_name (self ) -> MethodSpecifier :
406406 return self ._name
407407
408- def _check_arguments (self , args , kwargs ) :
408+ def _check_arguments (self , args : Tuple [ Any , ...], kwargs : Dict [ str , Any ]) -> None :
409409 if self ._explicit_response_port :
410410 self ._signature .bind (None , None , * args , ** kwargs )
411411 else :
@@ -415,7 +415,7 @@ def _send(
415415 self ,
416416 args : Tuple [Any , ...],
417417 kwargs : Dict [str , Any ],
418- port : "Optional[Port]" = None ,
418+ port : "Optional[Port[R] ]" = None ,
419419 selection : Selection = "all" ,
420420 ) -> Extent :
421421 """
@@ -449,7 +449,7 @@ def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]":
449449 r ._set_monitor (monitor )
450450 return (p , r )
451451
452- def _rref (self , args , kwargs ) :
452+ def _rref (self , args : Tuple [ Any , ...], kwargs : Dict [ str , Any ]) -> R :
453453 self ._check_arguments (args , kwargs )
454454 refs , buffer = flatten ((args , kwargs ), _is_ref_or_mailbox )
455455
@@ -479,7 +479,7 @@ def as_endpoint(
479479 * ,
480480 propagate : Propagator = None ,
481481 explicit_response_port : bool = False ,
482- ):
482+ ) -> Any :
483483 if not isinstance (not_an_endpoint , NotAnEndpoint ):
484484 raise ValueError ("expected an method of a spawned actor" )
485485 kind = (
@@ -557,7 +557,7 @@ def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
557557 remapped = [self ._hy .get (pos [g ]) for g in shape .ranks ()]
558558 return ValueMesh (shape , remapped )
559559
560- def item (self , ** kwargs ) -> R :
560+ def item (self , ** kwargs : int ) -> R :
561561 """
562562 Get the value at the given coordinates.
563563
@@ -621,10 +621,10 @@ def _ndslice(self) -> NDSlice:
621621 def _labels (self ) -> Iterable [str ]:
622622 return self ._shape .labels
623623
624- def __getstate__ (self ):
624+ def __getstate__ (self ) -> Dict [ str , Any ] :
625625 return {"shape" : self ._shape , "values" : self ._hy .values ()}
626626
627- def __setstate__ (self , state ) :
627+ def __setstate__ (self , state : Dict [ str , Any ]) -> None :
628628 self ._shape = state ["shape" ]
629629 vals = state ["values" ]
630630 self ._hy = HyValueMesh (self ._shape , vals )
@@ -634,7 +634,7 @@ def send(
634634 endpoint : Endpoint [P , R ],
635635 args : Tuple [Any , ...],
636636 kwargs : Dict [str , Any ],
637- port : "Optional[Port]" = None ,
637+ port : "Optional[Port[R] ]" = None ,
638638 selection : Selection = "all" ,
639639) -> None :
640640 """
@@ -690,15 +690,17 @@ def exception(self, obj: Exception) -> None:
690690 PythonMessage (PythonMessageKind .Exception (self ._rank ), _pickle (obj )),
691691 )
692692
693- def __reduce__ (self ):
693+ def __reduce__ (self ) -> Tuple [ Any , Tuple [ Any , ...]] :
694694 """
695695 When Port is sent over the wire, we do not want to send the actor instance
696696 from the current context. Instead, we want to reconstruct the Port with
697697 the receiver's context, since that is where the message will be sent
698698 from through this port.
699699 """
700700
701- def _reconstruct_port (port_ref , rank ):
701+ def _reconstruct_port (
702+ port_ref : PortRef | OncePortRef , rank : Optional [int ]
703+ ) -> "Port[R]" :
702704 instance = context ().actor_instance ._as_rust ()
703705 return Port (port_ref , instance , rank )
704706
@@ -714,7 +716,7 @@ class DroppingPort:
714716 Makes sure any exception sent to it causes the actor to report an exception.
715717 """
716718
717- def __init__ (self ):
719+ def __init__ (self ) -> None :
718720 pass
719721
720722 def send (self , obj : Any ) -> None :
@@ -810,7 +812,7 @@ def recv(self) -> "Future[R]":
810812 def ranked (self ) -> "RankedPortReceiver[R]" :
811813 return RankedPortReceiver [R ](self ._mailbox , self ._receiver , self ._monitor )
812814
813- def _set_monitor (self , monitor : "Optional[Shared[Exception]]" ):
815+ def _set_monitor (self , monitor : "Optional[Shared[Exception]]" ) -> None :
814816 self ._monitor = monitor
815817
816818
@@ -834,7 +836,7 @@ def _process(self, msg: PythonMessage) -> Tuple[int, R]:
834836# we need to signal to the consumer of the PythonTask object that the thread really isn't in an async context.
835837# We do this by blanking out the running event loop during the call to the synchronous actor function.
836838
837- MESSAGES_HANDLED = METER .create_counter ("py_mesages_handled" )
839+ MESSAGES_HANDLED : Counter = METER .create_counter ("py_mesages_handled" )
838840
839841
840842class _Actor :
@@ -968,15 +970,15 @@ async def handle(
968970 pass
969971 raise
970972
971- def _maybe_exit_debugger (self , do_continue = True ) -> None :
973+ def _maybe_exit_debugger (self , do_continue : bool = True ) -> None :
972974 if (pdb_wrapper := DebugContext .get ().pdb_wrapper ) is not None :
973975 if do_continue :
974976 pdb_wrapper .clear_all_breaks ()
975977 pdb_wrapper .do_continue ("" )
976978 pdb_wrapper .end_debug_session ()
977979 DebugContext .set (DebugContext ())
978980
979- def _post_mortem_debug (self , exc_tb ) -> None :
981+ def _post_mortem_debug (self , exc_tb : Any ) -> None :
980982 from monarch ._src .actor .debugger .debug_controller import debug_controller
981983
982984 if (pdb_wrapper := DebugContext .get ().pdb_wrapper ) is not None :
@@ -1005,7 +1007,7 @@ def _handle_undeliverable_message(
10051007 else :
10061008 return False
10071009
1008- def __supervise__ (self , cx : Context , * args , ** kwargs ) -> object :
1010+ def __supervise__ (self , cx : Context , * args : Any , ** kwargs : Any ) -> object :
10091011 _context .set (cx )
10101012 instance = self .instance
10111013 if instance is None :
@@ -1083,7 +1085,7 @@ def _new_with_shape(self, shape: Shape) -> Self:
10831085 )
10841086
10851087 @property
1086- def initialized (self ):
1088+ def initialized (self ) -> Any :
10871089 raise NotImplementedError (
10881090 "actor implementations are not meshes, but we can't convince the typechecker of it..."
10891091 )
@@ -1164,9 +1166,9 @@ def _endpoint(
11641166 self ,
11651167 name : MethodSpecifier ,
11661168 impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
1167- propagator : Any ,
1169+ propagator : Propagator ,
11681170 explicit_response_port : bool ,
1169- ):
1171+ ) -> Any :
11701172 return ActorEndpoint (
11711173 self ._inner ,
11721174 self ._shape ,
@@ -1215,7 +1217,9 @@ def from_actor_id(
12151217 ) -> "ActorMesh[T]" :
12161218 return cls (Class , _SingletonActorAdapator (actor_id ), singleton_shape , None )
12171219
1218- def __reduce_ex__ (self , protocol : ...) -> "Tuple[Type[ActorMesh], Tuple[Any, ...]]" :
1220+ def __reduce_ex__ (
1221+ self , protocol : Any
1222+ ) -> "Tuple[Type[ActorMesh[T]], Tuple[Any, ...]]" :
12191223 return ActorMesh , (self ._class , self ._inner , self ._shape , self ._proc_mesh )
12201224
12211225 @property
@@ -1269,7 +1273,7 @@ def __init__(
12691273 )
12701274 for s in actor_mesh_ref_tb
12711275 )
1272- self .exception_formatted = "" .join (actor_mesh_ref_tb )
1276+ self .exception_formatted : str = "" .join (actor_mesh_ref_tb )
12731277 self .message = message
12741278
12751279 def __str__ (self ) -> str :
0 commit comments