Skip to content

Commit d783f32

Browse files
authored
Merge pull request #2722 from bagerard/clone_list_enum_fields
[Clone] Fixing loading of EnumField inside ListField
2 parents 30c2485 + bae0275 commit d783f32

File tree

7 files changed

+162
-15
lines changed

7 files changed

+162
-15
lines changed

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Development
1818
- Addition of Decimal128Field: :class:`~mongoengine.fields.Decimal128Field` for accurate representation of Decimals (much better than the legacy field DecimalField).
1919
Although it could work to switch an existing DecimalField to Decimal128Field without applying a migration script,
2020
it is not recommended to do so (DecimalField uses float/str to store the value, Decimal128Field uses Decimal128).
21+
- BREAKING CHANGE: When using ListField(EnumField) or DictField(EnumField), the values weren't always cast into the Enum (#2531)
2122

2223
Changes in 0.25.0
2324
=================

mongoengine/base/fields.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,18 @@ def _lazy_load_refs(instance, name, ref_values, *, max_depth):
282282
)
283283
return documents
284284

285+
def __set__(self, instance, value):
286+
# Some fields e.g EnumField are converted upon __set__
287+
# So it is fair to mimic the same behavior when using e.g ListField(EnumField)
288+
EnumField = _import_class("EnumField")
289+
if self.field and isinstance(self.field, EnumField):
290+
if isinstance(value, (list, tuple)):
291+
value = [self.field.to_python(sub_val) for sub_val in value]
292+
elif isinstance(value, dict):
293+
value = {key: self.field.to_python(sub) for key, sub in value.items()}
294+
295+
return super().__set__(instance, value)
296+
285297
def __get__(self, instance, owner):
286298
"""Descriptor to automatically dereference references."""
287299
if instance is None:
@@ -434,12 +446,12 @@ def to_mongo(self, value, use_db_field=True, fields=None):
434446
" have been saved to the database"
435447
)
436448

437-
# If its a document that is not inheritable it won't have
449+
# If it's a document that is not inheritable it won't have
438450
# any _cls data so make it a generic reference allows
439451
# us to dereference
440452
meta = getattr(v, "_meta", {})
441453
allow_inheritance = meta.get("allow_inheritance")
442-
if not allow_inheritance and not self.field:
454+
if not allow_inheritance:
443455
value_dict[k] = GenericReferenceField().to_mongo(v)
444456
else:
445457
collection = v._get_collection_name()

