Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions fastapi_websocket_pubsub/pub_sub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ async def _primary_on_connect(self, channel: RpcChannel):

def subscribe(self, topic: Topic, callback: Coroutine):
"""
Subscribe for events (prior to starting the client)
Subscribe for events (before and after starting the client)
@see fastapi_websocket_pubsub/rpc_event_methods.py :: RpcEventServerMethods.subscribe

Args:
Expand All @@ -260,18 +260,68 @@ def subscribe(self, topic: Topic, callback: Coroutine):
'hello' or a complex path 'a/b/c/d' .
Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to subscribe to all topics
callback (Coroutine): the function to call upon relevant event publishing

Returns:
Coroutine: awaitable task to subscribe to topic if connected.
"""
# TODO: add support for post connection subscriptions
if not self.is_ready():
self._topics.add(topic)
# init to empty list if no entry
callbacks = self._callbacks[topic] = self._callbacks.get(topic, [])
# add callback to callbacks list of the topic
callbacks.append(callback)
topic_is_new = topic not in self._topics
self._topics.add(topic)
# init to empty list if no entry
callbacks = self._callbacks[topic] = self._callbacks.get(topic, [])
# add callback to callbacks list of the topic
callbacks.append(callback)
if topic_is_new and self.is_ready():
return self._rpc_channel.other.subscribe(topics=[topic])
else:
raise PubSubClientInvalidStateException(
"Client already connected and subscribed"
)
# If we can't return an RPC call future then we need
# to supply something else to not fail when the
# calling code awaits the result of this function.
future = asyncio.Future()
future.set_result(None)
return future

def unsubscribe(self, topic: Topic):
"""
Unsubscribe for events

Args:
topic (Topic): the identifier of the event topic to be unsubscribed.
Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to unsubscribe all topics

Returns:
Coroutine: awaitable task to subscribe to topic if connected.
"""
# Create none-future which can be safely awaited
# but which also will not give warnings
# if it isn't awaited. This is returned
# on code paths which do not make RPC calls.
none_future = asyncio.Future()
none_future.set_result(None)

# Topics to potentially make RPC calls about
topics = list(self._topics) if topic is ALL_TOPICS else [topic]

# Handle ALL_TOPICS or specific topics
if topic is ALL_TOPICS and not self._topics:
logger.warning(f"Cannot unsubscribe 'ALL_TOPICS'. No topics are subscribed.")
return none_future
elif topic is not ALL_TOPICS and topic not in self._topics:
logger.warning(f"Cannot unsubscribe topic '{topic}' which is not subscribed.")
return none_future
elif topic is ALL_TOPICS and self._topics:
logger.debug(f"Unsubscribing all topics: {self._topics}")
# remove all topics and callbacks
self._topics.clear()
self._callbacks.clear()
elif topic is not ALL_TOPICS and topic in self._topics:
logger.debug(f"Unsubscribing topic '{topic}'")
self._topics.remove(topic)
self._callbacks.pop(topic, None)

if self.is_ready():
return self._rpc_channel.other.unsubscribe(topics=topics)
else:
return none_future

async def publish(
self, topics: TopicList, data=None, sync=True, notifier_id=None
Expand Down
4 changes: 4 additions & 0 deletions fastapi_websocket_pubsub/pub_sub_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ async def subscribe(
) -> List[Subscription]:
return await self.notifier.subscribe(self._subscriber_id, topics, callback)

async def unsubscribe(
self, topics: Union[TopicList, ALL_TOPICS]) -> List[Subscription]:
return await self.notifier.unsubscribe(self._subscriber_id, topics)

async def publish(self, topics: Union[TopicList, Topic], data=None):
"""
Publish events to subscribres of given topics currently connected to the endpoint
Expand Down
13 changes: 13 additions & 0 deletions fastapi_websocket_pubsub/rpc_event_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ async def callback(subscription: Subscription, data):
"Failed to subscribe to RPC events notifier", topics)
return False

async def unsubscribe(self, topics: TopicList = []) -> bool:
"""
provided by the server so that the client can unsubscribe topics.
"""
for topic in topics.copy():
if topic not in self.event_notifier._topics:
self.logger.warning(f"Cannot unsubscribe topic '{topic}' which is not subscribed.")
topics.remove(topic)
# We'll use the remote channel id as our subscriber id
sub_id = await self._get_channel_id_()
await self.event_notifier.unsubscribe(sub_id, topics)
return True

async def publish(self, topics: TopicList = [], data=None, sync=True, notifier_id=None) -> bool:
"""
Publish an event through the server to subscribers
Expand Down
44 changes: 44 additions & 0 deletions tests/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,47 @@ async def on_event(data, topic):
assert published.result
# wait for finish trigger
await asyncio.wait_for(finish.wait(), 5)


@pytest.mark.asyncio
async def test_pub_sub_unsub(server):
"""
Check client can unsubscribe topic and subscribe again.
"""
# finish trigger
finish = asyncio.Event()
async with PubSubClient() as client:

async def on_event(data, topic):
assert data == DATA
finish.set()

# subscribe for the event
client.subscribe(EVENT_TOPIC, on_event)
# start listentining
client.start_client(uri)
# wait for the client to be ready to receive events
await client.wait_until_ready()
# trigger the server via an HTTP route
requests.get(trigger_url)
# wait for finish trigger
await asyncio.wait_for(finish.wait(), 5)
assert finish.is_set()

# unsubscribe and see that we don't get a message
finish.clear()
await client.unsubscribe(EVENT_TOPIC)
requests.get(trigger_url)
# wait for finish trigger which isn't coming
with pytest.raises(asyncio.TimeoutError) as excinfo:
await asyncio.wait_for(finish.wait(), 5)
assert not finish.is_set()

# subscribe again and observe that we get the trigger
finish.clear()
await client.subscribe(EVENT_TOPIC, on_event)
# trigger the server via an HTTP route
requests.get(trigger_url)
# wait for finish trigger
await asyncio.wait_for(finish.wait(), 5)
assert finish.is_set()