# Source code for hikari.impl.interaction_server

# -*- coding: utf-8 -*-
# cython: language_level=3
# Copyright (c) 2020 Nekokatt
# Copyright (c) 2021-present davfsa
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Standard implementation of a REST based interactions server."""
from __future__ import annotations

__all__: typing.Sequence[str] = ("InteractionServer",)

import asyncio
import inspect
import logging
import typing

import aiohttp
import aiohttp.web
import aiohttp.web_runner

from hikari import applications
from hikari import errors
from hikari.api import interaction_server
from hikari.api import special_endpoints
from hikari.interactions import base_interactions
from hikari.internal import data_binding

if typing.TYPE_CHECKING:
    import concurrent.futures
    import socket as socket_
    import ssl

    import aiohttp.abc
    import aiohttp.typedefs

    # This is kept inline as pynacl is an optional dependency.
    from nacl import signing

    from hikari import files as files_
    from hikari.api import entity_factory as entity_factory_api
    from hikari.api import rest as rest_api
    from hikari.interactions import command_interactions
    from hikari.interactions import component_interactions
    from hikari.interactions import modal_interactions

    _InteractionT_co = typing.TypeVar("_InteractionT_co", bound=base_interactions.PartialInteraction, covariant=True)
    _MessageResponseBuilderT = typing.Union[
        special_endpoints.InteractionDeferredBuilder,
        special_endpoints.InteractionMessageBuilder,
    ]
    _ModalOrMessageResponseBuilderT = typing.Union[
        _MessageResponseBuilderT,
        special_endpoints.InteractionModalBuilder,
    ]

# Logger shared by this module; it is also handed to aiohttp's `AppRunner` as
# the access log when the interaction server is started.
_LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.interaction_server")

# Internal interaction and interaction response types.
# A ping is detected by comparing the payload's "type" field against
# `_PING_INTERACTION_TYPE`; the pong reply carries `_PONG_RESPONSE_TYPE`.
_PING_INTERACTION_TYPE: typing.Final[int] = 1
_PONG_RESPONSE_TYPE: typing.Final[int] = 1

# HTTP status codes used when answering interaction requests.
_OK_STATUS: typing.Final[int] = 200
_BAD_REQUEST_STATUS: typing.Final[int] = 400
_PAYLOAD_TOO_LARGE_STATUS: typing.Final[int] = 413
_UNSUPPORTED_MEDIA_TYPE_STATUS: typing.Final[int] = 415
_INTERNAL_SERVER_ERROR_STATUS: typing.Final[int] = 500
# Returned for unrecognised interaction types and for types with no listener.
_NOT_IMPLEMENTED: typing.Final[int] = 501

# Charset reported for every response which carries a payload.
_UTF_8_CHARSET: typing.Final[str] = "UTF-8"

# Header keys and values
# The two signature headers carry the values used to verify the request body.
_X_SIGNATURE_ED25519_HEADER: typing.Final[str] = "X-Signature-Ed25519"
_X_SIGNATURE_TIMESTAMP_HEADER: typing.Final[str] = "X-Signature-Timestamp"
_CONTENT_TYPE_KEY: typing.Final[str] = "Content-Type"
# Only read to enrich debug logs about rejected requests.
_USER_AGENT_KEY: typing.Final[str] = "User-Agent"
# Fallback mimetype for attached files whose stream reports no mimetype.
_APPLICATION_OCTET_STREAM: typing.Final[str] = "application/octet-stream"
_JSON_CONTENT_TYPE: typing.Final[str] = "application/json"
# Default content type for responses which have a payload but no explicit type.
_TEXT_CONTENT_TYPE: typing.Final[str] = "text/plain"


class _Response:
    """Read-only description of how an interaction request should be answered.

    Instances are returned by `InteractionServer.on_interaction` and are
    converted into an aiohttp response by `InteractionServer.aiohttp_hook`.

    Parameters
    ----------
    status_code : int
        The HTTP status code to respond with.
    payload : typing.Optional[bytes]
        The raw response body, if any.

    Other Parameters
    ----------------
    content_type : typing.Optional[str]
        Media type of `payload`. When a payload is given without a content
        type this falls back to plain text.
    files : typing.Sequence[hikari.files.Resource[hikari.files.AsyncReader]]
        Resources to attach to the response. Defaults to no files.
    """

    __slots__: typing.Sequence[str] = ("_content_type", "_files", "_payload", "_status_code")

    def __init__(
        self,
        status_code: int,
        payload: typing.Optional[bytes] = None,
        *,
        content_type: typing.Optional[str] = None,
        files: typing.Sequence[files_.Resource[files_.AsyncReader]] = (),
    ) -> None:
        self._status_code = status_code
        self._payload = payload
        self._files = files

        # A body without a declared media type is sent as plain text.
        needs_default_type = bool(payload) and not content_type
        self._content_type = _TEXT_CONTENT_TYPE if needs_default_type else content_type

    @property
    def content_type(self) -> typing.Optional[str]:
        """Media type of the payload, if one applies."""
        return self._content_type

    @property
    def charset(self) -> typing.Optional[str]:
        """Charset of the payload, or `None` when there is no payload."""
        if not self._payload:
            return None

        # There are no cases where the charset is anything other than UTF-8.
        return _UTF_8_CHARSET

    @property
    def files(self) -> typing.Sequence[files_.Resource[files_.AsyncReader]]:
        """Resources which should be attached to the response."""
        return self._files

    @property
    def headers(self) -> typing.Optional[typing.MutableMapping[str, str]]:
        """Extra response headers; this implementation never provides any."""
        return None

    @property
    def payload(self) -> typing.Optional[bytes]:
        """Raw response body, if any."""
        return self._payload

    @property
    def status_code(self) -> int:
        """HTTP status code to respond with."""
        return self._status_code


# Constant response
# Pre-built reply for ping interactions: a 200 JSON response whose body is the
# pong response type, serialised once at import time with the default encoder.
_PONG_RESPONSE: typing.Final[_Response] = _Response(
    _OK_STATUS, data_binding.default_json_dumps({"type": _PONG_RESPONSE_TYPE}), content_type=_JSON_CONTENT_TYPE
)


class _FilePayload(aiohttp.Payload):
    """aiohttp payload which writes a file resource to the response in chunks.

    Parameters
    ----------
    value : hikari.files.Resource[hikari.files.AsyncReader]
        The resource to stream. It is forwarded to `aiohttp.Payload` and read
        back through `_value` when writing (presumably stored there by the
        base class; confirm against aiohttp).
    content_type : str
        The content type to declare for this payload.

    Other Parameters
    ----------------
    executor : typing.Optional[concurrent.futures.Executor]
        Executor passed to the resource's `stream` call when writing.
    headers : typing.Optional[typing.Dict[str, str]]
        Extra headers forwarded to `aiohttp.Payload`.
    """

    _value: files_.Resource[files_.AsyncReader]

    def __init__(
        self,
        value: files_.Resource[files_.AsyncReader],
        content_type: str,
        /,
        *,
        executor: typing.Optional[concurrent.futures.Executor] = None,
        headers: typing.Optional[typing.Dict[str, str]] = None,
    ) -> None:
        super().__init__(value=value, content_type=content_type, headers=headers)
        self._executor = executor

    async def write(self, writer: aiohttp.abc.AbstractStreamWriter) -> None:
        """Stream the wrapped resource into `writer`, one chunk at a time."""
        resource = self._value
        async with resource.stream(executor=self._executor) as reader:
            async for piece in reader:
                await writer.write(piece)


async def _consume_generator_listener(generator: typing.AsyncGenerator[typing.Any, None]) -> None:
    """Drive an async generator listener to completion.

    The generator is resumed once and is expected to finish without yielding.
    If it yields instead, a `RuntimeError` is thrown into it. Any exception
    other than the generator finishing is reported to the running event loop's
    exception handler rather than being raised.

    Parameters
    ----------
    generator : typing.AsyncGenerator[typing.Any, None]
        The generator to run to completion.
    """
    try:
        try:
            await generator.__anext__()

            # Getting here means the generator yielded again; we expect only one yield!
            await generator.athrow(
                RuntimeError("Generator listener yielded more than once, expected only one yield")
            )

        except StopAsyncIteration:
            # The generator finished as expected.
            return

    except Exception as error:
        failure_context = {"message": "Exception occurred during interaction post dispatch", "exception": error}
        asyncio.get_running_loop().call_exception_handler(failure_context)


class InteractionServer(interaction_server.InteractionServer):
    """Standard implementation of `hikari.api.interaction_server.InteractionServer`.

    Parameters
    ----------
    entity_factory : hikari.api.entity_factory.EntityFactory
        The entity factory instance this server should use.

    Other Parameters
    ----------------
    dumps : hikari.internal.data_binding.JSONEncoder
        The JSON encoder this server should use. Defaults to
        `hikari.internal.data_binding.default_json_dumps`.
    executor : typing.Optional[concurrent.futures.Executor]
        The executor handed to attached file resources when their contents
        are streamed into a response. Defaults to `None`.
    loads : hikari.internal.data_binding.JSONDecoder
        The JSON decoder this server should use. Defaults to
        `hikari.internal.data_binding.default_json_loads`.
    public_key : bytes
        The public key this server should use for verifying request payloads from
        Discord. If left as `None` then the client will try to work this
        out using `rest_client`.
    rest_client : hikari.api.rest.RESTClient
        The client this should use for making REST requests.

    Raises
    ------
    RuntimeError
        If the optional `nacl` dependency cannot be imported.
    """

    __slots__: typing.Sequence[str] = (
        "_application_fetch_lock",
        "_close_event",
        "_dumps",
        "_entity_factory",
        "_executor",
        "_is_closing",
        "_listeners",
        "_loads",
        "_nacl",
        "_public_key",
        "_rest_client",
        "_server",
        "_running_generator_listeners",
    )

    def __init__(
        self,
        *,
        dumps: data_binding.JSONEncoder = data_binding.default_json_dumps,
        entity_factory: entity_factory_api.EntityFactory,
        executor: typing.Optional[concurrent.futures.Executor] = None,
        loads: data_binding.JSONDecoder = data_binding.default_json_loads,
        rest_client: rest_api.RESTClient,
        public_key: typing.Optional[bytes] = None,
    ) -> None:
        # This is kept inline as pynacl is an optional dependency.
        try:
            import nacl.exceptions
            import nacl.signing

        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "You must install the optional `hikari[server]` dependencies to use the default interaction server."
            ) from exc

        # Building asyncio.Lock when there isn't a running loop may lead to runtime errors.
        self._application_fetch_lock: typing.Optional[asyncio.Lock] = None
        # Building asyncio.Event when there isn't a running loop may lead to runtime errors.
        self._close_event: typing.Optional[asyncio.Event] = None
        self._dumps = dumps
        self._entity_factory = entity_factory
        self._executor = executor
        self._is_closing = False
        self._listeners: typing.Dict[typing.Type[base_interactions.PartialInteraction], typing.Any] = {}
        self._loads = loads
        # The module is stored so later methods can reach `nacl.signing` and `nacl.exceptions`.
        self._nacl = nacl
        self._rest_client = rest_client
        self._server: typing.Optional[aiohttp.web_runner.AppRunner] = None
        self._public_key = nacl.signing.VerifyKey(public_key) if public_key is not None else None
        self._running_generator_listeners: typing.List[asyncio.Task[None]] = []

    @property
    def is_alive(self) -> bool:
        """Whether this interaction server is active."""
        return self._server is not None

    async def _fetch_public_key(self) -> signing.VerifyKey:
        """Return the verification key, fetching it through the REST client if needed.

        The fetched key is cached on the instance so later calls return it
        without another request.
        """
        if self._application_fetch_lock is None:
            self._application_fetch_lock = asyncio.Lock()

        application: typing.Union[applications.Application, applications.AuthorizationApplication]
        async with self._application_fetch_lock:
            # Another task may have stored the key while this one waited for the lock.
            if self._public_key:
                return self._public_key

            if self._rest_client.token_type == applications.TokenType.BOT:
                application = await self._rest_client.fetch_application()

            else:
                application = (await self._rest_client.fetch_authorization()).application

            self._public_key = self._nacl.signing.VerifyKey(application.public_key)
            return self._public_key

    async def aiohttp_hook(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Handle an AIOHTTP interaction request.

        This method handles aiohttp specific detail before calling
        `InteractionServer.on_interaction` with the data extracted from the
        request if it can and handles building an aiohttp response.

        Parameters
        ----------
        request : aiohttp.web.Request
            The received request.

        Returns
        -------
        aiohttp.web.Response
            The aiohttp response.
        """
        if request.content_type.lower() != _JSON_CONTENT_TYPE:
            _LOGGER.debug("Payload with invalid media type %r received", request.content_type)
            return aiohttp.web.Response(
                status=_UNSUPPORTED_MEDIA_TYPE_STATUS,
                body=b"Unsupported Media Type",
                content_type=_TEXT_CONTENT_TYPE,
                charset=_UTF_8_CHARSET,
            )

        try:
            # KeyError covers a missing header, ValueError a signature which isn't valid hex.
            signature_header = bytes.fromhex(request.headers[_X_SIGNATURE_ED25519_HEADER])
            timestamp_header = request.headers[_X_SIGNATURE_TIMESTAMP_HEADER].encode()

        except (KeyError, ValueError):
            user_agent = request.headers.get(_USER_AGENT_KEY, "NONE")
            _LOGGER.debug("Received a request with a missing or invalid signature header (UA %r)", user_agent)
            return aiohttp.web.Response(
                status=_BAD_REQUEST_STATUS,
                body=b"Missing or invalid required request signature header(s)",
                content_type=_TEXT_CONTENT_TYPE,
                charset=_UTF_8_CHARSET,
            )

        try:
            body = await request.read()

        except aiohttp.web.HTTPRequestEntityTooLarge:
            _LOGGER.debug("Received a request with a payload that's too large to process")
            return aiohttp.web.Response(
                status=_PAYLOAD_TOO_LARGE_STATUS,
                body=b"Payload too large",
                content_type=_TEXT_CONTENT_TYPE,
                charset=_UTF_8_CHARSET,
            )

        if not body:
            user_agent = request.headers.get(_USER_AGENT_KEY, "NONE")
            _LOGGER.debug("Received a body-less request (UA %r)", user_agent)
            return aiohttp.web.Response(
                status=_BAD_REQUEST_STATUS,
                body=b"POST request must have a body",
                content_type=_TEXT_CONTENT_TYPE,
                charset=_UTF_8_CHARSET,
            )

        response = await self.on_interaction(body=body, signature=signature_header, timestamp=timestamp_header)

        if response.files:
            # Responses with attachments are sent as multipart form data: the JSON
            # body goes in the "payload_json" part and each file in "files[<index>]".
            multipart = aiohttp.MultipartWriter(subtype="form-data")

            if response.payload:
                body_payload = aiohttp.BytesPayload(response.payload, content_type=response.content_type)
                body_payload.set_content_disposition("form-data", name="payload_json")
                multipart.append_payload(body_payload)

            for index, file_ in enumerate(response.files):
                # The stream is only opened here to work out the file's mimetype.
                async with file_.stream(head_only=True) as stream:
                    mimetype = stream.mimetype or _APPLICATION_OCTET_STREAM

                payload = _FilePayload(file_, mimetype, executor=self._executor)
                payload.set_content_disposition("form-data", name=f"files[{index}]", filename=file_.filename)
                multipart.append_payload(payload)

            return aiohttp.web.Response(status=response.status_code, headers=response.headers, body=multipart)

        return aiohttp.web.Response(
            status=response.status_code,
            headers=response.headers,
            body=response.payload,
            content_type=response.content_type,
            charset=response.charset,
        )

    async def close(self) -> None:
        """Gracefully close the server and any open connections.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the server isn't currently active.
        """
        if not self._server or not self._close_event:
            raise errors.ComponentStateConflictError("Cannot close an inactive interaction server")

        if self._is_closing:
            # Another call is already closing the server, so just wait for it.
            await self.join()
            return

        self._is_closing = True
        # This shut down then cleanup ordering matters.
        await self._server.shutdown()
        await self._server.cleanup()
        self._server = None
        self._application_fetch_lock = None
        # Wait for handlers to complete
        await asyncio.gather(*self._running_generator_listeners)
        self._running_generator_listeners = []
        self._close_event.set()
        self._close_event = None
        self._is_closing = False

    async def join(self) -> None:
        """Wait for the process to halt before continuing.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the server isn't currently active.
        """
        if not self._close_event:
            raise errors.ComponentStateConflictError("Cannot wait for an inactive interaction server to join")

        await self._close_event.wait()

    async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes) -> interaction_server.Response:
        """Handle an interaction received from Discord as a REST server.

        .. note::
            If this server instance is alive then this will be called internally
            by the server but if the instance isn't alive then this may still be
            called externally to trigger interaction dispatch.

        Parameters
        ----------
        body : bytes
            The interaction payload.
        signature : bytes
            Value of the `"X-Signature-Ed25519"` header used to verify the body.
        timestamp : bytes
            Value of the `"X-Signature-Timestamp"` header used to verify the body.

        Returns
        -------
        hikari.api.interaction_server.Response
            Instructions on how the REST server calling this should respond to
            the interaction request.
        """
        public_key = self._public_key or await self._fetch_public_key()

        try:
            public_key.verify(timestamp + body, signature)

        except (self._nacl.exceptions.BadSignatureError, ValueError):
            _LOGGER.error("Received a request with an invalid signature")
            return _Response(_BAD_REQUEST_STATUS, b"Invalid request signature")

        try:
            payload = self._loads(body)
            # An explicit check is used instead of `assert` so that a non-object
            # body is rejected as a bad request (even under `python -O`) rather
            # than escaping as an unhandled AssertionError.
            if not isinstance(payload, dict):
                raise TypeError("Expected the interaction payload to be a JSON object")

            interaction_type = int(payload["type"])

        except (ValueError, TypeError) as exc:
            _LOGGER.error("Received a request with an invalid JSON body", exc_info=exc)
            return _Response(_BAD_REQUEST_STATUS, b"Invalid JSON body")

        except KeyError as exc:
            _LOGGER.error("Missing 'type' field in received JSON payload", exc_info=exc)
            return _Response(_BAD_REQUEST_STATUS, b"Missing required 'type' field in payload")

        if interaction_type == _PING_INTERACTION_TYPE:
            _LOGGER.debug("Responding to ping interaction")
            return _PONG_RESPONSE

        try:
            interaction = self._entity_factory.deserialize_interaction(payload)

        except errors.UnrecognisedEntityError:
            _LOGGER.debug("Ignoring unknown interaction type %s", interaction_type)
            return _Response(_NOT_IMPLEMENTED, b"Interaction type not implemented")

        except Exception as exc:
            asyncio.get_running_loop().call_exception_handler(
                {"message": "Exception occurred during interaction deserialization", "exception": exc}
            )
            return _Response(_INTERNAL_SERVER_ERROR_STATUS, b"Exception occurred during interaction deserialization")

        if listener := self._listeners.get(type(interaction)):
            _LOGGER.debug("Dispatching interaction %s", interaction.id)
            try:
                call = listener(interaction)

                if inspect.isasyncgen(call):
                    # The first yielded value is the response; the rest of the
                    # generator runs in a background task tracked until `close`.
                    result = await call.__anext__()
                    task = asyncio.create_task(_consume_generator_listener(call))
                    task.add_done_callback(self._running_generator_listeners.remove)
                    self._running_generator_listeners.append(task)

                else:
                    result = await call

                raw_payload, files = result.build(self._entity_factory)
                response_body = self._dumps(raw_payload)

            except Exception as exc:
                asyncio.get_running_loop().call_exception_handler(
                    {"message": "Exception occurred during interaction dispatch", "exception": exc}
                )
                return _Response(_INTERNAL_SERVER_ERROR_STATUS, b"Exception occurred during interaction dispatch")

            return _Response(_OK_STATUS, response_body, files=files, content_type=_JSON_CONTENT_TYPE)

        _LOGGER.debug(
            "Ignoring interaction %s of type %s without registered listener", interaction.id, interaction.type
        )
        return _Response(_NOT_IMPLEMENTED, b"Handler not set for this interaction type")

    async def start(
        self,
        backlog: int = 128,
        host: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        port: typing.Optional[int] = None,
        path: typing.Optional[str] = None,
        reuse_address: typing.Optional[bool] = None,
        reuse_port: typing.Optional[bool] = None,
        socket: typing.Optional[socket_.socket] = None,
        shutdown_timeout: float = 60.0,
        ssl_context: typing.Optional[ssl.SSLContext] = None,
    ) -> None:
        """Start the bot and wait for the internal server to startup then return.

        .. note::
            For more information on the other parameters such as defaults see
            AIOHTTP's documentation.

        Other Parameters
        ----------------
        backlog : int
            The number of unaccepted connections that the system will allow before
            refusing new connections.
        host : typing.Optional[typing.Union[str, aiohttp.web.HostSequence]]
            TCP/IP host or a sequence of hosts for the HTTP server.
        port : typing.Optional[int]
            TCP/IP port for the HTTP server.
        path : typing.Optional[str]
            File system path for HTTP server unix domain socket.
        reuse_address : typing.Optional[bool]
            Tells the kernel to reuse a local socket in TIME_WAIT state, without
            waiting for its natural timeout to expire.
        reuse_port : typing.Optional[bool]
            Tells the kernel to allow this endpoint to be bound to the same port
            as other existing endpoints are also bound to.
        socket : typing.Optional[socket.socket]
            A pre-existing socket object to accept connections on.
        shutdown_timeout : float
            A delay to wait for graceful server shutdown before forcefully
            disconnecting all open client sockets. This defaults to 60 seconds.
        ssl_context : typing.Optional[ssl.SSLContext]
            SSL context for HTTPS servers.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the server is already active.
        """
        if self._server:
            raise errors.ComponentStateConflictError("Cannot start an already active interaction server")

        self._close_event = asyncio.Event()
        self._is_closing = False
        aio_app = aiohttp.web.Application()
        aio_app.add_routes([aiohttp.web.post("/", self.aiohttp_hook)])
        self._server = aiohttp.web_runner.AppRunner(aio_app, access_log=_LOGGER)
        await self._server.setup()
        sites: typing.List[aiohttp.web.BaseSite] = []

        if host is not None:
            if isinstance(host, str):
                host = (host,)

            for hostname in host:
                sites.append(
                    aiohttp.web.TCPSite(
                        self._server,
                        hostname,
                        port=port,
                        shutdown_timeout=shutdown_timeout,
                        ssl_context=ssl_context,
                        backlog=backlog,
                        reuse_address=reuse_address,
                        reuse_port=reuse_port,
                    )
                )

        # With no host, a TCP site is still created when neither a path nor a socket
        # was given, or when a port was explicitly requested.
        elif (path is None and socket is None) or port is not None:
            sites.append(
                aiohttp.web.TCPSite(
                    self._server,
                    port=port,
                    shutdown_timeout=shutdown_timeout,
                    ssl_context=ssl_context,
                    backlog=backlog,
                    reuse_address=reuse_address,
                    reuse_port=reuse_port,
                )
            )

        if path is not None:
            sites.append(
                aiohttp.web.UnixSite(
                    self._server, path, shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog
                )
            )

        if socket is not None:
            sites.append(
                aiohttp.web.SockSite(
                    self._server, socket, shutdown_timeout=shutdown_timeout, ssl_context=ssl_context, backlog=backlog
                )
            )

        for site in sites:
            _LOGGER.info("Starting site on %s", site.name)
            await site.start()

    @typing.overload
    def get_listener(
        self, interaction_type: typing.Type[command_interactions.CommandInteraction], /
    ) -> typing.Optional[
        interaction_server.ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilderT]
    ]:
        ...

    @typing.overload
    def get_listener(
        self, interaction_type: typing.Type[component_interactions.ComponentInteraction], /
    ) -> typing.Optional[
        interaction_server.ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilderT]
    ]:
        ...

    @typing.overload
    def get_listener(
        self, interaction_type: typing.Type[command_interactions.AutocompleteInteraction], /
    ) -> typing.Optional[
        interaction_server.ListenerT[
            command_interactions.AutocompleteInteraction, special_endpoints.InteractionAutocompleteBuilder
        ]
    ]:
        ...

    @typing.overload
    def get_listener(
        self, interaction_type: typing.Type[modal_interactions.ModalInteraction], /
    ) -> typing.Optional[interaction_server.ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT]]:
        ...

    @typing.overload
    def get_listener(
        self, interaction_type: typing.Type[_InteractionT_co], /
    ) -> typing.Optional[interaction_server.ListenerT[_InteractionT_co, special_endpoints.InteractionResponseBuilder]]:
        ...

    def get_listener(
        self, interaction_type: typing.Type[_InteractionT_co], /
    ) -> typing.Optional[interaction_server.ListenerT[_InteractionT_co, special_endpoints.InteractionResponseBuilder]]:
        """Return the listener registered for `interaction_type`, or `None` if there isn't one."""
        return self._listeners.get(interaction_type)

    @typing.overload
    def set_listener(
        self,
        interaction_type: typing.Type[command_interactions.CommandInteraction],
        listener: typing.Optional[
            interaction_server.ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilderT]
        ],
        /,
        *,
        replace: bool = False,
    ) -> None:
        ...

    @typing.overload
    def set_listener(
        self,
        interaction_type: typing.Type[component_interactions.ComponentInteraction],
        listener: typing.Optional[
            interaction_server.ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilderT]
        ],
        /,
        *,
        replace: bool = False,
    ) -> None:
        ...

    @typing.overload
    def set_listener(
        self,
        interaction_type: typing.Type[command_interactions.AutocompleteInteraction],
        listener: typing.Optional[
            interaction_server.ListenerT[
                command_interactions.AutocompleteInteraction, special_endpoints.InteractionAutocompleteBuilder
            ]
        ],
        /,
        *,
        replace: bool = False,
    ) -> None:
        ...

    @typing.overload
    def set_listener(
        self,
        interaction_type: typing.Type[modal_interactions.ModalInteraction],
        listener: typing.Optional[
            interaction_server.ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT]
        ],
        /,
        *,
        replace: bool = False,
    ) -> None:
        ...

    def set_listener(
        self,
        interaction_type: typing.Type[_InteractionT_co],
        listener: typing.Optional[
            interaction_server.ListenerT[_InteractionT_co, special_endpoints.InteractionResponseBuilder]
        ],
        /,
        *,
        replace: bool = False,
    ) -> None:
        """Register, replace or remove the listener for an interaction type.

        Passing a falsy `listener` (such as `None`) removes any listener
        registered for `interaction_type`.

        Raises
        ------
        TypeError
            If a listener is already registered for `interaction_type` and
            `replace` is `False`.
        """
        if listener:
            if not replace and interaction_type in self._listeners:
                raise TypeError(f"Listener already set for {interaction_type.__name__}")

            self._listeners[interaction_type] = listener

        else:
            self._listeners.pop(interaction_type, None)