mongoengine/fields.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,7 +1122,7 @@ class MapField(DictField):
11221122
"""
11231123

11241124
def __init__(self, field=None, *args, **kwargs):
1125-
# XXX ValidationError raised outside of the "validate" method.
1125+
# XXX ValidationError raised outside the "validate" method.
11261126
if not isinstance(field, BaseField):
11271127
self.error("Argument to MapField constructor must be a valid field")
11281128
super().__init__(field=field, *args, **kwargs)
@@ -1665,14 +1665,25 @@ def __init__(self, enum, **kwargs):
16651665
kwargs["choices"] = list(self._enum_cls) # Implicit validator
16661666
super().__init__(**kwargs)
16671667

1668-
def __set__(self, instance, value):
1669-
is_legal_value = value is None or isinstance(value, self._enum_cls)
1670-
if not is_legal_value:
1668+
def validate(self, value):
1669+
if isinstance(value, self._enum_cls):
1670+
return super().validate(value)
1671+
try:
1672+
self._enum_cls(value)
1673+
except ValueError:
1674+
self.error(f"{value} is not a valid {self._enum_cls}")
1675+
1676+
def to_python(self, value):
1677+
value = super().to_python(value)
1678+
if not isinstance(value, self._enum_cls):
16711679
try:
1672-
value = self._enum_cls(value)
1673-
except Exception:
1674-
pass
1675-
return super().__set__(instance, value)
1680+
return self._enum_cls(value)
1681+
except ValueError:
1682+
return value
1683+
return value
1684+
1685+
def __set__(self, instance, value):
1686+
return super().__set__(instance, self.to_python(value))
16761687

16771688
def to_mongo(self, value):
16781689
if isinstance(value, self._enum_cls):

tests/fields/test_binary_field.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class Attachment(Document):
3131
assert MIME_TYPE == attachment_1.content_type
3232
assert BLOB == bytes(attachment_1.blob)
3333

34+
def test_bytearray_conversion_to_bytes(self):
35+
class Dummy(Document):
36+
blob = BinaryField()
37+
38+
byte_arr = bytearray(b"\x00\x00\x00\x00\x00")
39+
dummy = Dummy(blob=byte_arr)
40+
assert isinstance(dummy.blob, bytes)
41+
3442
def test_validation_succeeds(self):
3543
"""Ensure that valid values can be assigned to binary fields."""
3644

tests/fields/test_enum_field.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44
from bson import InvalidDocument
55

6-
from mongoengine import Document, EnumField, ValidationError
6+
from mongoengine import (
7+
DictField,
8+
Document,
9+
EnumField,
10+
ListField,
11+
ValidationError,
12+
)
713
from tests.utils import MongoDBTestCase, get_as_pymongo
814

915

@@ -21,6 +27,12 @@ class ModelWithEnum(Document):
2127
status = EnumField(Status)
2228

2329

30+
class ModelComplexEnum(Document):
31+
status = EnumField(Status)
32+
statuses = ListField(EnumField(Status))
33+
color_mapping = DictField(EnumField(Color))
34+
35+
2436
class TestStringEnumField(MongoDBTestCase):
2537
def test_storage(self):
2638
model = ModelWithEnum(status=Status.NEW).save()
@@ -101,6 +113,42 @@ def test_wrong_choices(self):
101113
with pytest.raises(ValueError, match="Invalid choices"):
102114
EnumField(Status, choices=[Status.DONE, Color.RED])
103115

116+
def test_embedding_in_complex_field(self):
117+
ModelComplexEnum.drop_collection()
118+
model = ModelComplexEnum(
119+
status="new", statuses=["new"], color_mapping={"red": 1}
120+
).save()
121+
assert model.status == Status.NEW
122+
assert model.statuses == [Status.NEW]
123+
assert model.color_mapping == {"red": Color.RED}
124+
125+
model.reload()
126+
assert model.status == Status.NEW
127+
assert model.statuses == [Status.NEW]
128+
assert model.color_mapping == {"red": Color.RED}
129+
130+
model.status = "done"
131+
model.color_mapping = {"blue": 2}
132+
model.statuses = ["new", "done"]
133+
model.save()
134+
assert model.status == Status.DONE
135+
assert model.statuses == [Status.NEW, Status.DONE]
136+
assert model.color_mapping == {"blue": Color.BLUE}
137+
138+
model.reload()
139+
assert model.status == Status.DONE
140+
assert model.color_mapping == {"blue": Color.BLUE}
141+
assert model.statuses == [Status.NEW, Status.DONE]
142+
143+
with pytest.raises(ValidationError, match="must be one of ..Status"):
144+
model.statuses = [1]
145+
model.save()
146+
147+
model.statuses = ["done"]
148+
model.color_mapping = {"blue": "done"}
149+
with pytest.raises(ValidationError, match="must be one of ..Color"):
150+
model.save()
151+
104152

105153
class ModelWithColor(Document):
106154
color = EnumField(Color, default=Color.RED)
@@ -124,10 +172,7 @@ def test_storage_enum_with_int(self):
124172
assert get_as_pymongo(model) == {"_id": model.id, "color": 2}
125173

126174
def test_validate_model(self):
127-
with pytest.raises(ValidationError, match="Value must be one of"):
128-
ModelWithColor(color=3).validate()
129-
130-
with pytest.raises(ValidationError, match="Value must be one of"):
175+
with pytest.raises(ValidationError, match="must be one of ..Color"):
131176
ModelWithColor(color="wrong_type").validate()
132177

133178

tests/fields/test_fields.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,46 @@
4444

4545

4646
class TestField(MongoDBTestCase):
47+
def test_constructor_set_historical_behavior_is_kept(self):
48+
class MyDoc(Document):
49+
oid = ObjectIdField()
50+
51+
doc = MyDoc()
52+
doc.oid = str(ObjectId())
53+
assert isinstance(doc.oid, str)
54+
55+
# not modified on save (historical behavior)
56+
doc.save()
57+
assert isinstance(doc.oid, str)
58+
59+
# reloading goes through constructor so it is expected to go through to_python
60+
doc.reload()
61+
assert isinstance(doc.oid, ObjectId)
62+
63+
def test_constructor_set_list_field_historical_behavior_is_kept(self):
64+
# Although the behavior is not consistent between regular field and a ListField
65+
# This is the historical behavior so we must make sure we don't modify it (unless consciously done of course)
66+
67+
class MyOIDSDoc(Document):
68+
oids = ListField(ObjectIdField())
69+
70+
# constructor goes through to_python so casting occurs
71+
doc = MyOIDSDoc(oids=[str(ObjectId())])
72+
assert isinstance(doc.oids[0], ObjectId)
73+
74+
# constructor goes through to_python so casting occurs
75+
doc = MyOIDSDoc()
76+
doc.oids = [str(ObjectId())]
77+
assert isinstance(doc.oids[0], str)
78+
79+
doc.save()
80+
assert isinstance(doc.oids[0], str)
81+
82+
# reloading goes through constructor so it is expected to go through to_python
83+
# and cast
84+
doc.reload()
85+
assert isinstance(doc.oids[0], ObjectId)
86+
4787
def test_default_values_nothing_set(self):
4888
"""Ensure that default field values are used when creating
4989
a document.

tests/fields/test_object_id_field.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from bson import ObjectId
3+
4+
from mongoengine import Document, ObjectIdField, ValidationError
5+
from tests.utils import MongoDBTestCase, get_as_pymongo
6+
7+
8+
class TestObjectIdField(MongoDBTestCase):
9+
def test_storage(self):
10+
class MyDoc(Document):
11+
oid = ObjectIdField()
12+
13+
doc = MyDoc(oid=ObjectId())
14+
doc.save()
15+
assert get_as_pymongo(doc) == {"_id": doc.id, "oid": doc.oid}
16+
17+
def test_constructor_converts_str_to_ObjectId(self):
18+
class MyDoc(Document):
19+
oid = ObjectIdField()
20+
21+
doc = MyDoc(oid=str(ObjectId()))
22+
assert isinstance(doc.oid, ObjectId)
23+
24+
def test_validation_works(self):
25+
class MyDoc(Document):
26+
oid = ObjectIdField()
27+
28+
doc = MyDoc(oid="not-an-oid!")
29+
with pytest.raises(ValidationError, match="Invalid ObjectID"):
30+
doc.save()

0 commit comments

Comments
 (0)