diff --git a/aioitertools/asyncio.py b/aioitertools/asyncio.py index 5ace5c3..31d0e93 100644 --- a/aioitertools/asyncio.py +++ b/aioitertools/asyncio.py @@ -8,6 +8,7 @@ """ import asyncio +from contextlib import suppress import time from typing import ( Any, @@ -119,56 +120,60 @@ async def generator(x): ... # intermixed values yielded from gen1 and gen2 """ - exc_queue: asyncio.Queue[Exception] = asyncio.Queue() - queue: asyncio.Queue[T] = asyncio.Queue() + queue: asyncio.Queue[dict] = asyncio.Queue() + + tailer_count: int = 0 + + async def tailer(iterable: AsyncIterable[T]) -> None: + nonlocal tailer_count - async def tailer(iter: AsyncIterable[T]) -> None: try: - async for item in iter: - await queue.put(item) + async for item in iterable: + await queue.put({"value": item}) except asyncio.CancelledError: - if isinstance(iter, AsyncGenerator): # pragma:nocover - await iter.aclose() + if isinstance(iterable, AsyncGenerator): # pragma:nocover + with suppress(Exception): + await iterable.aclose() raise - except Exception as e: - await exc_queue.put(e) + except Exception as exc: + await queue.put({"exception": exc}) + finally: + tailer_count -= 1 + + if tailer_count == 0: + await queue.put({"done": True}) tasks = [asyncio.ensure_future(tailer(iter)) for iter in iterables] - pending = set(tasks) + + if not tasks: + # Nothing to do + return + + tailer_count = len(tasks) try: - while pending: - try: - exc = exc_queue.get_nowait() + while True: + i = await queue.get() + + if "value" in i: + yield i["value"] + elif "exception" in i: if return_exceptions: - yield exc # type: ignore + yield i["exception"] else: - raise exc - except asyncio.QueueEmpty: - pass - - try: - value = queue.get_nowait() - yield value - except asyncio.QueueEmpty: - for task in list(pending): - if task.done(): - pending.remove(task) - await asyncio.sleep(0.001) - + raise i["exception"] + elif "done" in i: + break except (asyncio.CancelledError, GeneratorExit): pass - finally: for task in tasks: if not task.done(): task.cancel() for task in tasks: - try: + with suppress(asyncio.CancelledError): await task - except asyncio.CancelledError: - pass @deprecated_wait_param diff --git a/aioitertools/tests/asyncio.py b/aioitertools/tests/asyncio.py index c352a94..352b571 100644 --- a/aioitertools/tests/asyncio.py +++ b/aioitertools/tests/asyncio.py @@ -65,6 +65,31 @@ async def gen(): self.assertEqual(30, len(results)) self.assertListEqual(sorted(expected), sorted(results)) + @async_test + async def test_as_generated_no_iterables(self): + gens = [] + expected = [] + results = [] + async for value in aio.as_generated(gens): + results.append(value) + self.assertEqual(0, len(results)) + self.assertListEqual(sorted(expected), sorted(results)) + + @async_test + async def test_as_generated_empty_iterables(self): + async def gen(stop): + for i in range(stop): + yield i + await asyncio.sleep(0) + + gens = [gen(0), gen(1), gen(2)] + expected = [0, 0, 1] + results = [] + async for value in aio.as_generated(gens): + results.append(value) + self.assertEqual(3, len(results)) + self.assertListEqual(sorted(expected), sorted(results)) + @async_test async def test_as_generated_exception(self): async def gen1():