Skip to content

Commit e6d2c5f

Browse files
[InferenceClient] Add content-type header whenever possible + refacto (#3321)
* [InferenceClient] Add content-type header whenever possible + refacto * address comments * Update src/huggingface_hub/inference/_common.py Co-authored-by: célina <hanouticelina@gmail.com> --------- Co-authored-by: célina <hanouticelina@gmail.com>
1 parent 1d971b0 commit e6d2c5f

File tree

12 files changed

+313
-167
lines changed

12 files changed

+313
-167
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
_bytes_to_list,
5454
_get_unsupported_text_generation_kwargs,
5555
_import_numpy,
56-
_open_as_binary,
5756
_set_unsupported_text_generation_kwargs,
5857
_stream_chat_completion_response,
5958
_stream_text_generation_response,
@@ -257,21 +256,20 @@ def _inner_post(
257256
if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
258257
request_parameters.headers["Accept"] = "image/png"
259258

260-
with _open_as_binary(request_parameters.data) as data_as_binary:
261-
try:
262-
response = get_session().post(
263-
request_parameters.url,
264-
json=request_parameters.json,
265-
data=data_as_binary,
266-
headers=request_parameters.headers,
267-
cookies=self.cookies,
268-
timeout=self.timeout,
269-
stream=stream,
270-
proxies=self.proxies,
271-
)
272-
except TimeoutError as error:
273-
# Convert any `TimeoutError` to a `InferenceTimeoutError`
274-
raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
259+
try:
260+
response = get_session().post(
261+
request_parameters.url,
262+
json=request_parameters.json,
263+
data=request_parameters.data,
264+
headers=request_parameters.headers,
265+
cookies=self.cookies,
266+
timeout=self.timeout,
267+
stream=stream,
268+
proxies=self.proxies,
269+
)
270+
except TimeoutError as error:
271+
# Convert any `TimeoutError` to an `InferenceTimeoutError`
272+
raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
275273

276274
try:
277275
hf_raise_for_status(response)

src/huggingface_hub/inference/_common.py

Lines changed: 77 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,14 @@
1919
import json
2020
import logging
2121
import mimetypes
22-
from contextlib import contextmanager
2322
from dataclasses import dataclass
2423
from pathlib import Path
2524
from typing import (
2625
TYPE_CHECKING,
2726
Any,
2827
AsyncIterable,
2928
BinaryIO,
30-
ContextManager,
3129
Dict,
32-
Generator,
3330
Iterable,
3431
List,
3532
Literal,
@@ -61,8 +58,7 @@
6158
# TYPES
6259
UrlT = str
6360
PathT = Union[str, Path]
64-
BinaryT = Union[bytes, BinaryIO]
65-
ContentT = Union[BinaryT, PathT, UrlT, "Image"]
61+
ContentT = Union[bytes, BinaryIO, PathT, UrlT, "Image", bytearray, memoryview]
6662

6763
# Used to set an Accept: image/png header
6864
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
@@ -76,10 +72,35 @@ class RequestParameters:
7672
task: str
7773
model: Optional[str]
7874
json: Optional[Union[str, Dict, List]]
79-
data: Optional[ContentT]
75+
data: Optional[bytes]
8076
headers: Dict[str, Any]
8177

8278

79+
class MimeBytes(bytes):
80+
"""
81+
A bytes object with a mime type.
82+
To be returned by `_prepare_payload_as_bytes` in subclasses.
83+
84+
Example:
85+
```python
86+
>>> b = MimeBytes(b"hello", "text/plain")
87+
>>> isinstance(b, bytes)
88+
True
89+
>>> b.mime_type
90+
'text/plain'
91+
```
92+
"""
93+
94+
mime_type: Optional[str]
95+
96+
def __new__(cls, data: bytes, mime_type: Optional[str] = None):
97+
obj = super().__new__(cls, data)
98+
obj.mime_type = mime_type
99+
if isinstance(data, MimeBytes) and mime_type is None:
100+
obj.mime_type = data.mime_type
101+
return obj
102+
103+
83104
## IMPORT UTILS
84105

85106

@@ -117,31 +138,49 @@ def _import_pil_image():
117138

118139

119140
@overload
120-
def _open_as_binary(
121-
content: ContentT,
122-
) -> ContextManager[BinaryT]: ... # means "if input is not None, output is not None"
141+
def _open_as_mime_bytes(content: ContentT) -> MimeBytes: ... # means "if input is not None, output is not None"
123142

124143

125144
@overload
126-
def _open_as_binary(
127-
content: Literal[None],
128-
) -> ContextManager[Literal[None]]: ... # means "if input is None, output is None"
145+
def _open_as_mime_bytes(content: Literal[None]) -> Literal[None]: ... # means "if input is None, output is None"
129146

130147

131-
@contextmanager # type: ignore
132-
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
148+
def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
133149
"""Read `content` as bytes tagged with an optional MIME type, either from a URL, a local path, raw bytes, a binary file-like object, or a PIL Image.
134150
135151
Do nothing if `content` is None.
136-
137-
TODO: handle base64 as input
138152
"""
153+
# If content is None, return None
154+
if content is None:
155+
return None
156+
157+
# If content is bytes, return it
158+
if isinstance(content, bytes):
159+
return MimeBytes(content)
160+
161+
# If content is raw binary data (bytearray, memoryview)
162+
if isinstance(content, (bytearray, memoryview)):
163+
return MimeBytes(bytes(content))
164+
165+
# If content is a binary file-like object
166+
if hasattr(content, "read"): # duck-typing instead of isinstance(content, BinaryIO)
167+
logger.debug("Reading content from BinaryIO")
168+
data = content.read()
169+
mime_type = mimetypes.guess_type(content.name)[0] if hasattr(content, "name") else None
170+
if isinstance(data, str):
171+
raise TypeError("Expected binary stream (bytes), but got text stream")
172+
return MimeBytes(data, mime_type=mime_type)
173+
139174
# If content is a string => must be either a URL or a path
140175
if isinstance(content, str):
141176
if content.startswith("https://") or content.startswith("http://"):
142177
logger.debug(f"Downloading content from {content}")
143-
yield get_session().get(content).content # TODO: retrieve as stream and pipe to post request ?
144-
return
178+
response = get_session().get(content)
179+
mime_type = response.headers.get("Content-Type")
180+
if mime_type is None:
181+
mime_type = mimetypes.guess_type(content)[0]
182+
return MimeBytes(response.content, mime_type=mime_type)
183+
145184
content = Path(content)
146185
if not content.exists():
147186
raise FileNotFoundError(
@@ -152,9 +191,7 @@ def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT],
152191
# If content is a Path => open it
153192
if isinstance(content, Path):
154193
logger.debug(f"Opening content from {content}")
155-
with content.open("rb") as f:
156-
yield f
157-
return
194+
return MimeBytes(content.read_bytes(), mime_type=mimetypes.guess_type(content)[0])
158195

159196
# If content is a PIL Image => convert to bytes
160197
if is_pillow_available():
@@ -163,38 +200,37 @@ def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT],
163200
if isinstance(content, Image.Image):
164201
logger.debug("Converting PIL Image to bytes")
165202
buffer = io.BytesIO()
166-
content.save(buffer, format=content.format or "PNG")
167-
yield buffer.getvalue()
168-
return
203+
format = content.format or "PNG"
204+
content.save(buffer, format=format)
205+
return MimeBytes(buffer.getvalue(), mime_type=f"image/{format.lower()}")
169206

170-
# Otherwise: already a file-like object or None
171-
yield content # type: ignore
207+
# If nothing matched, raise error
208+
raise TypeError(
209+
f"Unsupported content type: {type(content)}. "
210+
"Expected one of: bytes, bytearray, BinaryIO, memoryview, Path, str (URL or file path), or PIL.Image.Image."
211+
)
172212

173213

174214
def _b64_encode(content: ContentT) -> str:
175215
"""Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
176-
with _open_as_binary(content) as data:
177-
data_as_bytes = data if isinstance(data, bytes) else data.read()
178-
return base64.b64encode(data_as_bytes).decode()
216+
raw_bytes = _open_as_mime_bytes(content)
217+
return base64.b64encode(raw_bytes).decode()
179218

180219

181220
def _as_url(content: ContentT, default_mime_type: str) -> str:
182-
if isinstance(content, str) and (content.startswith("https://") or content.startswith("http://")):
221+
if isinstance(content, str) and content.startswith(("http://", "https://", "data:")):
183222
return content
184223

185-
# Handle MIME type detection for different content types
186-
mime_type = None
187-
if isinstance(content, (str, Path)):
188-
mime_type = mimetypes.guess_type(content, strict=False)[0]
189-
elif is_pillow_available():
190-
from PIL import Image
224+
# Convert content to bytes
225+
raw_bytes = _open_as_mime_bytes(content)
191226

192-
if isinstance(content, Image.Image):
193-
# Determine MIME type from PIL Image format, in sync with `_open_as_binary`
194-
mime_type = f"image/{(content.format or 'PNG').lower()}"
227+
# Get MIME type
228+
mime_type = raw_bytes.mime_type or default_mime_type
195229

196-
mime_type = mime_type or default_mime_type
197-
encoded_data = _b64_encode(content)
230+
# Encode content to base64
231+
encoded_data = base64.b64encode(raw_bytes).decode()
232+
233+
# Build data URL
198234
return f"data:{mime_type};base64,{encoded_data}"
199235

200236

@@ -239,9 +275,6 @@ def _as_dict(response: Union[bytes, Dict]) -> Dict:
239275
return json.loads(response) if isinstance(response, bytes) else response
240276

241277

242-
## PAYLOAD UTILS
243-
244-
245278
## STREAMING UTILS
246279

247280

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
_bytes_to_list,
4141
_get_unsupported_text_generation_kwargs,
4242
_import_numpy,
43-
_open_as_binary,
4443
_set_unsupported_text_generation_kwargs,
4544
raise_text_generation_error,
4645
)
@@ -255,39 +254,38 @@ async def _inner_post(
255254
if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
256255
request_parameters.headers["Accept"] = "image/png"
257256

258-
with _open_as_binary(request_parameters.data) as data_as_binary:
259-
# Do not use context manager as we don't want to close the connection immediately when returning
260-
# a stream
261-
session = self._get_client_session(headers=request_parameters.headers)
257+
# Do not use context manager as we don't want to close the connection immediately when returning
258+
# a stream
259+
session = self._get_client_session(headers=request_parameters.headers)
262260

263-
try:
264-
response = await session.post(
265-
request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies
266-
)
267-
response_error_payload = None
268-
if response.status != 200:
269-
try:
270-
response_error_payload = await response.json() # get payload before connection closed
271-
except Exception:
272-
pass
273-
response.raise_for_status()
274-
if stream:
275-
return _async_yield_from(session, response)
276-
else:
277-
content = await response.read()
278-
await session.close()
279-
return content
280-
except asyncio.TimeoutError as error:
281-
await session.close()
282-
# Convert any `TimeoutError` to a `InferenceTimeoutError`
283-
raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
284-
except aiohttp.ClientResponseError as error:
285-
error.response_error_payload = response_error_payload
286-
await session.close()
287-
raise error
288-
except Exception:
261+
try:
262+
response = await session.post(
263+
request_parameters.url, json=request_parameters.json, data=request_parameters.data, proxy=self.proxies
264+
)
265+
response_error_payload = None
266+
if response.status != 200:
267+
try:
268+
response_error_payload = await response.json() # get payload before connection closed
269+
except Exception:
270+
pass
271+
response.raise_for_status()
272+
if stream:
273+
return _async_yield_from(session, response)
274+
else:
275+
content = await response.read()
289276
await session.close()
290-
raise
277+
return content
278+
except asyncio.TimeoutError as error:
279+
await session.close()
280+
# Convert any `TimeoutError` to an `InferenceTimeoutError`
281+
raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore
282+
except aiohttp.ClientResponseError as error:
283+
error.response_error_payload = response_error_payload
284+
await session.close()
285+
raise error
286+
except Exception:
287+
await session.close()
288+
raise
291289

292290
async def __aenter__(self):
293291
return self

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from huggingface_hub import constants
55
from huggingface_hub.hf_api import InferenceProviderMapping
6-
from huggingface_hub.inference._common import RequestParameters
6+
from huggingface_hub.inference._common import MimeBytes, RequestParameters
77
from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage
88
from huggingface_hub.utils import build_hf_headers, get_token, logging
99

@@ -108,8 +108,17 @@ def prepare_request(
108108
raise ValueError("Both payload and data cannot be set in the same request.")
109109
if payload is None and data is None:
110110
raise ValueError("Either payload or data must be set in the request.")
111+
112+
# normalize headers to lowercase and add content-type if not present
113+
normalized_headers = self._normalize_headers(headers, payload, data)
114+
111115
return RequestParameters(
112-
url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers
116+
url=url,
117+
task=self.task,
118+
model=provider_mapping_info.provider_id,
119+
json=payload,
120+
data=data,
121+
headers=normalized_headers,
113122
)
114123

115124
def get_response(
@@ -172,7 +181,22 @@ def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMappin
172181
)
173182
return provider_mapping
174183

175-
def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
184+
def _normalize_headers(
185+
self, headers: Dict[str, Any], payload: Optional[Dict[str, Any]], data: Optional[MimeBytes]
186+
) -> Dict[str, Any]:
187+
"""Normalize the headers to use for the request.
188+
189+
Override this method in subclasses for customized headers.
190+
"""
191+
normalized_headers = {key.lower(): value for key, value in headers.items() if value is not None}
192+
if normalized_headers.get("content-type") is None:
193+
if data is not None and data.mime_type is not None:
194+
normalized_headers["content-type"] = data.mime_type
195+
elif payload is not None:
196+
normalized_headers["content-type"] = "application/json"
197+
return normalized_headers
198+
199+
def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
176200
"""Return the headers to use for the request.
177201
178202
Override this method in subclasses for customized headers.
@@ -222,7 +246,7 @@ def _prepare_payload_as_bytes(
222246
parameters: Dict,
223247
provider_mapping_info: InferenceProviderMapping,
224248
extra_payload: Optional[Dict],
225-
) -> Optional[bytes]:
249+
) -> Optional[MimeBytes]:
226250
"""Return the body to use for the request, as bytes.
227251
228252
Override this method in subclasses for customized body data.

src/huggingface_hub/inference/_providers/black_forest_labs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
1818
def __init__(self):
1919
super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")
2020

21-
def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
21+
def _prepare_headers(self, headers: Dict, api_key: str) -> Dict[str, Any]:
2222
headers = super()._prepare_headers(headers, api_key)
2323
if not api_key.startswith("hf_"):
2424
_ = headers.pop("authorization")

0 commit comments

Comments
 (0)