# Source code for hikari.impl.voice

# -*- coding: utf-8 -*-
# cython: language_level=3
# Copyright (c) 2020 Nekokatt
# Copyright (c) 2021-present davfsa
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Implementation of a simple voice management system."""

from __future__ import annotations

# Public API of this module. The annotation references ``typing`` before it is
# imported below; this is fine because ``from __future__ import annotations``
# keeps annotations unevaluated.
__all__: typing.Sequence[str] = ("VoiceComponentImpl",)

import asyncio
import logging
import types
import typing

from hikari import errors
from hikari import snowflakes
from hikari.api import voice
from hikari.events import voice_events
from hikari.internal import ux

if typing.TYPE_CHECKING:
    # These modules are only referenced from type annotations in this file, and
    # annotations are lazy (``from __future__ import annotations``), so they are
    # imported for type checkers only and skipped at runtime.
    from hikari import channels
    from hikari import guilds
    from hikari import traits

    # Concrete connection type requested by ``connect_to``: the method takes the
    # class via ``voice_connection_type`` and returns an instance of that type.
    # Only defined for type checkers; runtime code never evaluates it.
    _VoiceConnectionT = typing.TypeVar("_VoiceConnectionT", bound="voice.VoiceConnection")

# Logger shared by the voice management component below.
_LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.voice.management")


class VoiceComponentImpl(voice.VoiceComponent):
    """A standard voice component management implementation.

    This is the regular implementation you will generally use to connect to
    voice channels with.

    At most one voice connection is tracked per guild. A single listener for
    ``VoiceEvent`` is subscribed while at least one connection exists. It
    forwards each event to the connection for the event's guild.
    """

    __slots__: typing.Sequence[str] = (
        "_app",
        "_connections",
        "connections",
        "_is_alive",
        "_is_closing",
        "_voice_listener",
    )

    # Mutable registry of active connections, keyed by guild ID.
    _connections: typing.Dict[snowflakes.Snowflake, voice.VoiceConnection]
    # Read-only live view over ``_connections`` exposed to users.
    connections: typing.Mapping[snowflakes.Snowflake, voice.VoiceConnection]

    def __init__(self, app: traits.GatewayBotAware) -> None:
        self._app = app
        self._connections = {}
        self.connections = types.MappingProxyType(self._connections)
        self._is_alive = False
        self._is_closing = False
        # True while ``_on_voice_event`` is subscribed to ``VoiceEvent``.
        self._voice_listener = False

    @property
    def is_alive(self) -> bool:
        """Whether this component has been started and not yet closed."""
        return self._is_alive

    def _check_if_alive(self) -> None:
        """Reject usage of the component unless it is running and not closing.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the component is not alive, or is currently closing.
        """
        if not self._is_alive:
            raise errors.ComponentStateConflictError("Component cannot be used while it's not alive")

        if self._is_closing:
            raise errors.ComponentStateConflictError("Component cannot be used while it's closing")

    async def disconnect(self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild]) -> None:
        """Disconnect the voice connection held for the given guild.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the component is not alive or is closing.
        hikari.errors.VoiceError
            If there is no active voice connection for the guild.
        """
        self._check_if_alive()
        guild_id = snowflakes.Snowflake(guild)

        try:
            conn = self._connections[guild_id]
        except KeyError:
            raise errors.VoiceError(
                "This application doesn't have any active voice connection in this server"
            ) from None

        # We rely on the assumption that _on_connection_close will be called here rather than explicitly
        # to remove the connection from self._connections.
        await conn.disconnect()

    async def _disconnect_all(self) -> None:
        """Disconnect every tracked connection concurrently, without a state check."""
        # We rely on the assumption that _on_connection_close will be called here rather than explicitly
        # emptying self._connections. The argument tuple is built eagerly, so removals
        # during the gather do not affect the iteration.
        await asyncio.gather(*(c.disconnect() for c in self._connections.values()))

    async def disconnect_all(self) -> None:
        """Disconnect all active voice connections.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the component is not alive or is closing.
        """
        self._check_if_alive()
        await self._disconnect_all()

    async def close(self) -> None:
        """Stop the voice listener, disconnect every connection and mark the component dead.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the component is not alive or is already closing.
        """
        self._check_if_alive()
        self._is_closing = True

        try:
            if self._voice_listener:
                self._app.event_manager.unsubscribe(voice_events.VoiceEvent, self._on_voice_event)
                # Record the unsubscription immediately. ``_on_connection_close`` is expected to
                # run during the disconnects below, and it must not unsubscribe the same callback
                # a second time.
                self._voice_listener = False

            if self._connections:
                _LOGGER.info("shutting down %s active voice connection(s)", len(self._connections))
                await self._disconnect_all()

            self._is_alive = False
        finally:
            # Never leave the component stuck in the "closing" state if a disconnect failed.
            # Otherwise every later call, including a retried close(), would be rejected.
            self._is_closing = False

    def start(self) -> None:
        """Start this voice component."""
        if self._is_alive:
            raise errors.ComponentStateConflictError("Cannot start a voice component which is already running")

        self._is_alive = True
        self._voice_listener = False

    async def connect_to(
        self,
        guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
        channel: snowflakes.SnowflakeishOr[channels.GuildVoiceChannel],
        voice_connection_type: typing.Type[_VoiceConnectionT],
        *,
        deaf: bool = False,
        mute: bool = False,
        timeout: typing.Optional[int] = 5,
        **kwargs: typing.Any,
    ) -> _VoiceConnectionT:
        """Join a voice channel and hand the session to a new voice connection.

        Parameters
        ----------
        guild
            The guild the voice channel belongs to.
        channel
            The voice channel to join.
        voice_connection_type
            Connection class whose ``initialize`` builds the returned connection.
        deaf, mute
            Sent as ``self_deaf``/``self_mute`` in the voice state update.
        timeout
            Passed as-is to each ``event_manager.wait_for`` call.
        **kwargs
            Forwarded to ``voice_connection_type.initialize``.

        Returns
        -------
        The initialized connection, also registered under the guild's ID.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the component is not alive or is closing.
        hikari.errors.VoiceError
            If a connection for the guild already exists, the guild's shard is not
            present in this application, or the voice events did not arrive in time.
        Exception
            Whatever ``voice_connection_type.initialize`` raises. The voice channel
            is left (best-effort) before re-raising.
        """
        self._check_if_alive()
        guild_id = snowflakes.Snowflake(guild)

        if guild_id in self._connections:
            raise errors.VoiceError(
                "Already in a voice channel for that guild. Disconnect before attempting to connect again"
            )

        shard_id = snowflakes.calculate_shard_id(self._app, guild_id)
        try:
            shard = self._app.shards[shard_id]
        except KeyError:
            raise errors.VoiceError(
                f"Cannot connect to shard {shard_id} as it is not present in this application"
            ) from None

        user = self._app.cache.get_me()
        if not user:
            user = await self._app.rest.fetch_my_user()

        _LOGGER.log(ux.TRACE, "attempting to connect to voice channel %s in %s via shard %s", channel, guild, shard_id)
        await shard.update_voice_state(guild, channel, self_deaf=deaf, self_mute=mute)

        _LOGGER.log(
            ux.TRACE,
            "waiting for voice events for connecting to voice channel %s in %s via shard %s",
            channel,
            guild,
            shard_id,
        )

        try:
            # Both events are required: the session ID comes from the voice state update
            # (for this bot user), the token and endpoint from the voice server update.
            state_event, server_event = await asyncio.gather(
                # Voice state update:
                self._app.event_manager.wait_for(
                    voice_events.VoiceStateUpdateEvent,
                    timeout=timeout,
                    predicate=self._init_state_update_predicate(guild_id, user.id),
                ),
                # Server update:
                self._app.event_manager.wait_for(
                    voice_events.VoiceServerUpdateEvent,
                    timeout=timeout,
                    predicate=self._init_server_update_predicate(guild_id),
                ),
            )
        except asyncio.TimeoutError as e:
            raise errors.VoiceError(f"Could not connect to voice channel {channel} in guild {guild}.") from e

        # We will never receive the first endpoint as `None`
        assert server_event.endpoint is not None

        _LOGGER.debug(
            "joined voice channel %s in guild %s via shard %s using endpoint %s. Session will be %s. "
            "Delegating to voice websocket",
            state_event.state.channel_id,
            state_event.state.guild_id,
            shard_id,
            server_event.endpoint,
            state_event.state.session_id,
        )

        try:
            voice_connection = await voice_connection_type.initialize(
                channel_id=snowflakes.Snowflake(channel),
                endpoint=server_event.endpoint,
                guild_id=guild_id,
                on_close=self._on_connection_close,
                owner=self,
                session_id=state_event.state.session_id,
                shard_id=shard_id,
                token=server_event.token,
                user_id=user.id,
                **kwargs,
            )
        except Exception:
            _LOGGER.debug("error occurred in initialization, leaving voice channel %s in guild %s", channel, guild)
            # Best-effort leave so the bot does not linger in the channel; the original
            # initialization error is what gets re-raised.
            try:
                await asyncio.wait_for(shard.update_voice_state(guild, None), timeout=5.0)
            except asyncio.TimeoutError:
                pass

            raise

        if not self._voice_listener:
            self._app.event_manager.subscribe(voice_events.VoiceEvent, self._on_voice_event)
            self._voice_listener = True

        self._connections[guild_id] = voice_connection
        return voice_connection

    @staticmethod
    def _init_state_update_predicate(
        guild_id: snowflakes.Snowflake, user_id: snowflakes.Snowflake
    ) -> typing.Callable[[voice_events.VoiceStateUpdateEvent], bool]:
        """Build a predicate matching voice state updates for ``user_id`` in ``guild_id``."""

        def predicate(event: voice_events.VoiceStateUpdateEvent) -> bool:
            return event.state.guild_id == guild_id and event.state.user_id == user_id

        return predicate

    @staticmethod
    def _init_server_update_predicate(
        guild_id: snowflakes.Snowflake,
    ) -> typing.Callable[[voice_events.VoiceServerUpdateEvent], bool]:
        """Build a predicate matching voice server updates for ``guild_id``."""

        def predicate(event: voice_events.VoiceServerUpdateEvent) -> bool:
            return event.guild_id == guild_id

        return predicate

    async def _on_connection_close(self, connection: voice.VoiceConnection) -> None:
        """Unregister a closed connection and leave its voice channel.

        Passed as ``on_close`` to every connection created by ``connect_to``.
        """
        # Only the registry removal belongs in the try. A KeyError raised by anything
        # else must not be misreported as a phantom connection.
        try:
            del self._connections[connection.guild_id]
        except KeyError:
            _LOGGER.warning(
                "ignored closure of phantom unregistered voice connection %s to guild %s. Perhaps this is a bug?",
                connection,
                connection.guild_id,
            )
            return

        # Guarded by the flag so we never unsubscribe twice (e.g. when close() has
        # already removed the listener before disconnecting everything).
        if not self._connections and self._voice_listener:
            self._app.event_manager.unsubscribe(voice_events.VoiceEvent, self._on_voice_event)
            self._voice_listener = False

        # Leave the voice channel explicitly, otherwise we will just appear to
        # not leave properly.
        await self._app.shards[connection.shard_id].update_voice_state(guild=connection.guild_id, channel=None)
        _LOGGER.debug(
            "successfully unregistered voice connection %s to guild %s and left voice channel %s",
            connection,
            connection.guild_id,
            connection.channel_id,
        )

    async def _on_voice_event(self, event: voice_events.VoiceEvent) -> None:
        """Forward a voice event to the connection registered for its guild, if any."""
        connection = self._connections.get(event.guild_id)
        if connection is None:
            return

        _LOGGER.log(
            ux.TRACE, "notifying voice connection %s in guild %s of event %s", connection, event.guild_id, event
        )
        await connection.notify(event)