Skip to content

Commit 1b3ba8f

Browse files
authored
implement screenshot plugin (#5691)
* implement screenshot plugin * dang it darglint * woops
1 parent c106be7 commit 1b3ba8f

File tree

5 files changed

+167
-2
lines changed

5 files changed

+167
-2
lines changed

reflex/app.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,11 @@ def __call__(self) -> ASGIApp:
604604

605605
self._compile(prerender_routes=is_prod_mode())
606606

607+
config = get_config()
608+
609+
for plugin in config.plugins:
610+
plugin.post_compile(app=self)
611+
607612
# We will not be making more vars, so we can clear the global cache to free up memory.
608613
GLOBAL_CACHE.clear()
609614

reflex/istate/manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ async def set_state(self, token: str, state: BaseState):
143143
token: The token to set the state for.
144144
state: The state to set.
145145
"""
146+
token = _split_substate_key(token)[0]
147+
self.states[token] = state
146148

147149
@override
148150
@contextlib.asynccontextmanager
@@ -165,7 +167,6 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
165167
async with self._states_locks[token]:
166168
state = await self.get_state(token)
167169
yield state
168-
await self.set_state(token, state)
169170

170171

171172
def _default_token_expiration() -> int:

reflex/plugins/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Reflex Plugin System."""
22

3+
from ._screenshot import ScreenshotPlugin as _ScreenshotPlugin
34
from .base import CommonContext, Plugin, PreCompileContext
45
from .sitemap import SitemapPlugin
56
from .tailwind_v3 import TailwindV3Plugin
@@ -12,4 +13,5 @@
1213
"SitemapPlugin",
1314
"TailwindV3Plugin",
1415
"TailwindV4Plugin",
16+
"_ScreenshotPlugin",
1517
]

reflex/plugins/_screenshot.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Plugin to enable screenshot functionality."""
2+
3+
from typing import TYPE_CHECKING
4+
5+
from reflex.plugins.base import Plugin as BasePlugin
6+
7+
if TYPE_CHECKING:
8+
from starlette.requests import Request
9+
from starlette.responses import Response
10+
from typing_extensions import Unpack
11+
12+
from reflex.app import App
13+
from reflex.plugins.base import PostCompileContext
14+
from reflex.state import BaseState
15+
16+
ACTIVE_CONNECTIONS = "/_active_connections"
17+
CLONE_STATE = "/_clone_state"
18+
19+
20+
def _deep_copy(state: "BaseState") -> "BaseState":
21+
"""Create a deep copy of the state.
22+
23+
Args:
24+
state: The state to copy.
25+
26+
Returns:
27+
A deep copy of the state.
28+
"""
29+
import copy
30+
31+
copy_of_state = copy.deepcopy(state)
32+
33+
def copy_substate(substate: "BaseState") -> "BaseState":
34+
substate_copy = _deep_copy(substate)
35+
36+
substate_copy.parent_state = copy_of_state
37+
38+
return substate_copy
39+
40+
copy_of_state.substates = {
41+
substate_name: copy_substate(substate)
42+
for substate_name, substate in state.substates.items()
43+
}
44+
45+
return copy_of_state
46+
47+
48+
class ScreenshotPlugin(BasePlugin):
49+
"""Plugin to handle screenshot functionality."""
50+
51+
def post_compile(self, **context: "Unpack[PostCompileContext]") -> None:
52+
"""Called after the compilation of the plugin.
53+
54+
Args:
55+
context: The context for the plugin.
56+
"""
57+
app = context["app"]
58+
self._add_active_connections_endpoint(app)
59+
self._add_clone_state_endpoint(app)
60+
61+
@staticmethod
62+
def _add_active_connections_endpoint(app: "App") -> None:
63+
"""Add an endpoint to the app that returns the active connections.
64+
65+
Args:
66+
app: The application instance to which the endpoint will be added.
67+
"""
68+
if not app._api:
69+
return
70+
71+
async def active_connections(_request: "Request") -> "Response":
72+
from starlette.responses import JSONResponse
73+
74+
if not app.event_namespace:
75+
return JSONResponse({})
76+
77+
return JSONResponse(app.event_namespace.token_to_sid)
78+
79+
app._api.add_route(
80+
ACTIVE_CONNECTIONS,
81+
active_connections,
82+
methods=["GET"],
83+
)
84+
85+
@staticmethod
86+
def _add_clone_state_endpoint(app: "App") -> None:
87+
"""Add an endpoint to the app that clones the current state.
88+
89+
Args:
90+
app: The application instance to which the endpoint will be added.
91+
"""
92+
if not app._api:
93+
return
94+
95+
async def clone_state(request: "Request") -> "Response":
96+
import uuid
97+
98+
from starlette.responses import JSONResponse
99+
100+
from reflex.state import _substate_key
101+
102+
if not app.event_namespace:
103+
return JSONResponse({})
104+
105+
token_to_clone = await request.json()
106+
107+
if not isinstance(token_to_clone, str):
108+
return JSONResponse(
109+
{"error": "Token to clone must be a string."}, status_code=400
110+
)
111+
112+
old_state = await app.state_manager.get_state(token_to_clone)
113+
114+
new_state = _deep_copy(old_state)
115+
116+
new_token = uuid.uuid4().hex
117+
118+
all_states = [new_state]
119+
120+
found_new = True
121+
122+
while found_new:
123+
found_new = False
124+
125+
for state in all_states:
126+
for substate in state.substates.values():
127+
substate._was_touched = True
128+
129+
if substate not in all_states:
130+
all_states.append(substate)
131+
132+
found_new = True
133+
134+
await app.state_manager.set_state(
135+
_substate_key(new_token, new_state), new_state
136+
)
137+
138+
return JSONResponse(new_token)
139+
140+
app._api.add_route(
141+
CLONE_STATE,
142+
clone_state,
143+
methods=["POST"],
144+
)

reflex/plugins/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing_extensions import Unpack
88

99
if TYPE_CHECKING:
10-
from reflex.app import UnevaluatedPage
10+
from reflex.app import App, UnevaluatedPage
1111

1212

1313
class CommonContext(TypedDict):
@@ -44,6 +44,12 @@ class PreCompileContext(CommonContext):
4444
unevaluated_pages: Sequence["UnevaluatedPage"]
4545

4646

47+
class PostCompileContext(CommonContext):
48+
"""Context for post-compile hooks."""
49+
50+
app: "App"
51+
52+
4753
class Plugin:
4854
"""Base class for all plugins."""
4955

@@ -104,6 +110,13 @@ def pre_compile(self, **context: Unpack[PreCompileContext]) -> None:
104110
context: The context for the plugin.
105111
"""
106112

113+
def post_compile(self, **context: Unpack[PostCompileContext]) -> None:
114+
"""Called after the compilation of the plugin.
115+
116+
Args:
117+
context: The context for the plugin.
118+
"""
119+
107120
def __repr__(self):
108121
"""Return a string representation of the plugin.
109122

0 commit comments

Comments
 (0)