@@ -248,3 +248,62 @@ async def test_subscribe(self, mock_connect):
248
248
call ({"data" : {"messageAdded" : "two" }}),
249
249
]
250
250
)
251
+
252
+ @patch ("logging.info" )
253
+ @patch ("websockets.connect" )
254
+ async def test_does_not_crash_with_keep_alive (self , mock_connect , mock_info ):
255
+ """Subsribe a GraphQL subscription."""
256
+ mock_websocket = mock_connect .return_value .__aenter__ .return_value
257
+ mock_websocket .send = AsyncMock ()
258
+ mock_websocket .__aiter__ .return_value = [
259
+ '{"type": "ka"}' ,
260
+ ]
261
+
262
+ client = GraphqlClient (endpoint = "ws://www.test-api.com/graphql" )
263
+ query = """
264
+ subscription onMessageAdded {
265
+ messageAdded
266
+ }
267
+ """
268
+
269
+ await client .subscribe (query = query , handle = MagicMock ())
270
+
271
+ mock_info .assert_has_calls ([call ("the server sent a keep alive message" )])
272
+
273
+ @patch ("websockets.connect" )
274
+ async def test_headers_passed_to_websocket_connect (self , mock_connect ):
275
+ """Subsribe a GraphQL subscription."""
276
+ mock_websocket = mock_connect .return_value .__aenter__ .return_value
277
+ mock_websocket .send = AsyncMock ()
278
+ mock_websocket .__aiter__ .return_value = [
279
+ '{"type": "data", "id": "1", "payload": {"data": {"messageAdded": "one"}}}' ,
280
+ ]
281
+
282
+ expected_endpoint = "ws://www.test-api.com/graphql"
283
+ client = GraphqlClient (endpoint = expected_endpoint )
284
+
285
+ query = """
286
+ subscription onMessageAdded {
287
+ messageAdded
288
+ }
289
+ """
290
+
291
+ mock_handle = MagicMock ()
292
+
293
+ expected_headers = {"some" : "header" }
294
+
295
+ await client .subscribe (
296
+ query = query , handle = mock_handle , headers = expected_headers
297
+ )
298
+
299
+ mock_connect .assert_called_with (
300
+ expected_endpoint ,
301
+ subprotocols = ["graphql-ws" ],
302
+ extra_headers = expected_headers ,
303
+ )
304
+
305
+ mock_handle .assert_has_calls (
306
+ [
307
+ call ({"data" : {"messageAdded" : "one" }}),
308
+ ]
309
+ )
0 commit comments