diff --git a/yaqd-core/CHANGELOG.md b/yaqd-core/CHANGELOG.md index 5f8d640..0c16604 100644 --- a/yaqd-core/CHANGELOG.md +++ b/yaqd-core/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/). ## [Unreleased] ### Fixed +- handle cleanup when a client connection is broken - type hints for IsSensor attributes are appropriate for _n_-dimensional data ## [2023.11.0] diff --git a/yaqd-core/yaqd_core/_protocol.py b/yaqd-core/yaqd_core/_protocol.py index 6644a08..2ca4a16 100644 --- a/yaqd-core/yaqd_core/_protocol.py +++ b/yaqd-core/yaqd_core/_protocol.py @@ -28,7 +28,9 @@ def connection_made(self, transport): self.transport = transport self.unpacker = avrorpc.Unpacker(self._avro_protocol) self._daemon._connection_made(peername) - self.task = asyncio.get_event_loop().create_task(self.process_requests()) + self.task = self._daemon._loop.create_task(self.process_requests()) + self._daemon._tasks.append(self.task) + self.task.add_done_callback(self._daemon._tasks.remove) def data_received(self, data): """Process an incomming request.""" @@ -38,6 +40,15 @@ def data_received(self, data): self.unpacker.feed(data) async def process_requests(self): + try: + await self._process_requests() + except asyncio.CancelledError as e: + self.logger.debug("cancelling process_requests") + await self.unpacker.__aexit__(None, None, None) + self.transport.close() + raise e + + async def _process_requests(self): async for hs, meta, name, params in self.unpacker: if hs is not None: out = bytes(hs) @@ -46,13 +57,13 @@ async def process_requests(self): if hs.match == "NONE": name = "" - out_meta = io.BytesIO() + meta_out = io.BytesIO() fastavro.schemaless_writer( - out_meta, {"type": "map", "values": "bytes"}, meta + meta_out, {"type": "map", "values": "bytes"}, meta ) - length = out_meta.tell() - self.transport.write(struct.pack(">L", length) + out_meta.getvalue()) - self.logger.debug(f"Wrote meta, {meta}, {out_meta.getvalue()}") + length = meta_out.tell() + self.transport.write(struct.pack(">L", length) + meta_out.getvalue()) + self.logger.debug(f"Wrote meta, {meta}, {meta_out.getvalue()}") try: response_out = io.BytesIO() response = None @@ -82,6 +93,7 @@ async def process_requests(self): fastavro.schemaless_writer(error_out, ["string"], repr(e)) length = error_out.tell() self.transport.write(struct.pack(">L", length) + error_out.getvalue()) + error_out.close() else: self.transport.write(struct.pack(">L", 1) + b"\0") self.logger.debug(f"Wrote non-error flag") @@ -92,7 +104,11 @@ async def process_requests(self): self.logger.debug( f"Wrote response {response}, {response_out.getvalue()}" ) + finally: + response_out.close() + meta_out.close() self.transport.write(struct.pack(">L", 0)) + self.unpacker._file = io.BytesIO() if name == "shutdown": self.logger.debug("Closing transport") self.transport.close() diff --git a/yaqd-core/yaqd_core/avrorpc/unpacker.py b/yaqd-core/yaqd_core/avrorpc/unpacker.py index 1d507f7..60825fb 100644 --- a/yaqd-core/yaqd_core/avrorpc/unpacker.py +++ b/yaqd-core/yaqd_core/avrorpc/unpacker.py @@ -67,13 +67,19 @@ async def __anext__(self): except (ValueError, struct.error): await self.new_data.wait() + async def __aexit__(self, exc_type, exc_val, exc_tb): + await asyncio.sleep(0) + self._file.close() + self.buf.close() + def feed(self, data: bytes): - # Must support random access, if it does not, must be fed externally (e.g. TCP) - pos = self._file.tell() - self._file.seek(0, 2) - self._file.write(data) - self._file.seek(pos) - self.new_data.set() + if not self._file.closed: + # Must support random access, if it does not, must be fed externally (e.g. TCP) + pos = self._file.tell() + self._file.seek(0, 2) + self._file.write(data) + self._file.seek(pos) + self.new_data.set() async def _read_object(self, schema): schema = fastavro.parse_schema(