Skip to content

Commit 53d9cff

Browse files
authored
Accept duck typed python streams. (#185)
1 parent e67a5e2 commit 53d9cff

File tree

4 files changed

+121
-35
lines changed

4 files changed

+121
-35
lines changed

awscrt/http.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def port(self):
179179
def request(self, request, on_response=None, on_body=None):
180180
"""Create :class:`HttpClientStream` to carry out the request/response exchange.
181181
182-
NOTE: The stream sends no data until :meth:`HttpClientStream.activate()`
182+
NOTE: The HTTP stream sends no data until :meth:`HttpClientStream.activate()`
183183
is called. Call activate() when you're ready for callbacks and events to fire.
184184
185185
Args:
@@ -188,7 +188,7 @@ def request(self, request, on_response=None, on_body=None):
188188
on_response: Optional callback invoked once main response headers are received.
189189
The function should take the following arguments and return nothing:
190190
191-
* `http_stream` (:class:`HttpClientStream`): Stream carrying
191+
* `http_stream` (:class:`HttpClientStream`): HTTP stream carrying
192192
out this request/response exchange.
193193
194194
* `status_code` (int): Response status code.
@@ -198,21 +198,21 @@ def request(self, request, on_response=None, on_body=None):
198198
199199
* `**kwargs` (dict): Forward compatibility kwargs.
200200
201-
An exception raise by this function will cause the stream to end in error.
201+
An exception raise by this function will cause the HTTP stream to end in error.
202202
This callback is always invoked on the connection's event-loop thread.
203203
204204
on_body: Optional callback invoked 0+ times as response body data is received.
205205
The function should take the following arguments and return nothing:
206206
207-
* `http_stream` (:class:`HttpClientStream`): Stream carrying
207+
* `http_stream` (:class:`HttpClientStream`): HTTP stream carrying
208208
out this request/response exchange.
209209
210210
* `chunk` (buffer): Response body data (not necessarily
211211
a whole "chunk" of chunked encoding).
212212
213213
* `**kwargs` (dict): Forward-compatibility kwargs.
214214
215-
An exception raise by this function will cause the stream to end in error.
215+
An exception raise by this function will cause the HTTP stream to end in error.
216216
This callback is always invoked on the connection's event-loop thread.
217217
218218
Returns:
@@ -245,11 +245,11 @@ def _on_body(self, chunk):
245245

246246

247247
class HttpClientStream(HttpStreamBase):
248-
"""Stream that sends a request and receives a response.
248+
"""HTTP stream that sends a request and receives a response.
249249
250250
Create an HttpClientStream with :meth:`HttpClientConnection.request()`.
251251
252-
NOTE: The stream sends no data until :meth:`HttpClientStream.activate()`
252+
NOTE: The HTTP stream sends no data until :meth:`HttpClientStream.activate()`
253253
is called. Call activate() when you're ready for callbacks and events to fire.
254254
255255
Attributes:
@@ -288,7 +288,7 @@ def response_status_code(self):
288288
def activate(self):
289289
"""Begin sending the request.
290290
291-
The stream does nothing until this is called. Call activate() when you
291+
The HTTP stream does nothing until this is called. Call activate() when you
292292
are ready for its callbacks and events to fire.
293293
"""
294294
_awscrt.http_client_stream_activate(self)
@@ -332,7 +332,7 @@ def headers(self):
332332

333333
@property
334334
def body_stream(self):
335-
"""InputStream: Stream of outgoing body."""
335+
"""InputStream: Binary stream of outgoing body."""
336336
return _awscrt.http_message_get_body_stream(self._binding)
337337

338338
@body_stream.setter
@@ -352,7 +352,7 @@ class HttpRequest(HttpMessageBase):
352352
path (str): HTTP path-and-query value. Default value is "/".
353353
headers (Optional[HttpHeaders]): Optional headers. If None specified,
354354
an empty :class:`HttpHeaders` is created.
355-
body_string(Optional[Union[InputStream, io.IOBase]]): Optional body as stream.
355+
body_stream(Optional[Union[InputStream, io.IOBase]]): Optional body as binary stream.
356356
"""
357357

358358
__slots__ = ()

awscrt/io.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import _awscrt
1313
from awscrt import NativeResource
1414
from enum import IntEnum
15-
import io
1615
import threading
1716

1817

@@ -490,28 +489,55 @@ def is_alpn_available():
490489

491490

492491
class InputStream(NativeResource):
493-
"""InputStream allows `awscrt` native code to read from Python I/O classes.
492+
"""InputStream allows `awscrt` native code to read from Python binary I/O classes.
494493
495494
Args:
496-
stream (io.IOBase): Python I/O stream to wrap.
495+
stream (io.IOBase): Python binary I/O stream to wrap.
497496
"""
498-
__slots__ = ()
497+
__slots__ = ('_stream')
499498
# TODO: Implement IOBase interface so Python can read from this class as well.
500499

501500
def __init__(self, stream):
502-
assert isinstance(stream, io.IOBase)
501+
# duck-type instead of checking inheritance from IOBase.
502+
# At the least, stream must have read()
503+
if not callable(getattr(stream, 'read', None)):
504+
raise TypeError('I/O stream type expected')
503505
assert not isinstance(stream, InputStream)
504506

505507
super().__init__()
506-
self._binding = _awscrt.input_stream_new(stream)
508+
self._stream = stream
509+
self._binding = _awscrt.input_stream_new(self)
510+
511+
def _read_into_memoryview(self, m):
512+
# Read into memoryview m.
513+
# Return number of bytes read, or None if no data available.
514+
try:
515+
# prefer the most efficient read methods,
516+
if hasattr(self._stream, 'readinto1'):
517+
return self._stream.readinto1(m)
518+
if hasattr(self._stream, 'readinto'):
519+
return self._stream.readinto(m)
520+
521+
if hasattr(self._stream, 'read1'):
522+
data = self._stream.read1(len(m))
523+
else:
524+
data = self._stream.read(len(m))
525+
n = len(data)
526+
m[:n] = data
527+
return n
528+
except BlockingIOError:
529+
return None
530+
531+
def _seek(self, offset, whence):
532+
return self._stream.seek(offset, whence)
507533

508534
@classmethod
509535
def wrap(cls, stream, allow_none=False):
510536
"""
511537
Given some stream type, returns an :class:`InputStream`.
512538
513539
Args:
514-
stream (Union[io.IOBase, InputStream, None]): I/O stream to wrap.
540+
stream (Union[io.IOBase, InputStream, None]): Binary I/O stream to wrap.
515541
allow_none (bool): Whether to allow `stream` to be None.
516542
If False (default), and `stream` is None, an exception is raised.
517543
@@ -520,10 +546,8 @@ def wrap(cls, stream, allow_none=False):
520546
Otherwise, an :class:`InputStream` which wraps the `stream` is returned.
521547
If `allow_none` is True, and `stream` is None, then None is returned.
522548
"""
523-
if isinstance(stream, InputStream):
524-
return stream
525-
if isinstance(stream, io.IOBase):
526-
return cls(stream)
527549
if stream is None and allow_none:
528550
return None
529-
raise TypeError('I/O stream type expected')
551+
if isinstance(stream, InputStream):
552+
return stream
553+
return cls(stream)

source/io.c

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -627,13 +627,13 @@ struct aws_input_stream_py_impl {
627627

628628
bool is_end_of_stream;
629629

630-
/* Dependencies that must outlive this */
631-
PyObject *io;
630+
/* Weak reference proxy to python self. */
631+
PyObject *self_proxy;
632632
};
633633

634634
static void s_aws_input_stream_py_destroy(struct aws_input_stream *stream) {
635635
struct aws_input_stream_py_impl *impl = stream->impl;
636-
Py_DECREF(impl->io);
636+
Py_XDECREF(impl->self_proxy);
637637
aws_mem_release(stream->allocator, stream);
638638
}
639639

@@ -653,7 +653,7 @@ static int s_aws_input_stream_py_seek(
653653
return AWS_OP_ERR; /* Python has shut down. Nothing matters anymore, but don't crash */
654654
}
655655

656-
method_result = PyObject_CallMethod(impl->io, "seek", "(li)", offset, basis);
656+
method_result = PyObject_CallMethod(impl->self_proxy, "_seek", "(li)", offset, basis);
657657
if (!method_result) {
658658
aws_result = aws_py_raise_error();
659659
goto done;
@@ -689,7 +689,7 @@ int s_aws_input_stream_py_read(struct aws_input_stream *stream, struct aws_byte_
689689
goto done;
690690
}
691691

692-
method_result = PyObject_CallMethod(impl->io, "readinto", "(O)", memory_view);
692+
method_result = PyObject_CallMethod(impl->self_proxy, "_read_into_memoryview", "(O)", memory_view);
693693
if (!method_result) {
694694
aws_result = aws_py_raise_error();
695695
goto done;
@@ -745,9 +745,9 @@ static struct aws_input_stream_vtable s_aws_input_stream_py_vtable = {
745745
.destroy = s_aws_input_stream_py_destroy,
746746
};
747747

748-
static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *io) {
748+
static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *py_self) {
749749

750-
if (!io || (io == Py_None)) {
750+
if (!py_self || (py_self == Py_None)) {
751751
aws_raise_error(AWS_ERROR_INVALID_ARGUMENT);
752752
return NULL;
753753
}
@@ -761,10 +761,15 @@ static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *io) {
761761
impl->base.allocator = alloc;
762762
impl->base.vtable = &s_aws_input_stream_py_vtable;
763763
impl->base.impl = impl;
764-
impl->io = io;
765-
Py_INCREF(impl->io);
764+
impl->self_proxy = PyWeakref_NewProxy(py_self, NULL);
765+
if (!impl->self_proxy) {
766+
goto error;
767+
}
766768

767769
return &impl->base;
770+
error:
771+
aws_input_stream_destroy(&impl->base);
772+
return NULL;
768773
}
769774

770775
/**
@@ -783,12 +788,12 @@ static void s_input_stream_capsule_destructor(PyObject *py_capsule) {
783788
PyObject *aws_py_input_stream_new(PyObject *self, PyObject *args) {
784789
(void)self;
785790

786-
PyObject *py_io;
787-
if (!PyArg_ParseTuple(args, "O", &py_io)) {
791+
PyObject *py_self;
792+
if (!PyArg_ParseTuple(args, "O", &py_self)) {
788793
return NULL;
789794
}
790795

791-
struct aws_input_stream *stream = aws_input_stream_new_from_py(py_io);
796+
struct aws_input_stream *stream = aws_input_stream_new_from_py(py_self);
792797
if (!stream) {
793798
return PyErr_AwsLastError();
794799
}

test/test_io.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
# SPDX-License-Identifier: Apache-2.0.
33

44
from __future__ import absolute_import
5-
from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, TlsConnectionOptions, TlsContextOptions
5+
from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, InputStream, TlsConnectionOptions, TlsContextOptions
66
from test import NativeResourceTest, TIMEOUT
7+
import io
78
import unittest
89

910

@@ -84,5 +85,61 @@ def test_server_name(self):
8485
conn_opt.set_server_name('localhost')
8586

8687

88+
class MockPythonStream:
89+
"""For testing duck-typed stream classes.
90+
Doesn't inherit from io.IOBase. Doesn't implement readinto()"""
91+
92+
def __init__(self, src_data):
93+
self.data = bytes(src_data)
94+
self.len = len(src_data)
95+
self.pos = 0
96+
97+
def seek(self, where):
98+
self.pos = where
99+
100+
def tell(self):
101+
return self.pos
102+
103+
def read(self, amount=None):
104+
if amount is None:
105+
amount = self.len - self.pos
106+
else:
107+
amount = min(amount, self.len - self.pos)
108+
prev_pos = self.pos
109+
self.pos += amount
110+
return self.data[prev_pos: self.pos]
111+
112+
113+
class InputStreamTest(NativeResourceTest):
114+
def _test(self, python_stream, expected):
115+
input_stream = InputStream(python_stream)
116+
result = bytearray()
117+
fixed_mv_len = 4
118+
fixed_mv = memoryview(bytearray(fixed_mv_len))
119+
while True:
120+
read_len = input_stream._read_into_memoryview(fixed_mv)
121+
if read_len is None:
122+
continue
123+
if read_len == 0:
124+
break
125+
if read_len > 0:
126+
self.assertLessEqual(read_len, fixed_mv_len)
127+
result += fixed_mv[0:read_len]
128+
129+
self.assertEqual(expected, result)
130+
131+
def test_read_official_io(self):
132+
# Read from a class defined in the io module
133+
src_data = b'a long string here'
134+
python_stream = io.BytesIO(src_data)
135+
self._test(python_stream, src_data)
136+
137+
def test_read_duck_typed_io(self):
138+
# Read from a class defined in the io module
139+
src_data = b'a man a can a planal canada'
140+
python_stream = MockPythonStream(src_data)
141+
self._test(python_stream, src_data)
142+
143+
87144
if __name__ == '__main__':
88145
unittest.main()

0 commit comments

Comments
 (0)