Skip to content

Commit 7bcbcaa

Browse files
committed
EventBroadcaster: Add simple connect & close methods, clean the code a bit
1 parent 671c189 commit 7bcbcaa

File tree

2 files changed

+132
-169
lines changed

2 files changed

+132
-169
lines changed
Lines changed: 120 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
from typing import Any
34

45
from broadcaster import Broadcast
@@ -22,14 +23,6 @@ class BroadcastNotification(BaseModel):
2223
data: Any = None
2324

2425

25-
class EventBroadcasterException(Exception):
26-
pass
27-
28-
29-
class BroadcasterAlreadyStarted(EventBroadcasterException):
30-
pass
31-
32-
3326
class EventBroadcasterContextManager:
3427
"""
3528
Manages the context for the EventBroadcaster
@@ -56,56 +49,18 @@ def __init__(
5649
self._listen: bool = listen
5750

5851
async def __aenter__(self):
59-
async with self._event_broadcaster._context_manager_lock:
60-
if self._listen:
61-
self._event_broadcaster._listen_count += 1
62-
if self._event_broadcaster._listen_count == 1:
63-
# We have our first listener start the read-task for it (And all those who'd follow)
64-
logger.info(
65-
"Listening for incoming events from broadcast channel (first listener started)"
66-
)
67-
# Start task listening on incoming broadcasts
68-
await self._event_broadcaster.start_reader_task()
69-
70-
if self._share:
71-
self._event_broadcaster._share_count += 1
72-
if self._event_broadcaster._share_count == 1:
73-
# We have our first publisher
74-
# Init the broadcast used for sharing (reading has its own)
75-
logger.debug(
76-
"Subscribing to ALL_TOPICS, and sharing messages with broadcast channel"
77-
)
78-
# Subscribe to internal events form our own event notifier and broadcast them
79-
await self._event_broadcaster._subscribe_to_all_topics()
80-
else:
81-
logger.debug(
82-
f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}"
83-
)
84-
return self
52+
await self._event_broadcaster.connect(self._listen, self._share)
8553

8654
async def __aexit__(self, exc_type, exc, tb):
87-
async with self._event_broadcaster._context_manager_lock:
88-
try:
89-
if self._listen:
90-
self._event_broadcaster._listen_count -= 1
91-
# if this was last listener - we can stop the reading task
92-
if self._event_broadcaster._listen_count == 0:
93-
# Cancel task reading broadcast subscriptions
94-
if self._event_broadcaster._subscription_task is not None:
95-
logger.info("Cancelling broadcast listen task")
96-
self._event_broadcaster._subscription_task.cancel()
97-
self._event_broadcaster._subscription_task = None
98-
99-
if self._share:
100-
self._event_broadcaster._share_count -= 1
101-
# if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore
102-
if self._event_broadcaster._share_count == 0:
103-
# Unsubscribe from internal events
104-
logger.debug("Unsubscribing from ALL TOPICS")
105-
await self._event_broadcaster._unsubscribe_from_topics()
106-
107-
except:
108-
logger.exception("Failed to exit EventBroadcaster context")
55+
await self._event_broadcaster.close(self._listen, self._share)
56+
57+
58+
class EventBroadcasterException(Exception):
59+
pass
60+
61+
62+
class BroadcasterAlreadyStarted(EventBroadcasterException):
63+
pass
10964

11065

11166
class EventBroadcaster:
@@ -137,62 +92,46 @@ def __init__(
13792
broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast.
13893
is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
13994
"""
140-
# Broadcast init params
14195
self._broadcast_url = broadcast_url
14296
self._broadcast_type = broadcast_type or Broadcast
143-
# Publish broadcast (initialized within async with statement)
144-
self._sharing_broadcast_channel = None
145-
# channel to operate on
14697
self._channel = channel
147-
# Async-io task for reading broadcasts (initialized within async with statement)
14898
self._subscription_task = None
149-
# Uniqueue instance id (used to avoid reading own notifications sent in broadcast)
15099
self._id = gen_uid()
151-
# The internal events notifier
152100
self._notifier = notifier
101+
self._broadcast_channel = None
102+
self._connect_lock = asyncio.Lock()
103+
self._listen_refcount = 0
104+
self._share_refcount = 0
153105
self._is_publish_only = is_publish_only
154-
self._publish_lock = None
155-
# used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share)
156-
self._listen_count: int = 0
157-
self._share_count: int = 0
158-
# If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
159-
self._context_manager = None
160-
self._context_manager_lock = asyncio.Lock()
161-
self._tasks = set()
162-
self.listening_broadcast_channel = None
163106

