Driver: Migrate to tokio_tungstenite (#138)
This places songbird, serenity, and twilight onto the same WS library, hopefully reducing the compile overhead for everyone. Tested using `cargo make ready` and by running `examples/voice`. Closes #129.
This commit is contained in:
26
Cargo.toml
26
Cargo.toml
@@ -21,20 +21,10 @@ serde_json = "1"
|
|||||||
tracing = { version = "0.1", features = ["log"] }
|
tracing = { version = "0.1", features = ["log"] }
|
||||||
tracing-futures = "0.2"
|
tracing-futures = "0.2"
|
||||||
|
|
||||||
[dependencies.once_cell]
|
|
||||||
version = "1"
|
|
||||||
optional = true
|
|
||||||
|
|
||||||
[dependencies.async-trait]
|
[dependencies.async-trait]
|
||||||
optional = true
|
optional = true
|
||||||
version = "0.1"
|
version = "0.1"
|
||||||
|
|
||||||
[dependencies.async-tungstenite]
|
|
||||||
default-features = false
|
|
||||||
features = ["tokio-runtime"]
|
|
||||||
optional = true
|
|
||||||
version = "0.17"
|
|
||||||
|
|
||||||
[dependencies.audiopus]
|
[dependencies.audiopus]
|
||||||
optional = true
|
optional = true
|
||||||
version = "0.3.0-rc.0"
|
version = "0.3.0-rc.0"
|
||||||
@@ -59,6 +49,10 @@ version = "0.10"
|
|||||||
[dependencies.futures]
|
[dependencies.futures]
|
||||||
version = "0.3"
|
version = "0.3"
|
||||||
|
|
||||||
|
[dependencies.once_cell]
|
||||||
|
version = "1"
|
||||||
|
optional = true
|
||||||
|
|
||||||
[dependencies.parking_lot]
|
[dependencies.parking_lot]
|
||||||
optional = true
|
optional = true
|
||||||
version = "0.12"
|
version = "0.12"
|
||||||
@@ -127,6 +121,10 @@ optional = true
|
|||||||
version = "1.0"
|
version = "1.0"
|
||||||
default-features = false
|
default-features = false
|
||||||
|
|
||||||
|
[dependencies.tokio-tungstenite]
|
||||||
|
optional = true
|
||||||
|
version = "0.17"
|
||||||
|
|
||||||
[dependencies.tokio-util]
|
[dependencies.tokio-util]
|
||||||
optional = true
|
optional = true
|
||||||
version = "0.7"
|
version = "0.7"
|
||||||
@@ -184,7 +182,6 @@ gateway = [
|
|||||||
]
|
]
|
||||||
driver = [
|
driver = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"async-tungstenite",
|
|
||||||
"audiopus",
|
"audiopus",
|
||||||
"byteorder",
|
"byteorder",
|
||||||
"discortp",
|
"discortp",
|
||||||
@@ -201,7 +198,6 @@ driver = [
|
|||||||
"symphonia",
|
"symphonia",
|
||||||
"symphonia-core",
|
"symphonia-core",
|
||||||
"rusty_pool",
|
"rusty_pool",
|
||||||
"tokio-util",
|
|
||||||
"tokio/fs",
|
"tokio/fs",
|
||||||
"tokio/io-util",
|
"tokio/io-util",
|
||||||
"tokio/macros",
|
"tokio/macros",
|
||||||
@@ -210,13 +206,15 @@ driver = [
|
|||||||
"tokio/rt",
|
"tokio/rt",
|
||||||
"tokio/sync",
|
"tokio/sync",
|
||||||
"tokio/time",
|
"tokio/time",
|
||||||
|
"tokio-tungstenite",
|
||||||
|
"tokio-util",
|
||||||
"typemap_rev",
|
"typemap_rev",
|
||||||
"url",
|
"url",
|
||||||
"uuid",
|
"uuid",
|
||||||
"xsalsa20poly1305",
|
"xsalsa20poly1305",
|
||||||
]
|
]
|
||||||
rustls = ["async-tungstenite/tokio-rustls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"]
|
rustls = ["tokio-tungstenite/rustls-tls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"]
|
||||||
native = ["async-tungstenite/tokio-native-tls", "native-marker", "reqwest/native-tls"]
|
native = ["tokio-tungstenite/native-tls", "native-marker", "reqwest/native-tls"]
|
||||||
serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"]
|
serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"]
|
||||||
serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"]
|
serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"]
|
||||||
twilight-rustls = ["twilight", "twilight-gateway/rustls-native-roots", "rustls", "gateway"]
|
twilight-rustls = ["twilight", "twilight-gateway/rustls-native-roots", "rustls", "gateway"]
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use crate::{
|
|||||||
Event as GatewayEvent,
|
Event as GatewayEvent,
|
||||||
ProtocolData,
|
ProtocolData,
|
||||||
},
|
},
|
||||||
ws::{self, ReceiverExt, SenderExt, WsStream},
|
ws::WsStream,
|
||||||
ConnectionInfo,
|
ConnectionInfo,
|
||||||
};
|
};
|
||||||
use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket};
|
use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket};
|
||||||
@@ -24,12 +24,6 @@ use tracing::{debug, info, instrument};
|
|||||||
use url::Url;
|
use url::Url;
|
||||||
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
|
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
use ws::create_rustls_client;
|
|
||||||
|
|
||||||
#[cfg(feature = "native-marker")]
|
|
||||||
use ws::create_native_tls_client;
|
|
||||||
|
|
||||||
pub(crate) struct Connection {
|
pub(crate) struct Connection {
|
||||||
pub(crate) info: ConnectionInfo,
|
pub(crate) info: ConnectionInfo,
|
||||||
pub(crate) ssrc: u32,
|
pub(crate) ssrc: u32,
|
||||||
@@ -58,11 +52,7 @@ impl Connection {
|
|||||||
) -> Result<Connection> {
|
) -> Result<Connection> {
|
||||||
let url = generate_url(&mut info.endpoint)?;
|
let url = generate_url(&mut info.endpoint)?;
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
let mut client = WsStream::connect(url).await?;
|
||||||
let mut client = create_rustls_client(url).await?;
|
|
||||||
|
|
||||||
#[cfg(feature = "native-marker")]
|
|
||||||
let mut client = create_native_tls_client(url).await?;
|
|
||||||
|
|
||||||
let mut hello = None;
|
let mut hello = None;
|
||||||
let mut ready = None;
|
let mut ready = None;
|
||||||
@@ -241,12 +231,7 @@ impl Connection {
|
|||||||
|
|
||||||
// Thread may have died, we want to send to prompt a clean exit
|
// Thread may have died, we want to send to prompt a clean exit
|
||||||
// (if at all possible) and then proceed as normal.
|
// (if at all possible) and then proceed as normal.
|
||||||
|
let mut client = WsStream::connect(url).await?;
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
let mut client = create_rustls_client(url).await?;
|
|
||||||
|
|
||||||
#[cfg(feature = "native-marker")]
|
|
||||||
let mut client = create_native_tls_client(url).await?;
|
|
||||||
|
|
||||||
client
|
client
|
||||||
.send_json(&GatewayEvent::from(Resume {
|
.send_json(&GatewayEvent::from(Resume {
|
||||||
|
|||||||
@@ -8,10 +8,9 @@ use crate::{
|
|||||||
FromPrimitive,
|
FromPrimitive,
|
||||||
SpeakingState,
|
SpeakingState,
|
||||||
},
|
},
|
||||||
ws::{Error as WsError, ReceiverExt, SenderExt, WsStream},
|
ws::{Error as WsError, WsStream},
|
||||||
ConnectionInfo,
|
ConnectionInfo,
|
||||||
};
|
};
|
||||||
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
|
||||||
use flume::Receiver;
|
use flume::Receiver;
|
||||||
use rand::random;
|
use rand::random;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
@@ -19,6 +18,7 @@ use tokio::{
|
|||||||
select,
|
select,
|
||||||
time::{sleep_until, Instant},
|
time::{sleep_until, Instant},
|
||||||
};
|
};
|
||||||
|
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
||||||
use tracing::{debug, info, instrument, trace, warn};
|
use tracing::{debug, info, instrument, trace, warn};
|
||||||
|
|
||||||
struct AuxNetwork {
|
struct AuxNetwork {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use crate::{
|
|||||||
model::{CloseCode as VoiceCloseCode, FromPrimitive},
|
model::{CloseCode as VoiceCloseCode, FromPrimitive},
|
||||||
ws::Error as WsError,
|
ws::Error as WsError,
|
||||||
};
|
};
|
||||||
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
||||||
|
|
||||||
/// Voice connection details gathered at termination or failure.
|
/// Voice connection details gathered at termination or failure.
|
||||||
///
|
///
|
||||||
|
|||||||
208
src/ws.rs
208
src/ws.rs
@@ -1,25 +1,71 @@
|
|||||||
use crate::{error::JsonError, model::Event};
|
use crate::{error::JsonError, model::Event};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use futures::{SinkExt, StreamExt, TryStreamExt};
|
||||||
use async_tungstenite::{
|
use tokio::{
|
||||||
self as tungstenite,
|
net::TcpStream,
|
||||||
tokio::ConnectStream,
|
time::{timeout, Duration},
|
||||||
tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message},
|
};
|
||||||
|
use tokio_tungstenite::{
|
||||||
|
tungstenite::{
|
||||||
|
error::Error as TungsteniteError,
|
||||||
|
protocol::{CloseFrame, WebSocketConfig as Config},
|
||||||
|
Message,
|
||||||
|
},
|
||||||
|
MaybeTlsStream,
|
||||||
WebSocketStream,
|
WebSocketStream,
|
||||||
};
|
};
|
||||||
use futures::{SinkExt, StreamExt, TryStreamExt};
|
|
||||||
use tokio::time::{timeout, Duration};
|
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
pub type WsStream = WebSocketStream<ConnectStream>;
|
pub struct WsStream(WebSocketStream<MaybeTlsStream<TcpStream>>);
|
||||||
|
|
||||||
|
impl WsStream {
|
||||||
|
#[instrument]
|
||||||
|
pub(crate) async fn connect(url: Url) -> Result<Self> {
|
||||||
|
let (stream, _) = tokio_tungstenite::connect_async_with_config::<Url>(
|
||||||
|
url,
|
||||||
|
Some(Config {
|
||||||
|
max_message_size: None,
|
||||||
|
max_frame_size: None,
|
||||||
|
max_send_queue: None,
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Self(stream))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn recv_json(&mut self) -> Result<Option<Event>> {
|
||||||
|
const TIMEOUT: Duration = Duration::from_millis(500);
|
||||||
|
|
||||||
|
let ws_message = match timeout(TIMEOUT, self.0.next()).await {
|
||||||
|
Ok(Some(Ok(v))) => Some(v),
|
||||||
|
Ok(Some(Err(e))) => return Err(e.into()),
|
||||||
|
Ok(None) | Err(_) => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
convert_ws_message(ws_message)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
|
||||||
|
convert_ws_message(self.0.try_next().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> {
|
||||||
|
Ok(crate::json::to_string(value)
|
||||||
|
.map(Message::Text)
|
||||||
|
.map_err(Error::from)
|
||||||
|
.map(|m| self.0.send(m))?
|
||||||
|
.await?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
Json(JsonError),
|
Json(JsonError),
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
Tls(RustlsError),
|
|
||||||
|
|
||||||
/// The discord voice gateway does not support or offer zlib compression.
|
/// The discord voice gateway does not support or offer zlib compression.
|
||||||
/// As a result, only text messages are expected.
|
/// As a result, only text messages are expected.
|
||||||
@@ -36,80 +82,12 @@ impl From<JsonError> for Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
impl From<RustlsError> for Error {
|
|
||||||
fn from(e: RustlsError) -> Error {
|
|
||||||
Error::Tls(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<TungsteniteError> for Error {
|
impl From<TungsteniteError> for Error {
|
||||||
fn from(e: TungsteniteError) -> Error {
|
fn from(e: TungsteniteError) -> Error {
|
||||||
Error::Ws(e)
|
Error::Ws(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use futures::stream::SplitSink;
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
use std::{
|
|
||||||
error::Error as StdError,
|
|
||||||
fmt::{Display, Formatter, Result as FmtResult},
|
|
||||||
io::Error as IoError,
|
|
||||||
};
|
|
||||||
use url::Url;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait ReceiverExt {
|
|
||||||
async fn recv_json(&mut self) -> Result<Option<Event>>;
|
|
||||||
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait SenderExt {
|
|
||||||
async fn send_json(&mut self, value: &Event) -> Result<()>;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ReceiverExt for WsStream {
|
|
||||||
async fn recv_json(&mut self) -> Result<Option<Event>> {
|
|
||||||
const TIMEOUT: Duration = Duration::from_millis(500);
|
|
||||||
|
|
||||||
let ws_message = match timeout(TIMEOUT, self.next()).await {
|
|
||||||
Ok(Some(Ok(v))) => Some(v),
|
|
||||||
Ok(Some(Err(e))) => return Err(e.into()),
|
|
||||||
Ok(None) | Err(_) => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
convert_ws_message(ws_message)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
|
|
||||||
convert_ws_message(self.try_next().await?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl SenderExt for SplitSink<WsStream, Message> {
|
|
||||||
async fn send_json(&mut self, value: &Event) -> Result<()> {
|
|
||||||
Ok(crate::json::to_string(value)
|
|
||||||
.map(Message::Text)
|
|
||||||
.map_err(Error::from)
|
|
||||||
.map(|m| self.send(m))?
|
|
||||||
.await?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl SenderExt for WsStream {
|
|
||||||
async fn send_json(&mut self, value: &Event) -> Result<()> {
|
|
||||||
Ok(crate::json::to_string(value)
|
|
||||||
.map(Message::Text)
|
|
||||||
.map_err(Error::from)
|
|
||||||
.map(|m| self.send(m))?
|
|
||||||
.await?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
|
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
|
||||||
Ok(match message {
|
Ok(match message {
|
||||||
@@ -125,77 +103,3 @@ pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Even
|
|||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An error that occured while connecting over rustls
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
pub enum RustlsError {
|
|
||||||
/// An error with the handshake in tungstenite
|
|
||||||
HandshakeError,
|
|
||||||
/// Standard IO error happening while creating the tcp stream
|
|
||||||
Io(IoError),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
impl From<IoError> for RustlsError {
|
|
||||||
fn from(e: IoError) -> Self {
|
|
||||||
RustlsError::Io(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
impl Display for RustlsError {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
|
||||||
match self {
|
|
||||||
RustlsError::HandshakeError =>
|
|
||||||
f.write_str("TLS handshake failed when making the websocket connection"),
|
|
||||||
RustlsError::Io(inner) => Display::fmt(&inner, f),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
impl StdError for RustlsError {
|
|
||||||
fn source(&self) -> Option<&(dyn StdError + 'static)> {
|
|
||||||
match self {
|
|
||||||
RustlsError::Io(inner) => Some(inner),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
|
|
||||||
#[instrument]
|
|
||||||
pub(crate) async fn create_rustls_client(url: Url) -> Result<WsStream> {
|
|
||||||
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
|
|
||||||
url,
|
|
||||||
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
|
|
||||||
max_message_size: None,
|
|
||||||
max_frame_size: None,
|
|
||||||
max_send_queue: None,
|
|
||||||
..Default::default()
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(|_| RustlsError::HandshakeError)?;
|
|
||||||
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "native-marker")]
|
|
||||||
#[instrument]
|
|
||||||
pub(crate) async fn create_native_tls_client(url: Url) -> Result<WsStream> {
|
|
||||||
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
|
|
||||||
url,
|
|
||||||
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
|
|
||||||
max_message_size: None,
|
|
||||||
max_frame_size: None,
|
|
||||||
max_send_queue: None,
|
|
||||||
..Default::default()
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user