Driver: Migrate to tokio_tungstenite (#138)

This places songbird, serenity, and twilight onto the same WS library, hopefully reducing the compile overhead for everyone.

Tested using `cargo make ready` and by running `examples/voice`.

Closes #129.
This commit is contained in:
Kyle Simpson
2022-07-25 17:19:55 +01:00
parent 13946b47ce
commit 76c9851034
5 changed files with 74 additions and 187 deletions

View File

@@ -21,20 +21,10 @@ serde_json = "1"
tracing = { version = "0.1", features = ["log"] } tracing = { version = "0.1", features = ["log"] }
tracing-futures = "0.2" tracing-futures = "0.2"
[dependencies.once_cell]
version = "1"
optional = true
[dependencies.async-trait] [dependencies.async-trait]
optional = true optional = true
version = "0.1" version = "0.1"
[dependencies.async-tungstenite]
default-features = false
features = ["tokio-runtime"]
optional = true
version = "0.17"
[dependencies.audiopus] [dependencies.audiopus]
optional = true optional = true
version = "0.3.0-rc.0" version = "0.3.0-rc.0"
@@ -59,6 +49,10 @@ version = "0.10"
[dependencies.futures] [dependencies.futures]
version = "0.3" version = "0.3"
[dependencies.once_cell]
version = "1"
optional = true
[dependencies.parking_lot] [dependencies.parking_lot]
optional = true optional = true
version = "0.12" version = "0.12"
@@ -127,6 +121,10 @@ optional = true
version = "1.0" version = "1.0"
default-features = false default-features = false
[dependencies.tokio-tungstenite]
optional = true
version = "0.17"
[dependencies.tokio-util] [dependencies.tokio-util]
optional = true optional = true
version = "0.7" version = "0.7"
@@ -184,7 +182,6 @@ gateway = [
] ]
driver = [ driver = [
"async-trait", "async-trait",
"async-tungstenite",
"audiopus", "audiopus",
"byteorder", "byteorder",
"discortp", "discortp",
@@ -201,7 +198,6 @@ driver = [
"symphonia", "symphonia",
"symphonia-core", "symphonia-core",
"rusty_pool", "rusty_pool",
"tokio-util",
"tokio/fs", "tokio/fs",
"tokio/io-util", "tokio/io-util",
"tokio/macros", "tokio/macros",
@@ -210,13 +206,15 @@ driver = [
"tokio/rt", "tokio/rt",
"tokio/sync", "tokio/sync",
"tokio/time", "tokio/time",
"tokio-tungstenite",
"tokio-util",
"typemap_rev", "typemap_rev",
"url", "url",
"uuid", "uuid",
"xsalsa20poly1305", "xsalsa20poly1305",
] ]
rustls = ["async-tungstenite/tokio-rustls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"] rustls = ["tokio-tungstenite/rustls-tls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"]
native = ["async-tungstenite/tokio-native-tls", "native-marker", "reqwest/native-tls"] native = ["tokio-tungstenite/native-tls", "native-marker", "reqwest/native-tls"]
serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"] serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"]
serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"] serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"]
twilight-rustls = ["twilight", "twilight-gateway/rustls-native-roots", "rustls", "gateway"] twilight-rustls = ["twilight", "twilight-gateway/rustls-native-roots", "rustls", "gateway"]

View File

