12
12
class SerialTransport (asyncio .Transport ):
13
13
"""An asyncio serial transport."""
14
14
15
+ force_poll : bool = False
16
+
15
17
def __init__ (self , loop , protocol , * args , ** kwargs ):
16
18
"""Initialize."""
17
19
super ().__init__ ()
18
20
self .async_loop = loop
19
21
self ._protocol : asyncio .BaseProtocol = protocol
20
22
self .sync_serial = serial .serial_for_url (* args , ** kwargs )
21
23
self ._write_buffer = []
22
- self ._has_reader = False
23
- self ._has_writer = False
24
+ self .poll_task = None
24
25
self ._poll_wait_time = 0.0005
25
26
self .sync_serial .timeout = 0
26
27
self .sync_serial .write_timeout = 0
27
28
28
29
def setup (self ):
29
30
"""Prepare to read/write"""
30
- self .async_loop .call_soon (self ._protocol .connection_made , self )
31
- if os .name == "nt" :
32
- self ._has_reader = self .async_loop .call_later (
33
- self ._poll_wait_time , self ._poll_read
34
- )
31
+ if os .name == "nt" or self .force_poll :
32
+ self .poll_task = asyncio .create_task (self ._polling_task ())
35
33
else :
36
34
self .async_loop .add_reader (self .sync_serial .fileno (), self ._read_ready )
37
- self ._has_reader = True
35
+ self .async_loop . call_soon ( self . _protocol . connection_made , self )
38
36
39
37
def close (self , exc = None ):
40
38
"""Close the transport gracefully."""
@@ -43,13 +41,13 @@ def close(self, exc=None):
43
41
with contextlib .suppress (Exception ):
44
42
self .sync_serial .flush ()
45
43
46
- if self ._has_reader :
47
- if os .name == "nt" :
48
- self ._has_reader .cancel ()
49
- else :
50
- self .async_loop .remove_reader (self .sync_serial .fileno ())
51
- self ._has_reader = False
52
44
self .flush ()
45
+ if self .poll_task :
46
+ self .poll_task .cancel ()
47
+ _ = asyncio .ensure_future (self .poll_task )
48
+ self .poll_task = None
49
+ else :
50
+ self .async_loop .remove_reader (self .sync_serial .fileno ())
53
51
self .sync_serial .close ()
54
52
self .sync_serial = None
55
53
with contextlib .suppress (Exception ):
@@ -58,21 +56,13 @@ def close(self, exc=None):
58
56
def write (self , data ):
59
57
"""Write some data to the transport."""
60
58
self ._write_buffer .append (data )
61
- if not self ._has_writer :
62
- if os .name == "nt" :
63
- self ._has_writer = self .async_loop .call_soon (self ._poll_write )
64
- else :
65
- self .async_loop .add_writer (self .sync_serial .fileno (), self ._write_ready )
66
- self ._has_writer = True
59
+ if not self .poll_task :
60
+ self .async_loop .add_writer (self .sync_serial .fileno (), self ._write_ready )
67
61
68
62
def flush (self ):
69
63
"""Clear output buffer and stops any more data being written"""
70
- if self ._has_writer :
71
- if os .name == "nt" :
72
- self ._has_writer .cancel ()
73
- else :
74
- self .async_loop .remove_writer (self .sync_serial .fileno ())
75
- self ._has_writer = False
64
+ if not self .poll_task :
65
+ self .async_loop .remove_writer (self .sync_serial .fileno ())
76
66
self ._write_buffer .clear ()
77
67
78
68
# ------------------------------------------------
@@ -141,34 +131,32 @@ def _write_ready(self):
141
131
"""Asynchronously write buffered data."""
142
132
data = b"" .join (self ._write_buffer )
143
133
try :
144
- if nlen := self .sync_serial .write (data ) < len (data ):
145
- self ._write_buffer = data [nlen :]
146
- return True
134
+ if (nlen := self .sync_serial .write (data )) < len (data ):
135
+ self ._write_buffer = [data [nlen :]]
136
+ if not self .poll_task :
137
+ self .async_loop .add_writer (
138
+ self .sync_serial .fileno (), self ._write_ready
139
+ )
140
+ return
147
141
self .flush ()
148
142
except (BlockingIOError , InterruptedError ):
149
- return True
143
+ return
150
144
except serial .SerialException as exc :
151
145
self .close (exc = exc )
152
- return False
153
146
154
- def _poll_read (self ):
155
- if self ._has_reader :
156
- try :
157
- self ._has_reader = self .async_loop .call_later (
158
- self ._poll_wait_time , self ._poll_read
159
- )
147
+ async def _polling_task (self ):
148
+ """Poll and try to read/write."""
149
+ try :
150
+ while True :
151
+ await asyncio .sleep (self ._poll_wait_time )
152
+ while self ._write_buffer :
153
+ self ._write_ready ()
160
154
if self .sync_serial .in_waiting :
161
155
self ._read_ready ()
162
- except serial .SerialException as exc :
163
- self .close (exc = exc )
164
-
165
- def _poll_write (self ):
166
- if not self ._has_writer :
167
- return
168
- if self ._write_ready ():
169
- self ._has_writer = self .async_loop .call_later (
170
- self ._poll_wait_time , self ._poll_write
171
- )
156
+ except serial .SerialException as exc :
157
+ self .close (exc = exc )
158
+ except asyncio .CancelledError :
159
+ pass
172
160
173
161
174
162
async def create_serial_connection (loop , protocol_factory , * args , ** kwargs ):
0 commit comments