164-
async def __broadcast_notifications__(self, subscription: Subscription, data):
107+
async def connect(self, listen=True, share=True):
165108
"""
166-
Share incoming internal notifications with the entire broadcast channel
167-
168-
Args:
169-
subscription (Subscription): the subscription that got triggered
170-
data: the event data
109+
This connects the listening channel
171110
"""
172-
logger.info(
173-
"Broadcasting incoming event: {}".format(
174-
{"topic": subscription.topic, "notifier_id": self._id}
175-
)
176-
)
177-
note = BroadcastNotification(
178-
notifier_id=self._id, topics=[subscription.topic], data=data
179-
)
111+
async with self._connect_lock:
112+
if listen:
113+
await self._connect_listen()
114+
self._listen_refcount += 1
180115

181-
# Publish event to broadcast
182-
async with self._broadcast_type(
183-
self._broadcast_url
184-
) as sharing_broadcast_channel:
185-
await sharing_broadcast_channel.publish(
186-
self._channel, pydantic_serialize(note)
187-
)
116+
if share:
117+
await self._connect_share()
118+
self._share_refcount += 1
188119

189-
async def _subscribe_to_all_topics(self):
190-
return await self._notifier.subscribe(
191-
self._id, ALL_TOPICS, self.__broadcast_notifications__
192-
)
120+
async def close(self, listen=True, share=True):
121+
async with self._connect_lock:
122+
if listen:
123+
await self._close_listen()
124+
self._listen_refcount -= 1
125+
126+
if share:
127+
await self._close_share()
128+
self._share_refcount -= 1
129+
130+
async def __aenter__(self):
131+
await self.connect(listen=not self._is_publish_only)
193132

194-
async def _unsubscribe_from_topics(self):
195-
return await self._notifier.unsubscribe(self._id)
133+
async def __aexit__(self, exc_type, exc, tb):
134+
await self.close(listen=not self._is_publish_only)
196135

