diff --git a/tests/test_http.py b/tests/test_http.py index 52263cb..e0ece39 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -35,6 +35,10 @@ def test_headers_raw_dict_none(self): assert headers_raw_to_dict(None) is None assert headers_dict_to_raw(None) is None + def test_headers_raw_dict_empty(self): + assert headers_raw_to_dict(b"") == {} + assert headers_dict_to_raw({}) == b"" + def test_headers_raw_to_dict(self): raw = b"Content-type: text/html\n\rAccept: gzip\n\r\ Cache-Control: no-cache\n\rCache-Control: no-store\n\n" diff --git a/w3lib/http.py b/w3lib/http.py index 1791c98..fe47748 100644 --- a/w3lib/http.py +++ b/w3lib/http.py @@ -2,6 +2,7 @@ from base64 import b64encode from collections.abc import Mapping, MutableMapping, Sequence +from io import BytesIO from typing import Any, Union, overload from w3lib.util import to_bytes, to_unicode @@ -44,21 +45,23 @@ def headers_raw_to_dict(headers_raw: bytes | None) -> HeadersDictOutput | None: if headers_raw is None: return None - headers = headers_raw.splitlines() - headers_tuples = [header.split(b":", 1) for header in headers] + + if not headers_raw: + return {} result_dict: HeadersDictOutput = {} - for header_item in headers_tuples: - if len(header_item) != 2: + + for header in BytesIO(headers_raw): + key, sep, value = header.partition(b":") + if not sep: continue - item_key = header_item[0].strip() - item_value = header_item[1].strip() + key, value = key.strip(), value.strip() - if item_key in result_dict: - result_dict[item_key].append(item_value) + if key in result_dict: + result_dict[key].append(value) else: - result_dict[item_key] = [item_value] + result_dict[key] = [value] return result_dict @@ -93,13 +96,25 @@ def headers_dict_to_raw(headers_dict: HeadersDictInput | None) -> bytes | None: if headers_dict is None: return None - raw_lines = [] + + if not headers_dict: + return b"" + + parts = bytearray() + for key, value in headers_dict.items(): if isinstance(value, bytes): - raw_lines.append(b": ".join([key, value])) + if parts: + parts.extend(b"\r\n") + parts.extend(key + b": " + value) + elif isinstance(value, (list, tuple)): - raw_lines.extend(b": ".join([key, v]) for v in value) - return b"\r\n".join(raw_lines) + for v in value: + if parts: + parts.extend(b"\r\n") + parts.extend(key + b": " + v) + + return bytes(parts) def basic_auth_header(