Driver: Migrate to tokio_tungstenite (#138)
This places songbird, serenity, and twilight onto the same WS library, hopefully reducing the compile overhead for everyone. Tested using `cargo make ready` and by running `examples/voice`. Closes #129.
This commit is contained in:
26
Cargo.toml
26
Cargo.toml
@@ -21,20 +21,10 @@ serde_json = "1"
|
||||
tracing = { version = "0.1", features = ["log"] }
|
||||
tracing-futures = "0.2"
|
||||
|
||||
[dependencies.once_cell]
|
||||
version = "1"
|
||||
optional = true
|
||||
|
||||
[dependencies.async-trait]
|
||||
optional = true
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.async-tungstenite]
|
||||
default-features = false
|
||||
features = ["tokio-runtime"]
|
||||
optional = true
|
||||
version = "0.17"
|
||||
|
||||
[dependencies.audiopus]
|
||||
optional = true
|
||||
version = "0.3.0-rc.0"
|
||||
@@ -59,6 +49,10 @@ version = "0.10"
|
||||
[dependencies.futures]
|
||||
version = "0.3"
|
||||
|
||||
[dependencies.once_cell]
|
||||
version = "1"
|
||||
optional = true
|
||||
|
||||
[dependencies.parking_lot]
|
||||
optional = true
|
||||
version = "0.12"
|
||||
@@ -127,6 +121,10 @@ optional = true
|
||||
version = "1.0"
|
||||
default-features = false
|
||||
|
||||
[dependencies.tokio-tungstenite]
|
||||
optional = true
|
||||
version = "0.17"
|
||||
|
||||
[dependencies.tokio-util]
|
||||
optional = true
|
||||
version = "0.7"
|
||||
@@ -184,7 +182,6 @@ gateway = [
|
||||
]
|
||||
driver = [
|
||||
"async-trait",
|
||||
"async-tungstenite",
|
||||
"audiopus",
|
||||
"byteorder",
|
||||
"discortp",
|
||||
@@ -201,7 +198,6 @@ driver = [
|
||||
"symphonia",
|
||||
"symphonia-core",
|
||||
"rusty_pool",
|
||||
"tokio-util",
|
||||
"tokio/fs",
|
||||
"tokio/io-util",
|
||||
"tokio/macros",
|
||||
@@ -210,13 +206,15 @@ driver = [
|
||||
"tokio/rt",
|
||||
"tokio/sync",
|
||||
"tokio/time",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"typemap_rev",
|
||||
"url",
|
||||
"uuid",
|
||||
"xsalsa20poly1305",
|
||||
]
|
||||
rustls = ["async-tungstenite/tokio-rustls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"]
|
||||
native = ["async-tungstenite/tokio-native-tls", "native-marker", "reqwest/native-tls"]
|
||||
rustls = ["tokio-tungstenite/rustls-tls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"]
|
||||
native = ["tokio-tungstenite/native-tls", "native-marker", "reqwest/native-tls"]
|
||||
serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"]
|
||||
serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"]
|
||||
twilight-rustls = ["twilight", "twilight-gateway/rustls-native-roots", "rustls", "gateway"]
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
Event as GatewayEvent,
|
||||
ProtocolData,
|
||||
},
|
||||
ws::{self, ReceiverExt, SenderExt, WsStream},
|
||||
ws::WsStream,
|
||||
ConnectionInfo,
|
||||
};
|
||||
use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket};
|
||||
@@ -24,12 +24,6 @@ use tracing::{debug, info, instrument};
|
||||
use url::Url;
|
||||
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
use ws::create_rustls_client;
|
||||
|
||||
#[cfg(feature = "native-marker")]
|
||||
use ws::create_native_tls_client;
|
||||
|
||||
pub(crate) struct Connection {
|
||||
pub(crate) info: ConnectionInfo,
|
||||
pub(crate) ssrc: u32,
|
||||
@@ -58,11 +52,7 @@ impl Connection {
|
||||
) -> Result<Connection> {
|
||||
let url = generate_url(&mut info.endpoint)?;
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
let mut client = create_rustls_client(url).await?;
|
||||
|
||||
#[cfg(feature = "native-marker")]
|
||||
let mut client = create_native_tls_client(url).await?;
|
||||
let mut client = WsStream::connect(url).await?;
|
||||
|
||||
let mut hello = None;
|
||||
let mut ready = None;
|
||||
@@ -241,12 +231,7 @@ impl Connection {
|
||||
|
||||
// Thread may have died, we want to send to prompt a clean exit
|
||||
// (if at all possible) and then proceed as normal.
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
let mut client = create_rustls_client(url).await?;
|
||||
|
||||
#[cfg(feature = "native-marker")]
|
||||
let mut client = create_native_tls_client(url).await?;
|
||||
let mut client = WsStream::connect(url).await?;
|
||||
|
||||
client
|
||||
.send_json(&GatewayEvent::from(Resume {
|
||||
|
||||
@@ -8,10 +8,9 @@ use crate::{
|
||||
FromPrimitive,
|
||||
SpeakingState,
|
||||
},
|
||||
ws::{Error as WsError, ReceiverExt, SenderExt, WsStream},
|
||||
ws::{Error as WsError, WsStream},
|
||||
ConnectionInfo,
|
||||
};
|
||||
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
||||
use flume::Receiver;
|
||||
use rand::random;
|
||||
use std::time::Duration;
|
||||
@@ -19,6 +18,7 @@ use tokio::{
|
||||
select,
|
||||
time::{sleep_until, Instant},
|
||||
};
|
||||
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
||||
use tracing::{debug, info, instrument, trace, warn};
|
||||
|
||||
struct AuxNetwork {
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
model::{CloseCode as VoiceCloseCode, FromPrimitive},
|
||||
ws::Error as WsError,
|
||||
};
|
||||
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
||||
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
||||
|
||||
/// Voice connection details gathered at termination or failure.
|
||||
///
|
||||
|
||||
208
src/ws.rs
208
src/ws.rs
@@ -1,25 +1,71 @@
|
||||
use crate::{error::JsonError, model::Event};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use async_tungstenite::{
|
||||
self as tungstenite,
|
||||
tokio::ConnectStream,
|
||||
tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message},
|
||||
use futures::{SinkExt, StreamExt, TryStreamExt};
|
||||
use tokio::{
|
||||
net::TcpStream,
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::{
|
||||
error::Error as TungsteniteError,
|
||||
protocol::{CloseFrame, WebSocketConfig as Config},
|
||||
Message,
|
||||
},
|
||||
MaybeTlsStream,
|
||||
WebSocketStream,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt, TryStreamExt};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tracing::instrument;
|
||||
use url::Url;
|
||||
|
||||
pub type WsStream = WebSocketStream<ConnectStream>;
|
||||
pub struct WsStream(WebSocketStream<MaybeTlsStream<TcpStream>>);
|
||||
|
||||
impl WsStream {
|
||||
#[instrument]
|
||||
pub(crate) async fn connect(url: Url) -> Result<Self> {
|
||||
let (stream, _) = tokio_tungstenite::connect_async_with_config::<Url>(
|
||||
url,
|
||||
Some(Config {
|
||||
max_message_size: None,
|
||||
max_frame_size: None,
|
||||
max_send_queue: None,
|
||||
..Default::default()
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Self(stream))
|
||||
}
|
||||
|
||||
pub(crate) async fn recv_json(&mut self) -> Result<Option<Event>> {
|
||||
const TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
let ws_message = match timeout(TIMEOUT, self.0.next()).await {
|
||||
Ok(Some(Ok(v))) => Some(v),
|
||||
Ok(Some(Err(e))) => return Err(e.into()),
|
||||
Ok(None) | Err(_) => None,
|
||||
};
|
||||
|
||||
convert_ws_message(ws_message)
|
||||
}
|
||||
|
||||
pub(crate) async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
|
||||
convert_ws_message(self.0.try_next().await?)
|
||||
}
|
||||
|
||||
pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> {
|
||||
Ok(crate::json::to_string(value)
|
||||
.map(Message::Text)
|
||||
.map_err(Error::from)
|
||||
.map(|m| self.0.send(m))?
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Json(JsonError),
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
Tls(RustlsError),
|
||||
|
||||
/// The discord voice gateway does not support or offer zlib compression.
|
||||
/// As a result, only text messages are expected.
|
||||
@@ -36,80 +82,12 @@ impl From<JsonError> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
impl From<RustlsError> for Error {
|
||||
fn from(e: RustlsError) -> Error {
|
||||
Error::Tls(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TungsteniteError> for Error {
|
||||
fn from(e: TungsteniteError) -> Error {
|
||||
Error::Ws(e)
|
||||
}
|
||||
}
|
||||
|
||||
use futures::stream::SplitSink;
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
use std::{
|
||||
error::Error as StdError,
|
||||
fmt::{Display, Formatter, Result as FmtResult},
|
||||
io::Error as IoError,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ReceiverExt {
|
||||
async fn recv_json(&mut self) -> Result<Option<Event>>;
|
||||
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SenderExt {
|
||||
async fn send_json(&mut self, value: &Event) -> Result<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ReceiverExt for WsStream {
|
||||
async fn recv_json(&mut self) -> Result<Option<Event>> {
|
||||
const TIMEOUT: Duration = Duration::from_millis(500);
|
||||
|
||||
let ws_message = match timeout(TIMEOUT, self.next()).await {
|
||||
Ok(Some(Ok(v))) => Some(v),
|
||||
Ok(Some(Err(e))) => return Err(e.into()),
|
||||
Ok(None) | Err(_) => None,
|
||||
};
|
||||
|
||||
convert_ws_message(ws_message)
|
||||
}
|
||||
|
||||
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
|
||||
convert_ws_message(self.try_next().await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SenderExt for SplitSink<WsStream, Message> {
|
||||
async fn send_json(&mut self, value: &Event) -> Result<()> {
|
||||
Ok(crate::json::to_string(value)
|
||||
.map(Message::Text)
|
||||
.map_err(Error::from)
|
||||
.map(|m| self.send(m))?
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SenderExt for WsStream {
|
||||
async fn send_json(&mut self, value: &Event) -> Result<()> {
|
||||
Ok(crate::json::to_string(value)
|
||||
.map(Message::Text)
|
||||
.map_err(Error::from)
|
||||
.map(|m| self.send(m))?
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
|
||||
Ok(match message {
|
||||
@@ -125,77 +103,3 @@ pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Even
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
/// An error that occured while connecting over rustls
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
pub enum RustlsError {
|
||||
/// An error with the handshake in tungstenite
|
||||
HandshakeError,
|
||||
/// Standard IO error happening while creating the tcp stream
|
||||
Io(IoError),
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
impl From<IoError> for RustlsError {
|
||||
fn from(e: IoError) -> Self {
|
||||
RustlsError::Io(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
impl Display for RustlsError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
match self {
|
||||
RustlsError::HandshakeError =>
|
||||
f.write_str("TLS handshake failed when making the websocket connection"),
|
||||
RustlsError::Io(inner) => Display::fmt(&inner, f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
impl StdError for RustlsError {
|
||||
fn source(&self) -> Option<&(dyn StdError + 'static)> {
|
||||
match self {
|
||||
RustlsError::Io(inner) => Some(inner),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
||||
#[instrument]
|
||||
pub(crate) async fn create_rustls_client(url: Url) -> Result<WsStream> {
|
||||
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
|
||||
url,
|
||||
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
|
||||
max_message_size: None,
|
||||
max_frame_size: None,
|
||||
max_send_queue: None,
|
||||
..Default::default()
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| RustlsError::HandshakeError)?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-marker")]
|
||||
#[instrument]
|
||||
pub(crate) async fn create_native_tls_client(url: Url) -> Result<WsStream> {
|
||||
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
|
||||
url,
|
||||
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
|
||||
max_message_size: None,
|
||||
max_frame_size: None,
|
||||
max_send_queue: None,
|
||||
..Default::default()
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user