Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 87 additions & 34 deletions genai_processors_url_fetch/tests/test_url_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ async def mock_aiter_bytes() -> AsyncIterable[bytes]:
r for r in results if r.metadata.get("fetch_status") == "success"
]

assert len(status_parts) == 1
assert "Fetched successfully" in status_parts[0].text
assert len(status_parts) == 2 # Processing + success status
assert any("Processing 1 URL(s)" in part.text for part in status_parts)
assert any("Fetched successfully" in part.text for part in status_parts)
assert len(content_parts) == 1
assert content_parts[0].text == "Test Content"
assert content_parts[0].metadata["source_url"] == "https://example.com"
Expand Down Expand Up @@ -159,23 +160,26 @@ async def test_failed_fetch_with_mocking(self) -> None:
part = processor.ProcessorPart("Visit https://notfound.com")
results = [r async for r in p.call(part)]

# Should have status and failure parts
# Should have status parts and exception part
status_parts = [
r for r in results if r.substream_name == processor.STATUS_STREAM
]
failure_parts = [
r for r in results if r.metadata.get("fetch_status") == "failure"
exception_parts = [
r for r in results if r.metadata.get("exception_type") == "RuntimeError"
]

assert len(status_parts) == 1
assert "Fetch failed" in status_parts[0].text
assert len(failure_parts) == 1
assert failure_parts[0].text == "" # Empty content for failures
assert "404" in failure_parts[0].metadata["fetch_error"]
assert (
len(status_parts) == 3
) # Processing + failure status + exception part
assert any("Processing 1 URL(s)" in part.text for part in status_parts)
assert any("Fetch failed" in part.text for part in status_parts)
assert len(exception_parts) == 1
assert "An unexpected error occurred" in exception_parts[0].text
assert "404" in exception_parts[0].metadata["original_exception"]

@pytest.mark.anyio
async def test_fail_on_error_config(self) -> None:
"""Test that fail_on_error configuration raises exceptions."""
"""Test that errors are converted to exception parts by the decorator."""
config = FetchConfig(fail_on_error=True)
p = UrlFetchProcessor(config)

Expand All @@ -194,10 +198,14 @@ async def test_fail_on_error_config(self) -> None:
mock_client.get.side_effect = error

part = processor.ProcessorPart("Visit https://error.com")
results = [r async for r in p.call(part)]

with pytest.raises(RuntimeError):
async for _ in p.call(part):
pass
# With decorator, exceptions are converted to exception parts
exception_parts = [
r for r in results if r.metadata.get("exception_type") == "RuntimeError"
]
assert len(exception_parts) == 1
assert "Request Error" in exception_parts[0].metadata["original_exception"]

@pytest.mark.anyio
async def test_content_processor_raw_config(self) -> None:
Expand Down Expand Up @@ -470,19 +478,22 @@ async def test_url_validation_integration(self) -> None:
part = processor.ProcessorPart("Visit https://malicious.com")
results = [r async for r in p.call(part)]

# Should have status and failure parts
# Should have status parts and exception part
status_parts = [
r for r in results if r.substream_name == processor.STATUS_STREAM
]
failure_parts = [
r for r in results if r.metadata.get("fetch_status") == "failure"
exception_parts = [
r for r in results if r.metadata.get("exception_type") == "RuntimeError"
]

assert len(status_parts) == 1
assert "Fetch failed" in status_parts[0].text
assert len(failure_parts) == 1
error_msg = failure_parts[0].metadata["fetch_error"]
assert "Security validation failed" in error_msg
assert len(status_parts) == 3 # Processing + failure status + exception part
assert any("Processing 1 URL(s)" in part.text for part in status_parts)
assert any("Fetch failed" in part.text for part in status_parts)
assert len(exception_parts) == 1
assert (
"Domain 'malicious.com' not in allowed list"
in exception_parts[0].metadata["original_exception"]
)

@pytest.mark.anyio
async def test_create_success_part_validation_errors(self) -> None:
Expand Down Expand Up @@ -558,13 +569,34 @@ async def test_response_size_limits(self) -> None:
part = processor.ProcessorPart("Visit https://example.com")
results = [r async for r in p.call(part)]

# Should have failure parts due to size limit
failure_parts = [
r for r in results if r.metadata.get("fetch_status") == "failure"
# Should have status parts and exception part
status_parts = [
r
for r in results
if r.substream_name == processor.STATUS_STREAM # type: ignore
]
exception_parts = [
r
for r in results
# type: ignore
if r.metadata.get("exception_type") == "RuntimeError"
]
assert len(failure_parts) == 1
error_msg = failure_parts[0].metadata["fetch_error"]
assert "Response too large" in error_msg

assert (
len(status_parts) == 3
) # Processing + failure status + exception part
assert any(
"Processing 1 URL(s)" in part.text # type: ignore
for part in status_parts
)
assert any(
"Fetch failed" in part.text for part in status_parts # type: ignore
)
assert len(exception_parts) == 1
original_exception = exception_parts[0].metadata[ # type: ignore
"original_exception"
]
assert "Response too large" in original_exception

@pytest.mark.anyio
async def test_streaming_response_size_exceeded(self) -> None:
Expand Down Expand Up @@ -595,10 +627,31 @@ async def mock_aiter_bytes() -> AsyncIterable[bytes]:
part = processor.ProcessorPart("Visit https://example.com")
results = [r async for r in p.call(part)]

# Should have failure parts due to size limit during streaming
failure_parts = [
r for r in results if r.metadata.get("fetch_status") == "failure"
# Should have status parts and exception part
status_parts = [
r
for r in results
if r.substream_name == processor.STATUS_STREAM # type: ignore
]
exception_parts = [
r
for r in results
# type: ignore
if r.metadata.get("exception_type") == "RuntimeError"
]

assert (
len(status_parts) == 3
) # Processing + failure status + exception part
assert any(
"Processing 1 URL(s)" in part.text # type: ignore
for part in status_parts
)
assert any(
"Fetch failed" in part.text for part in status_parts # type: ignore
)
assert len(exception_parts) == 1
original_exception = exception_parts[0].metadata[ # type: ignore
"original_exception"
]
assert len(failure_parts) == 1
error_msg = failure_parts[0].metadata["fetch_error"]
assert "Response exceeded" in error_msg
assert "Response exceeded" in original_exception
44 changes: 16 additions & 28 deletions genai_processors_url_fetch/url_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,28 +516,17 @@ async def _create_success_part(
},
)

async def _create_failure_part(
self,
result: FetchResult,
original_part: processor.ProcessorPart,
) -> processor.ProcessorPart:
"""Create a ProcessorPart for a failed fetch."""
return processor.ProcessorPart(
# Empty content for failures
"",
metadata={
**original_part.metadata,
"source_url": result.url,
"fetch_status": "failure",
"fetch_error": result.error_message,
},
)

@processor.yield_exceptions_as_parts # type: ignore[arg-type]
async def call(
self,
part: processor.ProcessorPart,
) -> AsyncIterable[processor.ProcessorPart]:
"""Fetch URLs found in the part and yield results."""
"""Fetch URLs found in the part and yield results.

Extracts URLs from the part's text, fetches each one, and yields
status updates and results. Failed fetches are automatically converted
to error parts by the @processor.yield_exceptions_as_parts decorator.
"""
urls = list(dict.fromkeys(URL_REGEX.findall(part.text or "")))

if not urls:
Expand All @@ -546,14 +535,13 @@ async def call(
yield part
return

yield processor.status(f"📄 Processing {len(urls)} URL(s)")

headers = {"User-Agent": self.config.user_agent}
async with httpx.AsyncClient(headers=headers) as client:
# Create tasks to fetch all URLs concurrently
tasks = [self._fetch_one(url, client) for url in urls]
results = await asyncio.gather(*tasks)

# Process all results
for result in results:
# Process each URL individually
for url in urls:
result = await self._fetch_one(url, client)

if result.ok:
yield processor.status(
Expand All @@ -562,10 +550,10 @@ async def call(
yield await self._create_success_part(result, part)
else:
yield processor.status(f"❌ Fetch failed: {result.url}")
if self.config.fail_on_error:
raise RuntimeError(result.error_message)
yield await self._create_failure_part(result, part)
# Raise exception for failed fetch - decorator converts to error part
error_msg = f"Failed to fetch {url}: {result.error_message}"
raise RuntimeError(error_msg)

# Finally, yield the original part if configured to do so
# Include original part if configured to do so
if self.config.include_original_part:
yield part