197136
def get_context(self, listen=True, share=True):
198137
"""
@@ -213,97 +152,116 @@ def get_listening_context(self):
213152
def get_sharing_context(self):
214153
return EventBroadcasterContextManager(self, listen=False, share=True)
215154

216-
async def __aenter__(self):
155+
async def __broadcast_notifications__(self, subscription: Subscription, data):
217156
"""
218-
Convince caller (also backward compaltability)
157+
Share incoming internal notifications with the entire broadcast channel
158+
159+
Args:
160+
subscription (Subscription): the subscription that got triggered
161+
data: the event data
219162
"""
220-
if self._context_manager is None:
221-
self._context_manager = self.get_context(listen=not self._is_publish_only)
222-
return await self._context_manager.__aenter__()
163+
logger.info(
164+
"Broadcasting incoming event: {}".format(
165+
{"topic": subscription.topic, "notifier_id": self._id}
166+
)
167+
)
223168

224-
async def __aexit__(self, exc_type, exc, tb):
225-
await self._context_manager.__aexit__(exc_type, exc, tb)
169+
note = BroadcastNotification(
170+
notifier_id=self._id, topics=[subscription.topic], data=data
171+
)
226172

227-
async def start_reader_task(self):
228-
"""Spawn a task reading incoming broadcasts and posting them to the intreal notifier
229-
Raises:
230-
BroadcasterAlreadyStarted: if called more than once per context
231-
Returns:
232-
the spawned task
233-
"""
234-
# Make sure a task wasn't started already
235-
if self._subscription_task is not None:
236-
# we already started a task for this worker process
237-
logger.debug(
238-
"No need for listen task, already started broadcast listen task for this notifier"
173+
# Publish event to broadcast using a new connection from connection pool
174+
async with self._broadcast_type(
175+
self._broadcast_url
176+
) as sharing_broadcast_channel:
177+
await sharing_broadcast_channel.publish(
178+
self._channel, pydantic_serialize(note)
239179
)
240-
return
241180

242-
# Init new broadcast channel for reading
243-
try:
244-
if self.listening_broadcast_channel is None:
245-
self.listening_broadcast_channel = self._broadcast_type(
246-
self._broadcast_url
247-
)
248-
await self.listening_broadcast_channel.connect()
249-
except Exception as e:
250-
logger.error(
251-
f"Failed to connect to broadcast channel for reading incoming events: {e}"
181+
async def _connect_share(self):
182+
if self._share_refcount == 0:
183+
return await self._notifier.subscribe(
184+
self._id, ALL_TOPICS, self.__broadcast_notifications__
252185
)
253-
raise e
254186

255-
# Trigger the task
256-
logger.debug("Spawning broadcast listen task")
257-
self._subscription_task = asyncio.create_task(self.__read_notifications__())
258-
return self._subscription_task
187+
async def _close_share(self):
188+
if self._share_refcount == 1:
189+
return await self._notifier.unsubscribe(self._id)
190+
191+
async def _connect_listen(self):
192+
if self._listen_refcount == 0:
193+
if self._listen_refcount == 0:
194+
try:
195+
self._broadcast_channel = self._broadcast_type(self._broadcast_url)
196+
await self._broadcast_channel.connect()
197+
except Exception as e:
198+
logger.error(
199+
f"Failed to connect to broadcast channel for reading incoming events: {e}"
200+
)
201+
raise e
202+
self._subscription_task = asyncio.create_task(
203+
self.__read_notifications__()
204+
)
205+
return await self._notifier.subscribe(
206+
self._id, ALL_TOPICS, self.__broadcast_notifications__
207+
)
208+
209+
async def _close_listen(self):
210+
if self._listen_refcount == 1 and self._broadcast_channel is not None:
211+
await self._broadcast_channel.disconnect()
212+
await self.wait_until_done()
213+
self._broadcast_channel = None
259214

260215
def get_reader_task(self):
261216
return self._subscription_task
262217

218+
async def wait_until_done(self):
219+
if self._subscription_task is not None:
220+
await self._subscription_task
221+
self._subscription_task = None
222+
263223
async def __read_notifications__(self):
264224
"""
265225
read incoming broadcasts and posting them to the intreal notifier
266226
"""
267227
logger.debug("Starting broadcaster listener")
228+
229+
notify_tasks = set()
268230
try:
269231
# Subscribe to our channel
270-
async with self.listening_broadcast_channel.subscribe(
232+
async with self._broadcast_channel.subscribe(
271233
channel=self._channel
272234
) as subscriber:
273235
async for event in subscriber:
274-
try:
275-
notification = BroadcastNotification.parse_raw(event.message)
276-
# Avoid re-publishing our own broadcasts
277-
if notification.notifier_id != self._id:
278-
logger.debug(
279-
"Handling incoming broadcast event: {}".format(
280-
{
281-
"topics": notification.topics,
282-
"src": notification.notifier_id,
283-
}
284-
)
236+
notification = BroadcastNotification.parse_raw(event.message)
237+
# Avoid re-publishing our own broadcasts
238+
if notification.notifier_id != self._id:
239+
logger.debug(
240+
"Handling incoming broadcast event: {}".format(
241+
{
242+
"topics": notification.topics,
243+
"src": notification.notifier_id,
244+
}
285245
)
286-
# Notify subscribers of message received from broadcast
287-
task = asyncio.create_task(
288-
self._notifier.notify(
289-
notification.topics,
290-
notification.data,
291-
notifier_id=self._id,
292-
)
246+
)
247+
# Notify subscribers of message received from broadcast
248+
task = asyncio.create_task(
249+
self._notifier.notify(
250+
notification.topics,
251+
notification.data,
252+
notifier_id=self._id,
293253
)
254+
)
294255

295-
self._tasks.add(task)
256+
notify_tasks.add(task)
296257

297-
def cleanup(task):
298-
self._tasks.remove(task)
258+
def cleanup(t):
259+
notify_tasks.remove(t)
299260

300-
task.add_done_callback(cleanup)
301-
except:
302-
logger.exception("Failed handling incoming broadcast")
261+
task.add_done_callback(cleanup)
303262
logger.info(
304263
"No more events to read from subscriber (underlying connection closed)"
305264
)
306265
finally:
307-
if self.listening_broadcast_channel is not None:
308-
await self.listening_broadcast_channel.disconnect()
309-
self.listening_broadcast_channel = None
266+
# TODO: return_exceptions?
267+
await asyncio.gather(*notify_tasks, return_exceptions=True)

0 commit comments

Comments
 (0)