Skip to content

Commit aa6a232

Browse files
[Inference Providers] implement image-segmentation for fal (#3521)
* implement image-segmentation for fal.ai
* Update src/huggingface_hub/inference/_providers/fal_ai.py
  Co-authored-by: Lucain <lucain@huggingface.co>
* update tests

---------

Co-authored-by: Lucain <lucain@huggingface.co>
1 parent d116de0 commit aa6a232

File tree

6 files changed

+175
-1
lines changed

6 files changed

+175
-1
lines changed

docs/source/en/guides/inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
202202
| [`~InferenceClient.feature_extraction`] |||||||||||||||||||||
203203
| [`~InferenceClient.fill_mask`] |||||||||||||||||||||
204204
| [`~InferenceClient.image_classification`] |||||||||||||||||||||
205-
| [`~InferenceClient.image_segmentation`] ||||| ||||||||||||||||
205+
| [`~InferenceClient.image_segmentation`] ||||| ||||||||||||||||
206206
| [`~InferenceClient.image_to_image`] |||||||||||||||||||||
207207
| [`~InferenceClient.image_to_video`] |||||||||||||||||||||
208208
| [`~InferenceClient.image_to_text`] |||||||||||||||||||||

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,7 @@ def image_segmentation(
12621262
api_key=self.token,
12631263
)
12641264
response = self._inner_post(request_parameters)
1265+
response = provider_helper.get_response(response, request_parameters)
12651266
output = ImageSegmentationOutputElement.parse_obj_as_list(response)
12661267
for item in output:
12671268
item.mask = _b64_to_image(item.mask) # type: ignore [assignment]

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,7 @@ async def image_segmentation(
12931293
api_key=self.token,
12941294
)
12951295
response = await self._inner_post(request_parameters)
1296+
response = provider_helper.get_response(response, request_parameters)
12961297
output = ImageSegmentationOutputElement.parse_obj_as_list(response)
12971298
for item in output:
12981299
item.mask = _b64_to_image(item.mask) # type: ignore [assignment]

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .cohere import CohereConversationalTask
1414
from .fal_ai import (
1515
FalAIAutomaticSpeechRecognitionTask,
16+
FalAIImageSegmentationTask,
1617
FalAIImageToImageTask,
1718
FalAIImageToVideoTask,
1819
FalAITextToImageTask,
@@ -102,6 +103,7 @@
102103
"text-to-video": FalAITextToVideoTask(),
103104
"image-to-video": FalAIImageToVideoTask(),
104105
"image-to-image": FalAIImageToImageTask(),
106+
"image-segmentation": FalAIImageSegmentationTask(),
105107
},
106108
"featherless-ai": {
107109
"conversational": FeatherlessConversationalTask(),

src/huggingface_hub/inference/_providers/fal_ai.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,54 @@ def get_response(
246246
output = super().get_response(response, request_params)
247247
url = _as_dict(output)["video"]["url"]
248248
return get_session().get(url).content
249+
250+
251+
class FalAIImageSegmentationTask(FalAIQueueTask):
    """Provider helper for running the ``image-segmentation`` task on fal.ai.

    Builds the request payload from an input image and converts the queued
    result into a one-element list of ``{"label", "mask"}`` dicts, where
    ``mask`` is a base64-encoded string.
    """

    def __init__(self):
        super().__init__("image-segmentation")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[dict]:
        """Build the JSON payload sent to fal.ai.

        Args:
            inputs: The image to segment; converted to a URL by ``_as_url``
                with ``image/png`` as the default MIME type.
            parameters: Extra task parameters; ``None`` values are dropped
                by ``filter_none``.
            provider_mapping_info: Provider mapping for the model (not used
                when building this payload).

        Returns:
            The payload dict with ``image_url``, the filtered parameters and
            ``sync_mode`` set to ``True``.
        """
        image_url = _as_url(inputs, default_mime_type="image/png")
        payload: dict[str, Any] = {
            "image_url": image_url,
            **filter_none(parameters),
            # Listed last so a user-supplied "sync_mode" cannot override it.
            # NOTE(review): presumably asks fal.ai to return the image inline
            # as a data URL -- confirm against the fal.ai documentation.
            "sync_mode": True,
        }
        return payload

    def get_response(
        self,
        response: Union[bytes, dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Turn the fal.ai result into a list of segmentation elements.

        Args:
            response: Raw response handed to ``FalAIQueueTask.get_response``.
            request_params: Parameters of the original request, forwarded to
                the parent implementation.

        Returns:
            A single-element list ``[{"label": "mask", "mask": <base64 str>}]``.

        Raises:
            ValueError: If the result has no ``image`` entry, if that entry is
                not a mapping with a string ``url``, or if a ``data:`` URL has
                no ``,`` separator.
        """
        result = super().get_response(response, request_params)
        result_dict = _as_dict(result)

        if "image" not in result_dict:
            raise ValueError(f"Response from fal ai image-segmentation API does not contain an image: {result_dict}")

        image_data = result_dict["image"]
        # Require a real mapping: on a string, `"url" in image_data` would be a
        # substring test, and on None it would raise a bare TypeError.
        if not isinstance(image_data, dict) or "url" not in image_data:
            raise ValueError(f"Image data from fal ai image-segmentation API does not contain a URL: {image_data}")

        image_url = image_data["url"]
        # A non-string URL (e.g. None) must not reach the HTTP client below.
        if not isinstance(image_url, str):
            raise ValueError(f"Image data from fal ai image-segmentation API does not contain a URL: {image_data}")

        if image_url.startswith("data:"):
            # Inline result: keep everything after the first comma as the payload.
            _, separator, mask_base64 = image_url.partition(",")
            if not separator:
                raise ValueError(f"Invalid data URL format: {image_url}")
        else:
            # Regular URL: download the mask and base64-encode its bytes.
            mask_response = get_session().get(image_url)
            hf_raise_for_status(mask_response)
            mask_base64 = base64.b64encode(mask_response.content).decode()

        return [
            {
                "label": "mask",
                "mask": mask_base64,
            }
        ]

tests/test_inference_providers.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from huggingface_hub.inference._providers.fal_ai import (
2323
_POLLING_INTERVAL,
2424
FalAIAutomaticSpeechRecognitionTask,
25+
FalAIImageSegmentationTask,
2526
FalAIImageToImageTask,
2627
FalAIImageToVideoTask,
2728
FalAITextToImageTask,
@@ -629,6 +630,124 @@ def test_image_to_video_response(self, mocker):
629630
mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
630631
assert response == b"video_content"
631632

633+
def test_image_segmentation_payload(self):
    """The payload carries the image URL, the user parameters and ``sync_mode``."""
    task = FalAIImageSegmentationTask()
    mapping = InferenceProviderMapping(
        provider="fal-ai",
        hf_model_id="briaai/RMBG-2.0",
        providerId="fal-ai/rmbg-2.0",
        task="image-segmentation",
        status="live",
    )

    # A URL input is passed through as the image URL.
    from_url = task._prepare_payload_as_dict("https://example.com/image.png", {"threshold": 0.5}, mapping)
    expected_from_url = {
        "image_url": "https://example.com/image.png",
        "threshold": 0.5,
        "sync_mode": True,
    }
    assert from_url == expected_from_url

    # Raw bytes are embedded as a base64 PNG data URL.
    raw_image = b"dummy_image_data"
    encoded_image = base64.b64encode(raw_image).decode()
    from_bytes = task._prepare_payload_as_dict(raw_image, {"mask_threshold": 0.8}, mapping)
    expected_from_bytes = {
        "image_url": f"data:image/png;base64,{encoded_image}",
        "mask_threshold": 0.8,
        "sync_mode": True,
    }
    assert from_bytes == expected_from_bytes
651+
652+
def test_image_segmentation_response_with_data_url(self, mocker):
    """A ``data:`` URL in the result is unpacked without an extra HTTP fetch."""
    task = FalAIImageSegmentationTask()
    mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
    mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
    mock_get = mock_session.return_value.get

    expected_mask = base64.b64encode(b"mask_content").decode()
    data_url = f"data:image/png;base64,{expected_mask}"

    json_headers = {"Content-Type": "application/json"}
    status_reply = mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers=json_headers)
    result_reply = mocker.Mock(json=lambda: {"image": {"url": data_url}}, headers=json_headers)
    # Replies are consumed in order: status poll first, then the result.
    mock_get.side_effect = [status_reply, result_reply]

    api_key = task._prepare_api_key("hf_token")
    request_params = RequestParameters(
        url=task._prepare_url(api_key, "username/repo_name"),
        headers=task._prepare_headers({}, api_key),
        task="image-segmentation",
        model="username/repo_name",
        data=None,
        json=None,
    )

    queued_response = b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}'
    response = task.get_response(queued_response, request_params)

    # Only the status and result endpoints are hit; the data URL needs no fetch.
    expected_calls = [
        mocker.call(
            "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
            headers=request_params.headers,
        ),
        mocker.call(
            "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
            headers=request_params.headers,
        ),
    ]
    assert mock_get.call_count == 2
    mock_get.assert_has_calls(expected_calls)
    mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
    assert response == [{"label": "mask", "mask": expected_mask}]
698+
699+
def test_image_segmentation_response_with_regular_url(self, mocker):
    """A plain HTTP URL in the result is downloaded and base64-encoded."""
    task = FalAIImageSegmentationTask()
    mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
    mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
    mock_get = mock_session.return_value.get

    expected_mask = base64.b64encode(b"mask_content").decode()

    json_headers = {"Content-Type": "application/json"}
    status_reply = mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers=json_headers)
    result_reply = mocker.Mock(
        json=lambda: {"image": {"url": "https://example.com/mask.png"}},
        headers=json_headers,
    )
    mask_reply = mocker.Mock(content=b"mask_content", raise_for_status=lambda: None)
    # Replies are consumed in order: status poll, result, then the mask download.
    mock_get.side_effect = [status_reply, result_reply, mask_reply]

    api_key = task._prepare_api_key("hf_token")
    request_params = RequestParameters(
        url=task._prepare_url(api_key, "username/repo_name"),
        headers=task._prepare_headers({}, api_key),
        task="image-segmentation",
        model="username/repo_name",
        data=None,
        json=None,
    )

    queued_response = b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}'
    response = task.get_response(queued_response, request_params)

    # Status, result, and finally the mask fetch (issued without headers).
    expected_calls = [
        mocker.call(
            "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
            headers=request_params.headers,
        ),
        mocker.call(
            "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
            headers=request_params.headers,
        ),
        mocker.call("https://example.com/mask.png"),
    ]
    assert mock_get.call_count == 3
    mock_get.assert_has_calls(expected_calls)
    mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
    assert response == [{"label": "mask", "mask": expected_mask}]
750+
632751

633752
class TestFeatherlessAIProvider:
634753
def test_prepare_route_chat_completionurl(self):

0 commit comments

Comments
 (0)