@@ -29,6 +29,30 @@ def strictly_positive(value: int):
29
29
raise ValueError (f"Value must be strictly positive, got { value } " )
30
30
31
31
32
+ def dtype_validation (value : "ForwardDtype" ):
33
+ if not isinstance (value , str ):
34
+ raise ValueError (f"Value must be string, got { value } " )
35
+
36
+ if isinstance (value , str ) and value not in ["float32" , "bfloat16" , "float16" ]:
37
+ raise ValueError (f"Value must be one of `[float32, bfloat16, float16]` but got { value } " )
38
+
39
+
40
+ @strict
41
+ @dataclass
42
+ class ConfigForwardRef :
43
+ """Test forward reference handling.
44
+
45
+ In practice, forward reference types are not validated so a custom validator is highly recommended.
46
+ """
47
+
48
+ forward_ref_validated : "ForwardDtype" = validated_field (validator = dtype_validation )
49
+ forward_ref : "ForwardDtype" = "float32" # type is not validated by default
50
+
51
+
52
+ class ForwardDtype (str ):
53
+ """Dummy class to simulate a forward reference (e.g. `torch.dtype`)."""
54
+
55
+
32
56
@strict
33
57
@dataclass
34
58
class Config :
@@ -62,6 +86,26 @@ def test_default_values():
62
86
assert config .hidden_size == 1024
63
87
64
88
89
+ def test_forward_ref_validation_is_skipped ():
90
+ config = ConfigForwardRef (forward_ref = "float32" , forward_ref_validated = "float32" )
91
+ assert config .forward_ref == "float32"
92
+ assert config .forward_ref_validated == "float32"
93
+
94
+ # The `forward_ref_validated` has proper validation added in field-metadata and will be validated
95
+ with pytest .raises (StrictDataclassFieldValidationError ):
96
+ ConfigForwardRef (forward_ref_validated = "float64" )
97
+
98
+ with pytest .raises (StrictDataclassFieldValidationError ):
99
+ ConfigForwardRef (forward_ref_validated = - 1 )
100
+
101
+ with pytest .raises (StrictDataclassFieldValidationError ):
102
+ ConfigForwardRef (forward_ref_validated = "not_dtype" )
103
+
104
+ # The `forward_ref` type is not validated => user can input anything
105
+ ConfigForwardRef (forward_ref = - 1 , forward_ref_validated = "float32" )
106
+ ConfigForwardRef (forward_ref = ["float32" ], forward_ref_validated = "float32" )
107
+
108
+
65
109
def test_invalid_type_initialization ():
66
110
with pytest .raises (StrictDataclassFieldValidationError ):
67
111
Config (model_type = {"type" : "bert" }, vocab_size = 30000 , hidden_size = 768 )
0 commit comments