Gateway: Fix repeat joins on same channel from stalling (#47)
Joining a channel returns a future which fires on receipt of two messages from discord (by locally storing a channel). However, joining this same channel again after a success returns only *one* such message, causing the command to hang until another join fires or the channel is left. This alters internal behaviour to correctly cancel an in-progress connection attempt, or return success with known data if such a connection is present. This introduces a breaking change on `Call::update_state` to include the target `ChannelId`. The reason for this is that although the `ChannelId` of a target channel was being stored, server admins may move or kick a bot from its voice channel. This changes the true channel, and may accidentally trigger a "double join" elsewhere. This fix was tested by using an example to have a bot join its channel twice, to do so in a channel it had been moved to, and to move from a channel it had been moved to.
This commit is contained in:
@@ -32,7 +32,7 @@ enum Return {
|
|||||||
/// [`Driver`]: struct@Driver
|
/// [`Driver`]: struct@Driver
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Call {
|
pub struct Call {
|
||||||
connection: Option<(ChannelId, ConnectionProgress, Return)>,
|
connection: Option<(ConnectionProgress, Return)>,
|
||||||
|
|
||||||
#[cfg(feature = "driver-core")]
|
#[cfg(feature = "driver-core")]
|
||||||
/// The internal controller of the voice connection monitor thread.
|
/// The internal controller of the voice connection monitor thread.
|
||||||
@@ -132,12 +132,12 @@ impl Call {
|
|||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
fn do_connect(&mut self) {
|
fn do_connect(&mut self) {
|
||||||
match &self.connection {
|
match &self.connection {
|
||||||
Some((_, ConnectionProgress::Complete(c), Return::Info(tx))) => {
|
Some((ConnectionProgress::Complete(c), Return::Info(tx))) => {
|
||||||
// It's okay if the receiver hung up.
|
// It's okay if the receiver hung up.
|
||||||
let _ = tx.send(c.clone());
|
let _ = tx.send(c.clone());
|
||||||
},
|
},
|
||||||
#[cfg(feature = "driver-core")]
|
#[cfg(feature = "driver-core")]
|
||||||
Some((_, ConnectionProgress::Complete(c), Return::Conn(tx))) => {
|
Some((ConnectionProgress::Complete(c), Return::Conn(tx))) => {
|
||||||
self.driver.raw_connect(c.clone(), tx.clone());
|
self.driver.raw_connect(c.clone(), tx.clone());
|
||||||
},
|
},
|
||||||
_ => {},
|
_ => {},
|
||||||
@@ -171,6 +171,31 @@ impl Call {
|
|||||||
self.self_deaf
|
self.self_deaf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn should_actually_join<F, G>(
|
||||||
|
&mut self,
|
||||||
|
completion_generator: F,
|
||||||
|
tx: &Sender<G>,
|
||||||
|
channel_id: ChannelId,
|
||||||
|
) -> JoinResult<bool>
|
||||||
|
where
|
||||||
|
F: FnOnce(&Self) -> G,
|
||||||
|
{
|
||||||
|
Ok(if let Some(conn) = &self.connection {
|
||||||
|
if conn.0.in_progress() {
|
||||||
|
self.leave().await?;
|
||||||
|
true
|
||||||
|
} else if conn.0.channel_id() == channel_id {
|
||||||
|
let _ = tx.send(completion_generator(&self));
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
// not in progress, and/or a channel change.
|
||||||
|
true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "driver-core")]
|
#[cfg(feature = "driver-core")]
|
||||||
/// Connect or switch to the given voice channel by its Id.
|
/// Connect or switch to the given voice channel by its Id.
|
||||||
///
|
///
|
||||||
@@ -190,13 +215,20 @@ impl Call {
|
|||||||
) -> JoinResult<RecvFut<'static, ConnectionResult<()>>> {
|
) -> JoinResult<RecvFut<'static, ConnectionResult<()>>> {
|
||||||
let (tx, rx) = flume::unbounded();
|
let (tx, rx) = flume::unbounded();
|
||||||
|
|
||||||
|
let do_conn = self
|
||||||
|
.should_actually_join(|_| Ok(()), &tx, channel_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if do_conn {
|
||||||
self.connection = Some((
|
self.connection = Some((
|
||||||
channel_id,
|
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
|
||||||
ConnectionProgress::new(self.guild_id, self.user_id),
|
|
||||||
Return::Conn(tx),
|
Return::Conn(tx),
|
||||||
));
|
));
|
||||||
|
|
||||||
self.update().await.map(|_| rx.into_recv_async())
|
self.update().await.map(|_| rx.into_recv_async())
|
||||||
|
} else {
|
||||||
|
Ok(rx.into_recv_async())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Join the selected voice channel, *without* running/starting an RTP
|
/// Join the selected voice channel, *without* running/starting an RTP
|
||||||
@@ -221,13 +253,24 @@ impl Call {
|
|||||||
) -> JoinResult<RecvFut<'static, ConnectionInfo>> {
|
) -> JoinResult<RecvFut<'static, ConnectionInfo>> {
|
||||||
let (tx, rx) = flume::unbounded();
|
let (tx, rx) = flume::unbounded();
|
||||||
|
|
||||||
self.connection = Some((
|
let do_conn = self
|
||||||
|
.should_actually_join(
|
||||||
|
|call| call.connection.as_ref().unwrap().0.info().unwrap(),
|
||||||
|
&tx,
|
||||||
channel_id,
|
channel_id,
|
||||||
ConnectionProgress::new(self.guild_id, self.user_id),
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if do_conn {
|
||||||
|
self.connection = Some((
|
||||||
|
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
|
||||||
Return::Info(tx),
|
Return::Info(tx),
|
||||||
));
|
));
|
||||||
|
|
||||||
self.update().await.map(|_| rx.into_recv_async())
|
self.update().await.map(|_| rx.into_recv_async())
|
||||||
|
} else {
|
||||||
|
Ok(rx.into_recv_async())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the current voice connection details for this Call,
|
/// Returns the current voice connection details for this Call,
|
||||||
@@ -235,7 +278,7 @@ impl Call {
|
|||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub fn current_connection(&self) -> Option<&ConnectionInfo> {
|
pub fn current_connection(&self) -> Option<&ConnectionInfo> {
|
||||||
match &self.connection {
|
match &self.connection {
|
||||||
Some((_, progress, _)) => progress.get_connection_info(),
|
Some((progress, _)) => progress.get_connection_info(),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -265,13 +308,17 @@ impl Call {
|
|||||||
/// [`standalone`]: Call::standalone
|
/// [`standalone`]: Call::standalone
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn leave(&mut self) -> JoinResult<()> {
|
pub async fn leave(&mut self) -> JoinResult<()> {
|
||||||
|
self.leave_local();
|
||||||
|
|
||||||
// Only send an update if we were in a voice channel.
|
// Only send an update if we were in a voice channel.
|
||||||
|
self.update().await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn leave_local(&mut self) {
|
||||||
self.connection = None;
|
self.connection = None;
|
||||||
|
|
||||||
#[cfg(feature = "driver-core")]
|
#[cfg(feature = "driver-core")]
|
||||||
self.driver.leave();
|
self.driver.leave();
|
||||||
|
|
||||||
self.update().await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets whether the current connection is to be muted.
|
/// Sets whether the current connection is to be muted.
|
||||||
@@ -307,7 +354,7 @@ impl Call {
|
|||||||
/// [`standalone`]: Call::standalone
|
/// [`standalone`]: Call::standalone
|
||||||
#[instrument(skip(self, token))]
|
#[instrument(skip(self, token))]
|
||||||
pub fn update_server(&mut self, endpoint: String, token: String) {
|
pub fn update_server(&mut self, endpoint: String, token: String) {
|
||||||
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
|
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
|
||||||
progress.apply_server_update(endpoint, token)
|
progress.apply_server_update(endpoint, token)
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
@@ -325,9 +372,10 @@ impl Call {
|
|||||||
///
|
///
|
||||||
/// [`standalone`]: Call::standalone
|
/// [`standalone`]: Call::standalone
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub fn update_state(&mut self, session_id: String) {
|
pub fn update_state(&mut self, session_id: String, channel_id: Option<ChannelId>) {
|
||||||
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
|
if let Some(channel_id) = channel_id {
|
||||||
progress.apply_state_update(session_id)
|
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
|
||||||
|
progress.apply_state_update(session_id, channel_id)
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
};
|
};
|
||||||
@@ -335,6 +383,10 @@ impl Call {
|
|||||||
if try_conn {
|
if try_conn {
|
||||||
self.do_connect();
|
self.do_connect();
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Likely that we were disconnected by an admin.
|
||||||
|
self.leave_local();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send an update for the current session over WS.
|
/// Send an update for the current session over WS.
|
||||||
@@ -348,7 +400,7 @@ impl Call {
|
|||||||
let map = json!({
|
let map = json!({
|
||||||
"op": 4,
|
"op": 4,
|
||||||
"d": {
|
"d": {
|
||||||
"channel_id": self.connection.as_ref().map(|c| c.0.0),
|
"channel_id": self.connection.as_ref().map(|c| c.0.channel_id().0),
|
||||||
"guild_id": self.guild_id.0,
|
"guild_id": self.guild_id.0,
|
||||||
"self_deaf": self.self_deaf,
|
"self_deaf": self.self_deaf,
|
||||||
"self_mute": self.self_mute,
|
"self_mute": self.self_mute,
|
||||||
|
|||||||
69
src/info.rs
69
src/info.rs
@@ -1,4 +1,4 @@
|
|||||||
use crate::id::{GuildId, UserId};
|
use crate::id::{ChannelId, GuildId, UserId};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@@ -8,8 +8,9 @@ pub(crate) enum ConnectionProgress {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectionProgress {
|
impl ConnectionProgress {
|
||||||
pub fn new(guild_id: GuildId, user_id: UserId) -> Self {
|
pub(crate) fn new(guild_id: GuildId, user_id: UserId, channel_id: ChannelId) -> Self {
|
||||||
ConnectionProgress::Incomplete(Partial {
|
ConnectionProgress::Incomplete(Partial {
|
||||||
|
channel_id,
|
||||||
guild_id,
|
guild_id,
|
||||||
user_id,
|
user_id,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -24,7 +25,46 @@ impl ConnectionProgress {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn apply_state_update(&mut self, session_id: String) -> bool {
|
pub(crate) fn in_progress(&self) -> bool {
|
||||||
|
matches!(self, ConnectionProgress::Incomplete(_))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn channel_id(&self) -> ChannelId {
|
||||||
|
match self {
|
||||||
|
ConnectionProgress::Complete(conn_info) => conn_info
|
||||||
|
.channel_id
|
||||||
|
.expect("All code paths MUST set channel_id for local tracking."),
|
||||||
|
ConnectionProgress::Incomplete(part) => part.channel_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn guild_id(&self) -> GuildId {
|
||||||
|
match self {
|
||||||
|
ConnectionProgress::Complete(conn_info) => conn_info.guild_id,
|
||||||
|
ConnectionProgress::Incomplete(part) => part.guild_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn user_id(&self) -> UserId {
|
||||||
|
match self {
|
||||||
|
ConnectionProgress::Complete(conn_info) => conn_info.user_id,
|
||||||
|
ConnectionProgress::Incomplete(part) => part.user_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn info(&self) -> Option<ConnectionInfo> {
|
||||||
|
match self {
|
||||||
|
ConnectionProgress::Complete(conn_info) => Some(conn_info.clone()),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn apply_state_update(&mut self, session_id: String, channel_id: ChannelId) -> bool {
|
||||||
|
if self.channel_id() != channel_id {
|
||||||
|
// Likely that the bot was moved to a different channel by an admin.
|
||||||
|
*self = ConnectionProgress::new(self.guild_id(), self.user_id(), channel_id);
|
||||||
|
}
|
||||||
|
|
||||||
use ConnectionProgress::*;
|
use ConnectionProgress::*;
|
||||||
match self {
|
match self {
|
||||||
Complete(c) => {
|
Complete(c) => {
|
||||||
@@ -33,7 +73,7 @@ impl ConnectionProgress {
|
|||||||
should_reconn
|
should_reconn
|
||||||
},
|
},
|
||||||
Incomplete(i) => i
|
Incomplete(i) => i
|
||||||
.apply_state_update(session_id)
|
.apply_state_update(session_id, channel_id)
|
||||||
.map(|info| {
|
.map(|info| {
|
||||||
*self = Complete(info);
|
*self = Complete(info);
|
||||||
})
|
})
|
||||||
@@ -66,6 +106,11 @@ impl ConnectionProgress {
|
|||||||
/// with the Songbird driver, lavalink, or other system.
|
/// with the Songbird driver, lavalink, or other system.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ConnectionInfo {
|
pub struct ConnectionInfo {
|
||||||
|
/// ID of the voice channel being joined, if it is known.
|
||||||
|
///
|
||||||
|
/// This is not needed to establish a connection, but can be useful
|
||||||
|
/// for book-keeping.
|
||||||
|
pub channel_id: Option<ChannelId>,
|
||||||
/// URL of the voice websocket gateway server assigned to this call.
|
/// URL of the voice websocket gateway server assigned to this call.
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
/// ID of the target voice channel's parent guild.
|
/// ID of the target voice channel's parent guild.
|
||||||
@@ -83,6 +128,7 @@ pub struct ConnectionInfo {
|
|||||||
impl fmt::Debug for ConnectionInfo {
|
impl fmt::Debug for ConnectionInfo {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("ConnectionInfo")
|
f.debug_struct("ConnectionInfo")
|
||||||
|
.field("channel_id", &self.channel_id)
|
||||||
.field("endpoint", &self.endpoint)
|
.field("endpoint", &self.endpoint)
|
||||||
.field("guild_id", &self.guild_id)
|
.field("guild_id", &self.guild_id)
|
||||||
.field("session_id", &self.session_id)
|
.field("session_id", &self.session_id)
|
||||||
@@ -94,6 +140,7 @@ impl fmt::Debug for ConnectionInfo {
|
|||||||
|
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
pub(crate) struct Partial {
|
pub(crate) struct Partial {
|
||||||
|
pub channel_id: ChannelId,
|
||||||
pub endpoint: Option<String>,
|
pub endpoint: Option<String>,
|
||||||
pub guild_id: GuildId,
|
pub guild_id: GuildId,
|
||||||
pub session_id: Option<String>,
|
pub session_id: Option<String>,
|
||||||
@@ -104,6 +151,7 @@ pub(crate) struct Partial {
|
|||||||
impl fmt::Debug for Partial {
|
impl fmt::Debug for Partial {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("Partial")
|
f.debug_struct("Partial")
|
||||||
|
.field("channel_id", &self.channel_id)
|
||||||
.field("endpoint", &self.endpoint)
|
.field("endpoint", &self.endpoint)
|
||||||
.field("session_id", &self.session_id)
|
.field("session_id", &self.session_id)
|
||||||
.field("token_is_some", &self.token.is_some())
|
.field("token_is_some", &self.token.is_some())
|
||||||
@@ -119,6 +167,7 @@ impl Partial {
|
|||||||
let token = self.token.take().unwrap();
|
let token = self.token.take().unwrap();
|
||||||
|
|
||||||
Some(ConnectionInfo {
|
Some(ConnectionInfo {
|
||||||
|
channel_id: Some(self.channel_id),
|
||||||
endpoint,
|
endpoint,
|
||||||
session_id,
|
session_id,
|
||||||
token,
|
token,
|
||||||
@@ -130,7 +179,17 @@ impl Partial {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_state_update(&mut self, session_id: String) -> Option<ConnectionInfo> {
|
fn apply_state_update(
|
||||||
|
&mut self,
|
||||||
|
session_id: String,
|
||||||
|
channel_id: ChannelId,
|
||||||
|
) -> Option<ConnectionInfo> {
|
||||||
|
if self.channel_id != channel_id {
|
||||||
|
self.endpoint = None;
|
||||||
|
self.token = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.channel_id = channel_id;
|
||||||
self.session_id = Some(session_id);
|
self.session_id = Some(session_id);
|
||||||
|
|
||||||
self.finalise()
|
self.finalise()
|
||||||
|
|||||||
@@ -351,7 +351,10 @@ impl Songbird {
|
|||||||
|
|
||||||
if let Some(call) = call {
|
if let Some(call) = call {
|
||||||
let mut handler = call.lock().await;
|
let mut handler = call.lock().await;
|
||||||
handler.update_state(v.0.session_id.clone());
|
handler.update_state(
|
||||||
|
v.0.session_id.clone(),
|
||||||
|
v.0.channel_id.clone().map(Into::into),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
_ => {},
|
_ => {},
|
||||||
@@ -390,7 +393,10 @@ impl VoiceGatewayManager for Songbird {
|
|||||||
|
|
||||||
if let Some(call) = self.get(guild_id) {
|
if let Some(call) = self.get(guild_id) {
|
||||||
let mut handler = call.lock().await;
|
let mut handler = call.lock().await;
|
||||||
handler.update_state(voice_state.session_id.clone());
|
handler.update_state(
|
||||||
|
voice_state.session_id.clone(),
|
||||||
|
voice_state.channel_id.clone().map(Into::into),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user