# -*- coding: utf-8 -*-
# cython: language_level=3
# Copyright (c) 2020 Nekokatt
# Copyright (c) 2021-present davfsa
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Single-shard implementation for the V10 event gateway for Discord."""

from __future__ import annotations

__all__: typing.Sequence[str] = ("GatewayShardImpl",)

import asyncio
import contextlib
import logging
import platform
import sys
import typing
import urllib.parse
import zlib

import aiohttp

from hikari import _about as about
from hikari import errors
from hikari import intents as intents_
from hikari import presences
from hikari import snowflakes
from hikari import undefined
from hikari import urls
from hikari.api import shard
from hikari.impl import rate_limits
from hikari.internal import aio
from hikari.internal import data_binding
from hikari.internal import net
from hikari.internal import time
from hikari.internal import ux

if typing.TYPE_CHECKING:
    import datetime

    import aiohttp.http_websocket
    import aiohttp.typedefs

    from hikari import channels
    from hikari import guilds
    from hikari import users as users_
    from hikari.api import event_factory as event_factory_
    from hikari.api import event_manager as event_manager_
    from hikari.impl import config

# Important attributes
_D: typing.Final[str] = sys.intern("d")
_T: typing.Final[str] = sys.intern("t")
_S: typing.Final[str] = sys.intern("s")
_OP: typing.Final[str] = sys.intern("op")

# Opcodes
_DISPATCH: typing.Final[int] = 0
_HEARTBEAT: typing.Final[int] = 1
_IDENTIFY: typing.Final[int] = 2
_PRESENCE_UPDATE: typing.Final[int] = 3
_VOICE_STATE_UPDATE: typing.Final[int] = 4
_RESUME: typing.Final[int] = 6
_RECONNECT: typing.Final[int] = 7
_REQUEST_GUILD_MEMBERS: typing.Final[int] = 8
_INVALID_SESSION: typing.Final[int] = 9
_HELLO: typing.Final[int] = 10
_HEARTBEAT_ACK: typing.Final[int] = 11
# Special dispatches
_READY: typing.Final[str] = sys.intern("READY")
_RESUMED: typing.Final[str] = sys.intern("RESUMED")
# If we disconnect within this period of time after starting, we should
# use an exponential backoff before restarting.
_BACKOFF_WINDOW: typing.Final[float] = 30.0
_BACKOFF_BASE: typing.Final[float] = 1.85
_BACKOFF_CAP: typing.Final[float] = 60.0
# Discord seems to invalidate sessions if we send a 1xxx close code, which is
# useless for invalid-session and reconnect messages where we want to be able
# to resume.
_RESUME_CLOSE_CODE: typing.Final[int] = 3_000
# Per-shard sending rate-limit
_TOTAL_RATELIMIT: typing.Final[typing.Tuple[float, int]] = (60.0, 120)
# Rate-limit for non-priority payloads.
# This leaves 3 slots free in the total rate-limit window so that
# HEARTBEAT payloads can always get through.
_NON_PRIORITY_RATELIMIT: typing.Final[typing.Tuple[float, int]] = (60.0, 117)
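# Back-of-envelope check (illustrative only): Discord's HELLO typically sets a
# heartbeat interval of roughly 41.25s, so a shard sends at most ~2 heartbeats
# per 60s window; reserving 120 - 117 = 3 priority slots covers those plus one
# extra gateway-requested heartbeat.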
# Used to identify the end of a ZLIB payload
_ZLIB_SUFFIX: typing.Final[bytes] = b"\x00\x00\xff\xff"
# Close codes which don't invalidate the current session.
_RECONNECTABLE_CLOSE_CODES: typing.FrozenSet[errors.ShardCloseCode] = frozenset(
    (
        errors.ShardCloseCode.UNKNOWN_ERROR,
        errors.ShardCloseCode.DECODE_ERROR,
        errors.ShardCloseCode.INVALID_SEQ,
        errors.ShardCloseCode.SESSION_TIMEOUT,
        errors.ShardCloseCode.RATE_LIMITED,
    )
)
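# For example, a SESSION_TIMEOUT (4009) close is treated as recoverable and the
# shard will reconnect, while a code outside this set such as
# AUTHENTICATION_FAILED (4004) is terminal and propagates as an error (see
# _keep_alive below).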


def _log_filterer(token: str) -> typing.Callable[[str], str]:
    def filterer(entry: str) -> str:
        return entry.replace(token, "**REDACTED TOKEN**")

    return filterer
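# Doctest-style sketch of the redaction behaviour (hypothetical token value):
#
#     >>> filterer = _log_filterer("Nzk0MjY...")
#     >>> filterer("sending {'token': 'Nzk0MjY...'}")
#     "sending {'token': '**REDACTED TOKEN**'}"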


@typing.final
class _GatewayTransport:
    """Internal component to handle lower-level communication logic.

    This includes translating aiohttp error conditions to hikari ones,
    handling inbound zlib packets, creating the websocket and client session,
    and ensuring all resources are freed deterministically where possible.

    Payload logging is also performed here.
    """

    __slots__ = (
        "_zlib",
        "_sent_close",
        "_logger",
        "_exit_stack",
        "_log_filterer",
        "_ws",
        "_receive_and_check",
        "_loads",
        "_dumps",
    )

    def __init__(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        transport_compression: bool,
        exit_stack: contextlib.AsyncExitStack,
        logger: logging.Logger,
        log_filterer: typing.Callable[[str], str],
        dumps: data_binding.JSONEncoder,
        loads: data_binding.JSONDecoder,
    ) -> None:
        self._logger = logger
        self._log_filterer = log_filterer
        self._exit_stack = exit_stack
        self._sent_close = False
        self._ws = ws
        self._zlib = zlib.decompressobj()
        self._loads = loads
        self._dumps = dumps

        if transport_compression:
            self._receive_and_check = self._receive_and_check_zlib
        else:
            self._receive_and_check = self._receive_and_check_text

    async def send_close(self, *, code: int, message: bytes) -> None:
        if self._sent_close:
            return

        self._sent_close = True
        self._logger.debug("sending close frame with code %s and message %s", code, message)
        try:
            await asyncio.wait_for(self._ws.close(code=code, message=message), timeout=5)

        except asyncio.TimeoutError:
            self._logger.debug("failed to send close frame in time, probably connection issues")

        finally:
            await self._exit_stack.aclose()

            # We have to sleep to allow aiohttp time to close SSL transports...
            # This code can be removed in aiohttp v4.0.0
            # https://github.com/aio-libs/aiohttp/issues/1925
            # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
            await asyncio.sleep(0.25)

    async def receive_json(self) -> typing.Any:
        pl = await self._receive_and_check()
        if self._logger.isEnabledFor(ux.TRACE):
            filtered = self._log_filterer(pl)
            self._logger.log(ux.TRACE, "received payload with size %s\n    %s", len(pl), filtered)

        return self._loads(pl)

    async def send_json(self, data: data_binding.JSONObject) -> None:
        pl = self._dumps(data)
        if self._logger.isEnabledFor(ux.TRACE):
            filtered = self._log_filterer(pl.decode("utf-8"))
            self._logger.log(ux.TRACE, "sending payload with size %s\n    %s", len(pl), filtered)

        await self._ws.send_bytes(pl)

    def _handle_other_message(self, message: aiohttp.WSMessage, /) -> typing.NoReturn:
        if message.type == aiohttp.WSMsgType.TEXT:
            raise errors.GatewayError("Unexpected message type received TEXT, expected BINARY")

        if message.type == aiohttp.WSMsgType.BINARY:
            raise errors.GatewayError("Unexpected message type received BINARY, expected TEXT")

        if message.type == aiohttp.WSMsgType.CLOSE:
            close_code = int(message.data)

            can_reconnect = close_code < 4000 or close_code in _RECONNECTABLE_CLOSE_CODES
            raise errors.GatewayServerClosedConnectionError(message.extra, close_code, can_reconnect)

        if message.type == aiohttp.WSMsgType.CLOSING or message.type == aiohttp.WSMsgType.CLOSED:
            # May be caused by the server shutting us down.
            # May be caused by Windows injecting an EOF if something disconnects, as some
            # network drivers appear to do this.
            raise errors.GatewayConnectionError("Socket has closed")

        # Assume exception for now.
        raise errors.GatewayError("Unexpected websocket exception from gateway") from self._ws.exception()

    async def _receive_and_check_text(self) -> str:
        message = await self._ws.receive()

        if message.type == aiohttp.WSMsgType.TEXT:
            assert isinstance(message.data, str)
            return message.data

        self._handle_other_message(message)

    async def _receive_and_check_zlib(self) -> str:
        message = await self._ws.receive()

        if message.type == aiohttp.WSMsgType.BINARY:
            if message.data.endswith(_ZLIB_SUFFIX):
                return self._zlib.decompress(message.data).decode("utf-8")

            return await self._receive_and_check_complete_zlib_package(message.data)

        self._handle_other_message(message)

    async def _receive_and_check_complete_zlib_package(self, initial_data: bytes, /) -> str:
        buff = bytearray(initial_data)

        while not buff.endswith(_ZLIB_SUFFIX):
            message = await self._ws.receive()

            if message.type == aiohttp.WSMsgType.BINARY:
                buff.extend(message.data)
                continue

            self._handle_other_message(message)

        return self._zlib.decompress(buff).decode("utf-8")
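    # Illustrative framing sketch (hypothetical bytes): a large payload may be
    # split across several BINARY frames, and only the final frame carries the
    # flush marker:
    #
    #     frame_1 = b"\x78\x9c\x4a\x49..."  # no _ZLIB_SUFFIX yet, keep buffering
    #     frame_2 = b"...\x00\x00\xff\xff"  # suffix terminates the payload
    #
    # The concatenated buffer is fed to the shared decompressobj above, which
    # must be reused across payloads since the whole connection is one zlib
    # stream.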

    @classmethod
    async def connect(
        cls,
        *,
        http_settings: config.HTTPSettings,
        logger: logging.Logger,
        proxy_settings: config.ProxySettings,
        log_filterer: typing.Callable[[str], str],
        dumps: data_binding.JSONEncoder,
        loads: data_binding.JSONDecoder,
        transport_compression: bool,
        url: str,
    ) -> _GatewayTransport:
        """Generate a single-use websocket connection.

        This uses a single connection in a TCP connector pool, with a one-use
        aiohttp client session.
        """
        exit_stack = contextlib.AsyncExitStack()

        try:
            try:
                connector = net.create_tcp_connector(http_settings=http_settings, dns_cache=False, limit=1)
                client_session = await exit_stack.enter_async_context(
                    net.create_client_session(
                        connector=connector,
                        connector_owner=True,
                        http_settings=http_settings,
                        raise_for_status=True,
                        trust_env=proxy_settings.trust_env,
                    )
                )

                web_socket = await exit_stack.enter_async_context(
                    client_session.ws_connect(
                        max_msg_size=0,
                        proxy=proxy_settings.url,
                        proxy_headers=proxy_settings.headers,
                        url=url,
                        # We manage this ourselves
                        autoclose=False,
                    )
                )

                return cls(
                    ws=web_socket,
                    transport_compression=transport_compression,
                    exit_stack=exit_stack,
                    logger=logger,
                    log_filterer=log_filterer,
                    loads=loads,
                    dumps=dumps,
                )

            except (aiohttp.ClientConnectionError, aiohttp.ClientResponseError, asyncio.TimeoutError) as ex:
                # If we cannot do DNS lookup, this will fail with an aiohttp.ClientConnectionError
                # usually, but it might also fail with asyncio.TimeoutError if it gets stuck in a weird way
                #
                # aiohttp.ClientResponseError has a really bad str, so we use the repr instead
                if isinstance(ex, aiohttp.ClientResponseError):
                    reason = repr(ex)
                elif isinstance(ex, asyncio.TimeoutError):
                    reason = "Timeout exceeded"
                else:
                    reason = str(ex)
                raise errors.GatewayConnectionError(reason) from None

        except Exception:
            await exit_stack.aclose()

            # We have to sleep to allow aiohttp time to close SSL transports...
            # This code can be removed in aiohttp v4.0.0
            # https://github.com/aio-libs/aiohttp/issues/1925
            # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
            await asyncio.sleep(0.25)

            raise


def _serialize_datetime(dt: typing.Optional[datetime.datetime]) -> typing.Optional[int]:
    if dt is None:
        return None

    return int(dt.timestamp() * 1_000)
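# For example (assuming an aware UTC datetime):
#     _serialize_datetime(datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc))
# returns 1609459200000, i.e. milliseconds since the Unix epoch as Discord expects.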


def _serialize_activity(activity: typing.Optional[presences.Activity]) -> data_binding.JSONish:
    if activity is None:
        return None

    return {"name": activity.name, "type": int(activity.type), "url": activity.url}


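# Hedged usage sketch (assumes already-built manager/factory/settings objects;
# the names below are placeholders and the token is elided):
#
#     shard = GatewayShardImpl(
#         token="...",
#         url="wss://gateway.discord.gg",
#         intents=intents_.Intents.ALL_UNPRIVILEGED,
#         event_manager=event_manager,
#         event_factory=event_factory,
#         http_settings=http_settings,
#         proxy_settings=proxy_settings,
#     )
#     await shard.start()  # returns once the READY/RESUMED handshake completes
#     await shard.join()   # then waits until the shard fully shuts down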
class GatewayShardImpl(shard.GatewayShard):
    """Implementation of a V10 compatible gateway.

    .. note::
        If all four of `initial_activity`, `initial_idle_since`,
        `initial_is_afk`, and `initial_status` are not defined and left to
        their default values, then the presence will not be _updated_ on
        startup at all.

        If any of these _are_ specified, then any that are not specified will
        be set to sane defaults, which may change the previous status. This
        will only occur during startup, and is an artifact of how Discord
        manages these updates internally. All other calls to update the
        status of the shard will support partial updates.

    Parameters
    ----------
    token : str
        The bot token to use.
    url : str
        The gateway URL to use. This should not contain a query-string or
        fragments.
    event_manager : hikari.api.event_manager.EventManager
        The event manager this shard should make calls to.
    event_factory : hikari.api.event_factory.EventFactory
        The event factory this shard should use.

    Other Parameters
    ----------------
    compression : typing.Optional[str]
        Compression format to use for the shard. Only supported values are
        `"transport_zlib_stream"` or `None` to disable it.
    dumps : hikari.internal.data_binding.JSONEncoder
        The JSON encoder this application should use. Defaults to
        `hikari.internal.data_binding.default_json_dumps`.
    loads : hikari.internal.data_binding.JSONDecoder
        The JSON decoder this application should use. Defaults to
        `hikari.internal.data_binding.default_json_loads`.
    initial_activity : typing.Optional[hikari.presences.Activity]
        The initial activity to appear to have for this shard, or `None` if
        no activity should be set initially. This is the default.
    initial_idle_since : typing.Optional[datetime.datetime]
        The datetime to appear to be idle since, or `None` if the shard
        should not provide this. The default is `None`.
    initial_is_afk : bool
        Whether to appear to be AFK or not on login. Defaults to `False`.
    initial_status : hikari.presences.Status
        The initial status to set on login for the shard. Defaults to
        `hikari.presences.Status.ONLINE`.
    intents : hikari.intents.Intents
        Collection of intents to use.
    large_threshold : int
        The number of members to have in a guild for it to be considered
        large.
    shard_id : int
        The shard ID.
    shard_count : int
        The shard count.
    http_settings : hikari.impl.config.HTTPSettings
        The HTTP-related settings to use while negotiating a websocket.
    proxy_settings : hikari.impl.config.ProxySettings
        The proxy settings to use while negotiating a websocket.
    data_format : str
        Data format to use for inbound data. Only supported format is
        `"json"`.
""" __slots__: typing.Sequence[str] = ( "_activity", "_dumps", "_event_manager", "_event_factory", "_gateway_url", "_handshake_event", "_heartbeat_latency", "_http_settings", "_idle_since", "_intents", "_is_afk", "_is_closing", "_keep_alive_task", "_large_threshold", "_last_heartbeat_ack_received", "_last_heartbeat_sent", "_loads", "_logger", "_non_priority_rate_limit", "_proxy_settings", "_resume_gateway_url", "_seq", "_session_id", "_shard_count", "_shard_id", "_status", "_token", "_total_rate_limit", "_transport_compression", "_user_id", "_ws", ) def __init__( self, *, compression: typing.Optional[str] = shard.GatewayCompression.TRANSPORT_ZLIB_STREAM, dumps: data_binding.JSONEncoder = data_binding.default_json_dumps, loads: data_binding.JSONDecoder = data_binding.default_json_loads, initial_activity: typing.Optional[presences.Activity] = None, initial_idle_since: typing.Optional[datetime.datetime] = None, initial_is_afk: bool = False, initial_status: presences.Status = presences.Status.ONLINE, intents: intents_.Intents, large_threshold: int = 250, shard_id: int = 0, shard_count: int = 1, http_settings: config.HTTPSettings, proxy_settings: config.ProxySettings, data_format: str = shard.GatewayDataFormat.JSON, event_manager: event_manager_.EventManager, event_factory: event_factory_.EventFactory, token: str, url: str, ) -> None: if data_format != shard.GatewayDataFormat.JSON: raise NotImplementedError(f"Unsupported gateway data format: {data_format}") if compression and compression != shard.GatewayCompression.TRANSPORT_ZLIB_STREAM: raise NotImplementedError(f"Unsupported compression format {compression}") self._activity = initial_activity self._event_manager = event_manager self._event_factory = event_factory self._gateway_url = url self._handshake_event: typing.Optional[asyncio.Event] = None self._heartbeat_latency = float("nan") self._http_settings = http_settings self._idle_since = initial_idle_since self._intents = intents self._is_afk = initial_is_afk self._is_closing = False self._keep_alive_task: typing.Optional[asyncio.Task[None]] = None self._large_threshold = large_threshold self._last_heartbeat_ack_received = float("nan") self._last_heartbeat_sent = float("nan") self._logger = logging.getLogger(f"hikari.gateway.{shard_id}") self._non_priority_rate_limit = rate_limits.WindowedBurstRateLimiter( f"shard {shard_id} non-priority rate limit", *_NON_PRIORITY_RATELIMIT ) self._proxy_settings = proxy_settings self._resume_gateway_url: typing.Optional[str] = None self._seq: typing.Optional[int] = None self._session_id: typing.Optional[str] = None self._shard_count = shard_count self._shard_id = shard_id self._status = initial_status self._token = token self._total_rate_limit = rate_limits.WindowedBurstRateLimiter( f"shard {shard_id} total rate limit", *_TOTAL_RATELIMIT ) self._transport_compression = compression is not None self._dumps = dumps self._loads = loads self._user_id: typing.Optional[snowflakes.Snowflake] = None self._ws: typing.Optional[_GatewayTransport] = None @property
    def heartbeat_latency(self) -> float:
        return self._heartbeat_latency

@property
    def id(self) -> int:
        return self._shard_id

@property
    def intents(self) -> intents_.Intents:
        return self._intents

@property
    def is_alive(self) -> bool:
        return self._keep_alive_task is not None

@property
    def is_connected(self) -> bool:
        return self._ws is not None and self._handshake_event is not None and self._handshake_event.is_set()

@property
    def shard_count(self) -> int:
        return self._shard_count

    async def close(self) -> None:
        if not self._keep_alive_task:
            raise errors.ComponentStateConflictError("Cannot close an inactive shard")

        if self._is_closing:
            await self.join()
            return

        self._logger.info("shard has been requested to shutdown")
        self._is_closing = True
        self._keep_alive_task.cancel()

        try:
            await self._keep_alive_task
        except asyncio.CancelledError:
            pass

        self._keep_alive_task = None

        self._non_priority_rate_limit.close()
        self._total_rate_limit.close()

        self._is_closing = False
        self._logger.info("shard shutdown successfully")

    def get_user_id(self) -> snowflakes.Snowflake:
        self._check_if_connected()
        assert self._user_id is not None, "user_id was not known, this is probably a bug"
        return self._user_id

    async def join(self) -> None:
        if not self._keep_alive_task:
            raise errors.ComponentStateConflictError("Cannot join an inactive shard")

        await asyncio.wait_for(asyncio.shield(self._keep_alive_task), timeout=None)

    async def _send_json(self, data: data_binding.JSONObject, /, priority: bool = False) -> None:
        if not priority:
            await self._non_priority_rate_limit.acquire()

        await self._total_rate_limit.acquire()

        assert self._ws is not None
        await self._ws.send_json(data)

    def _check_if_connected(self) -> None:
        if not self.is_connected:
            raise errors.ComponentStateConflictError(
                f"shard {self._shard_id} is not connected so it cannot be interacted with"
            )

    async def request_guild_members(
        self,
        guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
        *,
        include_presences: undefined.UndefinedOr[bool] = undefined.UNDEFINED,
        query: str = "",
        limit: int = 0,
        users: undefined.UndefinedOr[snowflakes.SnowflakeishSequence[users_.User]] = undefined.UNDEFINED,
        nonce: undefined.UndefinedOr[str] = undefined.UNDEFINED,
    ) -> None:
        self._check_if_connected()
        if not query and not limit and not self._intents & intents_.Intents.GUILD_MEMBERS:
            raise errors.MissingIntentError(intents_.Intents.GUILD_MEMBERS)

        if include_presences and not self._intents & intents_.Intents.GUILD_PRESENCES:
            raise errors.MissingIntentError(intents_.Intents.GUILD_PRESENCES)

        if users is not undefined.UNDEFINED and (query or limit):
            raise ValueError("Cannot specify limit/query with users")

        if not 0 <= limit <= 100:
            raise ValueError("'limit' must be between 0 and 100, both inclusive")

        if users is not undefined.UNDEFINED and len(users) > 100:
            raise ValueError("'users' is limited to 100 users")

        if nonce is not undefined.UNDEFINED and len(bytes(nonce, "utf-8")) > 32:
            raise ValueError("'nonce' can be no longer than 32 bytes")

        payload = data_binding.JSONObjectBuilder()
        payload.put_snowflake("guild_id", guild)
        payload.put("presences", include_presences)
        payload.put("query", query)
        payload.put("limit", limit)
        payload.put_snowflake_array("user_ids", users)
        payload.put("nonce", nonce)

        await self._send_json({_OP: _REQUEST_GUILD_MEMBERS, _D: payload})

    async def start(self) -> None:
        if self._keep_alive_task or self._handshake_event:
            raise errors.ComponentStateConflictError("Cannot run more than one instance of one shard concurrently")

        self._handshake_event = asyncio.Event()
        keep_alive_task = asyncio.create_task(
            self._keep_alive(),
            name=f"keep alive (shard {self._shard_id})",
        )

        await aio.first_completed(self._handshake_event.wait(), asyncio.shield(keep_alive_task))

        if not self._handshake_event.is_set():
            # This occurs if the keep-alive task finished before the handshake completion
            # event, which implies the shard died before it could become ready/resume.
            # Awaiting its result re-raises whatever error killed it, if there was one.
            keep_alive_task.result()
            raise RuntimeError(f"shard {self._shard_id} was closed before it could start successfully")

        self._keep_alive_task = keep_alive_task

    async def update_presence(
        self,
        *,
        idle_since: undefined.UndefinedNoneOr[datetime.datetime] = undefined.UNDEFINED,
        afk: undefined.UndefinedOr[bool] = undefined.UNDEFINED,
        activity: undefined.UndefinedNoneOr[presences.Activity] = undefined.UNDEFINED,
        status: undefined.UndefinedOr[presences.Status] = undefined.UNDEFINED,
    ) -> None:
        self._check_if_connected()
        presence_payload = self._serialize_and_store_presence_payload(
            idle_since=idle_since,
            afk=afk,
            activity=activity,
            status=status,
        )
        await self._send_json({_OP: _PRESENCE_UPDATE, _D: presence_payload})

    async def update_voice_state(
        self,
        guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
        channel: typing.Optional[snowflakes.SnowflakeishOr[channels.GuildVoiceChannel]],
        *,
        self_mute: undefined.UndefinedOr[bool] = undefined.UNDEFINED,
        self_deaf: undefined.UndefinedOr[bool] = undefined.UNDEFINED,
    ) -> None:
        self._check_if_connected()
        payload = data_binding.JSONObjectBuilder()
        payload.put_snowflake("guild_id", guild)
        payload.put_snowflake("channel_id", channel)
        payload.put("self_mute", self_mute)
        payload.put("self_deaf", self_deaf)

        await self._send_json({_OP: _VOICE_STATE_UPDATE, _D: payload})

    async def _send_heartbeat(self) -> None:
        self._logger.log(ux.TRACE, "sending HEARTBEAT [s:%s]", self._seq)
        await self._send_json({_OP: _HEARTBEAT, _D: self._seq}, priority=True)
        self._last_heartbeat_sent = time.monotonic()

    async def _heartbeat(self, heartbeat_interval: float) -> None:
        # Prevent immediately zombie-ing.
        self._last_heartbeat_ack_received = time.monotonic()
        self._logger.debug("starting heartbeat with interval %ss", heartbeat_interval)

        while True:
            if self._last_heartbeat_ack_received <= self._last_heartbeat_sent:
                # Gateway is zombie, close and request reconnect.
                self._logger.error(
                    "connection has not received a HEARTBEAT_ACK for approx %.1fs and is being disconnected; "
                    "will attempt to reconnect",
                    time.monotonic() - self._last_heartbeat_ack_received,
                )
                return

            await self._send_heartbeat()
            await asyncio.sleep(heartbeat_interval)

    async def _poll_events(self) -> None:
        assert self._ws is not None
        assert self._handshake_event is not None

        while True:
            payload = await self._ws.receive_json()

            op = payload[_OP]

            if op == _DISPATCH:
                name = payload[_T]
                data = payload[_D]
                self._seq = payload[_S]
                self._logger.log(ux.TRACE, "dispatching %s with seq %s", name, self._seq)

                if name == _READY:
                    self._session_id = data["session_id"]
                    self._resume_gateway_url = data["resume_gateway_url"]
                    user_pl = data["user"]
                    self._user_id = snowflakes.Snowflake(user_pl["id"])
                    self._logger.info(
                        "shard is ready: %s guilds, %s (%s), session %r on v%s gateway",
                        len(data["guilds"]),
                        f"{user_pl['username']}#{user_pl['discriminator']}",
                        self._user_id,
                        self._session_id,
                        data["v"],
                    )
                    self._handshake_event.set()

                elif name == _RESUMED:
                    self._logger.info("resumed session [session:%s, seq:%s]", self._session_id, self._seq)
                    self._handshake_event.set()

                try:
                    self._event_manager.consume_raw_event(name, self, data)
                except LookupError:
                    self._logger.debug("ignoring unknown event %s:\n    %r", name, data)

            elif op == _HEARTBEAT_ACK:
                now = time.monotonic()
                self._last_heartbeat_ack_received = now
                self._heartbeat_latency = now - self._last_heartbeat_sent
                self._logger.log(ux.TRACE, "received HEARTBEAT ACK in %.1fms", self._heartbeat_latency * 1_000)

            elif op == _HEARTBEAT:
                self._logger.log(ux.TRACE, "sending heartbeat as requested by gateway")
                await self._send_heartbeat()

            elif op == _RECONNECT:
                self._logger.info("received instruction to reconnect, will resume existing session")
                return

            elif op == _INVALID_SESSION:
                # We can resume if the payload data is `true`.
                can_reconnect = payload[_D]
                if not can_reconnect:
                    self._logger.info("received invalid session, will need to start a new session")
                    self._seq = None
                    self._resume_gateway_url = None
                    self._session_id = None
                else:
                    self._logger.info("received invalid session, will resume existing session")

                return

            else:
                self._logger.log(ux.TRACE, "unknown opcode %s received, it will be ignored...", op)

    async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]:
        if self._ws is not None:
            raise errors.ComponentStateConflictError("Attempting to connect an already connected shard")

        assert self._handshake_event is not None

        url_parts = urllib.parse.urlparse(
            self._resume_gateway_url or self._gateway_url,
            allow_fragments=True,
        )

        query = dict(urllib.parse.parse_qsl(url_parts.query))
        query["v"] = str(urls.VERSION)
        query["encoding"] = "json"
        if self._transport_compression:
            query["compress"] = "zlib-stream"

        url = urllib.parse.urlunparse(
            (
                url_parts.scheme,
                url_parts.netloc,
                url_parts.path,
                url_parts.params,
                urllib.parse.urlencode(query),
                "",
            )
        )

        self._ws = await _GatewayTransport.connect(
            http_settings=self._http_settings,
            log_filterer=_log_filterer(self._token),
            logger=self._logger,
            proxy_settings=self._proxy_settings,
            transport_compression=self._transport_compression,
            loads=self._loads,
            dumps=self._dumps,
            url=url,
        )
        self._event_manager.dispatch(self._event_factory.deserialize_connected_event(self))

        # Expect initial HELLO
        hello_payload = await self._ws.receive_json()

        if hello_payload[_OP] != _HELLO:
            self._logger.debug(
                "expected %s (HELLO) opcode, received %s which makes no sense, closing with PROTOCOL ERROR",
                _HELLO,
                hello_payload[_OP],
            )
            await self._ws.send_close(code=errors.ShardCloseCode.PROTOCOL_ERROR, message=b"Expected HELLO op")
            raise errors.GatewayError(f"Expected opcode {_HELLO} (HELLO), but received {hello_payload[_OP]}")

        # Spawn lifetime tasks
        heartbeat_interval = float(hello_payload[_D]["heartbeat_interval"]) / 1_000.0
        heartbeat_task = asyncio.create_task(
            self._heartbeat(heartbeat_interval), name=f"heartbeat (shard {self._shard_id})"
        )
        poll_events_task = asyncio.create_task(self._poll_events(), name=f"poll events (shard {self._shard_id})")

        # Perform handshake
        if self._seq is None:
            self._logger.debug("identifying with new session")
            await self._send_json(
                {
                    _OP: _IDENTIFY,
                    _D: {
                        "token": self._token,
                        "compress": False,
                        "large_threshold": self._large_threshold,
                        "properties": {
                            "os": f"{platform.system()} {platform.architecture()[0]}",
                            "browser": f"hikari ({about.__version__}, aiohttp {aiohttp.__version__})",
                            "device": f"hikari {about.__version__}",
                        },
                        "shard": [self._shard_id, self._shard_count],
                        "intents": self._intents,
                        "presence": self._serialize_and_store_presence_payload(),
                    },
                }
            )
        else:
            self._logger.debug("resuming session %s", self._session_id)
            await self._send_json(
                {
                    _OP: _RESUME,
                    _D: {"token": self._token, "seq": self._seq, "session_id": self._session_id},
                }
            )

        lifetime_tasks = (heartbeat_task, poll_events_task)
        await aio.first_completed(self._handshake_event.wait(), *(asyncio.shield(t) for t in lifetime_tasks))
        return lifetime_tasks

    async def _keep_alive(self) -> None:
        assert self._handshake_event is not None

        lifetime_tasks: typing.Tuple[asyncio.Task[None], ...] = ()
        last_started_at = -float("inf")
        backoff = rate_limits.ExponentialBackOff(base=_BACKOFF_BASE, maximum=_BACKOFF_CAP)

        while True:
            self._handshake_event.clear()

            if time.monotonic() - last_started_at < _BACKOFF_WINDOW:
                backoff_time = next(backoff)
                self._logger.info("backing off reconnecting for %.2fs", backoff_time)
                await asyncio.sleep(backoff_time)

            try:
                last_started_at = time.monotonic()
                lifetime_tasks = await self._connect()

                if not self._handshake_event.is_set():
                    continue

                await aio.first_completed(*lifetime_tasks)

                # Since nothing went wrong, we can reset the backoff and try again
                backoff.reset()

            except errors.GatewayConnectionError as ex:
                self._logger.warning("failed to communicate with server, reason was: %r. Will retry shortly", ex.reason)

            except errors.GatewayServerClosedConnectionError as ex:
                if not ex.can_reconnect:
                    self._logger.info(
                        "server has closed the connection permanently [code:%s, reason:%s]",
                        ex.code,
                        ex.reason,
                    )
                    raise

                self._logger.info(
                    "server has closed the connection, will attempt to reconnect [code:%s, reason:%s]",
                    ex.code,
                    ex.reason,
                )

                # We don't want to back off from this. If Discord keeps closing the connection,
                # it is their issue. If we backed off here, a mass outage would prevent shards
                # from becoming healthy again on reconnect in large sharded bots for a very
                # long period of time.
                backoff.reset()

            except errors.GatewayError as ex:
                self._logger.error("encountered generic gateway error", exc_info=ex)
                raise

            except asyncio.CancelledError:
                self._is_closing = True
                return

            except Exception as ex:
                self._logger.error("encountered some unhandled error", exc_info=ex)
                raise

            finally:
                # Cancel any left-over tasks
                for task in lifetime_tasks:
                    if not task.done() and not task.cancelled():
                        task.cancel()

                        try:
                            await task
                        except asyncio.CancelledError:
                            pass

                # Close the ws
                if self._ws:
                    ws = self._ws
                    self._ws = None

                    if self._is_closing:
                        await ws.send_close(
                            code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting permanently"
                        )
                    else:
                        await ws.send_close(code=_RESUME_CLOSE_CODE, message=b"shard disconnecting temporarily")

                    self._event_manager.dispatch(self._event_factory.deserialize_disconnected_event(self))

    def _serialize_and_store_presence_payload(
        self,
        idle_since: undefined.UndefinedNoneOr[datetime.datetime] = undefined.UNDEFINED,
        afk: undefined.UndefinedOr[bool] = undefined.UNDEFINED,
        status: undefined.UndefinedOr[presences.Status] = undefined.UNDEFINED,
        activity: undefined.UndefinedNoneOr[presences.Activity] = undefined.UNDEFINED,
    ) -> data_binding.JSONObject:
        if activity is undefined.UNDEFINED:
            activity = self._activity
        else:
            self._activity = activity

        if status is undefined.UNDEFINED:
            status = self._status
        else:
            self._status = status

        if idle_since is undefined.UNDEFINED:
            idle_since = self._idle_since
        else:
            self._idle_since = idle_since

        if afk is undefined.UNDEFINED:
            afk = self._is_afk
        else:
            self._is_afk = afk

        payload = data_binding.JSONObjectBuilder()
        payload.put("since", idle_since, conversion=_serialize_datetime)
        payload.put("afk", afk)
        payload.put("game", activity, conversion=_serialize_activity)
        # Sending "offline" to the gateway won't do anything, we will have to
        # send "invisible" instead for this to work.
        payload.put("status", "invisible" if status is presences.Status.OFFLINE else status)
        return payload
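
# Illustrative result (hypothetical call, default initial presence state):
#     shard._serialize_and_store_presence_payload(status=presences.Status.OFFLINE)
# builds a payload along the lines of
#     {"since": None, "afk": False, "game": None, "status": "invisible"}
# because, as noted above, the gateway ignores "offline" and requires
# "invisible" instead.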