Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -4867,7 +4867,7 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
) {
out = typenode_collect_struct(state, t);
}
else if (Py_TYPE(t) == state->mod->EnumMetaType) {
else if (PyType_IsSubtype(Py_TYPE(t), state->mod->EnumMetaType)) {
out = typenode_collect_enum(state, t);
}
else if (origin == (PyObject*)(&PyDict_Type)) {
Expand Down Expand Up @@ -13101,7 +13101,7 @@ mpack_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj)
else if (type == &Raw_Type) {
return mpack_encode_raw(self, obj);
}
else if (Py_TYPE(type) == self->mod->EnumMetaType) {
else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) {
return mpack_encode_enum(self, obj);
}
else if (type == (PyTypeObject *)(self->mod->DecimalType)) {
Expand Down Expand Up @@ -13774,7 +13774,7 @@ json_encode_dict_key_noinline(EncoderState *self, PyObject *obj) {
else if (type == &PyFloat_Type) {
return json_encode_float_as_str(self, obj);
}
else if (Py_TYPE(type) == self->mod->EnumMetaType) {
else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) {
return json_encode_enum(self, obj, true);
}
else if (type == PyDateTimeAPI->DateTimeType) {
Expand Down Expand Up @@ -14178,7 +14178,7 @@ json_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj) {
else if (type == &Raw_Type) {
return json_encode_raw(self, obj);
}
else if (Py_TYPE(type) == self->mod->EnumMetaType) {
else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) {
return json_encode_enum(self, obj, false);
}
else if (PyType_IsSubtype(type, (PyTypeObject *)(self->mod->UUIDType))) {
Expand Down Expand Up @@ -19856,7 +19856,7 @@ to_builtins(ToBuiltinsState *self, PyObject *obj, bool is_key) {
else if (Py_TYPE(type) == &StructMetaType) {
return to_builtins_struct(self, obj, is_key);
}
else if (Py_TYPE(type) == self->mod->EnumMetaType) {
else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) {
return to_builtins_enum(self, obj);
}
else if (is_key & PyUnicode_Check(obj)) {
Expand Down Expand Up @@ -21620,7 +21620,7 @@ convert(
else if (pytype == (PyTypeObject *)self->mod->DecimalType) {
return convert_decimal(self, obj, type, path);
}
else if (Py_TYPE(pytype) == self->mod->EnumMetaType) {
else if (PyType_IsSubtype(Py_TYPE(pytype), self->mod->EnumMetaType)) {
return convert_enum(self, obj, type, path);
}
else if (pytype == &Ext_Type) {
Expand Down
14 changes: 14 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,20 @@ class Test(base_cls):

assert proto.encode(Test.A) == proto.encode("apple")

@pytest.mark.parametrize(
"base_metacls", [enum.EnumMeta] + ([enum.EnumType] if PY311 else [])
)
def test_enum_with_custom_enum_metaclass(self, proto, base_metacls):
class ChoicesMeta(base_metacls):
"""My custom enum metaclass"""

class Test(enum.Enum, metaclass=ChoicesMeta):
A = "apple"
B = "banana"

assert proto.encode(Test.A) == proto.encode("apple")
assert proto.decode(proto.encode("apple"), type=Test) is Test.A

@pytest.mark.parametrize("base_cls", [StrEnum, enum.Enum])
def test_decode(self, proto, base_cls):
class Test(base_cls):
Expand Down
19 changes: 19 additions & 0 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,25 @@ def _missing_(cls, val):
with pytest.raises(ValidationError, match="Invalid enum value 6"):
msgspec.convert(6, Ex)

@pytest.mark.parametrize(
"base_metacls", [enum.EnumMeta] + ([enum.EnumType] if PY311 else [])
)
def test_enum_with_custom_metaclass(self, base_metacls):
class ChoicesMeta(base_metacls):
"""My custom enum metaclass"""

class Test(enum.Enum, metaclass=ChoicesMeta):
A = "apple"
B = "banana"

assert convert(Test.A, Test) is Test.A
assert convert("apple", Test) is Test.A
assert convert("banana", Test) is Test.B
with pytest.raises(ValidationError, match="Invalid enum value 'ceeee'"):
convert("ceeee", Test)
with pytest.raises(ValidationError, match="Expected `str`, got `int`"):
convert(1, Test)


class TestLiteral:
def test_str_literal(self):
Expand Down