Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
CommandT = TypeVar('CommandT', bound='Command[Any, ..., Any]')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group[Any, ..., Any]')
SpecialDataT = TypeVar('SpecialDataT', discord.Attachment, discord.StickerItem)

if TYPE_CHECKING:
P = ParamSpec('P')
Expand Down Expand Up @@ -252,6 +253,31 @@ async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
return wrapped


async def _convert_stickers(
sticker_type: Type[Union[discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker]],
stickers: _SpecialIterator[discord.StickerItem],
param: Parameter,
/,
) -> Union[discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker]:
if sticker_type is discord.StickerItem:
try:
return next(stickers)
except StopIteration:
raise MissingRequiredSticker(param)

while not stickers.is_empty():
try:
sticker = next(stickers)
except StopIteration:
raise MissingRequiredSticker(param)

fetched = await sticker.fetch()
if isinstance(fetched, sticker_type):
return fetched

raise MissingRequiredSticker(param)


class _CaseInsensitiveDict(dict):
def __contains__(self, k):
return super().__contains__(k.casefold())
Expand All @@ -272,15 +298,15 @@ def __setitem__(self, k, v):
super().__setitem__(k.casefold(), v)


class _AttachmentIterator:
def __init__(self, data: List[discord.Attachment]):
self.data: List[discord.Attachment] = data
class _SpecialIterator(Generic[SpecialDataT]):
def __init__(self, data: List[SpecialDataT]):
self.data: List[SpecialDataT] = data
self.index: int = 0

def __iter__(self) -> Self:
return self

def __next__(self) -> discord.Attachment:
def __next__(self) -> SpecialDataT:
try:
value = self.data[self.index]
except IndexError:
Expand Down Expand Up @@ -649,7 +675,14 @@ async def dispatch_error(self, ctx: Context[BotT], error: CommandError, /) -> No
finally:
ctx.bot.dispatch('command_error', ctx, error)

async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _AttachmentIterator, /) -> Any:
async def transform(
self,
ctx: Context[BotT],
param: Parameter,
attachments: _SpecialIterator[discord.Attachment],
stickers: _SpecialIterator[discord.StickerItem],
/,
) -> Any:
converter = param.converter
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
view = ctx.view
Expand All @@ -661,6 +694,15 @@ async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _At
# Special case for Greedy[discord.Attachment] to consume the attachments iterator
if converter.converter is discord.Attachment:
return list(attachments)
# Special case for Greedy[discord.StickerItem] to consume the stickers iterator
elif converter.converter in (
discord.StickerItem,
discord.Sticker,
discord.StandardSticker,
discord.GuildSticker,
):
# can only send one sticker at a time
return [await _convert_stickers(converter.converter, stickers, param)]

if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
return await self._transform_greedy_pos(ctx, param, param.required, converter.constructed_converter)
Expand All @@ -679,12 +721,27 @@ async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _At
except StopIteration:
raise MissingRequiredAttachment(param)

if self._is_typing_optional(param.annotation) and param.annotation.__args__[0] is discord.Attachment:
if attachments.is_empty():
# I have no idea who would be doing Optional[discord.Attachment] = 1
# but for those cases then 1 should be returned instead of None
return None if param.default is param.empty else param.default
return next(attachments)
# Try to detect Optional[discord.StickerItem] or discord.StickerItem special converter
if converter in (discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker):
return await _convert_stickers(converter, stickers, param)

if self._is_typing_optional(param.annotation):
if param.annotation.__args__[0] is discord.Attachment:
if attachments.is_empty():
# I have no idea who would be doing Optional[discord.Attachment] = 1
# but for those cases then 1 should be returned instead of None
return None if param.default is param.empty else param.default
return next(attachments)
elif param.annotation.__args__[0] in (
discord.StickerItem,
discord.Sticker,
discord.StandardSticker,
discord.GuildSticker,
):
if stickers.is_empty():
return None if param.default is param.empty else param.default

return await _convert_stickers(param.annotation.__args__[0], stickers, param)

if view.eof:
if param.kind == param.VAR_POSITIONAL:
Expand Down Expand Up @@ -834,30 +891,32 @@ async def _parse_arguments(self, ctx: Context[BotT]) -> None:
ctx.kwargs = {}
args = ctx.args
kwargs = ctx.kwargs
attachments = _AttachmentIterator(ctx.message.attachments)

attachments = _SpecialIterator(ctx.message.attachments)
stickers = _SpecialIterator(ctx.message.stickers)

view = ctx.view
iterator = iter(self.params.items())

