Skip to content

Commit 1d1750f

Browse files
authored
Simplify transport_serial (modbus use) (#1808)
1 parent 39177d7 commit 1d1750f

File tree

3 files changed

+96
-48
lines changed

3 files changed

+96
-48
lines changed

pymodbus/transport/transport_serial.py

Lines changed: 35 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,27 @@
1212
class SerialTransport(asyncio.Transport):
1313
"""An asyncio serial transport."""
1414

15+
force_poll: bool = False
16+
1517
def __init__(self, loop, protocol, *args, **kwargs):
1618
"""Initialize."""
1719
super().__init__()
1820
self.async_loop = loop
1921
self._protocol: asyncio.BaseProtocol = protocol
2022
self.sync_serial = serial.serial_for_url(*args, **kwargs)
2123
self._write_buffer = []
22-
self._has_reader = False
23-
self._has_writer = False
24+
self.poll_task = None
2425
self._poll_wait_time = 0.0005
2526
self.sync_serial.timeout = 0
2627
self.sync_serial.write_timeout = 0
2728

2829
def setup(self):
2930
"""Prepare to read/write"""
30-
self.async_loop.call_soon(self._protocol.connection_made, self)
31-
if os.name == "nt":
32-
self._has_reader = self.async_loop.call_later(
33-
self._poll_wait_time, self._poll_read
34-
)
31+
if os.name == "nt" or self.force_poll:
32+
self.poll_task = asyncio.create_task(self._polling_task())
3533
else:
3634
self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
37-
self._has_reader = True
35+
self.async_loop.call_soon(self._protocol.connection_made, self)
3836

3937
def close(self, exc=None):
4038
"""Close the transport gracefully."""
@@ -43,13 +41,13 @@ def close(self, exc=None):
4341
with contextlib.suppress(Exception):
4442
self.sync_serial.flush()
4543

46-
if self._has_reader:
47-
if os.name == "nt":
48-
self._has_reader.cancel()
49-
else:
50-
self.async_loop.remove_reader(self.sync_serial.fileno())
51-
self._has_reader = False
5244
self.flush()
45+
if self.poll_task:
46+
self.poll_task.cancel()
47+
_ = asyncio.ensure_future(self.poll_task)
48+
self.poll_task = None
49+
else:
50+
self.async_loop.remove_reader(self.sync_serial.fileno())
5351
self.sync_serial.close()
5452
self.sync_serial = None
5553
with contextlib.suppress(Exception):
@@ -58,21 +56,13 @@ def close(self, exc=None):
5856
def write(self, data):
5957
"""Write some data to the transport."""
6058
self._write_buffer.append(data)
61-
if not self._has_writer:
62-
if os.name == "nt":
63-
self._has_writer = self.async_loop.call_soon(self._poll_write)
64-
else:
65-
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
66-
self._has_writer = True
59+
if not self.poll_task:
60+
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
6761

6862
def flush(self):
6963
"""Clear output buffer and stops any more data being written"""
70-
if self._has_writer:
71-
if os.name == "nt":
72-
self._has_writer.cancel()
73-
else:
74-
self.async_loop.remove_writer(self.sync_serial.fileno())
75-
self._has_writer = False
64+
if not self.poll_task:
65+
self.async_loop.remove_writer(self.sync_serial.fileno())
7666
self._write_buffer.clear()
7767

7868
# ------------------------------------------------
@@ -141,34 +131,32 @@ def _write_ready(self):
141131
"""Asynchronously write buffered data."""
142132
data = b"".join(self._write_buffer)
143133
try:
144-
if nlen := self.sync_serial.write(data) < len(data):
145-
self._write_buffer = data[nlen:]
146-
return True
134+
if (nlen := self.sync_serial.write(data)) < len(data):
135+
self._write_buffer = [data[nlen:]]
136+
if not self.poll_task:
137+
self.async_loop.add_writer(
138+
self.sync_serial.fileno(), self._write_ready
139+
)
140+
return
147141
self.flush()
148142
except (BlockingIOError, InterruptedError):
149-
return True
143+
return
150144
except serial.SerialException as exc:
151145
self.close(exc=exc)
152-
return False
153146

154-
def _poll_read(self):
155-
if self._has_reader:
156-
try:
157-
self._has_reader = self.async_loop.call_later(
158-
self._poll_wait_time, self._poll_read
159-
)
147+
async def _polling_task(self):
148+
"""Poll and try to read/write."""
149+
try:
150+
while True:
151+
await asyncio.sleep(self._poll_wait_time)
152+
while self._write_buffer:
153+
self._write_ready()
160154
if self.sync_serial.in_waiting:
161155
self._read_ready()
162-
except serial.SerialException as exc:
163-
self.close(exc=exc)
164-
165-
def _poll_write(self):
166-
if not self._has_writer:
167-
return
168-
if self._write_ready():
169-
self._has_writer = self.async_loop.call_later(
170-
self._poll_wait_time, self._poll_write
171-
)
156+
except serial.SerialException as exc:
157+
self.close(exc=exc)
158+
except asyncio.CancelledError:
159+
pass
172160

173161

174162
async def create_serial_connection(loop, protocol_factory, *args, **kwargs):

test/sub_transport/test_basic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,10 @@ async def test_external_methods(self):
325325
comm.close()
326326
comm = SerialTransport(mock.MagicMock(), mock.Mock(), "dummy")
327327
comm.abort()
328-
assert await create_serial_connection(
328+
transport, protocol = await create_serial_connection(
329329
asyncio.get_running_loop(), mock.Mock, url="dummy"
330330
)
331+
await asyncio.sleep(0.1)
332+
assert transport
333+
assert protocol
334+
transport.close()

test/sub_transport/test_comm.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
CommType,
1010
ModbusProtocol,
1111
)
12+
from pymodbus.transport.transport_serial import SerialTransport
1213

1314

1415
FACTOR = 1.2 if not pytest.IS_WINDOWS else 4.2
@@ -125,6 +126,61 @@ async def test_connected(self, client, server, use_comm_type):
125126
assert not server.active_connections
126127
server.transport_close()
127128

129+
def wrapped_write(self, data):
130+
"""Wrap serial write, to split parameters."""
131+
return self.serial_write(data[:2])
132+
133+
@pytest.mark.parametrize(
134+
("use_comm_type", "use_host"),
135+
[
136+
(CommType.SERIAL, "socket://localhost:5020"),
137+
],
138+
)
139+
async def test_split_serial_packet(self, client, server):
140+
"""Test connection and data exchange."""
141+
assert await server.transport_listen()
142+
assert await client.transport_connect()
143+
await asyncio.sleep(0.5)
144+
assert len(server.active_connections) == 1
145+
server_connected = list(server.active_connections.values())[0]
146+
test_data = b"abcd"
147+
148+
self.serial_write = ( # pylint: disable=attribute-defined-outside-init
149+
client.transport.sync_serial.write
150+
)
151+
with mock.patch.object(
152+
client.transport.sync_serial, "write", wraps=self.wrapped_write
153+
):
154+
client.transport_send(test_data)
155+
await asyncio.sleep(0.5)
156+
assert server_connected.recv_buffer == test_data
157+
assert not client.recv_buffer
158+
client.transport_close()
159+
server.transport_close()
160+
161+
@pytest.mark.parametrize(
162+
("use_comm_type", "use_host"),
163+
[
164+
(CommType.SERIAL, "socket://localhost:5020"),
165+
],
166+
)
167+
async def test_serial_poll(self, client, server):
168+
"""Test connection and data exchange."""
169+
assert await server.transport_listen()
170+
SerialTransport.force_poll = True
171+
assert await client.transport_connect()
172+
await asyncio.sleep(0.5)
173+
SerialTransport.force_poll = False
174+
assert len(server.active_connections) == 1
175+
server_connected = list(server.active_connections.values())[0]
176+
test_data = b"abcd" * 1000
177+
client.transport_send(test_data)
178+
await asyncio.sleep(0.5)
179+
assert server_connected.recv_buffer == test_data
180+
assert not client.recv_buffer
181+
client.transport_close()
182+
server.transport_close()
183+
128184
@pytest.mark.parametrize(
129185
("use_comm_type", "use_host"),
130186
[

0 commit comments

Comments
 (0)