Skip to content

Commit f3df526

Browse files
authored
Track all rx.memo components (#5172)
* Automatically compile any `@rx.memo` decorated function If a component is memoized anywhere in the app, include the component in the generated output. Avoid extra component tree walk, since we can know immediately what all of the custom components are. Any perf optimization gained by not compiling unused memo functions is handily saved by avoiding the tree walk. * dynamic: bundle local $/utils/components (rx.memo) module allow `@rx.memo` decorated functions to be referenced by dynamic components (which allows working around most limitations with dynamic components). move special cases for $/ prefix modules to `_normalize_library_path` and include them in `bundled_libraries` so they can be checked at runtime. * fixup memo registry * Pass dummy EventSpec to rx.memo function
1 parent fae7b3c commit f3df526

File tree

6 files changed

+61
-97
lines changed

6 files changed

+61
-97
lines changed

reflex/app.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from reflex.components.base.fragment import Fragment
4646
from reflex.components.base.strict_mode import StrictMode
4747
from reflex.components.component import (
48+
CUSTOM_COMPONENTS,
4849
Component,
4950
ComponentStyle,
5051
evaluate_style_namespaces,
@@ -1222,9 +1223,8 @@ def get_compilation_time() -> str:
12221223

12231224
progress.advance(task)
12241225

1225-
# Track imports and custom components found.
1226+
# Track imports found.
12261227
all_imports = {}
1227-
custom_components = set()
12281228

12291229
# This has to happen before compiling stateful components as that
12301230
# prevents recursive functions from reaching all components.
@@ -1235,9 +1235,6 @@ def get_compilation_time() -> str:
12351235
# Add the app wrappers from this component.
12361236
app_wrappers.update(component._get_all_app_wrap_components())
12371237

1238-
# Add the custom components from the page to the set.
1239-
custom_components |= component._get_all_custom_components()
1240-
12411238
if (toaster := self.toaster) is not None:
12421239
from reflex.components.component import memo
12431240

@@ -1255,9 +1252,6 @@ def memoized_toast_provider():
12551252
if component is not None:
12561253
app_wrappers[key] = component
12571254

1258-
for component in app_wrappers.values():
1259-
custom_components |= component._get_all_custom_components()
1260-
12611255
if self.error_boundary:
12621256
from reflex.compiler.compiler import into_component
12631257

@@ -1382,7 +1376,7 @@ def _submit_work(fn: Callable[..., tuple[str, str]], *args, **kwargs):
13821376
custom_components_output,
13831377
custom_components_result,
13841378
custom_components_imports,
1385-
) = compiler.compile_components(custom_components)
1379+
) = compiler.compile_components(set(CUSTOM_COMPONENTS.values()))
13861380
compile_results.append((custom_components_output, custom_components_result))
13871381
all_imports.update(custom_components_imports)
13881382

reflex/compiler/compiler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _normalize_library_name(lib: str) -> str:
5656
"""
5757
if lib == "react":
5858
return "React"
59-
return lib.replace("@", "").replace("/", "_").replace("-", "_")
59+
return lib.replace("$/", "").replace("@", "").replace("/", "_").replace("-", "_")
6060

6161

6262
def _compile_app(app_root: Component) -> str:
@@ -72,9 +72,6 @@ def _compile_app(app_root: Component) -> str:
7272

7373
window_libraries = [
7474
(_normalize_library_name(name), name) for name in bundled_libraries
75-
] + [
76-
("utils_context", f"$/{constants.Dirs.UTILS}/context"),
77-
("utils_state", f"$/{constants.Dirs.UTILS}/state"),
7875
]
7976

8077
return templates.APP_ROOT.render(

reflex/components/component.py

Lines changed: 39 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,32 +1647,6 @@ def _get_all_refs(self) -> set[str]:
16471647

16481648
return refs
16491649

1650-
def _get_all_custom_components(
1651-
self, seen: set[str] | None = None
1652-
) -> set[CustomComponent]:
1653-
"""Get all the custom components used by the component.
1654-
1655-
Args:
1656-
seen: The tags of the components that have already been seen.
1657-
1658-
Returns:
1659-
The set of custom components.
1660-
"""
1661-
custom_components = set()
1662-
1663-
# Store the seen components in a set to avoid infinite recursion.
1664-
if seen is None:
1665-
seen = set()
1666-
for child in self.children:
1667-
# Skip BaseComponent and StatefulComponent children.
1668-
if not isinstance(child, Component):
1669-
continue
1670-
custom_components |= child._get_all_custom_components(seen=seen)
1671-
for component in self._get_components_in_props():
1672-
if isinstance(component, Component) and component.tag is not None:
1673-
custom_components |= component._get_all_custom_components(seen=seen)
1674-
return custom_components
1675-
16761650
@property
16771651
def import_var(self):
16781652
"""The tag to import.
@@ -1857,37 +1831,6 @@ def get_props(cls) -> set[str]:
18571831
"""
18581832
return set()
18591833

1860-
def _get_all_custom_components(
1861-
self, seen: set[str] | None = None
1862-
) -> set[CustomComponent]:
1863-
"""Get all the custom components used by the component.
1864-
1865-
Args:
1866-
seen: The tags of the components that have already been seen.
1867-
1868-
Raises:
1869-
ValueError: If the tag is not set.
1870-
1871-
Returns:
1872-
The set of custom components.
1873-
"""
1874-
if self.tag is None:
1875-
raise ValueError("The tag must be set.")
1876-
1877-
# Store the seen components in a set to avoid infinite recursion.
1878-
if seen is None:
1879-
seen = set()
1880-
custom_components = {self} | super()._get_all_custom_components(seen=seen)
1881-
1882-
# Avoid adding the same component twice.
1883-
if self.tag not in seen:
1884-
seen.add(self.tag)
1885-
custom_components |= self.get_component(self)._get_all_custom_components(
1886-
seen=seen
1887-
)
1888-
1889-
return custom_components
1890-
18911834
@staticmethod
18921835
def _get_event_spec_from_args_spec(name: str, event: EventChain) -> Callable:
18931836
"""Get the event spec from the args spec.
@@ -1951,6 +1894,42 @@ def get_component(self) -> Component:
19511894
return self.component_fn(*self.get_prop_vars())
19521895

