From 78035bfa7f9bb81a8b8943f9c12db1a95831e797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Carvalho?= Date: Thu, 3 Mar 2022 10:25:10 +0000 Subject: [PATCH 1/3] Don't poll in `as_generated`. This makes use of a single queue and value wrappers to signal to the queue consumer what to do. It also prevents polling, which could make this more efficient. --- aioitertools/asyncio.py | 44 +++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/aioitertools/asyncio.py b/aioitertools/asyncio.py index 5ace5c3..8a0427c 100644 --- a/aioitertools/asyncio.py +++ b/aioitertools/asyncio.py @@ -119,42 +119,44 @@ async def generator(x): ... # intermixed values yielded from gen1 and gen2 """ - exc_queue: asyncio.Queue[Exception] = asyncio.Queue() - queue: asyncio.Queue[T] = asyncio.Queue() + queue: asyncio.Queue[dict] = asyncio.Queue() + + tailer_count: int = 0 async def tailer(iter: AsyncIterable[T]) -> None: + nonlocal tailer_count + try: async for item in iter: - await queue.put(item) + await queue.put({"value": item}) except asyncio.CancelledError: if isinstance(iter, AsyncGenerator): # pragma:nocover await iter.aclose() raise except Exception as e: - await exc_queue.put(e) + await queue.put({"exception": e}) + finally: + tailer_count -= 1 + + if tailer_count == 0: + await queue.put({"done": True}) tasks = [asyncio.ensure_future(tailer(iter)) for iter in iterables] - pending = set(tasks) + tailer_count = len(tasks) try: - while pending: - try: - exc = exc_queue.get_nowait() + while True: + i = await queue.get() + + if "value" in i: + yield i["value"] + elif "exception" in i: if return_exceptions: - yield exc # type: ignore + yield i["exception"] else: - raise exc - except asyncio.QueueEmpty: - pass - - try: - value = queue.get_nowait() - yield value - except asyncio.QueueEmpty: - for task in list(pending): - if task.done(): - pending.remove(task) - await asyncio.sleep(0.001) + raise i["exception"] + elif "done" in i: + break except (asyncio.CancelledError, GeneratorExit): pass From f86552753e626cb71a3a305b9ec890f97d771e6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Carvalho?= Date: Thu, 3 Mar 2022 10:35:42 +0000 Subject: [PATCH 2/3] Take into account empty iterables --- aioitertools/asyncio.py | 5 +++++ aioitertools/tests/asyncio.py | 25 +++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/aioitertools/asyncio.py b/aioitertools/asyncio.py index 8a0427c..8f81c57 100644 --- a/aioitertools/asyncio.py +++ b/aioitertools/asyncio.py @@ -142,6 +142,11 @@ async def tailer(iter: AsyncIterable[T]) -> None: await queue.put({"done": True}) tasks = [asyncio.ensure_future(tailer(iter)) for iter in iterables] + + if not tasks: + # Nothing to do + return + tailer_count = len(tasks) try: diff --git a/aioitertools/tests/asyncio.py b/aioitertools/tests/asyncio.py index c352a94..352b571 100644 --- a/aioitertools/tests/asyncio.py +++ b/aioitertools/tests/asyncio.py @@ -65,6 +65,31 @@ async def gen(): self.assertEqual(30, len(results)) self.assertListEqual(sorted(expected), sorted(results)) + @async_test + async def test_as_generated_no_iterables(self): + gens = [] + expected = [] + results = [] + async for value in aio.as_generated(gens): + results.append(value) + self.assertEqual(0, len(results)) + self.assertListEqual(sorted(expected), sorted(results)) + + @async_test + async def test_as_generated_empty_iterables(self): + async def gen(stop): + for i in range(stop): + yield i + await asyncio.sleep(0) + + gens = [gen(0), gen(1), gen(2)] + expected = [0, 0, 1] + results = [] + async for value in aio.as_generated(gens): + results.append(value) + self.assertEqual(3, len(results)) + self.assertListEqual(sorted(expected), sorted(results)) + @async_test async def test_as_generated_exception(self): async def gen1(): From fda95957fae5e261ae43c430ee36dc2b97308e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Carvalho?= <1012794+RedRoserade@users.noreply.github.com> Date: Fri, 18 Mar 2022 10:19:02 +0000 Subject: [PATCH 3/3] Update asyncio.py --- aioitertools/asyncio.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/aioitertools/asyncio.py b/aioitertools/asyncio.py index 8f81c57..31d0e93 100644 --- a/aioitertools/asyncio.py +++ b/aioitertools/asyncio.py @@ -8,6 +8,7 @@ """ import asyncio +from contextlib import suppress import time from typing import ( Any, @@ -123,18 +124,19 @@ async def generator(x): tailer_count: int = 0 - async def tailer(iter: AsyncIterable[T]) -> None: + async def tailer(iterable: AsyncIterable[T]) -> None: nonlocal tailer_count try: - async for item in iter: + async for item in iterable: await queue.put({"value": item}) except asyncio.CancelledError: - if isinstance(iter, AsyncGenerator): # pragma:nocover - await iter.aclose() + if isinstance(iterable, AsyncGenerator): # pragma:nocover + with suppress(Exception): + await iterable.aclose() raise - except Exception as e: - await queue.put({"exception": e}) + except Exception as exc: + await queue.put({"exception": exc}) finally: tailer_count -= 1 @@ -162,20 +164,16 @@ async def tailer(iter: AsyncIterable[T]) -> None: raise i["exception"] elif "done" in i: break - except (asyncio.CancelledError, GeneratorExit): pass - finally: for task in tasks: if not task.done(): task.cancel() for task in tasks: - try: + with suppress(asyncio.CancelledError): await task - except asyncio.CancelledError: - pass @deprecated_wait_param