Gateway: Fix repeat joins on same channel from stalling (#47)

Joining a channel returns a future which fires on receipt of two messages from discord (by locally storing a channel). However, joining this same channel again after a success returns only *one* such message, causing the command to hang until another join fires or the channel is left. This alters internal behaviour to correctly cancel an in-progress connection attempt, or return success with known data if such a connection is present.

This introduces a breaking change on `Call::update_state` to include the target `ChannelId`. The reason for this is that although the `ChannelId` of a target channel was being stored, server admins may move or kick a bot from its voice channel. This changes the true channel, and may accidentally trigger a "double join" elsewhere.

This fix was tested by using an example to have a bot join its channel twice, to do so in a channel it had been moved to, and to move from a channel it had been moved to.
This commit is contained in:
Kyle Simpson
2021-03-23 10:36:23 +00:00
parent ebff98e873
commit e59c546503
3 changed files with 152 additions and 35 deletions

View File

@@ -1,4 +1,4 @@
use crate::id::{GuildId, UserId};
use crate::id::{ChannelId, GuildId, UserId};
use std::fmt;
#[derive(Clone, Debug)]
@@ -8,8 +8,9 @@ pub(crate) enum ConnectionProgress {
}
impl ConnectionProgress {
pub fn new(guild_id: GuildId, user_id: UserId) -> Self {
pub(crate) fn new(guild_id: GuildId, user_id: UserId, channel_id: ChannelId) -> Self {
ConnectionProgress::Incomplete(Partial {
channel_id,
guild_id,
user_id,
..Default::default()
@@ -24,7 +25,46 @@ impl ConnectionProgress {
}
}
pub(crate) fn apply_state_update(&mut self, session_id: String) -> bool {
pub(crate) fn in_progress(&self) -> bool {
matches!(self, ConnectionProgress::Incomplete(_))
}
pub(crate) fn channel_id(&self) -> ChannelId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info
.channel_id
.expect("All code paths MUST set channel_id for local tracking."),
ConnectionProgress::Incomplete(part) => part.channel_id,
}
}
pub(crate) fn guild_id(&self) -> GuildId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info.guild_id,
ConnectionProgress::Incomplete(part) => part.guild_id,
}
}
pub(crate) fn user_id(&self) -> UserId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info.user_id,
ConnectionProgress::Incomplete(part) => part.user_id,
}
}
pub(crate) fn info(&self) -> Option<ConnectionInfo> {
match self {
ConnectionProgress::Complete(conn_info) => Some(conn_info.clone()),
_ => None,
}
}
pub(crate) fn apply_state_update(&mut self, session_id: String, channel_id: ChannelId) -> bool {
if self.channel_id() != channel_id {
// Likely that the bot was moved to a different channel by an admin.
*self = ConnectionProgress::new(self.guild_id(), self.user_id(), channel_id);
}
use ConnectionProgress::*;
match self {
Complete(c) => {
@@ -33,7 +73,7 @@ impl ConnectionProgress {
should_reconn
},
Incomplete(i) => i
.apply_state_update(session_id)
.apply_state_update(session_id, channel_id)
.map(|info| {
*self = Complete(info);
})
@@ -66,6 +106,11 @@ impl ConnectionProgress {
/// with the Songbird driver, lavalink, or other system.
#[derive(Clone)]
pub struct ConnectionInfo {
/// ID of the voice channel being joined, if it is known.
///
/// This is not needed to establish a connection, but can be useful
/// for book-keeping.
pub channel_id: Option<ChannelId>,
/// URL of the voice websocket gateway server assigned to this call.
pub endpoint: String,
/// ID of the target voice channel's parent guild.
@@ -83,6 +128,7 @@ pub struct ConnectionInfo {
impl fmt::Debug for ConnectionInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectionInfo")
.field("channel_id", &self.channel_id)
.field("endpoint", &self.endpoint)
.field("guild_id", &self.guild_id)
.field("session_id", &self.session_id)
@@ -94,6 +140,7 @@ impl fmt::Debug for ConnectionInfo {
#[derive(Clone, Default)]
pub(crate) struct Partial {
pub channel_id: ChannelId,
pub endpoint: Option<String>,
pub guild_id: GuildId,
pub session_id: Option<String>,
@@ -104,6 +151,7 @@ pub(crate) struct Partial {
impl fmt::Debug for Partial {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Partial")
.field("channel_id", &self.channel_id)
.field("endpoint", &self.endpoint)
.field("session_id", &self.session_id)
.field("token_is_some", &self.token.is_some())
@@ -119,6 +167,7 @@ impl Partial {
let token = self.token.take().unwrap();
Some(ConnectionInfo {
channel_id: Some(self.channel_id),
endpoint,
session_id,
token,
@@ -130,7 +179,17 @@ impl Partial {
}
}
fn apply_state_update(&mut self, session_id: String) -> Option<ConnectionInfo> {
fn apply_state_update(
&mut self,
session_id: String,
channel_id: ChannelId,
) -> Option<ConnectionInfo> {
if self.channel_id != channel_id {
self.endpoint = None;
self.token = None;
}
self.channel_id = channel_id;
self.session_id = Some(session_id);
self.finalise()