Skip to content

Commit d1578ba

Browse files
committed
✨ Patch feature
1 parent af2b9e6 commit d1578ba

File tree

3 files changed

+180
-40
lines changed

3 files changed

+180
-40
lines changed

main.py

Lines changed: 93 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import importlib
33
import importlib.util
44
import asyncio
5+
import os
56

67
from inspect import signature
78
from quart import Quart
@@ -10,23 +11,33 @@
1011
from src.logging import logger, patch
1112
from os.path import splitext, basename
1213
from types import ModuleType
13-
from typing import Any, Coroutine, Callable
14+
from typing import Any, Callable, TypedDict, TYPE_CHECKING
15+
from collections.abc import Coroutine
1416
from src.utils import validate_module, unzip_extensions
1517

18+
if TYPE_CHECKING:
19+
FunctionConfig = TypedDict("FunctionConfig", {"enabled": bool})
20+
FunctionlistType = list[tuple[Callable[..., Any], FunctionConfig]]
21+
1622

17-
# noinspection PyUnusedLocal
1823
async def start_bot(bot: discord.Bot, token: str):
19-
await bot.start(config["bot"]["token"])
24+
await bot.start(token)
2025

2126

2227
async def start_backend(app: Quart, bot: discord.Bot, token: str):
2328
from hypercorn.config import Config
2429
from hypercorn.logging import Logger as HypercornLogger
25-
from hypercorn.asyncio import serve
30+
from hypercorn.asyncio import serve # pyright: ignore [reportUnknownVariableType]
2631

2732
class CustomLogger(HypercornLogger):
28-
def __init__(self, *args, **kwargs) -> None:
29-
super().__init__(*args, **kwargs)
33+
def __init__(
34+
self,
35+
*args, # pyright: ignore [reportUnknownParameterType,reportMissingParameterType]
36+
**kwargs, # pyright: ignore [reportUnknownParameterType,reportMissingParameterType]
37+
) -> None:
38+
super().__init__(
39+
*args, **kwargs # pyright: ignore [reportUnknownArgumentType]
40+
)
3041
if self.error_logger:
3142
patch(self.error_logger)
3243
if self.access_logger:
@@ -42,7 +53,7 @@ def __init__(self, *args, **kwargs) -> None:
4253
patch("hypercorn.error")
4354

4455

45-
def setup_func(func: callable, **kwargs) -> Any:
56+
def setup_func(func: Callable[..., Any], **kwargs: Any) -> Any:
4657
parameters = signature(func).parameters
4758
func_kwargs = {}
4859
for name, parameter in parameters.items():
@@ -55,23 +66,35 @@ def setup_func(func: callable, **kwargs) -> Any:
5566
return func(**func_kwargs)
5667

5768

58-
async def main():
59-
assert (config.get("bot", {}) or {}).get(
60-
"token"
61-
), f"No bit token provided in config"
62-
unzip_extensions()
69+
async def load_and_run_patches():
70+
for patch_file in iglob("src/extensions/*/patch.py"):
71+
extension = os.path.basename(os.path.dirname(patch_file))
72+
if config["extensions"].get(extension, {}).get("enabled", False):
73+
logger.info(f"Loading patch for extension {extension}")
74+
spec = importlib.util.spec_from_file_location(
75+
f"src.extensions.{extension}.patch", patch_file
76+
)
77+
if not spec or not spec.loader:
78+
continue
79+
patch_module = importlib.util.module_from_spec(spec)
80+
spec.loader.exec_module(patch_module)
81+
if hasattr(patch_module, "patch") and callable(patch_module.patch):
82+
await asyncio.to_thread(patch_module.patch)
6383

64-
bot_functions: list[tuple[Callable, dict[Any]]] = []
65-
back_functions: list[tuple[Callable, dict[Any]]] = []
66-
startup_functions: list[tuple[Callable, dict[Any]]] = []
84+
85+
def load_extensions() -> (
86+
tuple["FunctionlistType", "FunctionlistType", "FunctionlistType"]
87+
):
88+
bot_functions: "FunctionlistType" = []
89+
back_functions: "FunctionlistType" = []
90+
startup_functions: "FunctionlistType" = []
6791

6892
for extension in iglob("src/extensions/*"):
6993
name = splitext(basename(extension))[0]
7094
its_config = config["extensions"].get(name, {})
7195
logger.info(f"Loading extension {name}")
7296
module: ModuleType = importlib.import_module(f"src.extensions.{name}")
7397
if not its_config:
74-
# use default config if not present
7598
its_config = module.default
7699
config["extensions"][name] = its_config
77100
if not its_config["enabled"]:
@@ -87,34 +110,64 @@ async def main():
87110
if hasattr(module, "on_startup") and callable(module.on_startup):
88111
startup_functions.append((module.on_startup, its_config))
89112

