|
22 | 22 | from huggingface_hub.inference._providers.fal_ai import ( |
23 | 23 | _POLLING_INTERVAL, |
24 | 24 | FalAIAutomaticSpeechRecognitionTask, |
| 25 | + FalAIImageSegmentationTask, |
25 | 26 | FalAIImageToImageTask, |
26 | 27 | FalAIImageToVideoTask, |
27 | 28 | FalAITextToImageTask, |
@@ -629,6 +630,124 @@ def test_image_to_video_response(self, mocker): |
629 | 630 | mock_sleep.assert_called_once_with(_POLLING_INTERVAL) |
630 | 631 | assert response == b"video_content" |
631 | 632 |
|
| 633 | + def test_image_segmentation_payload(self): |
| 634 | + helper = FalAIImageSegmentationTask() |
| 635 | + mapping_info = InferenceProviderMapping( |
| 636 | + provider="fal-ai", |
| 637 | + hf_model_id="briaai/RMBG-2.0", |
| 638 | + providerId="fal-ai/rmbg-2.0", |
| 639 | + task="image-segmentation", |
| 640 | + status="live", |
| 641 | + ) |
| 642 | + payload = helper._prepare_payload_as_dict("https://example.com/image.png", {"threshold": 0.5}, mapping_info) |
| 643 | + assert payload == {"image_url": "https://example.com/image.png", "threshold": 0.5, "sync_mode": True} |
| 644 | + |
| 645 | + payload = helper._prepare_payload_as_dict(b"dummy_image_data", {"mask_threshold": 0.8}, mapping_info) |
| 646 | + assert payload == { |
| 647 | + "image_url": f"data:image/png;base64,{base64.b64encode(b'dummy_image_data').decode()}", |
| 648 | + "mask_threshold": 0.8, |
| 649 | + "sync_mode": True, |
| 650 | + } |
| 651 | + |
| 652 | + def test_image_segmentation_response_with_data_url(self, mocker): |
| 653 | + """Test image segmentation response when image URL is a data URL.""" |
| 654 | + helper = FalAIImageSegmentationTask() |
| 655 | + mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") |
| 656 | + mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep") |
| 657 | + dummy_mask_base64 = base64.b64encode(b"mask_content").decode() |
| 658 | + data_url = f"data:image/png;base64,{dummy_mask_base64}" |
| 659 | + mock_session.return_value.get.side_effect = [ |
| 660 | + # First call: status |
| 661 | + mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}), |
| 662 | + # Second call: get result |
| 663 | + mocker.Mock(json=lambda: {"image": {"url": data_url}}, headers={"Content-Type": "application/json"}), |
| 664 | + ] |
| 665 | + api_key = helper._prepare_api_key("hf_token") |
| 666 | + headers = helper._prepare_headers({}, api_key) |
| 667 | + url = helper._prepare_url(api_key, "username/repo_name") |
| 668 | + |
| 669 | + request_params = RequestParameters( |
| 670 | + url=url, |
| 671 | + headers=headers, |
| 672 | + task="image-segmentation", |
| 673 | + model="username/repo_name", |
| 674 | + data=None, |
| 675 | + json=None, |
| 676 | + ) |
| 677 | + response = helper.get_response( |
| 678 | + b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}', |
| 679 | + request_params, |
| 680 | + ) |
| 681 | + |
| 682 | + # Verify the correct URLs were called (only status and result, no fetch needed for data URL) |
| 683 | + assert mock_session.return_value.get.call_count == 2 |
| 684 | + mock_session.return_value.get.assert_has_calls( |
| 685 | + [ |
| 686 | + mocker.call( |
| 687 | + "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue", |
| 688 | + headers=request_params.headers, |
| 689 | + ), |
| 690 | + mocker.call( |
| 691 | + "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue", |
| 692 | + headers=request_params.headers, |
| 693 | + ), |
| 694 | + ] |
| 695 | + ) |
| 696 | + mock_sleep.assert_called_once_with(_POLLING_INTERVAL) |
| 697 | + assert response == [{"label": "mask", "mask": dummy_mask_base64}] |
| 698 | + |
| 699 | + def test_image_segmentation_response_with_regular_url(self, mocker): |
| 700 | + """Test image segmentation response when image URL is a regular HTTP URL.""" |
| 701 | + helper = FalAIImageSegmentationTask() |
| 702 | + mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") |
| 703 | + mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep") |
| 704 | + dummy_mask_base64 = base64.b64encode(b"mask_content").decode() |
| 705 | + mock_session.return_value.get.side_effect = [ |
| 706 | + # First call: status |
| 707 | + mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}), |
| 708 | + # Second call: get result |
| 709 | + mocker.Mock( |
| 710 | + json=lambda: {"image": {"url": "https://example.com/mask.png"}}, |
| 711 | + headers={"Content-Type": "application/json"}, |
| 712 | + ), |
| 713 | + # Third call: get mask content |
| 714 | + mocker.Mock(content=b"mask_content", raise_for_status=lambda: None), |
| 715 | + ] |
| 716 | + api_key = helper._prepare_api_key("hf_token") |
| 717 | + headers = helper._prepare_headers({}, api_key) |
| 718 | + url = helper._prepare_url(api_key, "username/repo_name") |
| 719 | + |
| 720 | + request_params = RequestParameters( |
| 721 | + url=url, |
| 722 | + headers=headers, |
| 723 | + task="image-segmentation", |
| 724 | + model="username/repo_name", |
| 725 | + data=None, |
| 726 | + json=None, |
| 727 | + ) |
| 728 | + response = helper.get_response( |
| 729 | + b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}', |
| 730 | + request_params, |
| 731 | + ) |
| 732 | + |
| 733 | + # Verify the correct URLs were called (status, result, and mask fetch) |
| 734 | + assert mock_session.return_value.get.call_count == 3 |
| 735 | + mock_session.return_value.get.assert_has_calls( |
| 736 | + [ |
| 737 | + mocker.call( |
| 738 | + "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue", |
| 739 | + headers=request_params.headers, |
| 740 | + ), |
| 741 | + mocker.call( |
| 742 | + "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue", |
| 743 | + headers=request_params.headers, |
| 744 | + ), |
| 745 | + mocker.call("https://example.com/mask.png"), |
| 746 | + ] |
| 747 | + ) |
| 748 | + mock_sleep.assert_called_once_with(_POLLING_INTERVAL) |
| 749 | + assert response == [{"label": "mask", "mask": dummy_mask_base64}] |
| 750 | + |
632 | 751 |
|
633 | 752 | class TestFeatherlessAIProvider: |
634 | 753 | def test_prepare_route_chat_completionurl(self): |
|
0 commit comments