19531896

1897+
CUSTOM_COMPONENTS: dict[str, CustomComponent] = {}
1898+
1899+
1900+
def _register_custom_component(
1901+
component_fn: Callable[..., Component],
1902+
):
1903+
"""Register a custom component to be compiled.
1904+
1905+
Args:
1906+
component_fn: The function that creates the component.
1907+
1908+
Raises:
1909+
TypeError: If the tag name cannot be determined.
1910+
"""
1911+
dummy_props = {
1912+
prop: (
1913+
Var(
1914+
"",
1915+
_var_type=annotation,
1916+
)
1917+
if not types.safe_issubclass(annotation, EventHandler)
1918+
else EventSpec(handler=EventHandler(fn=lambda: []))
1919+
)
1920+
for prop, annotation in typing.get_type_hints(component_fn).items()
1921+
if prop != "return"
1922+
}
1923+
dummy_component = CustomComponent._create(
1924+
children=[],
1925+
component_fn=component_fn,
1926+
**dummy_props,
1927+
)
1928+
if dummy_component.tag is None:
1929+
raise TypeError(f"Could not determine the tag name for {component_fn!r}")
1930+
CUSTOM_COMPONENTS[dummy_component.tag] = dummy_component
1931+
1932+
19541933
def custom_component(
19551934
component_fn: Callable[..., Component],
19561935
) -> Callable[..., CustomComponent]:
@@ -1971,6 +1950,9 @@ def wrapper(*children, **props) -> CustomComponent:
19711950
children=list(children), component_fn=component_fn, **props
19721951
)
19731952

1953+
# Register this component so it can be compiled.
1954+
_register_custom_component(component_fn)
1955+
19741956
return wrapper
19751957

19761958

reflex/components/dynamic.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@ def get_cdn_url(lib: str) -> str:
2626
return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm"
2727

2828

29-
bundled_libraries = {"react", "@radix-ui/themes", "@emotion/react", "next/link"}
29+
bundled_libraries = {
30+
"react",
31+
"@radix-ui/themes",
32+
"@emotion/react",
33+
"next/link",
34+
f"$/{constants.Dirs.UTILS}/context",
35+
f"$/{constants.Dirs.UTILS}/state",
36+
f"$/{constants.Dirs.UTILS}/components",
37+
}
3038

3139

3240
def bundle_library(component: Union["Component", str]):

reflex/components/markdown/markdown.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -192,27 +192,6 @@ def create(cls, *children, **props) -> Component:
192192
**props,
193193
)
194194

195-
def _get_all_custom_components(
196-
self, seen: set[str] | None = None
197-
) -> set[CustomComponent]:
198-
"""Get all the custom components used by the component.
199-
200-
Args:
201-
seen: The tags of the components that have already been seen.
202-
203-
Returns:
204-
The set of custom components.
205-
"""
206-
custom_components = super()._get_all_custom_components(seen=seen)
207-
208-
# Get the custom components for each tag.
209-
for component in self.component_map.values():
210-
custom_components |= component(_MOCK_ARG)._get_all_custom_components(
211-
seen=seen
212-
)
213-
214-
return custom_components
215-
216195
def add_imports(self) -> ImportDict | list[ImportDict]:
217196
"""Add imports for the markdown component.
218197

tests/units/components/test_component.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
import reflex as rx
77
from reflex.base import Base
8-
from reflex.compiler.compiler import compile_components
8+
from reflex.compiler.utils import compile_custom_component
99
from reflex.components.base.bare import Bare
1010
from reflex.components.base.fragment import Fragment
1111
from reflex.components.component import (
12+
CUSTOM_COMPONENTS,
1213
Component,
1314
CustomComponent,
1415
StatefulComponent,
@@ -877,7 +878,7 @@ def test_create_custom_component(my_component):
877878
component = rx.memo(my_component)(prop1="test", prop2=1)
878879
assert component.tag == "MyComponent"
879880
assert component.get_props() == {"prop1", "prop2"}
880-
assert component._get_all_custom_components() == {component}
881+
assert component.tag in CUSTOM_COMPONENTS
881882

882883

883884
def test_custom_component_hash(my_component):
@@ -1801,10 +1802,13 @@ def outer(c: Component):
18011802

18021803
# Inner is not imported directly, but it is imported by the custom component.
18031804
assert "inner" not in custom_comp._get_all_imports()
1805+
assert "outer" not in custom_comp._get_all_imports()
18041806

18051807
# The imports are only resolved during compilation.
1806-
_, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
1808+
custom_comp.get_component(custom_comp)
1809+
_, imports_inner = compile_custom_component(custom_comp)
18071810
assert "inner" in imports_inner
1811+
assert "outer" not in imports_inner
18081812

18091813
outer_comp = outer(c=wrapper())
18101814

@@ -1813,8 +1817,8 @@ def outer(c: Component):
18131817
assert "other" not in outer_comp._get_all_imports()
18141818

18151819
# The imports are only resolved during compilation.
1816-
_, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
1817-
assert "inner" in imports_outer
1820+
_, imports_outer = compile_custom_component(outer_comp)
1821+
assert "inner" not in imports_outer
18181822
assert "other" in imports_outer
18191823

18201824

0 commit comments

Comments
 (0)