Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,15 +523,17 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):

new_backend_vars = {
name: value if not isinstance(value, Field) else value.default_value()
for name, value in list(cls.__dict__.items())
if types.is_backend_base_variable(name, cls)
for mixin_cls in [*cls._mixins(), cls]
for name, value in list(mixin_cls.__dict__.items())
if types.is_backend_base_variable(name, mixin_cls)
}
# Add annotated backend vars that may not have a default value.
new_backend_vars.update({
name: cls._get_var_default(name, annotation_value)
for name, annotation_value in cls._get_type_hints().items()
for mixin_cls in [*cls._mixins(), cls]
for name, annotation_value in mixin_cls._get_type_hints().items()
if name not in new_backend_vars
and types.is_backend_base_variable(name, cls)
and types.is_backend_base_variable(name, mixin_cls)
})

cls.backend_vars = {
Expand Down Expand Up @@ -579,9 +581,6 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
cls.computed_vars[name] = newcv
cls.vars[name] = newcv
continue
if types.is_backend_base_variable(name, mixin_cls):
cls.backend_vars[name] = copy.deepcopy(value)
continue
if events.get(name) is not None:
continue
if not cls._item_is_event_handler(name, value):
Expand Down
20 changes: 17 additions & 3 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,14 @@ class Object(Base):
prop2: str = "hello"


class TestState(BaseState):
class TestMixin(BaseState, mixin=True):
"""A test mixin."""

mixin: rx.Field[str] = rx.field("mixin_value")
_mixin_backend: rx.Field[int] = rx.field(default_factory=lambda: 10)


class TestState(TestMixin, BaseState): # pyright: ignore[reportUnsafeMultipleInheritance]
"""A test state."""

# Set this class as not test one
Expand Down Expand Up @@ -342,6 +349,7 @@ def test_class_vars(test_state):
"fig",
"dt",
"asynctest",
"mixin",
}


Expand All @@ -367,7 +375,7 @@ def test_event_handlers(test_state):
assert all(key in cls.event_handlers for key in expected_keys)


def test_default_value(test_state):
def test_default_value(test_state: TestState):
"""Test that the default value of a var is correct.

Args:
Expand All @@ -378,6 +386,10 @@ def test_default_value(test_state):
assert test_state.key == ""
assert test_state.sum == 3.15
assert test_state.upper == ""
assert test_state._backend == 0
assert test_state.mixin == "mixin_value"
assert test_state._mixin_backend == 10
assert test_state.array == [1, 2, 3.15]


def test_computed_vars(test_state):
Expand Down Expand Up @@ -735,7 +747,7 @@ def test_set_dirty_substate(
assert grandchild_state.dirty_vars == set()


def test_reset(test_state, child_state):
def test_reset(test_state: TestState, child_state: ChildState):
"""Test resetting the state.

Args:
Expand Down Expand Up @@ -771,6 +783,8 @@ def test_reset(test_state, child_state):
"mapping",
"dt",
"_backend",
"mixin",
"_mixin_backend",
"asynctest",
}

Expand Down
1 change: 1 addition & 0 deletions tests/units/utils/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def test_format_query_params(input, output):
"key" + FIELD_MARKER: "",
"map_key" + FIELD_MARKER: "a",
"mapping" + FIELD_MARKER: {"a": [1, 2, 3], "b": [4, 5, 6]},
"mixin" + FIELD_MARKER: "mixin_value",
"num1" + FIELD_MARKER: 0,
"num2" + FIELD_MARKER: 3.15,
"obj" + FIELD_MARKER: {"prop1": 42, "prop2": "hello"},
Expand Down