Skip to content

Commit b481f3c

Browse files
authored
eventstream fixes (#206)
- fix `setup.py` so that eventstream modules get installed - fix crash when `eventstream.rpc.ClientConnection` was GC'd during `on_connection_shutdown` callback - change default args to always be `=None`, so that it's easy to pass along defaults from wrapper functions - centralize logic that sanitizes message args going in each direction between C <-> Python
1 parent 037fe03 commit b481f3c

File tree

4 files changed

+49
-37
lines changed

4 files changed

+49
-37
lines changed

awscrt/eventstream/rpc.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,35 @@ def on_protocol_message(
185185
pass
186186

187187

188+
def _to_binding_msg_args(headers, payload, message_type, flags):
189+
"""
190+
Transform args that a python send-msg function would take,
191+
into args that a native send-msg function would take.
192+
"""
193+
# python functions for sending messages
194+
if headers is None:
195+
headers = []
196+
else:
197+
headers = [i._as_binding_tuple() for i in headers]
198+
if payload is None:
199+
payload = b''
200+
if flags is None:
201+
flags = MessageFlag.NONE
202+
return (headers, payload, message_type, flags)
203+
204+
205+
def _from_binding_msg_args(headers, payload, message_type, flags):
206+
"""
207+
Transform msg-received args that came from native,
208+
into msg-received args presented to python users.
209+
"""
210+
headers = [Header._from_binding_tuple(i) for i in headers]
211+
if payload is None:
212+
payload = b''
213+
message_type = MessageType(message_type)
214+
return (headers, payload, message_type, flags)
215+
216+
188217
class ClientConnection(NativeResource):
189218
"""A client connection for the event-stream RPC protocol.
190219
@@ -323,8 +352,7 @@ def _on_protocol_message(bound_weak_handler, headers, payload, message_type, fla
323352
handler = bound_weak_handler()
324353
if handler:
325354
# transform from simple types to actual classes
326-
headers = [Header._from_binding_tuple(i) for i in headers]
327-
message_type = MessageType(message_type)
355+
headers, payload, message_type, flags = _from_binding_msg_args(headers, payload, message_type, flags)
328356
handler.on_protocol_message(
329357
headers=headers,
330358
payload=payload,
@@ -336,8 +364,7 @@ def _on_continuation_message(bound_weak_handler, headers, payload, message_type,
336364
handler = bound_weak_handler()
337365
if handler:
338366
# transform from simple types to actual classes
339-
headers = [Header._from_binding_tuple(i) for i in headers]
340-
message_type = MessageType(message_type)
367+
headers, payload, message_type, flags = _from_binding_msg_args(headers, payload, message_type, flags)
341368
handler.on_continuation_message(
342369
headers=headers,
343370
payload=payload,
@@ -358,7 +385,7 @@ def _on_flush(bound_future, bound_callback, error_code):
358385
else:
359386
bound_future.set_result(None)
360387

361-
def close(self, reason=None):
388+
def close(self):
362389
"""Close the connection.
363390
364391
Shutdown is asynchronous. This call has no effect if the connection is already
@@ -384,10 +411,10 @@ def is_open(self):
384411
def send_protocol_message(
385412
self,
386413
*,
387-
headers: Optional[Sequence[Header]] = [],
388-
payload: Optional[ByteString] = b'',
414+
headers: Optional[Sequence[Header]] = None,
415+
payload: Optional[ByteString] = None,
389416
message_type: MessageType,
390-
flags: int = MessageFlag.NONE,
417+
flags: Optional[int] = None,
391418
on_flush: Callable = None) -> Future:
392419
"""Send a protocol message.
393420
@@ -428,7 +455,7 @@ def send_protocol_message(
428455
future = Future()
429456

430457
# native code deals with simplified types
431-
headers = [i._as_binding_tuple() for i in headers]
458+
headers, payload, message_type, flags = _to_binding_msg_args(headers, payload, message_type, flags)
432459

433460
_awscrt.event_stream_rpc_client_connection_send_protocol_message(
434461
self._binding,
@@ -488,10 +515,10 @@ def activate(
488515
self,
489516
*,
490517
operation: str,
491-
headers: Sequence[Header] = [],
492-
payload: ByteString = b'',
518+
headers: Sequence[Header] = None,
519+
payload: ByteString = None,
493520
message_type: MessageType,
494-
flags: int = MessageFlag.NONE,
521+
flags: int = None,
495522
on_flush: Callable = None):
496523
"""
497524
Activate the stream by sending its first message.
@@ -537,7 +564,7 @@ def activate(
537564
flush_future = Future()
538565

539566
# native code deals with simplified types
540-
headers = [i._as_binding_tuple() for i in headers]
567+
headers, payload, message_type, flags = _to_binding_msg_args(headers, payload, message_type, flags)
541568

542569
_awscrt.event_stream_rpc_client_continuation_activate(
543570
self._binding,
@@ -553,10 +580,10 @@ def activate(
553580
def send_message(
554581
self,
555582
*,
556-
headers: Sequence[Header] = [],
557-
payload: ByteString = b'',
583+
headers: Sequence[Header] = None,
584+
payload: ByteString = None,
558585
message_type: MessageType,
559-
flags: int = MessageFlag.NONE,
586+
flags: int = None,
560587
on_flush: Callable = None) -> Future:
561588
"""
562589
Send a continuation message.
@@ -600,7 +627,7 @@ def send_message(
600627
"""
601628
future = Future()
602629
# native code deals with simplified types
603-
headers = [i._as_binding_tuple() for i in headers]
630+
headers, payload, message_type, flags = _to_binding_msg_args(headers, payload, message_type, flags)
604631

605632
_awscrt.event_stream_rpc_client_continuation_send_message(
606633
self._binding,

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ def awscrt_ext():
304304
author_email="aws-sdk-common-runtime@amazon.com",
305305
description="A common runtime for AWS Python projects",
306306
url="https://github.com/awslabs/aws-crt-python",
307-
packages=['awscrt'],
307+
# Note: find_packages() without extra args will end up installing test/
308+
packages=setuptools.find_packages(include=['awscrt*']),
308309
classifiers=[
309310
"Programming Language :: Python :: 3",
310311
"License :: OSI Approved :: Apache Software License",

source/event_stream_rpc_client_connection.c

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ static void s_on_connection_shutdown(
187187

188188
AWS_FATAL_ASSERT(connection->native && "Illegal for event-stream connection shutdown to fire before setup");
189189
AWS_FATAL_ASSERT(!connection->shutdown_complete && "illegal for event-stream connection shutdown to fire twice");
190-
connection->shutdown_complete = true;
191190

192191
PyGILState_STATE state;
193192
if (aws_py_gilstate_ensure(&state)) {
@@ -202,6 +201,7 @@ static void s_on_connection_shutdown(
202201
PyErr_WriteUnraisable(PyErr_Occurred());
203202
}
204203

204+
connection->shutdown_complete = true;
205205
s_connection_destroy_if_ready(connection);
206206
PyGILState_Release(state);
207207
}
@@ -257,20 +257,12 @@ static void s_on_protocol_message(
257257
return; /* Python has shut down. Nothing matters anymore, but don't crash */
258258
}
259259

260-
/* We always want to deliver bytes to python user, even if the length is 0.
261-
* But PyObject_CallFunction() with "y#" will convert a NULL ptr to None instead of 0-length bytes.
262-
* Therefore, if message_args->payload_buffer is NULL, pass some other valid ptr instead. */
263-
const char *payload_ptr = (void *)message_args->payload->buffer;
264-
if (payload_ptr == NULL) {
265-
payload_ptr = "";
266-
}
267-
268260
PyObject *result = PyObject_CallFunction(
269261
connection->on_protocol_message,
270262
"(Oy#iI)",
271263
/* NOTE: if headers_create() returns NULL, then PyObject_CallFunction() fails too, which is convenient */
272264
aws_py_event_stream_python_headers_create(message_args->headers, message_args->headers_count),
273-
payload_ptr,
265+
message_args->payload->buffer,
274266
message_args->payload->len,
275267
message_args->message_type,
276268
message_args->message_flags);

source/event_stream_rpc_client_continuation.c

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,20 +162,12 @@ static void s_on_continuation_message(
162162
return; /* Python has shut down. Nothing matters anymore, but don't crash */
163163
}
164164

165-
/* We always want to deliver bytes to python user, even if the length is 0.
166-
* But PyObject_CallFunction() with "y#" will convert a NULL ptr to None instead of 0-length bytes.
167-
* Therefore, if message_args->payload_buffer is NULL, pass some other valid ptr instead. */
168-
const char *payload_ptr = (void *)message_args->payload->buffer;
169-
if (payload_ptr == NULL) {
170-
payload_ptr = "";
171-
}
172-
173165
PyObject *result = PyObject_CallFunction(
174166
continuation->on_message,
175167
"(Oy#iI)",
176168
/* NOTE: if headers_create() returns NULL, then PyObject_CallFunction() fails too, which is convenient */
177169
aws_py_event_stream_python_headers_create(message_args->headers, message_args->headers_count),
178-
payload_ptr,
170+
message_args->payload->buffer,
179171
message_args->payload->len,
180172
message_args->message_type,
181173
message_args->message_flags);

0 commit comments

Comments
 (0)