1
1
import asyncio
2
+ import contextlib
2
3
from typing import Any
3
4
4
5
from broadcaster import Broadcast
@@ -22,14 +23,6 @@ class BroadcastNotification(BaseModel):
22
23
data : Any = None
23
24
24
25
25
- class EventBroadcasterException (Exception ):
26
- pass
27
-
28
-
29
- class BroadcasterAlreadyStarted (EventBroadcasterException ):
30
- pass
31
-
32
-
33
26
class EventBroadcasterContextManager :
34
27
"""
35
28
Manages the context for the EventBroadcaster
@@ -56,56 +49,18 @@ def __init__(
56
49
self ._listen : bool = listen
57
50
58
51
async def __aenter__ (self ):
59
- async with self ._event_broadcaster ._context_manager_lock :
60
- if self ._listen :
61
- self ._event_broadcaster ._listen_count += 1
62
- if self ._event_broadcaster ._listen_count == 1 :
63
- # We have our first listener start the read-task for it (And all those who'd follow)
64
- logger .info (
65
- "Listening for incoming events from broadcast channel (first listener started)"
66
- )
67
- # Start task listening on incoming broadcasts
68
- await self ._event_broadcaster .start_reader_task ()
69
-
70
- if self ._share :
71
- self ._event_broadcaster ._share_count += 1
72
- if self ._event_broadcaster ._share_count == 1 :
73
- # We have our first publisher
74
- # Init the broadcast used for sharing (reading has its own)
75
- logger .debug (
76
- "Subscribing to ALL_TOPICS, and sharing messages with broadcast channel"
77
- )
78
- # Subscribe to internal events form our own event notifier and broadcast them
79
- await self ._event_broadcaster ._subscribe_to_all_topics ()
80
- else :
81
- logger .debug (
82
- f"Did not subscribe to ALL_TOPICS: share count == { self ._event_broadcaster ._share_count } "
83
- )
84
- return self
52
+ await self ._event_broadcaster .connect (self ._listen , self ._share )
85
53
86
54
async def __aexit__ (self , exc_type , exc , tb ):
87
- async with self ._event_broadcaster ._context_manager_lock :
88
- try :
89
- if self ._listen :
90
- self ._event_broadcaster ._listen_count -= 1
91
- # if this was last listener - we can stop the reading task
92
- if self ._event_broadcaster ._listen_count == 0 :
93
- # Cancel task reading broadcast subscriptions
94
- if self ._event_broadcaster ._subscription_task is not None :
95
- logger .info ("Cancelling broadcast listen task" )
96
- self ._event_broadcaster ._subscription_task .cancel ()
97
- self ._event_broadcaster ._subscription_task = None
98
-
99
- if self ._share :
100
- self ._event_broadcaster ._share_count -= 1
101
- # if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore
102
- if self ._event_broadcaster ._share_count == 0 :
103
- # Unsubscribe from internal events
104
- logger .debug ("Unsubscribing from ALL TOPICS" )
105
- await self ._event_broadcaster ._unsubscribe_from_topics ()
106
-
107
- except :
108
- logger .exception ("Failed to exit EventBroadcaster context" )
55
+ await self ._event_broadcaster .close (self ._listen , self ._share )
56
+
57
+
58
+ class EventBroadcasterException (Exception ):
59
+ pass
60
+
61
+
62
+ class BroadcasterAlreadyStarted (EventBroadcasterException ):
63
+ pass
109
64
110
65
111
66
class EventBroadcaster :
@@ -137,62 +92,46 @@ def __init__(
137
92
broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast.
138
93
is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
139
94
"""
140
- # Broadcast init params
141
95
self ._broadcast_url = broadcast_url
142
96
self ._broadcast_type = broadcast_type or Broadcast
143
- # Publish broadcast (initialized within async with statement)
144
- self ._sharing_broadcast_channel = None
145
- # channel to operate on
146
97
self ._channel = channel
147
- # Async-io task for reading broadcasts (initialized within async with statement)
148
98
self ._subscription_task = None
149
- # Uniqueue instance id (used to avoid reading own notifications sent in broadcast)
150
99
self ._id = gen_uid ()
151
- # The internal events notifier
152
100
self ._notifier = notifier
101
+ self ._broadcast_channel = None
102
+ self ._connect_lock = asyncio .Lock ()
103
+ self ._listen_refcount = 0
104
+ self ._share_refcount = 0
153
105
self ._is_publish_only = is_publish_only
154
- self ._publish_lock = None
155
- # used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share)
156
- self ._listen_count : int = 0
157
- self ._share_count : int = 0
158
- # If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
159
- self ._context_manager = None
160
- self ._context_manager_lock = asyncio .Lock ()
161
- self ._tasks = set ()
162
- self .listening_broadcast_channel = None
163
106
164
- async def __broadcast_notifications__ (self , subscription : Subscription , data ):
107
+ async def connect (self , listen = True , share = True ):
165
108
"""
166
- Share incoming internal notifications with the entire broadcast channel
167
-
168
- Args:
169
- subscription (Subscription): the subscription that got triggered
170
- data: the event data
109
+ This connects the listening channel
171
110
"""
172
- logger .info (
173
- "Broadcasting incoming event: {}" .format (
174
- {"topic" : subscription .topic , "notifier_id" : self ._id }
175
- )
176
- )
177
- note = BroadcastNotification (
178
- notifier_id = self ._id , topics = [subscription .topic ], data = data
179
- )
111
+ async with self ._connect_lock :
112
+ if listen :
113
+ await self ._connect_listen ()
114
+ self ._listen_refcount += 1
180
115
181
- # Publish event to broadcast
182
- async with self ._broadcast_type (
183
- self ._broadcast_url
184
- ) as sharing_broadcast_channel :
185
- await sharing_broadcast_channel .publish (
186
- self ._channel , pydantic_serialize (note )
187
- )
116
+ if share :
117
+ await self ._connect_share ()
118
+ self ._share_refcount += 1
188
119
189
- async def _subscribe_to_all_topics (self ):
190
- return await self ._notifier .subscribe (
191
- self ._id , ALL_TOPICS , self .__broadcast_notifications__
192
- )
120
+ async def close (self , listen = True , share = True ):
121
+ async with self ._connect_lock :
122
+ if listen :
123
+ await self ._close_listen ()
124
+ self ._listen_refcount -= 1
125
+
126
+ if share :
127
+ await self ._close_share ()
128
+ self ._share_refcount -= 1
129
+
130
+ async def __aenter__ (self ):
131
+ await self .connect (listen = not self ._is_publish_only )
193
132
194
- async def _unsubscribe_from_topics (self ):
195
- return await self ._notifier . unsubscribe ( self ._id )
133
+ async def __aexit__ (self , exc_type , exc , tb ):
134
+ await self .close ( listen = not self ._is_publish_only )
196
135
197
136
def get_context (self , listen = True , share = True ):
198
137
"""
@@ -213,97 +152,116 @@ def get_listening_context(self):
213
152
def get_sharing_context (self ):
214
153
return EventBroadcasterContextManager (self , listen = False , share = True )
215
154
216
- async def __aenter__ (self ):
155
+ async def __broadcast_notifications__ (self , subscription : Subscription , data ):
217
156
"""
218
- Convince caller (also backward compaltability)
157
+ Share incoming internal notifications with the entire broadcast channel
158
+
159
+ Args:
160
+ subscription (Subscription): the subscription that got triggered
161
+ data: the event data
219
162
"""
220
- if self ._context_manager is None :
221
- self ._context_manager = self .get_context (listen = not self ._is_publish_only )
222
- return await self ._context_manager .__aenter__ ()
163
+ logger .info (
164
+ "Broadcasting incoming event: {}" .format (
165
+ {"topic" : subscription .topic , "notifier_id" : self ._id }
166
+ )
167
+ )
223
168
224
- async def __aexit__ (self , exc_type , exc , tb ):
225
- await self ._context_manager .__aexit__ (exc_type , exc , tb )
169
+ note = BroadcastNotification (
170
+ notifier_id = self ._id , topics = [subscription .topic ], data = data
171
+ )
226
172
227
- async def start_reader_task (self ):
228
- """Spawn a task reading incoming broadcasts and posting them to the intreal notifier
229
- Raises:
230
- BroadcasterAlreadyStarted: if called more than once per context
231
- Returns:
232
- the spawned task
233
- """
234
- # Make sure a task wasn't started already
235
- if self ._subscription_task is not None :
236
- # we already started a task for this worker process
237
- logger .debug (
238
- "No need for listen task, already started broadcast listen task for this notifier"
173
+ # Publish event to broadcast using a new connection from connection pool
174
+ async with self ._broadcast_type (
175
+ self ._broadcast_url
176
+ ) as sharing_broadcast_channel :
177
+ await sharing_broadcast_channel .publish (
178
+ self ._channel , pydantic_serialize (note )
239
179
)
240
- return
241
180
242
- # Init new broadcast channel for reading
243
- try :
244
- if self .listening_broadcast_channel is None :
245
- self .listening_broadcast_channel = self ._broadcast_type (
246
- self ._broadcast_url
247
- )
248
- await self .listening_broadcast_channel .connect ()
249
- except Exception as e :
250
- logger .error (
251
- f"Failed to connect to broadcast channel for reading incoming events: { e } "
181
+ async def _connect_share (self ):
182
+ if self ._share_refcount == 0 :
183
+ return await self ._notifier .subscribe (
184
+ self ._id , ALL_TOPICS , self .__broadcast_notifications__
252
185
)
253
- raise e
254
186
255
- # Trigger the task
256
- logger .debug ("Spawning broadcast listen task" )
257
- self ._subscription_task = asyncio .create_task (self .__read_notifications__ ())
258
- return self ._subscription_task
187
+ async def _close_share (self ):
188
+ if self ._share_refcount == 1 :
189
+ return await self ._notifier .unsubscribe (self ._id )
190
+
191
+ async def _connect_listen (self ):
192
+ if self ._listen_refcount == 0 :
193
+ if self ._listen_refcount == 0 :
194
+ try :
195
+ self ._broadcast_channel = self ._broadcast_type (self ._broadcast_url )
196
+ await self ._broadcast_channel .connect ()
197
+ except Exception as e :
198
+ logger .error (
199
+ f"Failed to connect to broadcast channel for reading incoming events: { e } "
200
+ )
201
+ raise e
202
+ self ._subscription_task = asyncio .create_task (
203
+ self .__read_notifications__ ()
204
+ )
205
+ return await self ._notifier .subscribe (
206
+ self ._id , ALL_TOPICS , self .__broadcast_notifications__
207
+ )
208
+
209
+ async def _close_listen (self ):
210
+ if self ._listen_refcount == 1 and self ._broadcast_channel is not None :
211
+ await self ._broadcast_channel .disconnect ()
212
+ await self .wait_until_done ()
213
+ self ._broadcast_channel = None
259
214
260
215
def get_reader_task (self ):
261
216
return self ._subscription_task
262
217
218
+ async def wait_until_done (self ):
219
+ if self ._subscription_task is not None :
220
+ await self ._subscription_task
221
+ self ._subscription_task = None
222
+
263
223
async def __read_notifications__ (self ):
264
224
"""
265
225
read incoming broadcasts and posting them to the intreal notifier
266
226
"""
267
227
logger .debug ("Starting broadcaster listener" )
228
+
229
+ notify_tasks = set ()
268
230
try :
269
231
# Subscribe to our channel
270
- async with self .listening_broadcast_channel .subscribe (
232
+ async with self ._broadcast_channel .subscribe (
271
233
channel = self ._channel
272
234
) as subscriber :
273
235
async for event in subscriber :
274
- try :
275
- notification = BroadcastNotification .parse_raw (event .message )
276
- # Avoid re-publishing our own broadcasts
277
- if notification .notifier_id != self ._id :
278
- logger .debug (
279
- "Handling incoming broadcast event: {}" .format (
280
- {
281
- "topics" : notification .topics ,
282
- "src" : notification .notifier_id ,
283
- }
284
- )
236
+ notification = BroadcastNotification .parse_raw (event .message )
237
+ # Avoid re-publishing our own broadcasts
238
+ if notification .notifier_id != self ._id :
239
+ logger .debug (
240
+ "Handling incoming broadcast event: {}" .format (
241
+ {
242
+ "topics" : notification .topics ,
243
+ "src" : notification .notifier_id ,
244
+ }
285
245
)
286
- # Notify subscribers of message received from broadcast
287
- task = asyncio . create_task (
288
- self . _notifier . notify (
289
- notification . topics ,
290
- notification .data ,
291
- notifier_id = self . _id ,
292
- )
246
+ )
247
+ # Notify subscribers of message received from broadcast
248
+ task = asyncio . create_task (
249
+ self . _notifier . notify (
250
+ notification .topics ,
251
+ notification . data ,
252
+ notifier_id = self . _id ,
293
253
)
254
+ )
294
255
295
- self . _tasks .add (task )
256
+ notify_tasks .add (task )
296
257
297
- def cleanup (task ):
298
- self . _tasks . remove (task )
258
+ def cleanup (t ):
259
+ notify_tasks . remove (t )
299
260
300
- task .add_done_callback (cleanup )
301
- except :
302
- logger .exception ("Failed handling incoming broadcast" )
261
+ task .add_done_callback (cleanup )
303
262
logger .info (
304
263
"No more events to read from subscriber (underlying connection closed)"
305
264
)
306
265
finally :
307
- if self .listening_broadcast_channel is not None :
308
- await self .listening_broadcast_channel .disconnect ()
309
- self .listening_broadcast_channel = None
266
+ # TODO: return_exceptions?
267
+ await asyncio .gather (* notify_tasks , return_exceptions = True )
0 commit comments