From 210e3ae58499fa45edf9b65de6d9114292341d28 Mon Sep 17 00:00:00 2001 From: Kyle Simpson Date: Wed, 23 Jun 2021 17:11:14 +0100 Subject: [PATCH] Driver: Automate (re)connection logic (#81) This PR adds several enhancements to Driver connection logic: * Driver (re)connection attempts now have a default timeout of around 10s. * The driver will now attempt to retry full connection attempts using a user-provided strategy: currently, this defaults to 5 attempts under an exponential backoff strategy. * The driver will now fire `DriverDisconnect` events at the end of any session -- this unifies (re)connection failure events with session expiry as seen in #76, which should provide users with enough detail to know *which* voice channel to reconnect to. Users still need to be careful to read the session/channel IDs to ensure that they aren't overwriting another join. This has been tested using `cargo make ready`, and by setting low timeouts to force failures in the voice receive example (with some additional error handlers). Closes #68. 
--- src/config.rs | 33 ++- src/driver/connection/error.rs | 14 ++ src/driver/connection/mod.rs | 31 ++- src/driver/mod.rs | 1 + src/driver/retry/mod.rs | 49 +++++ src/driver/retry/strategy.rs | 84 ++++++++ src/driver/tasks/message/core.rs | 4 +- src/driver/tasks/mod.rs | 284 ++++++++++++++++++++------ src/driver/tasks/ws.rs | 37 +++- src/events/context/data/connect.rs | 11 + src/events/context/data/disconnect.rs | 119 +++++++++++ src/events/context/data/mod.rs | 3 +- src/events/context/internal_data.rs | 27 ++- src/events/context/mod.rs | 43 +++- src/events/core.rs | 10 + src/handler.rs | 10 +- src/info.rs | 2 +- 17 files changed, 672 insertions(+), 90 deletions(-) create mode 100644 src/driver/retry/mod.rs create mode 100644 src/driver/retry/strategy.rs create mode 100644 src/events/context/data/disconnect.rs diff --git a/src/config.rs b/src/config.rs index caad394..692a855 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,6 @@ #[cfg(feature = "driver-core")] -use super::driver::{CryptoMode, DecodeMode}; +use super::driver::{retry::Retry, CryptoMode, DecodeMode}; -#[cfg(feature = "gateway-core")] use std::time::Duration; /// Configuration for drivers and calls. @@ -61,6 +60,20 @@ pub struct Config { /// Changes to this field in a running driver will only ever increase /// the capacity of the track store. pub preallocated_tracks: usize, + #[cfg(feature = "driver-core")] + /// Connection retry logic for the [`Driver`]. + /// + /// This controls how many times the [`Driver`] should retry any connections, + /// as well as how long to wait between attempts. + /// + /// [`Driver`]: crate::driver::Driver + pub driver_retry: Retry, + #[cfg(feature = "driver-core")] + /// Configures the maximum amount of time to wait for an attempted voice + /// connection to Discord. + /// + /// Defaults to 10 seconds. If set to `None`, connections will never time out. 
+ pub driver_timeout: Option, } impl Default for Config { @@ -74,6 +87,10 @@ impl Default for Config { gateway_timeout: Some(Duration::from_secs(10)), #[cfg(feature = "driver-core")] preallocated_tracks: 1, + #[cfg(feature = "driver-core")] + driver_retry: Default::default(), + #[cfg(feature = "driver-core")] + driver_timeout: Some(Duration::from_secs(10)), } } } @@ -98,6 +115,18 @@ impl Config { self } + /// Sets this `Config`'s timeout for establishing a voice connection. + pub fn driver_timeout(mut self, driver_timeout: Option) -> Self { + self.driver_timeout = driver_timeout; + self + } + + /// Sets this `Config`'s voice connection retry configuration. + pub fn driver_retry(mut self, driver_retry: Retry) -> Self { + self.driver_retry = driver_retry; + self + } + /// This is used to prevent changes which would invalidate the current session. pub(crate) fn make_safe(&mut self, previous: &Config, connected: bool) { if connected { diff --git a/src/driver/connection/error.rs b/src/driver/connection/error.rs index d7f3236..2dd01bd 100644 --- a/src/driver/connection/error.rs +++ b/src/driver/connection/error.rs @@ -7,6 +7,10 @@ use crate::{ use flume::SendError; use serde_json::Error as JsonError; use std::{error::Error as StdError, fmt, io::Error as IoError}; +#[cfg(not(feature = "tokio-02-marker"))] +use tokio::time::error::Elapsed; +#[cfg(feature = "tokio-02-marker")] +use tokio_compat::time::Elapsed; use xsalsa20poly1305::aead::Error as CryptoError; /// Errors encountered while connecting to a Discord voice server over the driver. @@ -38,6 +42,8 @@ pub enum Error { InterconnectFailure(Recipient), /// Error communicating with gateway server over WebSocket. Ws(WsError), + /// Connection attempt timed out. 
+ TimedOut, } impl From for Error { @@ -82,6 +88,12 @@ impl From for Error { } } +impl From for Error { + fn from(_e: Elapsed) -> Error { + Error::TimedOut + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "failed to connect to Discord RTP server: ")?; @@ -99,6 +111,7 @@ impl fmt::Display for Error { Json(e) => e.fmt(f), InterconnectFailure(e) => write!(f, "failed to contact other task ({:?})", e), Ws(e) => write!(f, "websocket issue ({:?}).", e), + TimedOut => write!(f, "connection attempt timed out"), } } } @@ -118,6 +131,7 @@ impl StdError for Error { Error::Json(e) => e.source(), Error::InterconnectFailure(_) => None, Error::Ws(_) => None, + Error::TimedOut => None, } } } diff --git a/src/driver/connection/mod.rs b/src/driver/connection/mod.rs index cb82a86..6eb9b16 100644 --- a/src/driver/connection/mod.rs +++ b/src/driver/connection/mod.rs @@ -21,9 +21,9 @@ use error::{Error, Result}; use flume::Sender; use std::{net::IpAddr, str::FromStr, sync::Arc}; #[cfg(not(feature = "tokio-02-marker"))] -use tokio::{net::UdpSocket, spawn}; +use tokio::{net::UdpSocket, spawn, time::timeout}; #[cfg(feature = "tokio-02-marker")] -use tokio_compat::{net::UdpSocket, spawn}; +use tokio_compat::{net::UdpSocket, spawn, time::timeout}; use tracing::{debug, info, instrument}; use url::Url; use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher}; @@ -42,9 +42,23 @@ pub(crate) struct Connection { impl Connection { pub(crate) async fn new( + info: ConnectionInfo, + interconnect: &Interconnect, + config: &Config, + idx: usize, + ) -> Result { + if let Some(t) = config.driver_timeout { + timeout(t, Connection::new_inner(info, interconnect, config, idx)).await? 
+ } else { + Connection::new_inner(info, interconnect, config, idx).await + } + } + + pub(crate) async fn new_inner( mut info: ConnectionInfo, interconnect: &Interconnect, config: &Config, + idx: usize, ) -> Result { let url = generate_url(&mut info.endpoint)?; @@ -207,6 +221,8 @@ impl Connection { client, ssrc, hello.heartbeat_interval, + idx, + info.clone(), )); spawn(udp_rx::runner( @@ -226,7 +242,16 @@ impl Connection { } #[instrument(skip(self))] - pub async fn reconnect(&mut self) -> Result<()> { + pub async fn reconnect(&mut self, config: &Config) -> Result<()> { + if let Some(t) = config.driver_timeout { + timeout(t, self.reconnect_inner()).await? + } else { + self.reconnect_inner().await + } + } + + #[instrument(skip(self))] + pub async fn reconnect_inner(&mut self) -> Result<()> { let url = generate_url(&mut self.info.endpoint)?; // Thread may have died, we want to send to prompt a clean exit diff --git a/src/driver/mod.rs b/src/driver/mod.rs index a732a5b..cf1a021 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -14,6 +14,7 @@ pub mod bench_internals; pub(crate) mod connection; mod crypto; mod decode_mode; +pub mod retry; pub(crate) mod tasks; use connection::error::{Error, Result}; diff --git a/src/driver/retry/mod.rs b/src/driver/retry/mod.rs new file mode 100644 index 0000000..e25374b --- /dev/null +++ b/src/driver/retry/mod.rs @@ -0,0 +1,49 @@ +//! Configuration for connection retries. + +mod strategy; + +pub use self::strategy::*; + +use std::time::Duration; + +/// Configuration to be used for retrying driver connection attempts. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Retry { + /// Strategy used to determine how long to wait between retry attempts. + /// + /// *Defaults to an [`ExponentialBackoff`] from 0.25s + /// to 10s, with a jitter of `0.1`.* + /// + /// [`ExponentialBackoff`]: Strategy::Backoff + pub strategy: Strategy, + /// The maximum number of retries to attempt. 
+ /// + /// `None` will attempt an infinite number of retries, + /// while `Some(0)` will attempt to connect *once* (no retries). + /// + /// *Defaults to `Some(5)`.* + pub retry_limit: Option, +} + +impl Default for Retry { + fn default() -> Self { + Self { + strategy: Strategy::Backoff(Default::default()), + retry_limit: Some(5), + } + } +} + +impl Retry { + pub(crate) fn retry_in( + &self, + last_wait: Option, + attempts: usize, + ) -> Option { + if self.retry_limit.map(|a| attempts < a).unwrap_or(true) { + Some(self.strategy.retry_in(last_wait)) + } else { + None + } + } +} diff --git a/src/driver/retry/strategy.rs b/src/driver/retry/strategy.rs new file mode 100644 index 0000000..6de58e7 --- /dev/null +++ b/src/driver/retry/strategy.rs @@ -0,0 +1,84 @@ +use rand::random; +use std::time::Duration; + +/// Logic used to determine how long to wait between retry attempts. +#[derive(Clone, Copy, Debug, PartialEq)] +#[non_exhaustive] +pub enum Strategy { + /// The driver will wait for the same amount of time between each retry. + Every(Duration), + /// Exponential backoff waiting strategy, where the duration between + /// attempts (approximately) doubles each time. + Backoff(ExponentialBackoff), +} + +impl Strategy { + pub(crate) fn retry_in(&self, last_wait: Option) -> Duration { + match self { + Self::Every(t) => *t, + Self::Backoff(exp) => exp.retry_in(last_wait), + } + } +} + +/// Exponential backoff waiting strategy. +/// +/// Each attempt waits for twice the last delay plus/minus a +/// random jitter, clamped to a min and max value. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct ExponentialBackoff { + /// Minimum amount of time to wait between retries. + /// + /// *Defaults to 0.25s.* + pub min: Duration, + /// Maximum amount of time to wait between retries. + /// + /// This will be clamped to `>=` min. + /// + /// *Defaults to 10s.* + pub max: Duration, + /// Amount of uniform random jitter to apply to generated wait times. 
+ /// I.e., 0.1 will add +/-10% to generated intervals. + /// + /// This is restricted to within +/-100%. + /// + /// *Defaults to `0.1`.* + pub jitter: f32, +} + +impl Default for ExponentialBackoff { + fn default() -> Self { + Self { + min: Duration::from_millis(250), + max: Duration::from_secs(10), + jitter: 0.1, + } + } +} + +impl ExponentialBackoff { + pub(crate) fn retry_in(&self, last_wait: Option) -> Duration { + let attempt = last_wait.map(|t| 2 * t).unwrap_or(self.min); + let perturb = (1.0 - (self.jitter * 2.0 * (random::() - 1.0))) + .max(0.0) + .min(2.0); + let mut target_time = attempt.mul_f32(perturb); + + // Now clamp target time into given range. + let safe_max = if self.max < self.min { + self.min + } else { + self.max + }; + + if target_time > safe_max { + target_time = safe_max; + } + + if target_time < self.min { + target_time = self.min; + } + + target_time + } +} diff --git a/src/driver/tasks/message/core.rs b/src/driver/tasks/message/core.rs index b8b761e..5fbc0df 100644 --- a/src/driver/tasks/message/core.rs +++ b/src/driver/tasks/message/core.rs @@ -2,7 +2,7 @@ use crate::{ driver::{connection::error::Error, Bitrate, Config}, - events::EventData, + events::{context_data::DisconnectReason, EventData}, tracks::Track, ConnectionInfo, }; @@ -12,6 +12,8 @@ use flume::Sender; #[derive(Debug)] pub enum CoreMessage { ConnectWithResult(ConnectionInfo, Sender>), + RetryConnect(usize), + SignalWsClosure(usize, ConnectionInfo, Option), Disconnect, SetTrack(Option), AddTrack(Track), diff --git a/src/driver/tasks/mod.rs b/src/driver/tasks/mod.rs index 8f9c26c..b138964 100644 --- a/src/driver/tasks/mod.rs +++ b/src/driver/tasks/mod.rs @@ -9,18 +9,25 @@ pub(crate) mod udp_rx; pub(crate) mod udp_tx; pub(crate) mod ws; +use std::time::Duration; + use super::connection::{error::Error as ConnectionError, Connection}; use crate::{ - events::{internal_data::InternalConnect, CoreContext}, + events::{ + context_data::{DisconnectKind, DisconnectReason}, + 
internal_data::{InternalConnect, InternalDisconnect}, + CoreContext, + }, Config, + ConnectionInfo, }; use flume::{Receiver, RecvError, Sender}; use message::*; #[cfg(not(feature = "tokio-02-marker"))] -use tokio::{runtime::Handle, spawn}; +use tokio::{runtime::Handle, spawn, time::sleep as tsleep}; #[cfg(feature = "tokio-02-marker")] -use tokio_compat::{runtime::Handle, spawn}; -use tracing::{error, instrument, trace}; +use tokio_compat::{runtime::Handle, spawn, time::delay_for as tsleep}; +use tracing::{debug, instrument, trace}; pub(crate) fn start(config: Config, rx: Receiver, tx: Sender) { spawn(async move { @@ -61,8 +68,10 @@ fn start_internals(core: Sender, config: Config) -> Interconnect { #[instrument(skip(rx, tx))] async fn runner(mut config: Config, rx: Receiver, tx: Sender) { let mut next_config: Option = None; - let mut connection = None; + let mut connection: Option = None; let mut interconnect = start_internals(tx, config.clone()); + let mut retrying = None; + let mut attempt_idx = 0; loop { match rx.recv_async().await { @@ -76,36 +85,69 @@ async fn runner(mut config: Config, rx: Receiver, tx: Sender { - // Other side may not be listening: this is fine. - let _ = tx.send(Ok(())); - - let _ = interconnect.events.send(EventMessage::FireCoreEvent( - CoreContext::DriverConnect(InternalConnect { - server: connection.info.endpoint.clone(), - ssrc: connection.ssrc, - }), - )); - - Some(connection) - }, - Err(why) => { - // See above. - let _ = tx.send(Err(why)); - - let _ = interconnect.events.send(EventMessage::FireCoreEvent( - CoreContext::DriverConnectFailed, - )); - - None - }, - }; + if connection + .as_ref() + .map(|conn| conn.info != info) + .unwrap_or(true) + { + // Only *actually* reconnect if the conn info changed, or we don't have an + // active connection. + // This allows the gateway component to keep sending join requests independent + // of driver failures. 
+ connection = ConnectionRetryData::connect(tx, info, &mut attempt_idx) + .attempt(&mut retrying, &interconnect, &config) + .await; + } else { + // No reconnection was attempted as there's a valid, identical connection; + // tell the outside listener that the operation was a success. + let _ = tx.send(Ok(())); + } + }, + Ok(CoreMessage::RetryConnect(retry_idx)) => { + debug!("Retrying idx: {} (vs. {})", retry_idx, attempt_idx); + if retry_idx == attempt_idx { + if let Some(progress) = retrying.take() { + connection = progress + .attempt(&mut retrying, &interconnect, &config) + .await; + } + } }, Ok(CoreMessage::Disconnect) => { - connection = None; + let last_conn = connection.take(); let _ = interconnect.mixer.send(MixerMessage::DropConn); let _ = interconnect.mixer.send(MixerMessage::RebuildEncoder); + + if let Some(conn) = last_conn { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverDisconnect(InternalDisconnect { + kind: DisconnectKind::Runtime, + reason: None, + info: conn.info.clone(), + }), + )); + } + }, + Ok(CoreMessage::SignalWsClosure(ws_idx, ws_info, mut reason)) => { + // if idx is not a match, quash reason + // (i.e., prevent users from mistakenly trying to reconnect for an *old* dead conn). + // if it *is* a match, the conn needs to die! + // (as the WS channel has truly given up the ghost). 
+ if ws_idx != attempt_idx { + reason = None; + } else { + connection = None; + let _ = interconnect.mixer.send(MixerMessage::DropConn); + let _ = interconnect.mixer.send(MixerMessage::RebuildEncoder); + } + + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverDisconnect(InternalDisconnect { + kind: DisconnectKind::Runtime, + reason, + info: ws_info, + }), + )); }, Ok(CoreMessage::SetTrack(s)) => { let _ = interconnect.mixer.send(MixerMessage::SetTrack(s)); @@ -138,7 +180,7 @@ async fn runner(mut config: Config, rx: Receiver, tx: Sender { connection = Some(conn); false @@ -146,7 +188,7 @@ async fn runner(mut config: Config, rx: Receiver, tx: Sender { interconnect.restart_volatile_internals(); - match conn.reconnect().await { + match conn.reconnect(&config).await { Ok(()) => { connection = Some(conn); false @@ -158,22 +200,13 @@ async fn runner(mut config: Config, rx: Receiver, tx: Sender, tx: Sender { interconnect.restart_volatile_internals(); @@ -216,3 +233,138 @@ async fn runner(mut config: Config, rx: Receiver, tx: Sender, + info: ConnectionInfo, + idx: usize, +} + +impl ConnectionRetryData { + fn connect( + tx: Sender>, + info: ConnectionInfo, + idx_src: &mut usize, + ) -> Self { + Self::base(ConnectionFlavour::Connect(tx), info, idx_src) + } + + fn reconnect(info: ConnectionInfo, idx_src: &mut usize) -> Self { + Self::base(ConnectionFlavour::Reconnect, info, idx_src) + } + + fn base(flavour: ConnectionFlavour, info: ConnectionInfo, idx_src: &mut usize) -> Self { + *idx_src = idx_src.wrapping_add(1); + + Self { + flavour, + attempts: 0, + last_wait: None, + info, + idx: *idx_src, + } + } + + async fn attempt( + mut self, + attempt_slot: &mut Option, + interconnect: &Interconnect, + config: &Config, + ) -> Option { + match Connection::new(self.info.clone(), interconnect, config, self.idx).await { + Ok(connection) => { + match self.flavour { + ConnectionFlavour::Connect(tx) => { + // Other side may not be listening: this is fine. 
+ let _ = tx.send(Ok(())); + + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverConnect(InternalConnect { + info: connection.info.clone(), + ssrc: connection.ssrc, + }), + )); + }, + ConnectionFlavour::Reconnect => { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverReconnect(InternalConnect { + info: connection.info.clone(), + ssrc: connection.ssrc, + }), + )); + }, + } + + Some(connection) + }, + Err(why) => { + debug!("Failed to connect for {:?}: {}", self.info.guild_id, why); + if let Some(t) = config.driver_retry.retry_in(self.last_wait, self.attempts) { + let remote_ic = interconnect.clone(); + let idx = self.idx; + + spawn(async move { + tsleep(t).await; + let _ = remote_ic.core.send(CoreMessage::RetryConnect(idx)); + }); + + self.attempts += 1; + self.last_wait = Some(t); + + debug!( + "Retrying connection for {:?} in {}s ({}/{:?})", + self.info.guild_id, + t.as_secs_f32(), + self.attempts, + config.driver_retry.retry_limit + ); + + *attempt_slot = Some(self); + } else { + let reason = Some(DisconnectReason::from(&why)); + + match self.flavour { + ConnectionFlavour::Connect(tx) => { + // See above. 
+ let _ = tx.send(Err(why)); + + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverConnectFailed, + )); + + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverDisconnect(InternalDisconnect { + kind: DisconnectKind::Connect, + reason, + info: self.info, + }), + )); + }, + ConnectionFlavour::Reconnect => { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverReconnectFailed, + )); + + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::DriverDisconnect(InternalDisconnect { + kind: DisconnectKind::Reconnect, + reason, + info: self.info, + }), + )); + }, + } + } + + None + }, + } + } +} + +enum ConnectionFlavour { + Connect(Sender>), + Reconnect, +} diff --git a/src/driver/tasks/ws.rs b/src/driver/tasks/ws.rs index e9f3f0d..1fd2ffc 100644 --- a/src/driver/tasks/ws.rs +++ b/src/driver/tasks/ws.rs @@ -9,6 +9,7 @@ use crate::{ SpeakingState, }, ws::{Error as WsError, ReceiverExt, SenderExt, WsStream}, + ConnectionInfo, }; #[cfg(not(feature = "tokio-02-marker"))] use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode; @@ -39,6 +40,9 @@ struct AuxNetwork { speaking: SpeakingState, last_heartbeat_nonce: Option, + + attempt_idx: usize, + info: ConnectionInfo, } impl AuxNetwork { @@ -47,6 +51,8 @@ impl AuxNetwork { ws_client: WsStream, ssrc: u32, heartbeat_interval: f64, + attempt_idx: usize, + info: ConnectionInfo, ) -> Self { Self { rx: evt_rx, @@ -58,6 +64,9 @@ impl AuxNetwork { speaking: SpeakingState::empty(), last_heartbeat_nonce: None, + + attempt_idx, + info, } } @@ -68,6 +77,7 @@ impl AuxNetwork { loop { let mut ws_error = false; let mut should_reconnect = false; + let mut ws_reason = None; let hb = sleep_until(next_heartbeat); @@ -75,7 +85,8 @@ impl AuxNetwork { _ = hb => { ws_error = match self.send_heartbeat().await { Err(e) => { - should_reconnect = ws_error_is_not_final(e); + should_reconnect = ws_error_is_not_final(&e); + 
ws_reason = Some((&e).into()); true }, _ => false, @@ -89,7 +100,8 @@ impl AuxNetwork { false }, Err(e) => { - should_reconnect = ws_error_is_not_final(e); + should_reconnect = ws_error_is_not_final(&e); + ws_reason = Some((&e).into()); true }, Ok(Some(msg)) => { @@ -129,7 +141,8 @@ impl AuxNetwork { ws_error |= match ssu_status { Err(e) => { - should_reconnect = ws_error_is_not_final(e); + should_reconnect = ws_error_is_not_final(&e); + ws_reason = Some((&e).into()); true }, _ => false, @@ -149,6 +162,11 @@ impl AuxNetwork { if should_reconnect { let _ = interconnect.core.send(CoreMessage::Reconnect); } else { + let _ = interconnect.core.send(CoreMessage::SignalWsClosure( + self.attempt_idx, + self.info.clone(), + ws_reason, + )); break; } } @@ -217,15 +235,24 @@ pub(crate) async fn runner( ws_client: WsStream, ssrc: u32, heartbeat_interval: f64, + attempt_idx: usize, + info: ConnectionInfo, ) { trace!("WS thread started."); - let mut aux = AuxNetwork::new(evt_rx, ws_client, ssrc, heartbeat_interval); + let mut aux = AuxNetwork::new( + evt_rx, + ws_client, + ssrc, + heartbeat_interval, + attempt_idx, + info, + ); aux.run(&mut interconnect).await; trace!("WS thread finished."); } -fn ws_error_is_not_final(err: WsError) -> bool { +fn ws_error_is_not_final(err: &WsError) -> bool { match err { WsError::WsClosed(Some(frame)) => match frame.code { CloseCode::Library(l) => diff --git a/src/events/context/data/connect.rs b/src/events/context/data/connect.rs index c618828..26b04a3 100644 --- a/src/events/context/data/connect.rs +++ b/src/events/context/data/connect.rs @@ -1,7 +1,18 @@ +use crate::id::*; + /// Voice connection details gathered at setup/reinstantiation. #[derive(Clone, Debug, Eq, Hash, PartialEq)] #[non_exhaustive] pub struct ConnectData<'a> { + /// ID of the voice channel being joined, if it is known. + /// + /// If this is available, then this can be used to reconnect/renew + /// a voice session via the gateway. 
+ pub channel_id: Option, + /// ID of the target voice channel's parent guild. + pub guild_id: GuildId, + /// Unique string describing this session for validation/authentication purposes. + pub session_id: &'a str, /// The domain name of Discord's voice/TURN server. /// /// With the introduction of Discord's automatic voice server selection, diff --git a/src/events/context/data/disconnect.rs b/src/events/context/data/disconnect.rs new file mode 100644 index 0000000..e7956d1 --- /dev/null +++ b/src/events/context/data/disconnect.rs @@ -0,0 +1,119 @@ +use crate::{ + error::ConnectionError, + id::*, + model::{CloseCode as VoiceCloseCode, FromPrimitive}, + ws::Error as WsError, +}; +#[cfg(not(feature = "tokio-02-marker"))] +use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode; +#[cfg(feature = "tokio-02-marker")] +use async_tungstenite_compat::tungstenite::protocol::frame::coding::CloseCode; + +/// Voice connection details gathered at termination or failure. +/// +/// In the event of a failure, this event data is gathered after +/// a reconnection strategy has exhausted all of its attempts. +#[derive(Debug)] +#[non_exhaustive] +pub struct DisconnectData<'a> { + /// The location that a voice connection was terminated. + pub kind: DisconnectKind, + /// The cause of any connection failure. + /// + /// If `None`, then this disconnect was requested by the user in some way + /// (i.e., leaving or changing voice channels). + pub reason: Option, + /// ID of the voice channel being joined, if it is known. + /// + /// If this is available, then this can be used to reconnect/renew + /// a voice session via the gateway. + pub channel_id: Option, + /// ID of the target voice channel's parent guild. + pub guild_id: GuildId, + /// Unique string describing this session for validation/authentication purposes. + pub session_id: &'a str, +} + +/// The location that a voice connection was terminated. 
+#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +#[non_exhaustive] +pub enum DisconnectKind { + /// The voice driver failed to connect to the server. + /// + /// This requires explicit handling at the gateway level + /// to either reconnect or fully disconnect. + Connect, + /// The voice driver failed to reconnect to the server. + /// + /// This requires explicit handling at the gateway level + /// to either reconnect or fully disconnect. + Reconnect, + /// The voice connection was terminated mid-session by either + /// the user or Discord. + /// + /// If `reason == None`, then this disconnection is either + /// a full disconnect or a user-requested channel change. + /// Otherwise, this is likely a session expiry (requiring user + /// handling to fully disconnect/reconnect). + Runtime, +} + +/// The reason that a voice connection failed. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum DisconnectReason { + /// This (re)connection attempt was dropped due to another request. + AttemptDiscarded, + /// Songbird had an internal error. + /// + /// This should never happen; if this is ever seen, raise an issue with logs. + Internal, + /// A host-specific I/O error caused the fault; this is likely transient, and + /// should be retried some time later. + Io, + /// Songbird and Discord disagreed on the protocol used to establish a + /// voice connection. + /// + /// This should never happen; if this is ever seen, raise an issue with logs. + ProtocolViolation, + /// A voice connection was not established in the specified time. + TimedOut, + /// The Websocket connection was closed by Discord. + /// + /// This typically indicates that the voice session has expired, + /// and a new one needs to be requested via the gateway. 
+ WsClosed(Option), +} + +impl From<&ConnectionError> for DisconnectReason { + fn from(e: &ConnectionError) -> Self { + use ConnectionError::*; + + match e { + AttemptDiscarded => Self::AttemptDiscarded, + CryptoModeInvalid + | CryptoModeUnavailable + | EndpointUrl + | ExpectedHandshake + | IllegalDiscoveryResponse + | IllegalIp + | Json(_) => Self::ProtocolViolation, + Io(_) => Self::Io, + Crypto(_) | InterconnectFailure(_) => Self::Internal, + Ws(ws) => ws.into(), + TimedOut => Self::TimedOut, + } + } +} + +impl From<&WsError> for DisconnectReason { + fn from(e: &WsError) -> Self { + Self::WsClosed(match e { + WsError::WsClosed(Some(frame)) => match frame.code { + CloseCode::Library(l) => VoiceCloseCode::from_u16(l), + _ => None, + }, + _ => None, + }) + } +} diff --git a/src/events/context/data/mod.rs b/src/events/context/data/mod.rs index 6ed8211..abc385e 100644 --- a/src/events/context/data/mod.rs +++ b/src/events/context/data/mod.rs @@ -2,10 +2,11 @@ //! //! [`EventContext`]: super::EventContext mod connect; +mod disconnect; mod rtcp; mod speaking; mod voice; use discortp::{rtcp::Rtcp, rtp::Rtp}; -pub use self::{connect::*, rtcp::*, speaking::*, voice::*}; +pub use self::{connect::*, disconnect::*, rtcp::*, speaking::*, voice::*}; diff --git a/src/events/context/internal_data.rs b/src/events/context/internal_data.rs index 4a26a18..e26d5ad 100644 --- a/src/events/context/internal_data.rs +++ b/src/events/context/internal_data.rs @@ -1,12 +1,20 @@ use super::context_data::*; +use crate::ConnectionInfo; use discortp::{rtcp::Rtcp, rtp::Rtp}; #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct InternalConnect { - pub server: String, + pub info: ConnectionInfo, pub ssrc: u32, } +#[derive(Debug)] +pub struct InternalDisconnect { + pub kind: DisconnectKind, + pub reason: Option, + pub info: ConnectionInfo, +} + #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct InternalSpeakingUpdate { pub ssrc: u32, @@ -31,12 +39,27 @@ pub struct InternalRtcpPacket { 
impl<'a> From<&'a InternalConnect> for ConnectData<'a> { fn from(val: &'a InternalConnect) -> Self { Self { - server: &val.server, + channel_id: val.info.channel_id, + guild_id: val.info.guild_id, + session_id: &val.info.session_id, + server: &val.info.endpoint, ssrc: val.ssrc, } } } +impl<'a> From<&'a InternalDisconnect> for DisconnectData<'a> { + fn from(val: &'a InternalDisconnect) -> Self { + Self { + kind: val.kind, + reason: val.reason, + channel_id: val.info.channel_id, + guild_id: val.info.guild_id, + session_id: &val.info.session_id, + } + } +} + impl<'a> From<&'a InternalSpeakingUpdate> for SpeakingUpdateData { fn from(val: &'a InternalSpeakingUpdate) -> Self { Self { diff --git a/src/events/context/mod.rs b/src/events/context/mod.rs index 71ee51f..c44c8f9 100644 --- a/src/events/context/mod.rs +++ b/src/events/context/mod.rs @@ -17,7 +17,7 @@ use internal_data::*; /// /// [`Track`]: crate::tracks::Track /// [`Driver::add_global_event`]: crate::driver::Driver::add_global_event -#[derive(Clone, Debug)] +#[derive(Debug)] #[non_exhaustive] pub enum EventContext<'a> { /// Track event context, passed to events created via [`TrackHandle::add_event`], @@ -47,12 +47,32 @@ pub enum EventContext<'a> { DriverConnect(ConnectData<'a>), /// Fires when this driver successfully reconnects after a network error. DriverReconnect(ConnectData<'a>), + #[deprecated( + since = "0.2.0", + note = "Please use the DriverDisconnect event instead." + )] /// Fires when this driver fails to connect to a voice channel. + /// + /// Users will need to manually reconnect on receipt of this error. + /// **This event is deprecated in favour of [`DriverDisconnect`].** + /// + /// [`DriverDisconnect`]: Self::DriverDisconnect + // TODO: remove in 0.3.x DriverConnectFailed, + #[deprecated( + since = "0.2.0", + note = "Please use the DriverDisconnect event instead." + )] /// Fires when this driver fails to reconnect to a voice channel after a network error. 
/// /// Users will need to manually reconnect on receipt of this error. + /// **This event is deprecated in favour of [`DriverDisconnect`].** + /// + /// [`DriverDisconnect`]: Self::DriverDisconnect + // TODO: remove in 0.3.x DriverReconnectFailed, + /// Fires when this driver fails to connect to, or drops from, a voice channel. + DriverDisconnect(DisconnectData<'a>), #[deprecated( since = "0.2.0", note = "Please use the DriverConnect/Reconnect events instead." @@ -69,7 +89,7 @@ pub enum EventContext<'a> { SsrcKnown(u32), } -#[derive(Clone, Debug)] +#[derive(Debug)] pub enum CoreContext { SpeakingStateUpdate(Speaking), SpeakingUpdate(InternalSpeakingUpdate), @@ -79,6 +99,7 @@ pub enum CoreContext { ClientDisconnect(ClientDisconnect), DriverConnect(InternalConnect), DriverReconnect(InternalConnect), + DriverDisconnect(InternalDisconnect), DriverConnectFailed, DriverReconnectFailed, SsrcKnown(u32), @@ -97,7 +118,10 @@ impl<'a> CoreContext { ClientDisconnect(evt) => EventContext::ClientDisconnect(*evt), DriverConnect(evt) => EventContext::DriverConnect(ConnectData::from(evt)), DriverReconnect(evt) => EventContext::DriverReconnect(ConnectData::from(evt)), + DriverDisconnect(evt) => EventContext::DriverDisconnect(DisconnectData::from(evt)), + #[allow(deprecated)] DriverConnectFailed => EventContext::DriverConnectFailed, + #[allow(deprecated)] DriverReconnectFailed => EventContext::DriverReconnectFailed, #[allow(deprecated)] SsrcKnown(s) => EventContext::SsrcKnown(*s), @@ -112,15 +136,18 @@ impl EventContext<'_> { use EventContext::*; match self { - SpeakingStateUpdate { .. } => Some(CoreEvent::SpeakingStateUpdate), - SpeakingUpdate { .. } => Some(CoreEvent::SpeakingUpdate), - VoicePacket { .. } => Some(CoreEvent::VoicePacket), - RtcpPacket { .. } => Some(CoreEvent::RtcpPacket), - ClientConnect { .. } => Some(CoreEvent::ClientConnect), - ClientDisconnect { .. 
} => Some(CoreEvent::ClientDisconnect), + SpeakingStateUpdate(_) => Some(CoreEvent::SpeakingStateUpdate), + SpeakingUpdate(_) => Some(CoreEvent::SpeakingUpdate), + VoicePacket(_) => Some(CoreEvent::VoicePacket), + RtcpPacket(_) => Some(CoreEvent::RtcpPacket), + ClientConnect(_) => Some(CoreEvent::ClientConnect), + ClientDisconnect(_) => Some(CoreEvent::ClientDisconnect), DriverConnect(_) => Some(CoreEvent::DriverConnect), DriverReconnect(_) => Some(CoreEvent::DriverReconnect), + DriverDisconnect(_) => Some(CoreEvent::DriverDisconnect), + #[allow(deprecated)] DriverConnectFailed => Some(CoreEvent::DriverConnectFailed), + #[allow(deprecated)] DriverReconnectFailed => Some(CoreEvent::DriverReconnectFailed), #[allow(deprecated)] SsrcKnown(_) => Some(CoreEvent::SsrcKnown), diff --git a/src/events/core.rs b/src/events/core.rs index 803547f..16ea47d 100644 --- a/src/events/core.rs +++ b/src/events/core.rs @@ -33,12 +33,22 @@ pub enum CoreEvent { DriverConnect, /// Fires when this driver successfully reconnects after a network error. DriverReconnect, + #[deprecated( + since = "0.2.0", + note = "Please use the DriverDisconnect event instead." + )] /// Fires when this driver fails to connect to a voice channel. DriverConnectFailed, + #[deprecated( + since = "0.2.0", + note = "Please use the DriverDisconnect event instead." + )] /// Fires when this driver fails to reconnect to a voice channel after a network error. /// /// Users will need to manually reconnect on receipt of this error. DriverReconnectFailed, + /// Fires when this driver fails to connect to, or drops from, a voice channel. + DriverDisconnect, /// Fires whenever the driver is assigned a new [RTP SSRC] by the voice server. /// /// This typically fires alongside a [DriverConnect], or a full [DriverReconnect]. 
diff --git a/src/handler.rs b/src/handler.rs index 8891c3e..74c3886 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -203,7 +203,7 @@ impl Call { let (gw_tx, gw_rx) = flume::unbounded(); let do_conn = self - .should_actually_join(|_| Ok(()), &tx, channel_id) + .should_actually_join(|_| (), &gw_tx, channel_id) .await?; if do_conn { @@ -218,6 +218,14 @@ impl Call { .await .map(|_| Join::new(rx.into_recv_async(), gw_rx.into_recv_async(), timeout)) } else { + // Skipping the gateway connection implies that the current connection is complete + // AND the channel is a match. + // + // Send a polite request to the driver, which should only *actually* reconnect + // if it had a problem earlier. + let info = self.current_connection().unwrap().clone(); + self.driver.raw_connect(info, tx.clone()); + Ok(Join::new( rx.into_recv_async(), gw_rx.into_recv_async(), diff --git a/src/info.rs b/src/info.rs index 6ddc390..15e564f 100644 --- a/src/info.rs +++ b/src/info.rs @@ -104,7 +104,7 @@ impl ConnectionProgress { /// Parameters and information needed to start communicating with Discord's voice servers, either /// with the Songbird driver, lavalink, or other system. -#[derive(Clone)] +#[derive(Clone, Eq, Hash, PartialEq)] pub struct ConnectionInfo { /// ID of the voice channel being joined, if it is known. ///