diff --git a/fastapi_websocket_pubsub/pub_sub_client.py b/fastapi_websocket_pubsub/pub_sub_client.py index 4ad917e..893d715 100644 --- a/fastapi_websocket_pubsub/pub_sub_client.py +++ b/fastapi_websocket_pubsub/pub_sub_client.py @@ -251,7 +251,7 @@ async def _primary_on_connect(self, channel: RpcChannel): def subscribe(self, topic: Topic, callback: Coroutine): """ - Subscribe for events (prior to starting the client) + Subscribe for events (before and after starting the client) @see fastapi_websocket_pubsub/rpc_event_methods.py :: RpcEventServerMethods.subscribe Args: @@ -260,18 +260,68 @@ def subscribe(self, topic: Topic, callback: Coroutine): 'hello' or a complex path 'a/b/c/d' . Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to subscribe to all topics callback (Coroutine): the function to call upon relevant event publishing + + Returns: + Coroutine: awaitable task to subscribe to topic if connected. """ - # TODO: add support for post connection subscriptions - if not self.is_ready(): - self._topics.add(topic) - # init to empty list if no entry - callbacks = self._callbacks[topic] = self._callbacks.get(topic, []) - # add callback to callbacks list of the topic - callbacks.append(callback) + topic_is_new = topic not in self._topics + self._topics.add(topic) + # init to empty list if no entry + callbacks = self._callbacks[topic] = self._callbacks.get(topic, []) + # add callback to callbacks list of the topic + callbacks.append(callback) + if topic_is_new and self.is_ready(): + return self._rpc_channel.other.subscribe(topics=[topic]) else: - raise PubSubClientInvalidStateException( - "Client already connected and subscribed" - ) + # If we can't return an RPC call future then we need + # to supply something else to not fail when the + # calling code awaits the result of this function. + future = asyncio.Future() + future.set_result(None) + return future + + def unsubscribe(self, topic: Topic): + """ + Unsubscribe for events + + Args: + topic (Topic): the identifier of the event topic to be unsubscribed. + Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to unsubscribe all topics + + Returns: + Coroutine: awaitable task to subscribe to topic if connected. + """ + # Create none-future which can be safely awaited + # but which also will not give warnings + # if it isn't awaited. This is returned + # on code paths which do not make RPC calls. + none_future = asyncio.Future() + none_future.set_result(None) + + # Topics to potentially make RPC calls about + topics = list(self._topics) if topic is ALL_TOPICS else [topic] + + # Handle ALL_TOPICS or specific topics + if topic is ALL_TOPICS and not self._topics: + logger.warning(f"Cannot unsubscribe 'ALL_TOPICS'. No topics are subscribed.") + return none_future + elif topic is not ALL_TOPICS and topic not in self._topics: + logger.warning(f"Cannot unsubscribe topic '{topic}' which is not subscribed.") + return none_future + elif topic is ALL_TOPICS and self._topics: + logger.debug(f"Unsubscribing all topics: {self._topics}") + # remove all topics and callbacks + self._topics.clear() + self._callbacks.clear() + elif topic is not ALL_TOPICS and topic in self._topics: + logger.debug(f"Unsubscribing topic '{topic}'") + self._topics.remove(topic) + self._callbacks.pop(topic, None) + + if self.is_ready(): + return self._rpc_channel.other.unsubscribe(topics=topics) + else: + return none_future async def publish( self, topics: TopicList, data=None, sync=True, notifier_id=None diff --git a/fastapi_websocket_pubsub/pub_sub_server.py b/fastapi_websocket_pubsub/pub_sub_server.py index 83f3913..755268a 100644 --- a/fastapi_websocket_pubsub/pub_sub_server.py +++ b/fastapi_websocket_pubsub/pub_sub_server.py @@ -91,6 +91,10 @@ async def subscribe( ) -> List[Subscription]: return await self.notifier.subscribe(self._subscriber_id, topics, callback) + async def unsubscribe( + self, topics: Union[TopicList, ALL_TOPICS]) -> List[Subscription]: + return await self.notifier.unsubscribe(self._subscriber_id, topics) + async def publish(self, topics: Union[TopicList, Topic], data=None): """ Publish events to subscribres of given topics currently connected to the endpoint diff --git a/fastapi_websocket_pubsub/rpc_event_methods.py b/fastapi_websocket_pubsub/rpc_event_methods.py index ea797c6..73b4163 100644 --- a/fastapi_websocket_pubsub/rpc_event_methods.py +++ b/fastapi_websocket_pubsub/rpc_event_methods.py @@ -49,6 +49,19 @@ async def callback(subscription: Subscription, data): "Failed to subscribe to RPC events notifier", topics) return False + async def unsubscribe(self, topics: TopicList = []) -> bool: + """ + provided by the server so that the client can unsubscribe topics. + """ + for topic in topics.copy(): + if topic not in self.event_notifier._topics: + self.logger.warning(f"Cannot unsubscribe topic '{topic}' which is not subscribed.") + topics.remove(topic) + # We'll use the remote channel id as our subscriber id + sub_id = await self._get_channel_id_() + await self.event_notifier.unsubscribe(sub_id, topics) + return True + async def publish(self, topics: TopicList = [], data=None, sync=True, notifier_id=None) -> bool: """ Publish an event through the server to subscribers diff --git a/tests/basic_test.py b/tests/basic_test.py index 4b5c1aa..58416a4 100644 --- a/tests/basic_test.py +++ b/tests/basic_test.py @@ -133,3 +133,47 @@ async def on_event(data, topic): assert published.result # wait for finish trigger await asyncio.wait_for(finish.wait(), 5) + + +@pytest.mark.asyncio +async def test_pub_sub_unsub(server): + """ + Check client can unsubscribe topic and subscribe again. + """ + # finish trigger + finish = asyncio.Event() + async with PubSubClient() as client: + + async def on_event(data, topic): + assert data == DATA + finish.set() + + # subscribe for the event + client.subscribe(EVENT_TOPIC, on_event) + # start listentining + client.start_client(uri) + # wait for the client to be ready to receive events + await client.wait_until_ready() + # trigger the server via an HTTP route + requests.get(trigger_url) + # wait for finish trigger + await asyncio.wait_for(finish.wait(), 5) + assert finish.is_set() + + # unsubscribe and see that we don't get a message + finish.clear() + await client.unsubscribe(EVENT_TOPIC) + requests.get(trigger_url) + # wait for finish trigger which isn't coming + with pytest.raises(asyncio.TimeoutError) as excinfo: + await asyncio.wait_for(finish.wait(), 5) + assert not finish.is_set() + + # subscribe again and observe that we get the trigger + finish.clear() + await client.subscribe(EVENT_TOPIC, on_event) + # trigger the server via an HTTP route + requests.get(trigger_url) + # wait for finish trigger + await asyncio.wait_for(finish.wait(), 5) + assert finish.is_set() \ No newline at end of file