@@ -1,13 +1,55 @@
from __future__ import annotations

-from typing import Any
+import asyncio
+import logging
+import math
+from dataclasses import dataclass
+from datetime import timedelta
+from queue import Queue
+from time import sleep
+from typing import TYPE_CHECKING, Any, TypedDict

from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, parse_date_fields
+from more_itertools import constrained_batches

from apify_client._errors import ApifyApiError
from apify_client._utils import catch_not_found_or_throw, pluck_data
from apify_client.clients.base import ResourceClient, ResourceClientAsync

+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+logger = logging.getLogger(__name__)
+
+_RQ_MAX_REQUESTS_PER_BATCH = 25
+_MAX_PAYLOAD_SIZE_BYTES = 9 * 1024 * 1024  # 9 MB
+_SAFETY_BUFFER_PERCENT = 0.01 / 100  # 0.01%
+
+
+class BatchAddRequestsResult(TypedDict):
+    """Result of the batch add requests operation.
+
+    Args:
+        processedRequests: List of successfully added requests.
+        unprocessedRequests: List of requests that failed to be added.
+    """
+
+    processedRequests: list[dict]
+    unprocessedRequests: list[dict]
+
+
+@dataclass
+class AddRequestsBatch:
+    """Batch of requests to add to the request queue.
+
+    Args:
+        requests: List of requests to be added to the request queue.
+        num_of_retries: Number of times this batch has been retried.
+    """
+
+    requests: Iterable[dict]
+    num_of_retries: int = 0
+

class RequestQueueClient(ResourceClient):
    """Sub-client for manipulating a single request queue."""
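The constants above bound every batch call the new code makes: at most 25 requests per batch, with the JSON payload kept just under the 9 MB limit (minus a 0.01% safety buffer). The splitting itself is delegated to more_itertools.constrained_batches, which caps both the item count and the combined item size of each batch. A standalone sketch of that behavior with toy values (the items, limits, and explicit get_len are illustrative only, not the arguments used in the client):

from more_itertools import constrained_batches

# Toy "requests" whose size is measured by string length instead of JSON payload bytes.
items = ['aa', 'bb', 'cc', 'dddd']

# At most 3 items per batch and at most 5 "bytes" per batch.
batches = list(constrained_batches(items, max_size=5, max_count=3, get_len=len))

print(batches)
# Expected grouping: [('aa', 'bb'), ('cc',), ('dddd',)]
# A new batch starts as soon as adding the next item would exceed max_size or max_count.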
@@ -240,28 +282,84 @@ def delete_request_lock(self: RequestQueueClient, request_id: str, *, forefront:
        )

    def batch_add_requests(
-        self: RequestQueueClient,
+        self,
        requests: list[dict],
        *,
-        forefront: bool | None = None,
-    ) -> dict:
-        """Add requests to the queue.
+        forefront: bool = False,
+        max_parallel: int = 1,
+        max_unprocessed_requests_retries: int = 3,
+        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
+    ) -> BatchAddRequestsResult:
+        """Add requests to the request queue in batches.
+
+        Requests are split into batches based on size and processed sequentially.

        https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests

        Args:
-            requests (list[dict]): list of the requests to add
-            forefront (bool, optional): Whether to add the requests to the head or the end of the queue
+            requests: List of requests to be added to the queue.
+            forefront: Whether to add requests to the front of the queue.
+            max_parallel: Specifies the maximum number of parallel tasks for API calls. This is only applicable
+                to the async client. For the sync client, this value must be set to 1, as parallel execution
+                is not supported.
+            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
+            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed requests.
+
+        Returns:
+            Result containing lists of processed and unprocessed requests.
        """
+        if max_parallel != 1:
+            raise NotImplementedError('max_parallel is only supported in async client')
+
        request_params = self._params(clientKey=self.client_key, forefront=forefront)

-        response = self.http_client.call(
-            url=self._url('requests/batch'),
-            method='POST',
-            params=request_params,
-            json=requests,
+        # Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
+        payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
+
+        # Split the requests into batches, constrained by the max payload size and max requests per batch.
+        batches = constrained_batches(
+            requests,
+            max_size=payload_size_limit_bytes,
+            max_count=_RQ_MAX_REQUESTS_PER_BATCH,
        )
-        return parse_date_fields(pluck_data(response.json()))
+
+        # Put the batches into the queue for processing.
+        queue = Queue[AddRequestsBatch]()
+
+        for b in batches:
+            queue.put(AddRequestsBatch(b))
+
+        processed_requests = list[dict]()
+        unprocessed_requests = list[dict]()
+
+        # Process all batches in the queue sequentially.
+        while not queue.empty():
+            batch = queue.get()
+
+            # Send the batch to the API.
+            response = self.http_client.call(
+                url=self._url('requests/batch'),
+                method='POST',
+                params=request_params,
+                json=list(batch.requests),
+            )
+
+            # Retry if the request failed and the retry limit has not been reached.
+            if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
+                batch.num_of_retries += 1
+                sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
+                queue.put(batch)
+
+            # Otherwise, add the processed/unprocessed requests to their respective lists.
+            else:
+                response_parsed = parse_date_fields(pluck_data(response.json()))
+                processed_requests.extend(response_parsed.get('processedRequests', []))
+                unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
+
+        return {
+            'processedRequests': processed_requests,
+            'unprocessedRequests': unprocessed_requests,
+        }

    def batch_delete_requests(self: RequestQueueClient, requests: list[dict]) -> dict:
        """Delete given requests from the queue.
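A minimal usage sketch for the new synchronous method, assuming an existing request queue; the token, queue ID, and example URLs are placeholders, and the request fields (url, uniqueKey, method) follow the Apify request schema:

from apify_client import ApifyClient

client = ApifyClient(token='<YOUR_API_TOKEN>')           # placeholder token
rq_client = client.request_queue('<REQUEST_QUEUE_ID>')   # placeholder queue ID

# Build some requests; splitting into size- and count-limited batches happens inside the client.
requests = [
    {'url': f'https://example.com/{i}', 'uniqueKey': f'https://example.com/{i}', 'method': 'GET'}
    for i in range(100)
]

result = rq_client.batch_add_requests(requests)
print(f"added: {len(result['processedRequests'])}, failed: {len(result['unprocessedRequests'])}")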
@@ -540,29 +638,139 @@ async def delete_request_lock(
            params=request_params,
        )

+    async def _batch_add_requests_worker(
+        self,
+        queue: asyncio.Queue[AddRequestsBatch],
+        request_params: dict,
+        max_unprocessed_requests_retries: int,
+        min_delay_between_unprocessed_requests_retries: timedelta,
+    ) -> BatchAddRequestsResult:
+        """Worker function to process a batch of requests.
+
+        This worker will process batches from the queue, retrying requests that fail until the retry limit is reached.
+
+        Returns a result containing lists of processed and unprocessed requests handled by the worker.
+        """
+        processed_requests = list[dict]()
+        unprocessed_requests = list[dict]()
+
+        while True:
+            # Get the next batch from the queue.
+            try:
+                batch = await queue.get()
+            except asyncio.CancelledError:
+                break
+
+            try:
+                # Send the batch to the API.
+                response = await self.http_client.call(
+                    url=self._url('requests/batch'),
+                    method='POST',
+                    params=request_params,
+                    json=list(batch.requests),
+                )
+
+                response_parsed = parse_date_fields(pluck_data(response.json()))
+
+                # Retry if the request failed and the retry limit has not been reached.
+                if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
+                    batch.num_of_retries += 1
+                    await asyncio.sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
+                    await queue.put(batch)
+
+                # Otherwise, add the processed/unprocessed requests to their respective lists.
+                else:
+                    processed_requests.extend(response_parsed.get('processedRequests', []))
+                    unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
+
+            except Exception as exc:
+                logger.warning(f'Error occurred while processing a batch of requests: {exc}')
+
+            finally:
+                # Mark the batch as done whether it succeeded or failed.
+                queue.task_done()
+
+        return {
+            'processedRequests': processed_requests,
+            'unprocessedRequests': unprocessed_requests,
+        }
+
    async def batch_add_requests(
-        self: RequestQueueClientAsync,
+        self,
        requests: list[dict],
        *,
-        forefront: bool | None = None,
-    ) -> dict:
-        """Add requests to the queue.
+        forefront: bool = False,
+        max_parallel: int = 5,
+        max_unprocessed_requests_retries: int = 3,
+        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
+    ) -> BatchAddRequestsResult:
+        """Add requests to the request queue in batches.
+
+        Requests are split into batches based on size and processed in parallel.

        https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests

        Args:
-            requests (list[dict]): list of the requests to add
-            forefront (bool, optional): Whether to add the requests to the head or the end of the queue
+            requests: List of requests to be added to the queue.
+            forefront: Whether to add requests to the front of the queue.
+            max_parallel: Specifies the maximum number of parallel tasks for API calls. This is only applicable
+                to the async client. For the sync client, this value must be set to 1, as parallel execution
+                is not supported.
+            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
+            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed requests.
+
+        Returns:
+            Result containing lists of processed and unprocessed requests.
        """
+        tasks = set[asyncio.Task]()
+        queue: asyncio.Queue[AddRequestsBatch] = asyncio.Queue()
        request_params = self._params(clientKey=self.client_key, forefront=forefront)

-        response = await self.http_client.call(
-            url=self._url('requests/batch'),
-            method='POST',
-            params=request_params,
-            json=requests,
+        # Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
+        payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
+
+        # Split the requests into batches, constrained by the max payload size and max requests per batch.
+        batches = constrained_batches(
+            requests,
+            max_size=payload_size_limit_bytes,
+            max_count=_RQ_MAX_REQUESTS_PER_BATCH,
        )
-        return parse_date_fields(pluck_data(response.json()))
+
+        for batch in batches:
+            await queue.put(AddRequestsBatch(batch))
+
+        # Start the requested number of worker tasks to process the batches.
+        for i in range(max_parallel):
+            coro = self._batch_add_requests_worker(
+                queue,
+                request_params,
+                max_unprocessed_requests_retries,
+                min_delay_between_unprocessed_requests_retries,
+            )
+            task = asyncio.create_task(coro, name=f'batch_add_requests_worker_{i}')
+            tasks.add(task)
+
+        # Wait for all batches to be processed.
+        await queue.join()
+
+        # Send cancellation signals to all worker tasks and wait for them to finish.
+        for task in tasks:
+            task.cancel()
+
+        results: list[BatchAddRequestsResult] = await asyncio.gather(*tasks)
+
+        # Combine the results from all workers and return them.
+        processed_requests = []
+        unprocessed_requests = []
+
+        for result in results:
+            processed_requests.extend(result['processedRequests'])
+            unprocessed_requests.extend(result['unprocessedRequests'])
+
+        return {
+            'processedRequests': processed_requests,
+            'unprocessedRequests': unprocessed_requests,
+        }

    async def batch_delete_requests(self: RequestQueueClientAsync, requests: list[dict]) -> dict:
        """Delete given requests from the queue.
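The asynchronous counterpart fans the same batches out to up to max_parallel worker tasks. A minimal sketch, again with a placeholder token, queue ID, and example URLs:

import asyncio

from apify_client import ApifyClientAsync


async def main() -> None:
    client = ApifyClientAsync(token='<YOUR_API_TOKEN>')      # placeholder token
    rq_client = client.request_queue('<REQUEST_QUEUE_ID>')   # placeholder queue ID

    requests = [
        {'url': f'https://example.com/{i}', 'uniqueKey': f'https://example.com/{i}', 'method': 'GET'}
        for i in range(1000)
    ]

    # Up to 5 batches are in flight at once; failed batches are retried up to 3 times by default.
    result = await rq_client.batch_add_requests(requests, max_parallel=5)
    print(f"added: {len(result['processedRequests'])}, failed: {len(result['unprocessedRequests'])}")


asyncio.run(main())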