90-
startup_coros: list[Coroutine] = []
91-
coros: list[Coroutine] = []
92-
93-
bot = None
94-
back_bot = None
95-
app = None
96-
97-
if bot_functions:
98-
bot = discord.Bot(intents=discord.Intents.default())
99-
for function, its_config in bot_functions:
100-
setup_func(function, bot=bot, config=its_config)
101-
coros.append(start_bot(bot, config["bot"]["token"]))
102-
103-
if back_functions:
104-
back_bot = discord.Bot(intents=discord.Intents.default())
105-
app = Quart("backend")
106-
for function, its_config in back_functions:
107-
setup_func(function, app=app, bot=back_bot, config=its_config)
108-
coros.append(start_backend(app, back_bot, config["bot"]["token"]))
113+
return bot_functions, back_functions, startup_functions
114+
115+
116+
async def setup_and_start_bot(
117+
bot_functions: "FunctionlistType",
118+
):
119+
bot = discord.Bot(intents=discord.Intents.default())
120+
for function, its_config in bot_functions:
121+
setup_func(function, bot=bot, config=its_config)
122+
await start_bot(bot, config["bot"]["token"])
123+
124+
125+
async def setup_and_start_backend(
126+
back_functions: "FunctionlistType",
127+
):
128+
back_bot = discord.Bot(intents=discord.Intents.default())
129+
app = Quart("backend")
130+
for function, its_config in back_functions:
131+
setup_func(function, app=app, bot=back_bot, config=its_config)
132+
await start_backend(app, back_bot, config["bot"]["token"])
133+
134+
135+
async def run_startup_functions(
136+
startup_functions: "FunctionlistType",
137+
app: Quart | None,
138+
back_bot: discord.Bot | None,
139+
):
140+
startup_coros = [
141+
setup_func(function, app=app, bot=back_bot, config=its_config)
142+
for function, its_config in startup_functions
143+
]
144+
await asyncio.gather(*startup_coros)
145+
146+
147+
async def main(run_bot: bool = True, run_backend: bool = True):
148+
assert config.get("bot", {}).get("token"), "No bot token provided in config"
149+
unzip_extensions()
150+
151+
await load_and_run_patches()
152+
153+
bot_functions, back_functions, startup_functions = load_extensions()
154+
155+
coros: list[Coroutine[Any, Any, Any]] = []
156+
if bot_functions and run_bot:
157+
coros.append(setup_and_start_bot(bot_functions))
158+
if back_functions and run_backend:
159+
coros.append(setup_and_start_backend(back_functions))
109160
assert coros, "No extensions to run"
110161

111162
if startup_functions:
112-
for function, its_config in startup_functions:
113-
startup_coros.append(
114-
setup_func(function, app=app, bot=back_bot, config=its_config)
115-
)
163+
app = Quart("backend") if (back_functions and run_backend) else None
164+
back_bot = (
165+
discord.Bot(intents=discord.Intents.default())
166+
if (back_functions and run_backend)
167+
else None
168+
)
169+
await run_startup_functions(startup_functions, app, back_bot)
116170

117-
await asyncio.gather(*startup_coros)
118171
await asyncio.gather(*coros)
119172

120173
store_config()

readme.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,64 @@ schema = {
118118
```
119119
We really encourage you to follow these instructions, even if you’re coding privately, as it will make your code more readable and maintainable in the long run.
120120

121+
## Using Patch Files
122+
123+
Botkit supports the use of patch files to modify or extend the functionality of the bot or its dependencies before the main extension code runs. This is particularly useful for applying global changes or monkey-patching existing classes.
124+
125+
### How It Works
126+
127+
1. Create a file named `patch.py` in your extension's directory.
128+
2. Define a `patch()` function in this file. This function will be called before the extension is loaded.
129+
3. The `patch()` function can modify global state, patch classes, or perform any other setup needed.
130+
131+
### Example: Error Handling Patch
132+
133+
Here's an example from the `nice-errors` extension that demonstrates how to use a patch file to enhance error handling:
134+
135+
```python
136+
# nice-errors/patch.py
137+
138+
import discord
139+
from discord import Interaction
140+
from discord.ui import Item
141+
from typing_extensions import override
142+
143+
def patch():
144+
class PatchedView(discord.ui.View):
145+
@override
146+
async def on_error(
147+
self,
148+
error: Exception,
149+
item: Item,
150+
interaction: Interaction,
151+
) -> None:
152+
if not isinstance(error, discord.Forbidden):
153+
await interaction.respond(
154+
"Whoops! An error occurred while executing this command",
155+
ephemeral=True,
156+
)
157+
raise error
158+
await interaction.respond(
159+
f"Whoops! I don't have permission to do that\n`{error.args[0].split(':')[-1].strip()}`",
160+
ephemeral=True,
161+
)
162+
163+
discord.ui.View = PatchedView
164+
165+
```
166+
167+
This patch modifies the `discord.ui.View` class to provide more user-friendly error messages. It catches exceptions and responds to the user with an appropriate message, enhancing the overall user experience.
168+
169+
### When to Use Patch Files
170+
171+
Patch files are powerful but should be used judiciously. They are best suited for:
172+
173+
1. Applying global changes that affect multiple parts of your bot.
174+
2. Modifying third-party libraries when you can't or don't want to fork them.
175+
3. Implementing cross-cutting concerns like logging or error handling.
176+
177+
Remember that patches are applied early in the bot's lifecycle, so they can affect all subsequent code. Use them carefully and document their effects clearly.
178+
121179
## Using scripts
122180

123181
### `check-listings`

src/extensions/nice-errors/patch.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import discord
2+
from discord import Interaction
3+
from discord.ui import Item
4+
from typing_extensions import override
5+
6+
7+
def patch():
8+
9+
class PatchedView(discord.ui.View):
10+
11+
@override
12+
async def on_error(
13+
self,
14+
error: Exception,
15+
item: Item, # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType]
16+
interaction: Interaction,
17+
) -> None:
18+
if not isinstance(error, discord.Forbidden):
19+
await interaction.respond( # pyright: ignore[reportUnknownMemberType]
20+
"Whoops! An error occurred while executing this command",
21+
ephemeral=True,
22+
)
23+
raise error
24+
await interaction.respond( # pyright: ignore[reportUnknownMemberType]
25+
f"Whoops! I don't have permission to do that\n`{error.args[0].split(':')[-1].strip()}`",
26+
ephemeral=True,
27+
)
28+
29+
discord.ui.View = PatchedView

0 commit comments

Comments
 (0)