diff --git a/src/msgspec/_core.c b/src/msgspec/_core.c index 6b6ad1c4..51c9b88a 100644 --- a/src/msgspec/_core.c +++ b/src/msgspec/_core.c @@ -4948,7 +4948,7 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) { ) { out = typenode_collect_struct(state, t); } - else if (Py_TYPE(t) == state->mod->EnumMetaType) { + else if (PyType_IsSubtype(Py_TYPE(t), state->mod->EnumMetaType)) { out = typenode_collect_enum(state, t); } else if (origin == (PyObject*)(&PyDict_Type)) { @@ -13310,7 +13310,7 @@ mpack_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj) else if (type == &Raw_Type) { return mpack_encode_raw(self, obj); } - else if (Py_TYPE(type) == self->mod->EnumMetaType) { + else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) { return mpack_encode_enum(self, obj); } else if (type == (PyTypeObject *)(self->mod->DecimalType)) { @@ -13987,7 +13987,7 @@ json_encode_dict_key_noinline(EncoderState *self, PyObject *obj) { else if (type == &PyFloat_Type) { return json_encode_float_as_str(self, obj); } - else if (Py_TYPE(type) == self->mod->EnumMetaType) { + else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) { return json_encode_enum(self, obj, true); } else if (type == PyDateTimeAPI->DateTimeType) { @@ -14414,7 +14414,7 @@ json_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj) { else if (type == &Raw_Type) { return json_encode_raw(self, obj); } - else if (Py_TYPE(type) == self->mod->EnumMetaType) { + else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) { return json_encode_enum(self, obj, false); } else if (PyType_IsSubtype(type, (PyTypeObject *)(self->mod->UUIDType))) { @@ -20119,7 +20119,7 @@ to_builtins(ToBuiltinsState *self, PyObject *obj, bool is_key) { else if (PyType_IsSubtype(Py_TYPE(type), &StructMetaType)) { return to_builtins_struct(self, obj, is_key); } - else if (Py_TYPE(type) == self->mod->EnumMetaType) { + else if (PyType_IsSubtype(Py_TYPE(type), self->mod->EnumMetaType)) { return to_builtins_enum(self, obj); } else if (is_key & PyUnicode_Check(obj)) { @@ -21888,7 +21888,7 @@ convert( else if (pytype == (PyTypeObject *)self->mod->DecimalType) { return convert_decimal(self, obj, type, path); } - else if (Py_TYPE(pytype) == self->mod->EnumMetaType) { + else if (PyType_IsSubtype(Py_TYPE(pytype), self->mod->EnumMetaType)) { return convert_enum(self, obj, type, path); } else if (pytype == &Ext_Type) { diff --git a/tests/test_common.py b/tests/test_common.py index a1bc5093..2b9f5a5f 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -647,6 +647,20 @@ class Test(base_cls): assert proto.encode(Test.A) == proto.encode("apple") + @pytest.mark.parametrize( + "base_metacls", [enum.EnumMeta] + ([enum.EnumType] if PY311 else []) + ) + def test_enum_with_custom_enum_metaclass(self, proto, base_metacls): + class ChoicesMeta(base_metacls): + """My custom enum metaclass""" + + class Test(enum.Enum, metaclass=ChoicesMeta): + A = "apple" + B = "banana" + + assert proto.encode(Test.A) == proto.encode("apple") + assert proto.decode(proto.encode("apple"), type=Test) is Test.A + @pytest.mark.parametrize("base_cls", [StrEnum, enum.Enum]) def test_decode(self, proto, base_cls): class Test(base_cls): diff --git a/tests/test_convert.py b/tests/test_convert.py index afb668ee..ad1fe30e 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -879,6 +879,25 @@ def _missing_(cls, val): with pytest.raises(ValidationError, match="Invalid enum value 6"): msgspec.convert(6, Ex) + @pytest.mark.parametrize( + "base_metacls", [enum.EnumMeta] + ([enum.EnumType] if PY311 else []) + ) + def test_enum_with_custom_metaclass(self, base_metacls): + class ChoicesMeta(base_metacls): + """My custom enum metaclass""" + + class Test(enum.Enum, metaclass=ChoicesMeta): + A = "apple" + B = "banana" + + assert convert(Test.A, Test) is Test.A + assert convert("apple", Test) is Test.A + assert convert("banana", Test) is Test.B + with pytest.raises(ValidationError, match="Invalid enum value 'ceeee'"): + convert("ceeee", Test) + with pytest.raises(ValidationError, match="Expected `str`, got `int`"): + convert(1, Test) + class TestLiteral: def test_str_literal(self):