Skip to content

Commit 14b50ee

Browse files
committed
refactor: Utilize processor.yield_exceptions_as_parts in call
1 parent 82cc59f commit 14b50ee

File tree

2 files changed

+103
-62
lines changed

2 files changed

+103
-62
lines changed

genai_processors_url_fetch/tests/test_url_fetch.py

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ async def mock_aiter_bytes() -> AsyncIterable[bytes]:
125125
r for r in results if r.metadata.get("fetch_status") == "success"
126126
]
127127

128-
assert len(status_parts) == 1
129-
assert "Fetched successfully" in status_parts[0].text
128+
assert len(status_parts) == 2 # Processing + success status
129+
assert any("Processing 1 URL(s)" in part.text for part in status_parts)
130+
assert any("Fetched successfully" in part.text for part in status_parts)
130131
assert len(content_parts) == 1
131132
assert content_parts[0].text == "Test Content"
132133
assert content_parts[0].metadata["source_url"] == "https://example.com"
@@ -159,23 +160,26 @@ async def test_failed_fetch_with_mocking(self) -> None:
159160
part = processor.ProcessorPart("Visit https://notfound.com")
160161
results = [r async for r in p.call(part)]
161162

162-
# Should have status and failure parts
163+
# Should have status parts and exception part
163164
status_parts = [
164165
r for r in results if r.substream_name == processor.STATUS_STREAM
165166
]
166-
failure_parts = [
167-
r for r in results if r.metadata.get("fetch_status") == "failure"
167+
exception_parts = [
168+
r for r in results if r.metadata.get("exception_type") == "RuntimeError"
168169
]
169170

170-
assert len(status_parts) == 1
171-
assert "Fetch failed" in status_parts[0].text
172-
assert len(failure_parts) == 1
173-
assert failure_parts[0].text == "" # Empty content for failures
174-
assert "404" in failure_parts[0].metadata["fetch_error"]
171+
assert (
172+
len(status_parts) == 3
173+
) # Processing + failure status + exception part
174+
assert any("Processing 1 URL(s)" in part.text for part in status_parts)
175+
assert any("Fetch failed" in part.text for part in status_parts)
176+
assert len(exception_parts) == 1
177+
assert "An unexpected error occurred" in exception_parts[0].text
178+
assert "404" in exception_parts[0].metadata["original_exception"]
175179

176180
@pytest.mark.anyio
177181
async def test_fail_on_error_config(self) -> None:
178-
"""Test that fail_on_error configuration raises exceptions."""
182+
"""Test that errors are converted to exception parts by the decorator."""
179183
config = FetchConfig(fail_on_error=True)
180184
p = UrlFetchProcessor(config)
181185

@@ -194,10 +198,14 @@ async def test_fail_on_error_config(self) -> None:
194198
mock_client.get.side_effect = error
195199

196200
part = processor.ProcessorPart("Visit https://error.com")
201+
results = [r async for r in p.call(part)]
197202

198-
with pytest.raises(RuntimeError):
199-
async for _ in p.call(part):
200-
pass
203+
# With decorator, exceptions are converted to exception parts
204+
exception_parts = [
205+
r for r in results if r.metadata.get("exception_type") == "RuntimeError"
206+
]
207+
assert len(exception_parts) == 1
208+
assert "Request Error" in exception_parts[0].metadata["original_exception"]
201209

202210
@pytest.mark.anyio
203211
async def test_content_processor_raw_config(self) -> None:
@@ -470,19 +478,22 @@ async def test_url_validation_integration(self) -> None:
470478
part = processor.ProcessorPart("Visit https://malicious.com")
471479
results = [r async for r in p.call(part)]
472480

473-
# Should have status and failure parts
481+
# Should have status parts and exception part
474482
status_parts = [
475483
r for r in results if r.substream_name == processor.STATUS_STREAM
476484
]
477-
failure_parts = [
478-
r for r in results if r.metadata.get("fetch_status") == "failure"
485+
exception_parts = [
486+
r for r in results if r.metadata.get("exception_type") == "RuntimeError"
479487
]
480488

481-
assert len(status_parts) == 1
482-
assert "Fetch failed" in status_parts[0].text
483-
assert len(failure_parts) == 1
484-
error_msg = failure_parts[0].metadata["fetch_error"]
485-
assert "Security validation failed" in error_msg
489+
assert len(status_parts) == 3 # Processing + failure status + exception part
490+
assert any("Processing 1 URL(s)" in part.text for part in status_parts)
491+
assert any("Fetch failed" in part.text for part in status_parts)
492+
assert len(exception_parts) == 1
493+
assert (
494+
"Domain 'malicious.com' not in allowed list"
495+
in exception_parts[0].metadata["original_exception"]
496+
)
486497

487498
@pytest.mark.anyio
488499
async def test_create_success_part_validation_errors(self) -> None:
@@ -558,13 +569,34 @@ async def test_response_size_limits(self) -> None:
558569
part = processor.ProcessorPart("Visit https://example.com")
559570
results = [r async for r in p.call(part)]
560571

561-
# Should have failure parts due to size limit
562-
failure_parts = [
563-
r for r in results if r.metadata.get("fetch_status") == "failure"
572+
# Should have status parts and exception part
573+
status_parts = [
574+
r
575+
for r in results
576+
if r.substream_name == processor.STATUS_STREAM # type: ignore
577+
]
578+
exception_parts = [
579+
r
580+
for r in results
581+
# type: ignore
582+
if r.metadata.get("exception_type") == "RuntimeError"
564583
]
565-
assert len(failure_parts) == 1
566-
error_msg = failure_parts[0].metadata["fetch_error"]
567-
assert "Response too large" in error_msg
584+
585+
assert (
586+
len(status_parts) == 3
587+
) # Processing + failure status + exception part
588+
assert any(
589+
"Processing 1 URL(s)" in part.text # type: ignore
590+
for part in status_parts
591+
)
592+
assert any(
593+
"Fetch failed" in part.text for part in status_parts # type: ignore
594+
)
595+
assert len(exception_parts) == 1
596+
original_exception = exception_parts[0].metadata[ # type: ignore
597+
"original_exception"
598+
]
599+
assert "Response too large" in original_exception
568600

569601
@pytest.mark.anyio
570602
async def test_streaming_response_size_exceeded(self) -> None:
@@ -595,10 +627,31 @@ async def mock_aiter_bytes() -> AsyncIterable[bytes]:
595627
part = processor.ProcessorPart("Visit https://example.com")
596628
results = [r async for r in p.call(part)]
597629

598-
# Should have failure parts due to size limit during streaming
599-
failure_parts = [
600-
r for r in results if r.metadata.get("fetch_status") == "failure"
630+
# Should have status parts and exception part
631+
status_parts = [
632+
r
633+
for r in results
634+
if r.substream_name == processor.STATUS_STREAM # type: ignore
635+
]
636+
exception_parts = [
637+
r
638+
for r in results
639+
# type: ignore
640+
if r.metadata.get("exception_type") == "RuntimeError"
641+
]
642+
643+
assert (
644+
len(status_parts) == 3
645+
) # Processing + failure status + exception part
646+
assert any(
647+
"Processing 1 URL(s)" in part.text # type: ignore
648+
for part in status_parts
649+
)
650+
assert any(
651+
"Fetch failed" in part.text for part in status_parts # type: ignore
652+
)
653+
assert len(exception_parts) == 1
654+
original_exception = exception_parts[0].metadata[ # type: ignore
655+
"original_exception"
601656
]
602-
assert len(failure_parts) == 1
603-
error_msg = failure_parts[0].metadata["fetch_error"]
604-
assert "Response exceeded" in error_msg
657+
assert "Response exceeded" in original_exception

genai_processors_url_fetch/url_fetch.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -516,28 +516,17 @@ async def _create_success_part(
516516
},
517517
)
518518

519-
async def _create_failure_part(
520-
self,
521-
result: FetchResult,
522-
original_part: processor.ProcessorPart,
523-
) -> processor.ProcessorPart:
524-
"""Create a ProcessorPart for a failed fetch."""
525-
return processor.ProcessorPart(
526-
# Empty content for failures
527-
"",
528-
metadata={
529-
**original_part.metadata,
530-
"source_url": result.url,
531-
"fetch_status": "failure",
532-
"fetch_error": result.error_message,
533-
},
534-
)
535-
519+
@processor.yield_exceptions_as_parts # type: ignore[arg-type]
536520
async def call(
537521
self,
538522
part: processor.ProcessorPart,
539523
) -> AsyncIterable[processor.ProcessorPart]:
540-
"""Fetch URLs found in the part and yield results."""
524+
"""Fetch URLs found in the part and yield results.
525+
526+
Extracts URLs from the part's text, fetches each one, and yields
527+
status updates and results. Failed fetches are automatically converted
528+
to error parts by the @processor.yield_exceptions_as_parts decorator.
529+
"""
541530
urls = list(dict.fromkeys(URL_REGEX.findall(part.text or "")))
542531

543532
if not urls:
@@ -546,14 +535,13 @@ async def call(
546535
yield part
547536
return
548537

538+
yield processor.status(f"📄 Processing {len(urls)} URL(s)")
539+
549540
headers = {"User-Agent": self.config.user_agent}
550541
async with httpx.AsyncClient(headers=headers) as client:
551-
# Create tasks to fetch all URLs concurrently
552-
tasks = [self._fetch_one(url, client) for url in urls]
553-
results = await asyncio.gather(*tasks)
554-
555-
# Process all results
556-
for result in results:
542+
# Process each URL individually
543+
for url in urls:
544+
result = await self._fetch_one(url, client)
557545

558546
if result.ok:
559547
yield processor.status(
@@ -562,10 +550,10 @@ async def call(
562550
yield await self._create_success_part(result, part)
563551
else:
564552
yield processor.status(f"❌ Fetch failed: {result.url}")
565-
if self.config.fail_on_error:
566-
raise RuntimeError(result.error_message)
567-
yield await self._create_failure_part(result, part)
553+
# Raise exception for failed fetch - decorator converts to error part
554+
error_msg = f"Failed to fetch {url}: {result.error_message}"
555+
raise RuntimeError(error_msg)
568556

569-
# Finally, yield the original part if configured to do so
557+
# Include original part if configured to do so
570558
if self.config.include_original_part:
571559
yield part

0 commit comments

Comments
 (0)