Gateway: Add connection timeout, add Config to gateway. (#51)

This change fixes tasks hanging due to rare cases of messages being lost between full Discord reconnections by placing a configurable timeout on the `ConnectionInfo` responses. This is a companion fix to [serenity#1255](https://github.com/serenity-rs/serenity/pull/1255). To make this doable, `Config`s are now used by all versions of `Songbird`/`Call`, and relevant functions are  added to simplify setup with configuration. These are now non-exhaustive, correcting an earlier oversight. For future extensibility, this PR moves the return type of `join`/`join_gateway` into a custom future (no longer leaking flume's `RecvFut` type).

Additionally, this fixes the Makefile's feature sets for driver/gateway-only compilation.

This is a breaking change in:
* the return types of `join`/`join_gateway`
* moving `crate::driver::Config` -> `crate::Config`,
* `Config` and `JoinError` becoming `#[non_breaking]`.

This was tested via `cargo make ready`, and by testing `examples/serenity/voice_receive` with various timeout settings.
This commit is contained in:
Kyle Simpson
2021-03-29 19:51:13 +01:00
parent f449d4f679
commit 1fc3dc2259
18 changed files with 426 additions and 119 deletions

View File

@@ -21,6 +21,7 @@ If the driver feature is enabled, then every `Call` is/has an associated `Driver
src/manager.rs src/manager.rs
src/handler.rs src/handler.rs
src/serenity.rs src/serenity.rs
src/join.rs
``` ```
# Driver # Driver

View File

@@ -63,6 +63,10 @@ version = "0.3"
optional = true optional = true
version = "0.11" version = "0.11"
[dependencies.pin-project]
optional = true
version = "1"
[dependencies.rand] [dependencies.rand]
optional = true optional = true
version = "0.8" version = "0.8"
@@ -142,11 +146,13 @@ default = [
gateway = [ gateway = [
"gateway-core", "gateway-core",
"tokio/sync", "tokio/sync",
"tokio/time",
] ]
gateway-core = [ gateway-core = [
"dashmap", "dashmap",
"flume", "flume",
"parking_lot", "parking_lot",
"pin-project",
"spinning_top", "spinning_top",
] ]
driver = [ driver = [

View File

@@ -14,18 +14,18 @@ command = "cargo"
dependencies = ["format"] dependencies = ["format"]
[tasks.build-gateway] [tasks.build-gateway]
args = ["build", "--features", "serenity-rustls"] args = ["build", "--no-default-features", "--features", "serenity-rustls"]
command = "cargo" command = "cargo"
dependencies = ["format"] dependencies = ["format"]
[tasks.build-driver] [tasks.build-driver]
args = ["build", "--features", "driver,rustls"] args = ["build", "--no-default-features", "--features", "driver,rustls"]
command = "cargo" command = "cargo"
dependencies = ["format"] dependencies = ["format"]
[tasks.build-old-tokio] [tasks.build-old-tokio]
command = "cargo" command = "cargo"
args = ["build", "--features", "serenity-rustls-tokio-02,driver-tokio-02"] args = ["build", "--no-default-features", "--features", "serenity-rustls-tokio-02,driver-tokio-02"]
dependencies = ["format"] dependencies = ["format"]
[tasks.build-variants] [tasks.build-variants]
@@ -45,7 +45,12 @@ command = "cargo"
args = ["bench", "--features", "internals,full-doc"] args = ["bench", "--features", "internals,full-doc"]
[tasks.doc] [tasks.doc]
command = "cargo"
args = ["doc", "--features", "full-doc"] args = ["doc", "--features", "full-doc"]
[tasks.doc-open]
command = "cargo"
args = ["doc", "--features", "full-doc", "--open"]
[tasks.ready] [tasks.ready]
dependencies = ["format", "test", "build-variants", "build-examples", "doc", "clippy"] dependencies = ["format", "test", "build-variants", "build-examples", "doc", "clippy"]

View File

@@ -28,14 +28,14 @@ use serenity::{
}; };
use songbird::{ use songbird::{
driver::{Config as DriverConfig, DecodeMode}, driver::DecodeMode,
model::payload::{ClientConnect, ClientDisconnect, Speaking}, model::payload::{ClientConnect, ClientDisconnect, Speaking},
Config,
CoreEvent, CoreEvent,
Event, Event,
EventContext, EventContext,
EventHandler as VoiceEventHandler, EventHandler as VoiceEventHandler,
SerenityInit, SerenityInit,
Songbird,
}; };
struct Handler; struct Handler;
@@ -167,16 +167,13 @@ async fn main() {
// Here, we need to configure Songbird to decode all incoming voice packets. // Here, we need to configure Songbird to decode all incoming voice packets.
// If you want, you can do this on a per-call basis---here, we need it to // If you want, you can do this on a per-call basis---here, we need it to
// read the audio data that other people are sending us! // read the audio data that other people are sending us!
let songbird = Songbird::serenity(); let songbird_config = Config::default()
songbird.set_config( .decode_mode(DecodeMode::Decode);
DriverConfig::default()
.decode_mode(DecodeMode::Decode)
);
let mut client = Client::builder(&token) let mut client = Client::builder(&token)
.event_handler(Handler) .event_handler(Handler)
.framework(framework) .framework(framework)
.register_songbird_with(songbird.into()) .register_songbird_from_config(songbird_config)
.await .await
.expect("Err creating client"); .expect("Err creating client");

View File

@@ -1,9 +1,14 @@
use super::{CryptoMode, DecodeMode}; #[cfg(feature = "driver-core")]
use super::driver::{CryptoMode, DecodeMode};
/// Configuration for the inner Driver. #[cfg(feature = "gateway-core")]
/// use std::time::Duration;
/// Configuration for drivers and calls.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Config { pub struct Config {
#[cfg(feature = "driver-core")]
/// Selected tagging mode for voice packet encryption. /// Selected tagging mode for voice packet encryption.
/// ///
/// Defaults to [`CryptoMode::Normal`]. /// Defaults to [`CryptoMode::Normal`].
@@ -14,6 +19,7 @@ pub struct Config {
/// ///
/// [`CryptoMode::Normal`]: CryptoMode::Normal /// [`CryptoMode::Normal`]: CryptoMode::Normal
pub crypto_mode: CryptoMode, pub crypto_mode: CryptoMode,
#[cfg(feature = "driver-core")]
/// Configures whether decoding and decryption occur for all received packets. /// Configures whether decoding and decryption occur for all received packets.
/// ///
/// If voice receiving voice packets, generally you should choose [`DecodeMode::Decode`]. /// If voice receiving voice packets, generally you should choose [`DecodeMode::Decode`].
@@ -29,6 +35,20 @@ pub struct Config {
/// [`DecodeMode::Pass`]: DecodeMode::Pass /// [`DecodeMode::Pass`]: DecodeMode::Pass
/// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate /// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate
pub decode_mode: DecodeMode, pub decode_mode: DecodeMode,
#[cfg(feature = "gateway-core")]
/// Configures the amount of time to wait for Discord to reply with connection information
/// if [`Call::join`]/[`join_gateway`] are used.
///
/// This is a useful fallback in the event that:
/// * the underlying Discord client restarts and loses a join request, or
/// * a channel join fails because the bot is already believed to be there.
///
/// Defaults to 10 seconds. If set to `None`, connections will never time out.
///
/// [`Call::join`]: crate::Call::join
/// [`join_gateway`]: crate::Call::join_gateway
pub gateway_timeout: Option<Duration>,
#[cfg(feature = "driver-core")]
/// Number of concurrently active tracks to allocate memory for. /// Number of concurrently active tracks to allocate memory for.
/// ///
/// This should be set at, or just above, the maximum number of tracks /// This should be set at, or just above, the maximum number of tracks
@@ -46,13 +66,19 @@ pub struct Config {
impl Default for Config { impl Default for Config {
fn default() -> Self { fn default() -> Self {
Self { Self {
#[cfg(feature = "driver-core")]
crypto_mode: CryptoMode::Normal, crypto_mode: CryptoMode::Normal,
#[cfg(feature = "driver-core")]
decode_mode: DecodeMode::Decrypt, decode_mode: DecodeMode::Decrypt,
#[cfg(feature = "gateway-core")]
gateway_timeout: Some(Duration::from_secs(10)),
#[cfg(feature = "driver-core")]
preallocated_tracks: 1, preallocated_tracks: 1,
} }
} }
} }
#[cfg(feature = "driver-core")]
impl Config { impl Config {
/// Sets this `Config`'s chosen cryptographic tagging scheme. /// Sets this `Config`'s chosen cryptographic tagging scheme.
pub fn crypto_mode(mut self, crypto_mode: CryptoMode) -> Self { pub fn crypto_mode(mut self, crypto_mode: CryptoMode) -> Self {
@@ -79,3 +105,12 @@ impl Config {
} }
} }
} }
#[cfg(feature = "gateway-core")]
impl Config {
/// Sets this `Config`'s timeout for joining a voice channel.
pub fn gateway_timeout(mut self, gateway_timeout: Option<Duration>) -> Self {
self.gateway_timeout = gateway_timeout;
self
}
}

View File

@@ -11,13 +11,11 @@
#[cfg(feature = "internals")] #[cfg(feature = "internals")]
pub mod bench_internals; pub mod bench_internals;
mod config;
pub(crate) mod connection; pub(crate) mod connection;
mod crypto; mod crypto;
mod decode_mode; mod decode_mode;
pub(crate) mod tasks; pub(crate) mod tasks;
pub use config::Config;
use connection::error::{Error, Result}; use connection::error::{Error, Result};
pub use crypto::CryptoMode; pub use crypto::CryptoMode;
pub(crate) use crypto::CryptoState; pub(crate) use crypto::CryptoState;
@@ -29,6 +27,7 @@ use crate::{
events::EventData, events::EventData,
input::Input, input::Input,
tracks::{self, Track, TrackHandle}, tracks::{self, Track, TrackHandle},
Config,
ConnectionInfo, ConnectionInfo,
Event, Event,
EventHandler, EventHandler,
@@ -212,13 +211,19 @@ impl Driver {
self.send(CoreMessage::SetTrack(None)) self.send(CoreMessage::SetTrack(None))
} }
/// Sets the configuration for this driver. /// Sets the configuration for this driver (and parent `Call`, if applicable).
#[instrument(skip(self))] #[instrument(skip(self))]
pub fn set_config(&mut self, config: Config) { pub fn set_config(&mut self, config: Config) {
self.config = config.clone(); self.config = config.clone();
self.send(CoreMessage::SetConfig(config)) self.send(CoreMessage::SetConfig(config))
} }
/// Returns a view of this driver's configuration.
#[instrument(skip(self))]
pub fn config(&self) -> &Config {
&self.config
}
/// Attach a global event handler to an audio context. Global events may receive /// Attach a global event handler to an audio context. Global events may receive
/// any [`EventContext`]. /// any [`EventContext`].
/// ///

View File

@@ -17,6 +17,7 @@ pub enum Recipient {
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)] #[derive(Debug)]
#[non_exhaustive]
pub enum Error { pub enum Error {
Crypto(CryptoError), Crypto(CryptoError),
/// Received an illegal voice packet on the voice UDP socket. /// Received an illegal voice packet on the voice UDP socket.

View File

@@ -1,7 +1,8 @@
use super::{disposal, error::Result, message::*, Config}; use super::{disposal, error::Result, message::*};
use crate::{ use crate::{
constants::*, constants::*,
tracks::{PlayMode, Track}, tracks::{PlayMode, Track},
Config,
}; };
use audiopus::{ use audiopus::{
coder::Encoder as OpusEncoder, coder::Encoder as OpusEncoder,

View File

@@ -9,11 +9,8 @@ pub(crate) mod udp_rx;
pub(crate) mod udp_tx; pub(crate) mod udp_tx;
pub(crate) mod ws; pub(crate) mod ws;
use super::{ use super::connection::{error::Error as ConnectionError, Connection};
connection::{error::Error as ConnectionError, Connection}, use crate::{events::CoreContext, Config};
Config,
};
use crate::events::CoreContext;
use flume::{Receiver, RecvError, Sender}; use flume::{Receiver, RecvError, Sender};
use message::*; use message::*;
#[cfg(not(feature = "tokio-02-marker"))] #[cfg(not(feature = "tokio-02-marker"))]

View File

@@ -1,12 +1,9 @@
use super::{ use super::{
error::{Error, Result}, error::{Error, Result},
message::*, message::*,
Config,
}; };
use crate::{ use crate::{constants::*, driver::DecodeMode, events::CoreContext};
constants::*,
driver::{Config, DecodeMode},
events::CoreContext,
};
use audiopus::{ use audiopus::{
coder::Decoder as OpusDecoder, coder::Decoder as OpusDecoder,
error::{Error as OpusError, ErrorCode}, error::{Error as OpusError, ErrorCode},

View File

@@ -11,6 +11,7 @@ use twilight_gateway::shard::CommandError;
#[cfg(feature = "gateway-core")] #[cfg(feature = "gateway-core")]
#[derive(Debug)] #[derive(Debug)]
#[non_exhaustive]
/// Error returned when a manager or call handler is /// Error returned when a manager or call handler is
/// unable to send messages over Discord's gateway. /// unable to send messages over Discord's gateway.
pub enum JoinError { pub enum JoinError {
@@ -23,8 +24,23 @@ pub enum JoinError {
/// ///
/// [`Call`]: crate::Call /// [`Call`]: crate::Call
NoCall, NoCall,
/// Connection details were not received from Discord in the
/// time given in [the `Call`'s configuration].
///
/// This can occur if a message is lost by the Discord client
/// between restarts, or if Discord's gateway believes that
/// this bot is still in the channel it attempts to join.
///
/// *Users should `leave` the server on the gateway before
/// re-attempting connection.*
///
/// [the `Call`'s configuration]: crate::Config
TimedOut,
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
/// The driver failed to establish a voice connection. /// The driver failed to establish a voice connection.
///
/// *Users should `leave` the server on the gateway before
/// re-attempting connection.*
Driver(ConnectionError), Driver(ConnectionError),
#[cfg(feature = "serenity")] #[cfg(feature = "serenity")]
/// Serenity-specific WebSocket send error. /// Serenity-specific WebSocket send error.
@@ -34,6 +50,31 @@ pub enum JoinError {
Twilight(CommandError), Twilight(CommandError),
} }
#[cfg(feature = "gateway-core")]
impl JoinError {
/// Indicates whether this failure may have left (or been
/// caused by) Discord's gateway state being in an
/// inconsistent state.
///
/// Failure to `leave` before rejoining may cause further
/// timeouts.
pub fn should_leave_server(&self) -> bool {
matches!(self, JoinError::TimedOut)
}
#[cfg(feature = "driver-core")]
/// Indicates whether this failure can be reattempted via
/// [`Driver::connect`] with retreived connection info.
///
/// Failure to `leave` before rejoining may cause further
/// timeouts.
///
/// [`Driver::connect`]: crate::driver::Driver
pub fn should_reconnect_driver(&self) -> bool {
matches!(self, JoinError::Driver(_))
}
}
#[cfg(feature = "gateway-core")] #[cfg(feature = "gateway-core")]
impl fmt::Display for JoinError { impl fmt::Display for JoinError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -42,6 +83,7 @@ impl fmt::Display for JoinError {
JoinError::Dropped => write!(f, "request was cancelled/dropped."), JoinError::Dropped => write!(f, "request was cancelled/dropped."),
JoinError::NoSender => write!(f, "no gateway destination."), JoinError::NoSender => write!(f, "no gateway destination."),
JoinError::NoCall => write!(f, "tried to leave a non-existent call."), JoinError::NoCall => write!(f, "tried to leave a non-existent call."),
JoinError::TimedOut => write!(f, "gateway response from Discord timed out."),
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
JoinError::Driver(t) => write!(f, "internal driver error {}.", t), JoinError::Driver(t) => write!(f, "internal driver error {}.", t),
#[cfg(feature = "serenity")] #[cfg(feature = "serenity")]

View File

@@ -1,15 +1,14 @@
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
use crate::{ use crate::{driver::Driver, error::ConnectionResult};
driver::{Config, Driver},
error::ConnectionResult,
};
use crate::{ use crate::{
error::{JoinError, JoinResult}, error::{JoinError, JoinResult},
id::{ChannelId, GuildId, UserId}, id::{ChannelId, GuildId, UserId},
info::{ConnectionInfo, ConnectionProgress}, info::{ConnectionInfo, ConnectionProgress},
join::*,
shards::Shard, shards::Shard,
Config,
}; };
use flume::{r#async::RecvFut, Sender}; use flume::Sender;
use serde_json::json; use serde_json::json;
use tracing::instrument; use tracing::instrument;
@@ -18,9 +17,15 @@ use std::ops::{Deref, DerefMut};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum Return { enum Return {
// Return the connection info as it is received.
Info(Sender<ConnectionInfo>), Info(Sender<ConnectionInfo>),
// Two channels: first indicates "gateway connection" was successful,
// second indicates that the driver successfully connected.
// The first is needed to cancel a timeout as the driver can/should
// have separate connection timing/retry config.
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
Conn(Sender<ConnectionResult<()>>), Conn(Sender<()>, Sender<ConnectionResult<()>>),
} }
/// The Call handler is responsible for a single voice connection, acting /// The Call handler is responsible for a single voice connection, acting
@@ -32,6 +37,9 @@ enum Return {
/// [`Driver`]: struct@Driver /// [`Driver`]: struct@Driver
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Call { pub struct Call {
#[cfg(not(feature = "driver-core"))]
config: Config,
connection: Option<(ConnectionProgress, Return)>, connection: Option<(ConnectionProgress, Return)>,
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
@@ -61,19 +69,13 @@ impl Call {
#[inline] #[inline]
#[instrument] #[instrument]
pub fn new(guild_id: GuildId, ws: Shard, user_id: UserId) -> Self { pub fn new(guild_id: GuildId, ws: Shard, user_id: UserId) -> Self {
Self::new_raw(guild_id, Some(ws), user_id) Self::new_raw_cfg(guild_id, Some(ws), user_id, Default::default())
} }
#[cfg(feature = "driver-core")]
/// Creates a new Call, configuring the driver as specified. /// Creates a new Call, configuring the driver as specified.
#[inline] #[inline]
#[instrument] #[instrument]
pub fn from_driver_config( pub fn from_config(guild_id: GuildId, ws: Shard, user_id: UserId, config: Config) -> Self {
guild_id: GuildId,
ws: Shard,
user_id: UserId,
config: Config,
) -> Self {
Self::new_raw_cfg(guild_id, Some(ws), user_id, config) Self::new_raw_cfg(guild_id, Some(ws), user_id, config)
} }
@@ -88,38 +90,22 @@ impl Call {
#[inline] #[inline]
#[instrument] #[instrument]
pub fn standalone(guild_id: GuildId, user_id: UserId) -> Self { pub fn standalone(guild_id: GuildId, user_id: UserId) -> Self {
Self::new_raw(guild_id, None, user_id) Self::new_raw_cfg(guild_id, None, user_id, Default::default())
} }
#[cfg(feature = "driver-core")] /// Creates a new standalone Call from the given configuration file.
/// Creates a new standalone Call, configuring the driver as specified.
#[inline] #[inline]
#[instrument] #[instrument]
pub fn standalone_from_driver_config( pub fn standalone_from_config(guild_id: GuildId, user_id: UserId, config: Config) -> Self {
guild_id: GuildId,
user_id: UserId,
config: Config,
) -> Self {
Self::new_raw_cfg(guild_id, None, user_id, config) Self::new_raw_cfg(guild_id, None, user_id, config)
} }
fn new_raw(guild_id: GuildId, ws: Option<Shard>, user_id: UserId) -> Self {
Call {
connection: None,
#[cfg(feature = "driver-core")]
driver: Default::default(),
guild_id,
self_deaf: false,
self_mute: false,
user_id,
ws,
}
}
#[cfg(feature = "driver-core")]
fn new_raw_cfg(guild_id: GuildId, ws: Option<Shard>, user_id: UserId, config: Config) -> Self { fn new_raw_cfg(guild_id: GuildId, ws: Option<Shard>, user_id: UserId, config: Config) -> Self {
Call { Call {
#[cfg(not(feature = "driver-core"))]
config,
connection: None, connection: None,
#[cfg(feature = "driver-core")]
driver: Driver::new(config), driver: Driver::new(config),
guild_id, guild_id,
self_deaf: false, self_deaf: false,
@@ -137,8 +123,11 @@ impl Call {
let _ = tx.send(c.clone()); let _ = tx.send(c.clone());
}, },
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
Some((ConnectionProgress::Complete(c), Return::Conn(tx))) => { Some((ConnectionProgress::Complete(c), Return::Conn(first_tx, driver_tx))) => {
self.driver.raw_connect(c.clone(), tx.clone()); // It's okay if the receiver hung up.
let _ = first_tx.send(());
self.driver.raw_connect(c.clone(), driver_tx.clone());
}, },
_ => {}, _ => {},
} }
@@ -209,11 +198,9 @@ impl Call {
/// ///
/// [`Songbird::join`]: crate::Songbird::join /// [`Songbird::join`]: crate::Songbird::join
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn join( pub async fn join(&mut self, channel_id: ChannelId) -> JoinResult<Join> {
&mut self,
channel_id: ChannelId,
) -> JoinResult<RecvFut<'static, ConnectionResult<()>>> {
let (tx, rx) = flume::unbounded(); let (tx, rx) = flume::unbounded();
let (gw_tx, gw_rx) = flume::unbounded();
let do_conn = self let do_conn = self
.should_actually_join(|_| Ok(()), &tx, channel_id) .should_actually_join(|_| Ok(()), &tx, channel_id)
@@ -222,12 +209,20 @@ impl Call {
if do_conn { if do_conn {
self.connection = Some(( self.connection = Some((
ConnectionProgress::new(self.guild_id, self.user_id, channel_id), ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
Return::Conn(tx), Return::Conn(gw_tx, tx),
)); ));
self.update().await.map(|_| rx.into_recv_async()) let timeout = self.config().gateway_timeout;
self.update()
.await
.map(|_| Join::new(rx.into_recv_async(), gw_rx.into_recv_async(), timeout))
} else { } else {
Ok(rx.into_recv_async()) Ok(Join::new(
rx.into_recv_async(),
gw_rx.into_recv_async(),
None,
))
} }
} }
@@ -247,10 +242,7 @@ impl Call {
/// ///
/// [`Songbird::join_gateway`]: crate::Songbird::join_gateway /// [`Songbird::join_gateway`]: crate::Songbird::join_gateway
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn join_gateway( pub async fn join_gateway(&mut self, channel_id: ChannelId) -> JoinResult<JoinGateway> {
&mut self,
channel_id: ChannelId,
) -> JoinResult<RecvFut<'static, ConnectionInfo>> {
let (tx, rx) = flume::unbounded(); let (tx, rx) = flume::unbounded();
let do_conn = self let do_conn = self
@@ -267,9 +259,13 @@ impl Call {
Return::Info(tx), Return::Info(tx),
)); ));
self.update().await.map(|_| rx.into_recv_async()) let timeout = self.config().gateway_timeout;
self.update()
.await
.map(|_| JoinGateway::new(rx.into_recv_async(), timeout))
} else { } else {
Ok(rx.into_recv_async()) Ok(JoinGateway::new(rx.into_recv_async(), None))
} }
} }
@@ -414,6 +410,24 @@ impl Call {
} }
} }
#[cfg(not(feature = "driver-core"))]
impl Call {
/// Access this call handler's configuration.
pub fn config(&self) -> &Config {
&self.config
}
/// Mutably access this call handler's configuration.
pub fn config_mut(&mut self) -> &mut Config {
&mut self.config
}
/// Set this call handler's configuration.
pub fn set_config(&mut self, config: Config) {
self.config = config;
}
}
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
impl Deref for Call { impl Deref for Call {
type Target = Driver; type Target = Driver;

View File

@@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::{ use crate::{
constants::*, constants::*,
input::{error::Error, ffmpeg, Codec, Container, Input, Reader}, input::{error::Error, Codec, Container, Input},
test_utils::*, test_utils::*,
}; };
use audiopus::{coder::Decoder, Bitrate, Channels, SampleRate}; use audiopus::{coder::Decoder, Bitrate, Channels, SampleRate};

View File

@@ -559,7 +559,7 @@ mod tests {
let mut input = Input::new(false, data.clone().into(), Codec::Pcm, Container::Raw, None); let mut input = Input::new(false, data.clone().into(), Codec::Pcm, Container::Raw, None);
let mut out_vec = vec![]; let mut out_vec = vec![];
let len = input.read_to_end(&mut out_vec).unwrap(); let _len = input.read_to_end(&mut out_vec).unwrap();
let mut i16_window = &data[..]; let mut i16_window = &data[..];
let mut float_window = &out_vec[..]; let mut float_window = &out_vec[..];
@@ -580,7 +580,7 @@ mod tests {
let mut input = Input::new(true, data.clone().into(), Codec::Pcm, Container::Raw, None); let mut input = Input::new(true, data.clone().into(), Codec::Pcm, Container::Raw, None);
let mut out_vec = vec![]; let mut out_vec = vec![];
let len = input.read_to_end(&mut out_vec).unwrap(); let _len = input.read_to_end(&mut out_vec).unwrap();
let mut i16_window = &data[..]; let mut i16_window = &data[..];
let mut float_window = &out_vec[..]; let mut float_window = &out_vec[..];

174
src/join.rs Normal file
View File

@@ -0,0 +1,174 @@
//! Future types for gateway interactions.
#[cfg(feature = "driver-core")]
use crate::error::ConnectionResult;
use crate::{
error::{JoinError, JoinResult},
ConnectionInfo,
};
use core::{
convert,
future::Future,
marker::Unpin,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use flume::r#async::RecvFut;
use pin_project::pin_project;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::time::{self, Timeout};
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::time::{self, Timeout};
#[cfg(feature = "driver-core")]
/// Future for a call to [`Call::join`].
///
/// This future `await`s Discord's response *and*
/// connection via the [`Driver`]. Both phases have
/// separate timeouts and failure conditions.
///
/// This future ***must not*** be `await`ed while
/// holding the lock around a [`Call`].
///
/// [`Call::join`]: crate::Call::join
/// [`Call`]: crate::Call
/// [`Driver`]: crate::driver::Driver
#[pin_project]
pub struct Join {
#[pin]
gw: JoinClass<()>,
#[pin]
driver: JoinClass<ConnectionResult<()>>,
state: JoinState,
}
#[cfg(feature = "driver-core")]
impl Join {
pub(crate) fn new(
driver: RecvFut<'static, ConnectionResult<()>>,
gw_recv: RecvFut<'static, ()>,
timeout: Option<Duration>,
) -> Self {
Self {
gw: JoinClass::new(gw_recv, timeout),
driver: JoinClass::new(driver, None),
state: JoinState::BeforeGw,
}
}
}
#[cfg(feature = "driver-core")]
impl Future for Join {
type Output = JoinResult<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if *this.state == JoinState::BeforeGw {
let poll = this.gw.poll(cx);
match poll {
Poll::Ready(a) if a.is_ok() => {
*this.state = JoinState::AfterGw;
},
Poll::Ready(a) => {
*this.state = JoinState::Finalised;
return Poll::Ready(a);
},
Poll::Pending => return Poll::Pending,
}
}
if *this.state == JoinState::AfterGw {
let poll = this
.driver
.poll(cx)
.map_ok(|res| res.map_err(JoinError::Driver))
.map(|res| res.and_then(convert::identity));
match poll {
Poll::Ready(a) => {
*this.state = JoinState::Finalised;
return Poll::Ready(a);
},
Poll::Pending => return Poll::Pending,
}
}
Poll::Pending
}
}
#[cfg(feature = "driver-core")]
#[derive(Copy, Clone, Eq, PartialEq)]
enum JoinState {
BeforeGw,
AfterGw,
Finalised,
}
/// Future for a call to [`Call::join_gateway`].
///
/// This future `await`s Discord's gateway response, subject
/// to any timeouts.
///
/// This future ***must not*** be `await`ed while
/// holding the lock around a [`Call`].
///
/// [`Call::join_gateway`]: crate::Call::join_gateway
/// [`Call`]: crate::Call
/// [`Driver`]: crate::driver::Driver
#[pin_project]
pub struct JoinGateway {
#[pin]
inner: JoinClass<ConnectionInfo>,
}
impl JoinGateway {
pub(crate) fn new(recv: RecvFut<'static, ConnectionInfo>, timeout: Option<Duration>) -> Self {
Self {
inner: JoinClass::new(recv, timeout),
}
}
}
impl Future for JoinGateway {
type Output = JoinResult<ConnectionInfo>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
#[pin_project(project = JoinClassProj)]
enum JoinClass<T: 'static> {
WithTimeout(#[pin] Timeout<RecvFut<'static, T>>),
Vanilla(RecvFut<'static, T>),
}
impl<T: 'static> JoinClass<T> {
pub(crate) fn new(recv: RecvFut<'static, T>, timeout: Option<Duration>) -> Self {
match timeout {
Some(t) => JoinClass::WithTimeout(time::timeout(t, recv)),
None => JoinClass::Vanilla(recv),
}
}
}
impl<T> Future for JoinClass<T>
where
T: Unpin,
{
type Output = JoinResult<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
JoinClassProj::WithTimeout(t) => t
.poll(cx)
.map_err(|_| JoinError::TimedOut)
.map_ok(|res| res.map_err(|_| JoinError::Dropped))
.map(|m| m.and_then(convert::identity)),
JoinClassProj::Vanilla(t) => Pin::new(t).poll(cx).map_err(|_| JoinError::Dropped),
}
}
}

View File

@@ -37,6 +37,7 @@
//! [`ConnectionInfo`]: struct@ConnectionInfo //! [`ConnectionInfo`]: struct@ConnectionInfo
//! [lavalink]: https://github.com/Frederikam/Lavalink //! [lavalink]: https://github.com/Frederikam/Lavalink
mod config;
pub mod constants; pub mod constants;
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
pub mod driver; pub mod driver;
@@ -50,6 +51,8 @@ pub(crate) mod info;
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
pub mod input; pub mod input;
#[cfg(feature = "gateway-core")] #[cfg(feature = "gateway-core")]
pub mod join;
#[cfg(feature = "gateway-core")]
mod manager; mod manager;
#[cfg(feature = "serenity")] #[cfg(feature = "serenity")]
pub mod serenity; pub mod serenity;
@@ -61,6 +64,7 @@ pub mod tracks;
mod ws; mod ws;
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
/// Opus encoder bitrate settings.
pub use audiopus::{self as opus, Bitrate}; pub use audiopus::{self as opus, Bitrate};
#[cfg(feature = "driver-core")] #[cfg(feature = "driver-core")]
pub use discortp as packet; pub use discortp as packet;
@@ -86,4 +90,5 @@ pub use crate::{handler::*, manager::*};
#[cfg(feature = "serenity")] #[cfg(feature = "serenity")]
pub use crate::serenity::*; pub use crate::serenity::*;
pub use config::Config;
pub use info::ConnectionInfo; pub use info::ConnectionInfo;

View File

@@ -1,10 +1,9 @@
#[cfg(feature = "driver-core")]
use crate::driver::Config;
use crate::{ use crate::{
error::{JoinError, JoinResult}, error::{JoinError, JoinResult},
id::{ChannelId, GuildId, UserId}, id::{ChannelId, GuildId, UserId},
shards::Sharder, shards::Sharder,
Call, Call,
Config,
ConnectionInfo, ConnectionInfo,
}; };
#[cfg(feature = "serenity")] #[cfg(feature = "serenity")]
@@ -50,9 +49,7 @@ pub struct Songbird {
client_data: PRwLock<ClientData>, client_data: PRwLock<ClientData>,
calls: DashMap<GuildId, Arc<Mutex<Call>>>, calls: DashMap<GuildId, Arc<Mutex<Call>>>,
sharder: Sharder, sharder: Sharder,
config: PRwLock<Option<Config>>,
#[cfg(feature = "driver-core")]
driver_config: PRwLock<Option<Config>>,
} }
impl Songbird { impl Songbird {
@@ -63,13 +60,21 @@ impl Songbird {
/// ///
/// [registered]: crate::serenity::register_with /// [registered]: crate::serenity::register_with
pub fn serenity() -> Arc<Self> { pub fn serenity() -> Arc<Self> {
Self::serenity_from_config(Default::default())
}
#[cfg(feature = "serenity")]
/// Create a new Songbird instance for serenity, using the given configuration.
///
/// This must be [registered] after creation.
///
/// [registered]: crate::serenity::register_with
pub fn serenity_from_config(config: Config) -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
client_data: Default::default(), client_data: Default::default(),
calls: Default::default(), calls: Default::default(),
sharder: Sharder::Serenity(Default::default()), sharder: Sharder::Serenity(Default::default()),
config: Some(config).into(),
#[cfg(feature = "driver-core")]
driver_config: Default::default(),
}) })
} }
@@ -82,6 +87,26 @@ impl Songbird {
/// ///
/// [`process`]: Songbird::process /// [`process`]: Songbird::process
pub fn twilight<U>(cluster: Cluster, shard_count: u64, user_id: U) -> Arc<Self> pub fn twilight<U>(cluster: Cluster, shard_count: u64, user_id: U) -> Arc<Self>
where
U: Into<UserId>,
{
Self::twilight_from_config(cluster, shard_count, user_id, Default::default())
}
#[cfg(feature = "twilight")]
/// Create a new Songbird instance for twilight.
///
/// Twilight handlers do not need to be registered, but
/// users are responsible for passing in any events using
/// [`process`].
///
/// [`process`]: Songbird::process
pub fn twilight_from_config<U>(
cluster: Cluster,
shard_count: u64,
user_id: U,
config: Config,
) -> Arc<Self>
where where
U: Into<UserId>, U: Into<UserId>,
{ {
@@ -93,9 +118,7 @@ impl Songbird {
}), }),
calls: Default::default(), calls: Default::default(),
sharder: Sharder::Twilight(cluster), sharder: Sharder::Twilight(cluster),
config: Some(config).into(),
#[cfg(feature = "driver-core")]
driver_config: Default::default(),
}) })
} }
@@ -144,23 +167,30 @@ impl Songbird {
.get_shard(shard) .get_shard(shard)
.expect("Failed to get shard handle: shard_count incorrect?"); .expect("Failed to get shard handle: shard_count incorrect?");
#[cfg(feature = "driver-core")] let call = Call::from_config(
let call = Call::from_driver_config(
guild_id, guild_id,
shard_handle, shard_handle,
info.user_id, info.user_id,
self.driver_config.read().clone().unwrap_or_default(), self.config.read().clone().unwrap_or_default(),
); );
#[cfg(not(feature = "driver-core"))]
let call = Call::new(guild_id, shard_handle, info.user_id);
Arc::new(Mutex::new(call)) Arc::new(Mutex::new(call))
}) })
.clone() .clone()
}) })
} }
/// Sets a shared configuration for all drivers created from this
/// manager.
///
/// Changes made here will apply to new Call and Driver instances only.
///
/// Requires the `"driver"` feature.
pub fn set_config(&self, new_config: Config) {
let mut config = self.config.write();
*config = Some(new_config);
}
fn manager_info(&self) -> ClientData { fn manager_info(&self) -> ClientData {
let client_data = self.client_data.write(); let client_data = self.client_data.write();
@@ -213,10 +243,7 @@ impl Songbird {
}; };
let result = match stage_1 { let result = match stage_1 {
Ok(chan) => chan Ok(chan) => chan.await,
.await
.map_err(|_| JoinError::Dropped)
.and_then(|x| x.map_err(JoinError::from)),
Err(e) => Err(e), Err(e) => Err(e),
}; };
@@ -401,20 +428,6 @@ impl VoiceGatewayManager for Songbird {
} }
} }
#[cfg(feature = "driver-core")]
impl Songbird {
/// Sets a shared configuration for all drivers created from this
/// manager.
///
/// Changes made here will apply to new Call and Driver instances only.
///
/// Requires the `"driver"` feature.
pub fn set_config(&self, new_config: Config) {
let mut config = self.driver_config.write();
*config = Some(new_config);
}
}
#[inline] #[inline]
fn shard_id(guild_id: u64, shard_count: u64) -> u64 { fn shard_id(guild_id: u64, shard_count: u64) -> u64 {
(guild_id >> 22) % shard_count (guild_id >> 22) % shard_count

View File

@@ -3,7 +3,7 @@
//! //!
//! [serenity]: https://crates.io/crates/serenity/0.9.0-rc.2 //! [serenity]: https://crates.io/crates/serenity/0.9.0-rc.2
use crate::manager::Songbird; use crate::{Config, Songbird};
use serenity::{ use serenity::{
client::{ClientBuilder, Context}, client::{ClientBuilder, Context},
prelude::TypeMapKey, prelude::TypeMapKey,
@@ -37,6 +37,14 @@ pub fn register_with(client_builder: ClientBuilder, voice: Arc<Songbird>) -> Cli
.type_map_insert::<SongbirdKey>(voice) .type_map_insert::<SongbirdKey>(voice)
} }
/// Installs a given songbird instance into the serenity client.
///
/// This should be called after any uses of `ClientBuilder::type_map`.
pub fn register_from_config(client_builder: ClientBuilder, config: Config) -> ClientBuilder {
let voice = Songbird::serenity_from_config(config);
register_with(client_builder, voice)
}
/// Retrieve the Songbird voice client from a serenity context's /// Retrieve the Songbird voice client from a serenity context's
/// shared key-value store. /// shared key-value store.
pub async fn get(ctx: &Context) -> Option<Arc<Songbird>> { pub async fn get(ctx: &Context) -> Option<Arc<Songbird>> {
@@ -58,6 +66,8 @@ pub trait SerenityInit {
fn register_songbird(self) -> Self; fn register_songbird(self) -> Self;
/// Registers a given Songbird voice system with serenity, as above. /// Registers a given Songbird voice system with serenity, as above.
fn register_songbird_with(self, voice: Arc<Songbird>) -> Self; fn register_songbird_with(self, voice: Arc<Songbird>) -> Self;
/// Registers a Songbird voice system serenity, based on the given configuration.
fn register_songbird_from_config(self, config: Config) -> Self;
} }
impl SerenityInit for ClientBuilder<'_> { impl SerenityInit for ClientBuilder<'_> {
@@ -68,4 +78,8 @@ impl SerenityInit for ClientBuilder<'_> {
fn register_songbird_with(self, voice: Arc<Songbird>) -> Self { fn register_songbird_with(self, voice: Arc<Songbird>) -> Self {
register_with(self, voice) register_with(self, voice)
} }
fn register_songbird_from_config(self, config: Config) -> Self {
register_from_config(self, config)
}
} }