@@ -12,7 +12,7 @@ use crate::{
Event as GatewayEvent, Event as GatewayEvent,
ProtocolData, ProtocolData,
}, },
ws::{self, ReceiverExt, SenderExt, WsStream}, ws::WsStream,
ConnectionInfo, ConnectionInfo,
}; };
use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket}; use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket};
@@ -24,12 +24,6 @@ use tracing::{debug, info, instrument};
use url::Url; use url::Url;
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher}; use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
use ws::create_rustls_client;
#[cfg(feature = "native-marker")]
use ws::create_native_tls_client;
pub(crate) struct Connection { pub(crate) struct Connection {
pub(crate) info: ConnectionInfo, pub(crate) info: ConnectionInfo,
pub(crate) ssrc: u32, pub(crate) ssrc: u32,
@@ -58,11 +52,7 @@ impl Connection {
) -> Result<Connection> { ) -> Result<Connection> {
let url = generate_url(&mut info.endpoint)?; let url = generate_url(&mut info.endpoint)?;
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] let mut client = WsStream::connect(url).await?;
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native-marker")]
let mut client = create_native_tls_client(url).await?;
let mut hello = None; let mut hello = None;
let mut ready = None; let mut ready = None;
@@ -241,12 +231,7 @@ impl Connection {
// Thread may have died, we want to send to prompt a clean exit // Thread may have died, we want to send to prompt a clean exit
// (if at all possible) and then proceed as normal. // (if at all possible) and then proceed as normal.
let mut client = WsStream::connect(url).await?;
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native-marker")]
let mut client = create_native_tls_client(url).await?;
client client
.send_json(&GatewayEvent::from(Resume { .send_json(&GatewayEvent::from(Resume {

View File

@@ -8,10 +8,9 @@ use crate::{
FromPrimitive, FromPrimitive,
SpeakingState, SpeakingState,
}, },
ws::{Error as WsError, ReceiverExt, SenderExt, WsStream}, ws::{Error as WsError, WsStream},
ConnectionInfo, ConnectionInfo,
}; };
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use flume::Receiver; use flume::Receiver;
use rand::random; use rand::random;
use std::time::Duration; use std::time::Duration;
@@ -19,6 +18,7 @@ use tokio::{
select, select,
time::{sleep_until, Instant}, time::{sleep_until, Instant},
}; };
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tracing::{debug, info, instrument, trace, warn}; use tracing::{debug, info, instrument, trace, warn};
struct AuxNetwork { struct AuxNetwork {

View File

@@ -4,7 +4,7 @@ use crate::{
model::{CloseCode as VoiceCloseCode, FromPrimitive}, model::{CloseCode as VoiceCloseCode, FromPrimitive},
ws::Error as WsError, ws::Error as WsError,
}; };
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
/// Voice connection details gathered at termination or failure. /// Voice connection details gathered at termination or failure.
/// ///

208
src/ws.rs
View File

@@ -1,25 +1,71 @@
use crate::{error::JsonError, model::Event}; use crate::{error::JsonError, model::Event};
use async_trait::async_trait; use futures::{SinkExt, StreamExt, TryStreamExt};
use async_tungstenite::{ use tokio::{
self as tungstenite, net::TcpStream,
tokio::ConnectStream, time::{timeout, Duration},
tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message}, };
use tokio_tungstenite::{
tungstenite::{
error::Error as TungsteniteError,
protocol::{CloseFrame, WebSocketConfig as Config},
Message,
},
MaybeTlsStream,
WebSocketStream, WebSocketStream,
}; };
use futures::{SinkExt, StreamExt, TryStreamExt};
use tokio::time::{timeout, Duration};
use tracing::instrument; use tracing::instrument;
use url::Url;
pub type WsStream = WebSocketStream<ConnectStream>; pub struct WsStream(WebSocketStream<MaybeTlsStream<TcpStream>>);
impl WsStream {
#[instrument]
pub(crate) async fn connect(url: Url) -> Result<Self> {
let (stream, _) = tokio_tungstenite::connect_async_with_config::<Url>(
url,
Some(Config {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await?;
Ok(Self(stream))
}
pub(crate) async fn recv_json(&mut self) -> Result<Option<Event>> {
const TIMEOUT: Duration = Duration::from_millis(500);
let ws_message = match timeout(TIMEOUT, self.0.next()).await {
Ok(Some(Ok(v))) => Some(v),
Ok(Some(Err(e))) => return Err(e.into()),
Ok(None) | Err(_) => None,
};
convert_ws_message(ws_message)
}
pub(crate) async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
convert_ws_message(self.0.try_next().await?)
}
pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.0.send(m))?
.await?)
}
}
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
Json(JsonError), Json(JsonError),
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
Tls(RustlsError),
/// The discord voice gateway does not support or offer zlib compression. /// The discord voice gateway does not support or offer zlib compression.
/// As a result, only text messages are expected. /// As a result, only text messages are expected.
@@ -36,80 +82,12 @@ impl From<JsonError> for Error {
} }
} }
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl From<RustlsError> for Error {
fn from(e: RustlsError) -> Error {
Error::Tls(e)
}
}
impl From<TungsteniteError> for Error { impl From<TungsteniteError> for Error {
fn from(e: TungsteniteError) -> Error { fn from(e: TungsteniteError) -> Error {
Error::Ws(e) Error::Ws(e)
} }
} }
use futures::stream::SplitSink;
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
use std::{
error::Error as StdError,
fmt::{Display, Formatter, Result as FmtResult},
io::Error as IoError,
};
use url::Url;
#[async_trait]
pub trait ReceiverExt {
async fn recv_json(&mut self) -> Result<Option<Event>>;
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>>;
}
#[async_trait]
pub trait SenderExt {
async fn send_json(&mut self, value: &Event) -> Result<()>;
}
#[async_trait]
impl ReceiverExt for WsStream {
async fn recv_json(&mut self) -> Result<Option<Event>> {
const TIMEOUT: Duration = Duration::from_millis(500);
let ws_message = match timeout(TIMEOUT, self.next()).await {
Ok(Some(Ok(v))) => Some(v),
Ok(Some(Err(e))) => return Err(e.into()),
Ok(None) | Err(_) => None,
};
convert_ws_message(ws_message)
}
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
convert_ws_message(self.try_next().await?)
}
}
#[async_trait]
impl SenderExt for SplitSink<WsStream, Message> {
async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.send(m))?
.await?)
}
}
#[async_trait]
impl SenderExt for WsStream {
async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.send(m))?
.await?)
}
}
#[inline] #[inline]
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> { pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
Ok(match message { Ok(match message {
@@ -125,77 +103,3 @@ pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Even
_ => None, _ => None,
}) })
} }
/// An error that occured while connecting over rustls
#[derive(Debug)]
#[non_exhaustive]
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
pub enum RustlsError {
/// An error with the handshake in tungstenite
HandshakeError,
/// Standard IO error happening while creating the tcp stream
Io(IoError),
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl From<IoError> for RustlsError {
fn from(e: IoError) -> Self {
RustlsError::Io(e)
}
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl Display for RustlsError {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
RustlsError::HandshakeError =>
f.write_str("TLS handshake failed when making the websocket connection"),
RustlsError::Io(inner) => Display::fmt(&inner, f),
}
}
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl StdError for RustlsError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
RustlsError::Io(inner) => Some(inner),
_ => None,
}
}
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
#[instrument]
pub(crate) async fn create_rustls_client(url: Url) -> Result<WsStream> {
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
url,
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await
.map_err(|_| RustlsError::HandshakeError)?;
Ok(stream)
}
#[cfg(feature = "native-marker")]
#[instrument]
pub(crate) async fn create_native_tls_client(url: Url) -> Result<WsStream> {
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
url,
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await?;
Ok(stream)
}