2020from typing import Any , Callable
2121
2222import zmq
23+ import zmq_anyio
2324from anyio import sleep
2425from jupyter_client .session import extract_header
2526
@@ -48,7 +49,7 @@ class IOPubThread:
4849 whose IO is always run in a thread.
4950 """
5051
51- def __init__ (self , socket , pipe = False ):
52+ def __init__ (self , socket : zmq_anyio . Socket , pipe : bool = False ):
5253 """Create IOPub thread
5354
5455 Parameters
@@ -61,10 +62,7 @@ def __init__(self, socket, pipe=False):
6162 """
6263 # ensure all of our sockets as sync zmq.Sockets
6364 # don't create async wrappers until we are within the appropriate coroutines
64- self .socket : zmq .Socket [bytes ] | None = zmq .Socket (socket )
65- if self .socket .context is None :
66- # bug in pyzmq, shadow socket doesn't always inherit context attribute
67- self .socket .context = socket .context # type:ignore[unreachable]
65+ self .socket : zmq_anyio .Socket = socket
6866 self ._context = socket .context
6967
7068 self .background_socket = BackgroundSocket (self )
@@ -78,7 +76,7 @@ def __init__(self, socket, pipe=False):
7876 self ._event_pipe_gc_lock : threading .Lock = threading .Lock ()
7977 self ._event_pipe_gc_seconds : float = 10
8078 self ._setup_event_pipe ()
81- tasks = [self ._handle_event , self ._run_event_pipe_gc ]
79+ tasks = [self ._handle_event , self ._run_event_pipe_gc , self . socket . start ]
8280 if pipe :
8381 tasks .append (self ._handle_pipe_msgs )
8482 self .thread = BaseThread (name = "IOPub" , daemon = True )
@@ -87,7 +85,7 @@ def __init__(self, socket, pipe=False):
8785
8886 def _setup_event_pipe (self ):
8987 """Create the PULL socket listening for events that should fire in this thread."""
90- self ._pipe_in0 = self ._context .socket (zmq .PULL , socket_class = zmq . Socket )
88+ self ._pipe_in0 = self ._context .socket (zmq .PULL )
9189 self ._pipe_in0 .linger = 0
9290
9391 _uuid = b2a_hex (os .urandom (16 )).decode ("ascii" )
@@ -99,11 +97,11 @@ async def _run_event_pipe_gc(self):
9997 while True :
10098 await sleep (self ._event_pipe_gc_seconds )
10199 try :
102- await self ._event_pipe_gc ()
100+ self ._event_pipe_gc ()
103101 except Exception as e :
104102 print (f"Exception in IOPubThread._event_pipe_gc: { e } " , file = sys .__stderr__ )
105103
106- async def _event_pipe_gc (self ):
104+ def _event_pipe_gc (self ):
107105 """run a single garbage collection on event pipes"""
108106 if not self ._event_pipes :
109107 # don't acquire the lock if there's nothing to do
@@ -122,7 +120,7 @@ def _event_pipe(self):
122120 except AttributeError :
123121 # new thread, new event pipe
124122 # create sync base socket
125- event_pipe = self ._context .socket (zmq .PUSH , socket_class = zmq . Socket )
123+ event_pipe = self ._context .socket (zmq .PUSH )
126124 event_pipe .linger = 0
127125 event_pipe .connect (self ._event_interface )
128126 self ._local .event_pipe = event_pipe
@@ -141,30 +139,28 @@ async def _handle_event(self):
141139 Whenever *an* event arrives on the event stream,
142140 *all* waiting events are processed in order.
143141 """
144- # create async wrapper within coroutine
145- pipe_in = zmq . asyncio . Socket ( self . _pipe_in0 )
146- try :
147- while True :
148- await pipe_in .recv ()
149- # freeze event count so new writes don't extend the queue
150- # while we are processing
151- n_events = len (self ._events )
152- for _ in range (n_events ):
153- event_f = self ._events .popleft ()
154- event_f ()
155- except Exception :
156- if self .thread .stopped .is_set ():
157- return
158- raise
142+ pipe_in = zmq_anyio . Socket ( self . _pipe_in0 )
143+ async with pipe_in :
144+ try :
145+ while True :
146+ await pipe_in .arecv (). wait ()
147+ # freeze event count so new writes don't extend the queue
148+ # while we are processing
149+ n_events = len (self ._events )
150+ for _ in range (n_events ):
151+ event_f = self ._events .popleft ()
152+ event_f ()
153+ except Exception :
154+ if self .thread .stopped .is_set ():
155+ return
156+ raise
159157
160158 def _setup_pipe_in (self ):
161159 """setup listening pipe for IOPub from forked subprocesses"""
162- ctx = self ._context
163-
164160 # use UUID to authenticate pipe messages
165161 self ._pipe_uuid = os .urandom (16 )
166162
167- self ._pipe_in1 = ctx . socket (zmq .PULL , socket_class = zmq . Socket )
163+ self ._pipe_in1 = zmq_anyio . Socket ( self . _context . socket (zmq .PULL ) )
168164 self ._pipe_in1 .linger = 0
169165
170166 try :
@@ -181,19 +177,18 @@ def _setup_pipe_in(self):
181177
182178 async def _handle_pipe_msgs (self ):
183179 """handle pipe messages from a subprocess"""
184- # create async wrapper within coroutine
185- self ._async_pipe_in1 = zmq .asyncio .Socket (self ._pipe_in1 )
186- try :
187- while True :
188- await self ._handle_pipe_msg ()
189- except Exception :
190- if self .thread .stopped .is_set ():
191- return
192- raise
180+ async with self ._pipe_in1 :
181+ try :
182+ while True :
183+ await self ._handle_pipe_msg ()
184+ except Exception :
185+ if self .thread .stopped .is_set ():
186+ return
187+ raise
193188
194189 async def _handle_pipe_msg (self , msg = None ):
195190 """handle a pipe message from a subprocess"""
196- msg = msg or await self ._async_pipe_in1 . recv_multipart ()
191+ msg = msg or await self ._pipe_in1 . arecv_multipart (). wait ()
197192 if not self ._pipe_flag or not self ._is_main_process ():
198193 return
199194 if msg [0 ] != self ._pipe_uuid :
@@ -246,7 +241,10 @@ def close(self):
246241 """Close the IOPub thread."""
247242 if self .closed :
248243 return
249- self ._pipe_in0 .close ()
244+ try :
245+ self ._pipe_in0 .close ()
246+ except Exception :
247+ pass
250248 if self ._pipe_flag :
251249 self ._pipe_in1 .close ()
252250 if self .socket is not None :
0 commit comments