From 1fc3dc225919dbabc53a14ead35adf9de9e097e3 Mon Sep 17 00:00:00 2001 From: Kyle Simpson Date: Mon, 29 Mar 2021 19:51:13 +0100 Subject: [PATCH] Gateway: Add connection timeout, add `Config` to gateway. (#51) This change fixes tasks hanging due to rare cases of messages being lost between full Discord reconnections by placing a configurable timeout on the `ConnectionInfo` responses. This is a companion fix to [serenity#1255](https://github.com/serenity-rs/serenity/pull/1255). To make this doable, `Config`s are now used by all versions of `Songbird`/`Call`, and relevant functions are added to simplify setup with configuration. These are now non-exhaustive, correcting an earlier oversight. For future extensibility, this PR moves the return type of `join`/`join_gateway` into a custom future (no longer leaking flume's `RecvFut` type). Additionally, this fixes the Makefile's feature sets for driver/gateway-only compilation. This is a breaking change in: * the return types of `join`/`join_gateway` * moving `crate::driver::Config` -> `crate::Config`, * `Config` and `JoinError` becoming `#[non_breaking]`. This was tested via `cargo make ready`, and by testing `examples/serenity/voice_receive` with various timeout settings. --- ARCHITECTURE.md | 1 + Cargo.toml | 6 + Makefile.toml | 13 +- examples/serenity/voice_receive/src/main.rs | 13 +- src/{driver => }/config.rs | 41 ++++- src/driver/mod.rs | 11 +- src/driver/tasks/error.rs | 1 + src/driver/tasks/mixer.rs | 3 +- src/driver/tasks/mod.rs | 7 +- src/driver/tasks/udp_rx.rs | 7 +- src/error.rs | 42 +++++ src/handler.rs | 116 +++++++------ src/input/cached/tests.rs | 2 +- src/input/mod.rs | 4 +- src/join.rs | 174 ++++++++++++++++++++ src/lib.rs | 5 + src/manager.rs | 83 ++++++---- src/serenity.rs | 16 +- 18 files changed, 426 insertions(+), 119 deletions(-) rename src/{driver => }/config.rs (66%) create mode 100644 src/join.rs diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 8435d17..ced2cdd 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -21,6 +21,7 @@ If the driver feature is enabled, then every `Call` is/has an associated `Driver src/manager.rs src/handler.rs src/serenity.rs +src/join.rs ``` # Driver diff --git a/Cargo.toml b/Cargo.toml index e31505d..148cf0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,10 @@ version = "0.3" optional = true version = "0.11" +[dependencies.pin-project] +optional = true +version = "1" + [dependencies.rand] optional = true version = "0.8" @@ -142,11 +146,13 @@ default = [ gateway = [ "gateway-core", "tokio/sync", + "tokio/time", ] gateway-core = [ "dashmap", "flume", "parking_lot", + "pin-project", "spinning_top", ] driver = [ diff --git a/Makefile.toml b/Makefile.toml index 9372dbb..6ed76d1 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -14,18 +14,18 @@ command = "cargo" dependencies = ["format"] [tasks.build-gateway] -args = ["build", "--features", "serenity-rustls"] +args = ["build", "--no-default-features", "--features", "serenity-rustls"] command = "cargo" dependencies = ["format"] [tasks.build-driver] -args = ["build", "--features", "driver,rustls"] +args = ["build", "--no-default-features", "--features", "driver,rustls"] command = "cargo" dependencies = ["format"] [tasks.build-old-tokio] command = "cargo" -args = ["build", "--features", "serenity-rustls-tokio-02,driver-tokio-02"] +args = ["build", "--no-default-features", "--features", "serenity-rustls-tokio-02,driver-tokio-02"] dependencies = ["format"] [tasks.build-variants] @@ -45,7 +45,12 @@ command = "cargo" args = ["bench", "--features", "internals,full-doc"] [tasks.doc] +command = "cargo" args = ["doc", "--features", "full-doc"] +[tasks.doc-open] +command = "cargo" +args = ["doc", "--features", "full-doc", "--open"] + [tasks.ready] -dependencies = ["format", "test", "build-variants", "build-examples", "doc", "clippy"] \ No newline at end of file +dependencies = ["format", "test", "build-variants", "build-examples", "doc", "clippy"] diff --git a/examples/serenity/voice_receive/src/main.rs b/examples/serenity/voice_receive/src/main.rs index 2249c32..dbaf47a 100644 --- a/examples/serenity/voice_receive/src/main.rs +++ b/examples/serenity/voice_receive/src/main.rs @@ -28,14 +28,14 @@ use serenity::{ }; use songbird::{ - driver::{Config as DriverConfig, DecodeMode}, + driver::DecodeMode, model::payload::{ClientConnect, ClientDisconnect, Speaking}, + Config, CoreEvent, Event, EventContext, EventHandler as VoiceEventHandler, SerenityInit, - Songbird, }; struct Handler; @@ -167,16 +167,13 @@ async fn main() { // Here, we need to configure Songbird to decode all incoming voice packets. // If you want, you can do this on a per-call basis---here, we need it to // read the audio data that other people are sending us! - let songbird = Songbird::serenity(); - songbird.set_config( - DriverConfig::default() - .decode_mode(DecodeMode::Decode) - ); + let songbird_config = Config::default() + .decode_mode(DecodeMode::Decode); let mut client = Client::builder(&token) .event_handler(Handler) .framework(framework) - .register_songbird_with(songbird.into()) + .register_songbird_from_config(songbird_config) .await .expect("Err creating client"); diff --git a/src/driver/config.rs b/src/config.rs similarity index 66% rename from src/driver/config.rs rename to src/config.rs index f3a2972..caad394 100644 --- a/src/driver/config.rs +++ b/src/config.rs @@ -1,9 +1,14 @@ -use super::{CryptoMode, DecodeMode}; +#[cfg(feature = "driver-core")] +use super::driver::{CryptoMode, DecodeMode}; -/// Configuration for the inner Driver. -/// +#[cfg(feature = "gateway-core")] +use std::time::Duration; + +/// Configuration for drivers and calls. #[derive(Clone, Debug)] +#[non_exhaustive] pub struct Config { + #[cfg(feature = "driver-core")] /// Selected tagging mode for voice packet encryption. /// /// Defaults to [`CryptoMode::Normal`]. @@ -14,6 +19,7 @@ pub struct Config { /// /// [`CryptoMode::Normal`]: CryptoMode::Normal pub crypto_mode: CryptoMode, + #[cfg(feature = "driver-core")] /// Configures whether decoding and decryption occur for all received packets. /// /// If voice receiving voice packets, generally you should choose [`DecodeMode::Decode`]. @@ -29,6 +35,20 @@ pub struct Config { /// [`DecodeMode::Pass`]: DecodeMode::Pass /// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate pub decode_mode: DecodeMode, + #[cfg(feature = "gateway-core")] + /// Configures the amount of time to wait for Discord to reply with connection information + /// if [`Call::join`]/[`join_gateway`] are used. + /// + /// This is a useful fallback in the event that: + /// * the underlying Discord client restarts and loses a join request, or + /// * a channel join fails because the bot is already believed to be there. + /// + /// Defaults to 10 seconds. If set to `None`, connections will never time out. + /// + /// [`Call::join`]: crate::Call::join + /// [`join_gateway`]: crate::Call::join_gateway + pub gateway_timeout: Option, + #[cfg(feature = "driver-core")] /// Number of concurrently active tracks to allocate memory for. /// /// This should be set at, or just above, the maximum number of tracks @@ -46,13 +66,19 @@ pub struct Config { impl Default for Config { fn default() -> Self { Self { + #[cfg(feature = "driver-core")] crypto_mode: CryptoMode::Normal, + #[cfg(feature = "driver-core")] decode_mode: DecodeMode::Decrypt, + #[cfg(feature = "gateway-core")] + gateway_timeout: Some(Duration::from_secs(10)), + #[cfg(feature = "driver-core")] preallocated_tracks: 1, } } } +#[cfg(feature = "driver-core")] impl Config { /// Sets this `Config`'s chosen cryptographic tagging scheme. pub fn crypto_mode(mut self, crypto_mode: CryptoMode) -> Self { @@ -79,3 +105,12 @@ impl Config { } } } + +#[cfg(feature = "gateway-core")] +impl Config { + /// Sets this `Config`'s timeout for joining a voice channel. + pub fn gateway_timeout(mut self, gateway_timeout: Option) -> Self { + self.gateway_timeout = gateway_timeout; + self + } +} diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 1c612d8..e2d988d 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -11,13 +11,11 @@ #[cfg(feature = "internals")] pub mod bench_internals; -mod config; pub(crate) mod connection; mod crypto; mod decode_mode; pub(crate) mod tasks; -pub use config::Config; use connection::error::{Error, Result}; pub use crypto::CryptoMode; pub(crate) use crypto::CryptoState; @@ -29,6 +27,7 @@ use crate::{ events::EventData, input::Input, tracks::{self, Track, TrackHandle}, + Config, ConnectionInfo, Event, EventHandler, @@ -212,13 +211,19 @@ impl Driver { self.send(CoreMessage::SetTrack(None)) } - /// Sets the configuration for this driver. + /// Sets the configuration for this driver (and parent `Call`, if applicable). #[instrument(skip(self))] pub fn set_config(&mut self, config: Config) { self.config = config.clone(); self.send(CoreMessage::SetConfig(config)) } + /// Returns a view of this driver's configuration. + #[instrument(skip(self))] + pub fn config(&self) -> &Config { + &self.config + } + /// Attach a global event handler to an audio context. Global events may receive /// any [`EventContext`]. /// diff --git a/src/driver/tasks/error.rs b/src/driver/tasks/error.rs index c9e1fdb..c56319b 100644 --- a/src/driver/tasks/error.rs +++ b/src/driver/tasks/error.rs @@ -17,6 +17,7 @@ pub enum Recipient { pub type Result = std::result::Result; #[derive(Debug)] +#[non_exhaustive] pub enum Error { Crypto(CryptoError), /// Received an illegal voice packet on the voice UDP socket. diff --git a/src/driver/tasks/mixer.rs b/src/driver/tasks/mixer.rs index 61f4392..9db40cc 100644 --- a/src/driver/tasks/mixer.rs +++ b/src/driver/tasks/mixer.rs @@ -1,7 +1,8 @@ -use super::{disposal, error::Result, message::*, Config}; +use super::{disposal, error::Result, message::*}; use crate::{ constants::*, tracks::{PlayMode, Track}, + Config, }; use audiopus::{ coder::Encoder as OpusEncoder, diff --git a/src/driver/tasks/mod.rs b/src/driver/tasks/mod.rs index aacbd7f..f85e3fe 100644 --- a/src/driver/tasks/mod.rs +++ b/src/driver/tasks/mod.rs @@ -9,11 +9,8 @@ pub(crate) mod udp_rx; pub(crate) mod udp_tx; pub(crate) mod ws; -use super::{ - connection::{error::Error as ConnectionError, Connection}, - Config, -}; -use crate::events::CoreContext; +use super::connection::{error::Error as ConnectionError, Connection}; +use crate::{events::CoreContext, Config}; use flume::{Receiver, RecvError, Sender}; use message::*; #[cfg(not(feature = "tokio-02-marker"))] diff --git a/src/driver/tasks/udp_rx.rs b/src/driver/tasks/udp_rx.rs index af13f95..d118fd5 100644 --- a/src/driver/tasks/udp_rx.rs +++ b/src/driver/tasks/udp_rx.rs @@ -1,12 +1,9 @@ use super::{ error::{Error, Result}, message::*, + Config, }; -use crate::{ - constants::*, - driver::{Config, DecodeMode}, - events::CoreContext, -}; +use crate::{constants::*, driver::DecodeMode, events::CoreContext}; use audiopus::{ coder::Decoder as OpusDecoder, error::{Error as OpusError, ErrorCode}, diff --git a/src/error.rs b/src/error.rs index 7fbc86b..f3a05ec 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,7 @@ use twilight_gateway::shard::CommandError; #[cfg(feature = "gateway-core")] #[derive(Debug)] +#[non_exhaustive] /// Error returned when a manager or call handler is /// unable to send messages over Discord's gateway. pub enum JoinError { @@ -23,8 +24,23 @@ pub enum JoinError { /// /// [`Call`]: crate::Call NoCall, + /// Connection details were not received from Discord in the + /// time given in [the `Call`'s configuration]. + /// + /// This can occur if a message is lost by the Discord client + /// between restarts, or if Discord's gateway believes that + /// this bot is still in the channel it attempts to join. + /// + /// *Users should `leave` the server on the gateway before + /// re-attempting connection.* + /// + /// [the `Call`'s configuration]: crate::Config + TimedOut, #[cfg(feature = "driver-core")] /// The driver failed to establish a voice connection. + /// + /// *Users should `leave` the server on the gateway before + /// re-attempting connection.* Driver(ConnectionError), #[cfg(feature = "serenity")] /// Serenity-specific WebSocket send error. @@ -34,6 +50,31 @@ pub enum JoinError { Twilight(CommandError), } +#[cfg(feature = "gateway-core")] +impl JoinError { + /// Indicates whether this failure may have left (or been + /// caused by) Discord's gateway state being in an + /// inconsistent state. + /// + /// Failure to `leave` before rejoining may cause further + /// timeouts. + pub fn should_leave_server(&self) -> bool { + matches!(self, JoinError::TimedOut) + } + + #[cfg(feature = "driver-core")] + /// Indicates whether this failure can be reattempted via + /// [`Driver::connect`] with retreived connection info. + /// + /// Failure to `leave` before rejoining may cause further + /// timeouts. + /// + /// [`Driver::connect`]: crate::driver::Driver + pub fn should_reconnect_driver(&self) -> bool { + matches!(self, JoinError::Driver(_)) + } +} + #[cfg(feature = "gateway-core")] impl fmt::Display for JoinError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -42,6 +83,7 @@ impl fmt::Display for JoinError { JoinError::Dropped => write!(f, "request was cancelled/dropped."), JoinError::NoSender => write!(f, "no gateway destination."), JoinError::NoCall => write!(f, "tried to leave a non-existent call."), + JoinError::TimedOut => write!(f, "gateway response from Discord timed out."), #[cfg(feature = "driver-core")] JoinError::Driver(t) => write!(f, "internal driver error {}.", t), #[cfg(feature = "serenity")] diff --git a/src/handler.rs b/src/handler.rs index eb29fb6..a4af634 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,15 +1,14 @@ #[cfg(feature = "driver-core")] -use crate::{ - driver::{Config, Driver}, - error::ConnectionResult, -}; +use crate::{driver::Driver, error::ConnectionResult}; use crate::{ error::{JoinError, JoinResult}, id::{ChannelId, GuildId, UserId}, info::{ConnectionInfo, ConnectionProgress}, + join::*, shards::Shard, + Config, }; -use flume::{r#async::RecvFut, Sender}; +use flume::Sender; use serde_json::json; use tracing::instrument; @@ -18,9 +17,15 @@ use std::ops::{Deref, DerefMut}; #[derive(Clone, Debug)] enum Return { + // Return the connection info as it is received. Info(Sender), + + // Two channels: first indicates "gateway connection" was successful, + // second indicates that the driver successfully connected. + // The first is needed to cancel a timeout as the driver can/should + // have separate connection timing/retry config. #[cfg(feature = "driver-core")] - Conn(Sender>), + Conn(Sender<()>, Sender>), } /// The Call handler is responsible for a single voice connection, acting @@ -32,6 +37,9 @@ enum Return { /// [`Driver`]: struct@Driver #[derive(Clone, Debug)] pub struct Call { + #[cfg(not(feature = "driver-core"))] + config: Config, + connection: Option<(ConnectionProgress, Return)>, #[cfg(feature = "driver-core")] @@ -61,19 +69,13 @@ impl Call { #[inline] #[instrument] pub fn new(guild_id: GuildId, ws: Shard, user_id: UserId) -> Self { - Self::new_raw(guild_id, Some(ws), user_id) + Self::new_raw_cfg(guild_id, Some(ws), user_id, Default::default()) } - #[cfg(feature = "driver-core")] /// Creates a new Call, configuring the driver as specified. #[inline] #[instrument] - pub fn from_driver_config( - guild_id: GuildId, - ws: Shard, - user_id: UserId, - config: Config, - ) -> Self { + pub fn from_config(guild_id: GuildId, ws: Shard, user_id: UserId, config: Config) -> Self { Self::new_raw_cfg(guild_id, Some(ws), user_id, config) } @@ -88,38 +90,22 @@ impl Call { #[inline] #[instrument] pub fn standalone(guild_id: GuildId, user_id: UserId) -> Self { - Self::new_raw(guild_id, None, user_id) + Self::new_raw_cfg(guild_id, None, user_id, Default::default()) } - #[cfg(feature = "driver-core")] - /// Creates a new standalone Call, configuring the driver as specified. + /// Creates a new standalone Call from the given configuration file. #[inline] #[instrument] - pub fn standalone_from_driver_config( - guild_id: GuildId, - user_id: UserId, - config: Config, - ) -> Self { + pub fn standalone_from_config(guild_id: GuildId, user_id: UserId, config: Config) -> Self { Self::new_raw_cfg(guild_id, None, user_id, config) } - fn new_raw(guild_id: GuildId, ws: Option, user_id: UserId) -> Self { - Call { - connection: None, - #[cfg(feature = "driver-core")] - driver: Default::default(), - guild_id, - self_deaf: false, - self_mute: false, - user_id, - ws, - } - } - - #[cfg(feature = "driver-core")] fn new_raw_cfg(guild_id: GuildId, ws: Option, user_id: UserId, config: Config) -> Self { Call { + #[cfg(not(feature = "driver-core"))] + config, connection: None, + #[cfg(feature = "driver-core")] driver: Driver::new(config), guild_id, self_deaf: false, @@ -137,8 +123,11 @@ impl Call { let _ = tx.send(c.clone()); }, #[cfg(feature = "driver-core")] - Some((ConnectionProgress::Complete(c), Return::Conn(tx))) => { - self.driver.raw_connect(c.clone(), tx.clone()); + Some((ConnectionProgress::Complete(c), Return::Conn(first_tx, driver_tx))) => { + // It's okay if the receiver hung up. + let _ = first_tx.send(()); + + self.driver.raw_connect(c.clone(), driver_tx.clone()); }, _ => {}, } @@ -209,11 +198,9 @@ impl Call { /// /// [`Songbird::join`]: crate::Songbird::join #[instrument(skip(self))] - pub async fn join( - &mut self, - channel_id: ChannelId, - ) -> JoinResult>> { + pub async fn join(&mut self, channel_id: ChannelId) -> JoinResult { let (tx, rx) = flume::unbounded(); + let (gw_tx, gw_rx) = flume::unbounded(); let do_conn = self .should_actually_join(|_| Ok(()), &tx, channel_id) @@ -222,12 +209,20 @@ impl Call { if do_conn { self.connection = Some(( ConnectionProgress::new(self.guild_id, self.user_id, channel_id), - Return::Conn(tx), + Return::Conn(gw_tx, tx), )); - self.update().await.map(|_| rx.into_recv_async()) + let timeout = self.config().gateway_timeout; + + self.update() + .await + .map(|_| Join::new(rx.into_recv_async(), gw_rx.into_recv_async(), timeout)) } else { - Ok(rx.into_recv_async()) + Ok(Join::new( + rx.into_recv_async(), + gw_rx.into_recv_async(), + None, + )) } } @@ -247,10 +242,7 @@ impl Call { /// /// [`Songbird::join_gateway`]: crate::Songbird::join_gateway #[instrument(skip(self))] - pub async fn join_gateway( - &mut self, - channel_id: ChannelId, - ) -> JoinResult> { + pub async fn join_gateway(&mut self, channel_id: ChannelId) -> JoinResult { let (tx, rx) = flume::unbounded(); let do_conn = self @@ -267,9 +259,13 @@ impl Call { Return::Info(tx), )); - self.update().await.map(|_| rx.into_recv_async()) + let timeout = self.config().gateway_timeout; + + self.update() + .await + .map(|_| JoinGateway::new(rx.into_recv_async(), timeout)) } else { - Ok(rx.into_recv_async()) + Ok(JoinGateway::new(rx.into_recv_async(), None)) } } @@ -414,6 +410,24 @@ impl Call { } } +#[cfg(not(feature = "driver-core"))] +impl Call { + /// Access this call handler's configuration. + pub fn config(&self) -> &Config { + &self.config + } + + /// Mutably access this call handler's configuration. + pub fn config_mut(&mut self) -> &mut Config { + &mut self.config + } + + /// Set this call handler's configuration. + pub fn set_config(&mut self, config: Config) { + self.config = config; + } +} + #[cfg(feature = "driver-core")] impl Deref for Call { type Target = Driver; diff --git a/src/input/cached/tests.rs b/src/input/cached/tests.rs index d4a7021..49aeef5 100644 --- a/src/input/cached/tests.rs +++ b/src/input/cached/tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::{ constants::*, - input::{error::Error, ffmpeg, Codec, Container, Input, Reader}, + input::{error::Error, Codec, Container, Input}, test_utils::*, }; use audiopus::{coder::Decoder, Bitrate, Channels, SampleRate}; diff --git a/src/input/mod.rs b/src/input/mod.rs index d3c35a8..620c776 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -559,7 +559,7 @@ mod tests { let mut input = Input::new(false, data.clone().into(), Codec::Pcm, Container::Raw, None); let mut out_vec = vec![]; - let len = input.read_to_end(&mut out_vec).unwrap(); + let _len = input.read_to_end(&mut out_vec).unwrap(); let mut i16_window = &data[..]; let mut float_window = &out_vec[..]; @@ -580,7 +580,7 @@ mod tests { let mut input = Input::new(true, data.clone().into(), Codec::Pcm, Container::Raw, None); let mut out_vec = vec![]; - let len = input.read_to_end(&mut out_vec).unwrap(); + let _len = input.read_to_end(&mut out_vec).unwrap(); let mut i16_window = &data[..]; let mut float_window = &out_vec[..]; diff --git a/src/join.rs b/src/join.rs new file mode 100644 index 0000000..c8a88f4 --- /dev/null +++ b/src/join.rs @@ -0,0 +1,174 @@ +//! Future types for gateway interactions. + +#[cfg(feature = "driver-core")] +use crate::error::ConnectionResult; +use crate::{ + error::{JoinError, JoinResult}, + ConnectionInfo, +}; +use core::{ + convert, + future::Future, + marker::Unpin, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use flume::r#async::RecvFut; +use pin_project::pin_project; +#[cfg(not(feature = "tokio-02-marker"))] +use tokio::time::{self, Timeout}; +#[cfg(feature = "tokio-02-marker")] +use tokio_compat::time::{self, Timeout}; + +#[cfg(feature = "driver-core")] +/// Future for a call to [`Call::join`]. +/// +/// This future `await`s Discord's response *and* +/// connection via the [`Driver`]. Both phases have +/// separate timeouts and failure conditions. +/// +/// This future ***must not*** be `await`ed while +/// holding the lock around a [`Call`]. +/// +/// [`Call::join`]: crate::Call::join +/// [`Call`]: crate::Call +/// [`Driver`]: crate::driver::Driver +#[pin_project] +pub struct Join { + #[pin] + gw: JoinClass<()>, + #[pin] + driver: JoinClass>, + state: JoinState, +} + +#[cfg(feature = "driver-core")] +impl Join { + pub(crate) fn new( + driver: RecvFut<'static, ConnectionResult<()>>, + gw_recv: RecvFut<'static, ()>, + timeout: Option, + ) -> Self { + Self { + gw: JoinClass::new(gw_recv, timeout), + driver: JoinClass::new(driver, None), + state: JoinState::BeforeGw, + } + } +} + +#[cfg(feature = "driver-core")] +impl Future for Join { + type Output = JoinResult<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if *this.state == JoinState::BeforeGw { + let poll = this.gw.poll(cx); + match poll { + Poll::Ready(a) if a.is_ok() => { + *this.state = JoinState::AfterGw; + }, + Poll::Ready(a) => { + *this.state = JoinState::Finalised; + return Poll::Ready(a); + }, + Poll::Pending => return Poll::Pending, + } + } + + if *this.state == JoinState::AfterGw { + let poll = this + .driver + .poll(cx) + .map_ok(|res| res.map_err(JoinError::Driver)) + .map(|res| res.and_then(convert::identity)); + + match poll { + Poll::Ready(a) => { + *this.state = JoinState::Finalised; + return Poll::Ready(a); + }, + Poll::Pending => return Poll::Pending, + } + } + + Poll::Pending + } +} + +#[cfg(feature = "driver-core")] +#[derive(Copy, Clone, Eq, PartialEq)] +enum JoinState { + BeforeGw, + AfterGw, + Finalised, +} + +/// Future for a call to [`Call::join_gateway`]. +/// +/// This future `await`s Discord's gateway response, subject +/// to any timeouts. +/// +/// This future ***must not*** be `await`ed while +/// holding the lock around a [`Call`]. +/// +/// [`Call::join_gateway`]: crate::Call::join_gateway +/// [`Call`]: crate::Call +/// [`Driver`]: crate::driver::Driver +#[pin_project] +pub struct JoinGateway { + #[pin] + inner: JoinClass, +} + +impl JoinGateway { + pub(crate) fn new(recv: RecvFut<'static, ConnectionInfo>, timeout: Option) -> Self { + Self { + inner: JoinClass::new(recv, timeout), + } + } +} + +impl Future for JoinGateway { + type Output = JoinResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) + } +} + +#[pin_project(project = JoinClassProj)] +enum JoinClass { + WithTimeout(#[pin] Timeout>), + Vanilla(RecvFut<'static, T>), +} + +impl JoinClass { + pub(crate) fn new(recv: RecvFut<'static, T>, timeout: Option) -> Self { + match timeout { + Some(t) => JoinClass::WithTimeout(time::timeout(t, recv)), + None => JoinClass::Vanilla(recv), + } + } +} + +impl Future for JoinClass +where + T: Unpin, +{ + type Output = JoinResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + JoinClassProj::WithTimeout(t) => t + .poll(cx) + .map_err(|_| JoinError::TimedOut) + .map_ok(|res| res.map_err(|_| JoinError::Dropped)) + .map(|m| m.and_then(convert::identity)), + JoinClassProj::Vanilla(t) => Pin::new(t).poll(cx).map_err(|_| JoinError::Dropped), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 91fb875..44a7cbc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,7 @@ //! [`ConnectionInfo`]: struct@ConnectionInfo //! [lavalink]: https://github.com/Frederikam/Lavalink +mod config; pub mod constants; #[cfg(feature = "driver-core")] pub mod driver; @@ -50,6 +51,8 @@ pub(crate) mod info; #[cfg(feature = "driver-core")] pub mod input; #[cfg(feature = "gateway-core")] +pub mod join; +#[cfg(feature = "gateway-core")] mod manager; #[cfg(feature = "serenity")] pub mod serenity; @@ -61,6 +64,7 @@ pub mod tracks; mod ws; #[cfg(feature = "driver-core")] +/// Opus encoder bitrate settings. pub use audiopus::{self as opus, Bitrate}; #[cfg(feature = "driver-core")] pub use discortp as packet; @@ -86,4 +90,5 @@ pub use crate::{handler::*, manager::*}; #[cfg(feature = "serenity")] pub use crate::serenity::*; +pub use config::Config; pub use info::ConnectionInfo; diff --git a/src/manager.rs b/src/manager.rs index f12bbc1..c9ac72d 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -1,10 +1,9 @@ -#[cfg(feature = "driver-core")] -use crate::driver::Config; use crate::{ error::{JoinError, JoinResult}, id::{ChannelId, GuildId, UserId}, shards::Sharder, Call, + Config, ConnectionInfo, }; #[cfg(feature = "serenity")] @@ -50,9 +49,7 @@ pub struct Songbird { client_data: PRwLock, calls: DashMap>>, sharder: Sharder, - - #[cfg(feature = "driver-core")] - driver_config: PRwLock>, + config: PRwLock>, } impl Songbird { @@ -63,13 +60,21 @@ impl Songbird { /// /// [registered]: crate::serenity::register_with pub fn serenity() -> Arc { + Self::serenity_from_config(Default::default()) + } + + #[cfg(feature = "serenity")] + /// Create a new Songbird instance for serenity, using the given configuration. + /// + /// This must be [registered] after creation. + /// + /// [registered]: crate::serenity::register_with + pub fn serenity_from_config(config: Config) -> Arc { Arc::new(Self { client_data: Default::default(), calls: Default::default(), sharder: Sharder::Serenity(Default::default()), - - #[cfg(feature = "driver-core")] - driver_config: Default::default(), + config: Some(config).into(), }) } @@ -82,6 +87,26 @@ impl Songbird { /// /// [`process`]: Songbird::process pub fn twilight(cluster: Cluster, shard_count: u64, user_id: U) -> Arc + where + U: Into, + { + Self::twilight_from_config(cluster, shard_count, user_id, Default::default()) + } + + #[cfg(feature = "twilight")] + /// Create a new Songbird instance for twilight. + /// + /// Twilight handlers do not need to be registered, but + /// users are responsible for passing in any events using + /// [`process`]. + /// + /// [`process`]: Songbird::process + pub fn twilight_from_config( + cluster: Cluster, + shard_count: u64, + user_id: U, + config: Config, + ) -> Arc where U: Into, { @@ -93,9 +118,7 @@ impl Songbird { }), calls: Default::default(), sharder: Sharder::Twilight(cluster), - - #[cfg(feature = "driver-core")] - driver_config: Default::default(), + config: Some(config).into(), }) } @@ -144,23 +167,30 @@ impl Songbird { .get_shard(shard) .expect("Failed to get shard handle: shard_count incorrect?"); - #[cfg(feature = "driver-core")] - let call = Call::from_driver_config( + let call = Call::from_config( guild_id, shard_handle, info.user_id, - self.driver_config.read().clone().unwrap_or_default(), + self.config.read().clone().unwrap_or_default(), ); - #[cfg(not(feature = "driver-core"))] - let call = Call::new(guild_id, shard_handle, info.user_id); - Arc::new(Mutex::new(call)) }) .clone() }) } + /// Sets a shared configuration for all drivers created from this + /// manager. + /// + /// Changes made here will apply to new Call and Driver instances only. + /// + /// Requires the `"driver"` feature. + pub fn set_config(&self, new_config: Config) { + let mut config = self.config.write(); + *config = Some(new_config); + } + fn manager_info(&self) -> ClientData { let client_data = self.client_data.write(); @@ -213,10 +243,7 @@ impl Songbird { }; let result = match stage_1 { - Ok(chan) => chan - .await - .map_err(|_| JoinError::Dropped) - .and_then(|x| x.map_err(JoinError::from)), + Ok(chan) => chan.await, Err(e) => Err(e), }; @@ -401,20 +428,6 @@ impl VoiceGatewayManager for Songbird { } } -#[cfg(feature = "driver-core")] -impl Songbird { - /// Sets a shared configuration for all drivers created from this - /// manager. - /// - /// Changes made here will apply to new Call and Driver instances only. - /// - /// Requires the `"driver"` feature. - pub fn set_config(&self, new_config: Config) { - let mut config = self.driver_config.write(); - *config = Some(new_config); - } -} - #[inline] fn shard_id(guild_id: u64, shard_count: u64) -> u64 { (guild_id >> 22) % shard_count diff --git a/src/serenity.rs b/src/serenity.rs index dd6674c..014ff74 100644 --- a/src/serenity.rs +++ b/src/serenity.rs @@ -3,7 +3,7 @@ //! //! [serenity]: https://crates.io/crates/serenity/0.9.0-rc.2 -use crate::manager::Songbird; +use crate::{Config, Songbird}; use serenity::{ client::{ClientBuilder, Context}, prelude::TypeMapKey, @@ -37,6 +37,14 @@ pub fn register_with(client_builder: ClientBuilder, voice: Arc) -> Cli .type_map_insert::(voice) } +/// Installs a given songbird instance into the serenity client. +/// +/// This should be called after any uses of `ClientBuilder::type_map`. +pub fn register_from_config(client_builder: ClientBuilder, config: Config) -> ClientBuilder { + let voice = Songbird::serenity_from_config(config); + register_with(client_builder, voice) +} + /// Retrieve the Songbird voice client from a serenity context's /// shared key-value store. pub async fn get(ctx: &Context) -> Option> { @@ -58,6 +66,8 @@ pub trait SerenityInit { fn register_songbird(self) -> Self; /// Registers a given Songbird voice system with serenity, as above. fn register_songbird_with(self, voice: Arc) -> Self; + /// Registers a Songbird voice system serenity, based on the given configuration. + fn register_songbird_from_config(self, config: Config) -> Self; } impl SerenityInit for ClientBuilder<'_> { @@ -68,4 +78,8 @@ impl SerenityInit for ClientBuilder<'_> { fn register_songbird_with(self, voice: Arc) -> Self { register_with(self, voice) } + + fn register_songbird_from_config(self, config: Config) -> Self { + register_from_config(self, config) + } }