for name, param in iterator:
ctx.current_parameter = param
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
transformed = await self.transform(ctx, param, attachments)
transformed = await self.transform(ctx, param, attachments, stickers)
args.append(transformed)
elif param.kind == param.KEYWORD_ONLY:
# kwarg only param denotes "consume rest" semantics
if self.rest_is_raw:
ctx.current_argument = argument = view.read_rest()
kwargs[name] = await run_converters(ctx, param.converter, argument, param)
else:
kwargs[name] = await self.transform(ctx, param, attachments)
kwargs[name] = await self.transform(ctx, param, attachments, stickers)
break
elif param.kind == param.VAR_POSITIONAL:
if view.eof and self.require_var_positional:
raise MissingRequiredArgument(param)
while not view.eof:
try:
transformed = await self.transform(ctx, param, attachments)
transformed = await self.transform(ctx, param, attachments, stickers)
args.append(transformed)
except RuntimeError:
break
Expand Down Expand Up @@ -1202,6 +1261,22 @@ def signature(self) -> str:
result.append(f'<{name} (upload a file)>')
continue

if annotation in (discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker):
if annotation is discord.GuildSticker:
sticker_type = 'server sticker'
elif annotation is discord.StandardSticker:
sticker_type = 'standard sticker'
else:
sticker_type = 'sticker'

if optional:
result.append(f'[{name} (upload a {sticker_type})]')
elif greedy:
result.append(f'[{name} (upload {sticker_type}s)]...')
else:
result.append(f'<{name} (upload a {sticker_type})>')
continue

# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
if origin is Literal:
Expand Down
30 changes: 30 additions & 0 deletions discord/ext/commands/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
'CommandError',
'MissingRequiredArgument',
'MissingRequiredAttachment',
'MissingRequiredSticker',
'BadArgument',
'PrivateMessageOnly',
'NoPrivateMessage',
Expand Down Expand Up @@ -206,6 +207,35 @@ def __init__(self, param: Parameter) -> None:
super().__init__(f'{param.displayed_name or param.name} is a required argument that is missing an attachment.')


class MissingRequiredSticker(UserInputError):
"""Exception raised when parsing a command and a parameter
that requires a sticker is not given.

This inherits from :exc:`UserInputError`

.. versionadded:: 2.5

Attributes
-----------
param: :class:`Parameter`
The argument that is missing a sticker.
"""

def __init__(self, param: Parameter) -> None:
from ...sticker import GuildSticker, StandardSticker

self.param: Parameter = param
converter = param.converter
if converter == GuildSticker:
sticker_type = 'server sticker'
elif converter == StandardSticker:
sticker_type = 'standard sticker'
else:
sticker_type = 'sticker'

super().__init__(f'{param.displayed_name or param.name} is a required argument that is missing a {sticker_type}.')


class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its
:attr:`.Command.ignore_extra` attribute was not set to ``True``.
Expand Down
4 changes: 4 additions & 0 deletions docs/ext/commands/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,9 @@ Exceptions
.. autoexception:: discord.ext.commands.MissingRequiredAttachment
:members:

.. autoexception:: discord.ext.commands.MissingRequiredSticker
:members:

.. autoexception:: discord.ext.commands.ArgumentParsingError
:members:

Expand Down Expand Up @@ -789,6 +792,7 @@ Exception Hierarchy
- :exc:`~.commands.UserInputError`
- :exc:`~.commands.MissingRequiredArgument`
- :exc:`~.commands.MissingRequiredAttachment`
- :exc:`~.commands.MissingRequiredSticker`
- :exc:`~.commands.TooManyArguments`
- :exc:`~.commands.BadArgument`
- :exc:`~.commands.MessageNotFound`
Expand Down
41 changes: 41 additions & 0 deletions docs/ext/commands/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,47 @@ Note that using a :class:`discord.Attachment` converter after a :class:`~ext.com

If an attachment is expected but not given, then :exc:`~ext.commands.MissingRequiredAttachment` is raised to the error handlers.


Stickers
^^^^^^^^^^^^^^^^^^

.. versionadded:: 2.5

Annotating a parameter with any of the following sticker types will automatically get the uploaded sticker on a message and return the corresponding object:

- :class:`~discord.StickerItem`
- :class:`~discord.Sticker`
- :class:`~discord.StandardSticker`
- :class:`~discord.GuildSticker`

Consider the following example:

.. code-block:: python3

import discord

@bot.command()
async def sticker(ctx, sticker: discord.Sticker):
await ctx.send(f'You have uploaded {sticker.name} with format: {sticker.format}!')

When this command is invoked, the user must directly upload a sticker for the command body to be executed. When combined with the :data:`typing.Optional` converter, the user does not have to provide a sticker.

.. code-block:: python3

import typing
import discord

@bot.command()
async def upload(ctx, attachment: typing.Optional[discord.GuildSticker]):
if attachment is None:
await ctx.send('You did not upload anything!')
else:
await ctx.send(f'You have uploaded {sticker.name} with format: {sticker.format} from server: {sticker.guild}!')

If a sticker is expected but not given, then :exc:`~ext.commands.MissingRequiredSticker` is raised to the error handlers.

:class:`~ext.commands.Greedy` is supported too but at the moment, users can only upload one sticker at a time.

.. _ext_commands_flag_converter:

FlagConverter
Expand Down
Loading