From 14b50ee6e4e2bb46ca9a5cf8d8e184bb5dd2a773 Mon Sep 17 00:00:00 2001 From: Mark Beacom Date: Fri, 18 Jul 2025 13:26:51 -0400 Subject: [PATCH] refactor: Utilize processor.yield_exceptions_as_parts in call --- .../tests/test_url_fetch.py | 121 +++++++++++++----- genai_processors_url_fetch/url_fetch.py | 44 +++---- 2 files changed, 103 insertions(+), 62 deletions(-) diff --git a/genai_processors_url_fetch/tests/test_url_fetch.py b/genai_processors_url_fetch/tests/test_url_fetch.py index 621930c..09424d8 100644 --- a/genai_processors_url_fetch/tests/test_url_fetch.py +++ b/genai_processors_url_fetch/tests/test_url_fetch.py @@ -125,8 +125,9 @@ async def mock_aiter_bytes() -> AsyncIterable[bytes]: r for r in results if r.metadata.get("fetch_status") == "success" ] - assert len(status_parts) == 1 - assert "Fetched successfully" in status_parts[0].text + assert len(status_parts) == 2 # Processing + success status + assert any("Processing 1 URL(s)" in part.text for part in status_parts) + assert any("Fetched successfully" in part.text for part in status_parts) assert len(content_parts) == 1 assert content_parts[0].text == "Test Content" assert content_parts[0].metadata["source_url"] == "https://example.com" @@ -159,23 +160,26 @@ async def test_failed_fetch_with_mocking(self) -> None: part = processor.ProcessorPart("Visit https://notfound.com") results = [r async for r in p.call(part)] - # Should have status and failure parts + # Should have status parts and exception part status_parts = [ r for r in results if r.substream_name == processor.STATUS_STREAM ] - failure_parts = [ - r for r in results if r.metadata.get("fetch_status") == "failure" + exception_parts = [ + r for r in results if r.metadata.get("exception_type") == "RuntimeError" ] - assert len(status_parts) == 1 - assert "Fetch failed" in status_parts[0].text - assert len(failure_parts) == 1 - assert failure_parts[0].text == "" # Empty content for failures - assert "404" in failure_parts[0].metadata["fetch_error"] + assert ( + len(status_parts) == 3 + ) # Processing + failure status + exception part + assert any("Processing 1 URL(s)" in part.text for part in status_parts) + assert any("Fetch failed" in part.text for part in status_parts) + assert len(exception_parts) == 1 + assert "An unexpected error occurred" in exception_parts[0].text + assert "404" in exception_parts[0].metadata["original_exception"] @pytest.mark.anyio async def test_fail_on_error_config(self) -> None: - """Test that fail_on_error configuration raises exceptions.""" + """Test that errors are converted to exception parts by the decorator.""" config = FetchConfig(fail_on_error=True) p = UrlFetchProcessor(config) @@ -194,10 +198,14 @@ async def test_fail_on_error_config(self) -> None: mock_client.get.side_effect = error part = processor.ProcessorPart("Visit https://error.com") + results = [r async for r in p.call(part)] - with pytest.raises(RuntimeError): - async for _ in p.call(part): - pass + # With decorator, exceptions are converted to exception parts + exception_parts = [ + r for r in results if r.metadata.get("exception_type") == "RuntimeError" + ] + assert len(exception_parts) == 1 + assert "Request Error" in exception_parts[0].metadata["original_exception"] @pytest.mark.anyio async def test_content_processor_raw_config(self) -> None: @@ -470,19 +478,22 @@ async def test_url_validation_integration(self) -> None: part = processor.ProcessorPart("Visit https://malicious.com") results = [r async for r in p.call(part)] - # Should have status and failure parts + # Should have status parts and exception part status_parts = [ r for r in results if r.substream_name == processor.STATUS_STREAM ] - failure_parts = [ - r for r in results if r.metadata.get("fetch_status") == "failure" + exception_parts = [ + r for r in results if r.metadata.get("exception_type") == "RuntimeError" ] - assert len(status_parts) == 1 - assert "Fetch failed" in status_parts[0].text - assert len(failure_parts) == 1 - error_msg = failure_parts[0].metadata["fetch_error"] - assert "Security validation failed" in error_msg + assert len(status_parts) == 3 # Processing + failure status + exception part + assert any("Processing 1 URL(s)" in part.text for part in status_parts) + assert any("Fetch failed" in part.text for part in status_parts) + assert len(exception_parts) == 1 + assert ( + "Domain 'malicious.com' not in allowed list" + in exception_parts[0].metadata["original_exception"] + ) @pytest.mark.anyio async def test_create_success_part_validation_errors(self) -> None: @@ -558,13 +569,34 @@ async def test_response_size_limits(self) -> None: part = processor.ProcessorPart("Visit https://example.com") results = [r async for r in p.call(part)] - # Should have failure parts due to size limit - failure_parts = [ - r for r in results if r.metadata.get("fetch_status") == "failure" + # Should have status parts and exception part + status_parts = [ + r + for r in results + if r.substream_name == processor.STATUS_STREAM # type: ignore + ] + exception_parts = [ + r + for r in results + # type: ignore + if r.metadata.get("exception_type") == "RuntimeError" ] - assert len(failure_parts) == 1 - error_msg = failure_parts[0].metadata["fetch_error"] - assert "Response too large" in error_msg + + assert ( + len(status_parts) == 3 + ) # Processing + failure status + exception part + assert any( + "Processing 1 URL(s)" in part.text # type: ignore + for part in status_parts + ) + assert any( + "Fetch failed" in part.text for part in status_parts # type: ignore + ) + assert len(exception_parts) == 1 + original_exception = exception_parts[0].metadata[ # type: ignore + "original_exception" + ] + assert "Response too large" in original_exception @pytest.mark.anyio async def test_streaming_response_size_exceeded(self) -> None: @@ -595,10 +627,31 @@ async def mock_aiter_bytes() -> AsyncIterable[bytes]: part = processor.ProcessorPart("Visit https://example.com") results = [r async for r in p.call(part)] - # Should have failure parts due to size limit during streaming - failure_parts = [ - r for r in results if r.metadata.get("fetch_status") == "failure" + # Should have status parts and exception part + status_parts = [ + r + for r in results + if r.substream_name == processor.STATUS_STREAM # type: ignore + ] + exception_parts = [ + r + for r in results + # type: ignore + if r.metadata.get("exception_type") == "RuntimeError" + ] + + assert ( + len(status_parts) == 3 + ) # Processing + failure status + exception part + assert any( + "Processing 1 URL(s)" in part.text # type: ignore + for part in status_parts + ) + assert any( + "Fetch failed" in part.text for part in status_parts # type: ignore + ) + assert len(exception_parts) == 1 + original_exception = exception_parts[0].metadata[ # type: ignore + "original_exception" ] - assert len(failure_parts) == 1 - error_msg = failure_parts[0].metadata["fetch_error"] - assert "Response exceeded" in error_msg + assert "Response exceeded" in original_exception diff --git a/genai_processors_url_fetch/url_fetch.py b/genai_processors_url_fetch/url_fetch.py index ce480f6..86efaae 100644 --- a/genai_processors_url_fetch/url_fetch.py +++ b/genai_processors_url_fetch/url_fetch.py @@ -516,28 +516,17 @@ async def _create_success_part( }, ) - async def _create_failure_part( - self, - result: FetchResult, - original_part: processor.ProcessorPart, - ) -> processor.ProcessorPart: - """Create a ProcessorPart for a failed fetch.""" - return processor.ProcessorPart( - # Empty content for failures - "", - metadata={ - **original_part.metadata, - "source_url": result.url, - "fetch_status": "failure", - "fetch_error": result.error_message, - }, - ) - + @processor.yield_exceptions_as_parts # type: ignore[arg-type] async def call( self, part: processor.ProcessorPart, ) -> AsyncIterable[processor.ProcessorPart]: - """Fetch URLs found in the part and yield results.""" + """Fetch URLs found in the part and yield results. + + Extracts URLs from the part's text, fetches each one, and yields + status updates and results. Failed fetches are automatically converted + to error parts by the @processor.yield_exceptions_as_parts decorator. + """ urls = list(dict.fromkeys(URL_REGEX.findall(part.text or ""))) if not urls: @@ -546,14 +535,13 @@ async def call( yield part return + yield processor.status(f"📄 Processing {len(urls)} URL(s)") + headers = {"User-Agent": self.config.user_agent} async with httpx.AsyncClient(headers=headers) as client: - # Create tasks to fetch all URLs concurrently - tasks = [self._fetch_one(url, client) for url in urls] - results = await asyncio.gather(*tasks) - - # Process all results - for result in results: + # Process each URL individually + for url in urls: + result = await self._fetch_one(url, client) if result.ok: yield processor.status( @@ -562,10 +550,10 @@ async def call( yield await self._create_success_part(result, part) else: yield processor.status(f"❌ Fetch failed: {result.url}") - if self.config.fail_on_error: - raise RuntimeError(result.error_message) - yield await self._create_failure_part(result, part) + # Raise exception for failed fetch - decorator converts to error part + error_msg = f"Failed to fetch {url}: {result.error_message}" + raise RuntimeError(error_msg) - # Finally, yield the original part if configured to do so + # Include original part if configured to do so if self.config.include_original_part: yield part