Skip to content

Commit fd294b4

Browse files
[type validation] skip unresolved forward ref (#3376)
* skip if unresolved forward ref * make style * Update src/huggingface_hub/dataclasses.py Co-authored-by: Lucain <lucainp@gmail.com> * add forward ref in test cases * update the test * fix * maybe like this? * update tests --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Lucain <lucain@huggingface.co>
1 parent ae08790 commit fd294b4

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

src/huggingface_hub/dataclasses.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Any,
66
Callable,
77
Dict,
8+
ForwardRef,
89
List,
910
Literal,
1011
Optional,
@@ -325,6 +326,8 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
325326
validator(name, value, args)
326327
elif isinstance(expected_type, type): # simple types
327328
_validate_simple_type(name, value, expected_type)
329+
elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
330+
return
328331
else:
329332
raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
330333

tests/test_utils_strict_dataclass.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,30 @@ def strictly_positive(value: int):
2929
raise ValueError(f"Value must be strictly positive, got {value}")
3030

3131

32+
def dtype_validation(value: "ForwardDtype"):
33+
if not isinstance(value, str):
34+
raise ValueError(f"Value must be string, got {value}")
35+
36+
if isinstance(value, str) and value not in ["float32", "bfloat16", "float16"]:
37+
raise ValueError(f"Value must be one of `[float32, bfloat16, float16]` but got {value}")
38+
39+
40+
@strict
41+
@dataclass
42+
class ConfigForwardRef:
43+
"""Test forward reference handling.
44+
45+
In practice, forward reference types are not validated so a custom validator is highly recommended.
46+
"""
47+
48+
forward_ref_validated: "ForwardDtype" = validated_field(validator=dtype_validation)
49+
forward_ref: "ForwardDtype" = "float32" # type is not validated by default
50+
51+
52+
class ForwardDtype(str):
53+
"""Dummy class to simulate a forward reference (e.g. `torch.dtype`)."""
54+
55+
3256
@strict
3357
@dataclass
3458
class Config:
@@ -62,6 +86,26 @@ def test_default_values():
6286
assert config.hidden_size == 1024
6387

6488

89+
def test_forward_ref_validation_is_skipped():
90+
config = ConfigForwardRef(forward_ref="float32", forward_ref_validated="float32")
91+
assert config.forward_ref == "float32"
92+
assert config.forward_ref_validated == "float32"
93+
94+
# The `forward_ref_validated` has proper validation added in field-metadata and will be validated
95+
with pytest.raises(StrictDataclassFieldValidationError):
96+
ConfigForwardRef(forward_ref_validated="float64")
97+
98+
with pytest.raises(StrictDataclassFieldValidationError):
99+
ConfigForwardRef(forward_ref_validated=-1)
100+
101+
with pytest.raises(StrictDataclassFieldValidationError):
102+
ConfigForwardRef(forward_ref_validated="not_dtype")
103+
104+
# The `forward_ref` type is not validated => user can input anything
105+
ConfigForwardRef(forward_ref=-1, forward_ref_validated="float32")
106+
ConfigForwardRef(forward_ref=["float32"], forward_ref_validated="float32")
107+
108+
65109
def test_invalid_type_initialization():
66110
with pytest.raises(StrictDataclassFieldValidationError):
67111
Config(model_type={"type": "bert"}, vocab_size=30000, hidden_size=768)

0 commit comments

Comments
 (0)