diff --git a/Cargo.toml b/Cargo.toml index 90ef9d1..7055ac3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,20 +21,10 @@ serde_json = "1" tracing = { version = "0.1", features = ["log"] } tracing-futures = "0.2" -[dependencies.once_cell] -version = "1" -optional = true - [dependencies.async-trait] optional = true version = "0.1" -[dependencies.async-tungstenite] -default-features = false -features = ["tokio-runtime"] -optional = true -version = "0.17" - [dependencies.audiopus] optional = true version = "0.3.0-rc.0" @@ -59,6 +49,10 @@ version = "0.10" [dependencies.futures] version = "0.3" +[dependencies.once_cell] +version = "1" +optional = true + [dependencies.parking_lot] optional = true version = "0.12" @@ -127,6 +121,10 @@ optional = true version = "1.0" default-features = false +[dependencies.tokio-tungstenite] +optional = true +version = "0.17" + [dependencies.tokio-util] optional = true version = "0.7" @@ -184,7 +182,6 @@ gateway = [ ] driver = [ "async-trait", - "async-tungstenite", "audiopus", "byteorder", "discortp", @@ -201,7 +198,6 @@ driver = [ "symphonia", "symphonia-core", "rusty_pool", - "tokio-util", "tokio/fs", "tokio/io-util", "tokio/macros", @@ -210,13 +206,15 @@ driver = [ "tokio/rt", "tokio/sync", "tokio/time", + "tokio-tungstenite", + "tokio-util", "typemap_rev", "url", "uuid", "xsalsa20poly1305", ] -rustls = ["async-tungstenite/tokio-rustls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"] -native = ["async-tungstenite/tokio-native-tls", "native-marker", "reqwest/native-tls"] +rustls = ["tokio-tungstenite/rustls-tls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"] +native = ["tokio-tungstenite/native-tls", "native-marker", "reqwest/native-tls"] serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"] serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"] twilight-rustls = ["twilight", "twilight-gateway/rustls-native-roots", "rustls", "gateway"] diff --git a/src/driver/connection/mod.rs b/src/driver/connection/mod.rs index 76b93a1..9a92a25 100644 --- a/src/driver/connection/mod.rs +++ b/src/driver/connection/mod.rs @@ -12,7 +12,7 @@ use crate::{ Event as GatewayEvent, ProtocolData, }, - ws::{self, ReceiverExt, SenderExt, WsStream}, + ws::WsStream, ConnectionInfo, }; use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket}; @@ -24,12 +24,6 @@ use tracing::{debug, info, instrument}; use url::Url; use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher}; -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -use ws::create_rustls_client; - -#[cfg(feature = "native-marker")] -use ws::create_native_tls_client; - pub(crate) struct Connection { pub(crate) info: ConnectionInfo, pub(crate) ssrc: u32, @@ -58,11 +52,7 @@ impl Connection { ) -> Result { let url = generate_url(&mut info.endpoint)?; - #[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] - let mut client = create_rustls_client(url).await?; - - #[cfg(feature = "native-marker")] - let mut client = create_native_tls_client(url).await?; + let mut client = WsStream::connect(url).await?; let mut hello = None; let mut ready = None; @@ -241,12 +231,7 @@ impl Connection { // Thread may have died, we want to send to prompt a clean exit // (if at all possible) and then proceed as normal. - - #[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] - let mut client = create_rustls_client(url).await?; - - #[cfg(feature = "native-marker")] - let mut client = create_native_tls_client(url).await?; + let mut client = WsStream::connect(url).await?; client .send_json(&GatewayEvent::from(Resume { diff --git a/src/driver/tasks/ws.rs b/src/driver/tasks/ws.rs index c7a3626..07bc796 100644 --- a/src/driver/tasks/ws.rs +++ b/src/driver/tasks/ws.rs @@ -8,10 +8,9 @@ use crate::{ FromPrimitive, SpeakingState, }, - ws::{Error as WsError, ReceiverExt, SenderExt, WsStream}, + ws::{Error as WsError, WsStream}, ConnectionInfo, }; -use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use flume::Receiver; use rand::random; use std::time::Duration; @@ -19,6 +18,7 @@ use tokio::{ select, time::{sleep_until, Instant}, }; +use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tracing::{debug, info, instrument, trace, warn}; struct AuxNetwork { diff --git a/src/events/context/data/disconnect.rs b/src/events/context/data/disconnect.rs index 6fbbba0..536236e 100644 --- a/src/events/context/data/disconnect.rs +++ b/src/events/context/data/disconnect.rs @@ -4,7 +4,7 @@ use crate::{ model::{CloseCode as VoiceCloseCode, FromPrimitive}, ws::Error as WsError, }; -use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode; +use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; /// Voice connection details gathered at termination or failure. /// diff --git a/src/ws.rs b/src/ws.rs index f47e1a9..a71cf55 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -1,25 +1,71 @@ use crate::{error::JsonError, model::Event}; -use async_trait::async_trait; -use async_tungstenite::{ - self as tungstenite, - tokio::ConnectStream, - tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message}, +use futures::{SinkExt, StreamExt, TryStreamExt}; +use tokio::{ + net::TcpStream, + time::{timeout, Duration}, +}; +use tokio_tungstenite::{ + tungstenite::{ + error::Error as TungsteniteError, + protocol::{CloseFrame, WebSocketConfig as Config}, + Message, + }, + MaybeTlsStream, WebSocketStream, }; -use futures::{SinkExt, StreamExt, TryStreamExt}; -use tokio::time::{timeout, Duration}; use tracing::instrument; +use url::Url; -pub type WsStream = WebSocketStream; +pub struct WsStream(WebSocketStream>); + +impl WsStream { + #[instrument] + pub(crate) async fn connect(url: Url) -> Result { + let (stream, _) = tokio_tungstenite::connect_async_with_config::( + url, + Some(Config { + max_message_size: None, + max_frame_size: None, + max_send_queue: None, + ..Default::default() + }), + ) + .await?; + + Ok(Self(stream)) + } + + pub(crate) async fn recv_json(&mut self) -> Result> { + const TIMEOUT: Duration = Duration::from_millis(500); + + let ws_message = match timeout(TIMEOUT, self.0.next()).await { + Ok(Some(Ok(v))) => Some(v), + Ok(Some(Err(e))) => return Err(e.into()), + Ok(None) | Err(_) => None, + }; + + convert_ws_message(ws_message) + } + + pub(crate) async fn recv_json_no_timeout(&mut self) -> Result> { + convert_ws_message(self.0.try_next().await?) + } + + pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> { + Ok(crate::json::to_string(value) + .map(Message::Text) + .map_err(Error::from) + .map(|m| self.0.send(m))? + .await?) + } +} pub type Result = std::result::Result; #[derive(Debug)] pub enum Error { Json(JsonError), - #[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] - Tls(RustlsError), /// The discord voice gateway does not support or offer zlib compression. /// As a result, only text messages are expected. @@ -36,80 +82,12 @@ impl From for Error { } } -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -impl From for Error { - fn from(e: RustlsError) -> Error { - Error::Tls(e) - } -} - impl From for Error { fn from(e: TungsteniteError) -> Error { Error::Ws(e) } } -use futures::stream::SplitSink; -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -use std::{ - error::Error as StdError, - fmt::{Display, Formatter, Result as FmtResult}, - io::Error as IoError, -}; -use url::Url; - -#[async_trait] -pub trait ReceiverExt { - async fn recv_json(&mut self) -> Result>; - async fn recv_json_no_timeout(&mut self) -> Result>; -} - -#[async_trait] -pub trait SenderExt { - async fn send_json(&mut self, value: &Event) -> Result<()>; -} - -#[async_trait] -impl ReceiverExt for WsStream { - async fn recv_json(&mut self) -> Result> { - const TIMEOUT: Duration = Duration::from_millis(500); - - let ws_message = match timeout(TIMEOUT, self.next()).await { - Ok(Some(Ok(v))) => Some(v), - Ok(Some(Err(e))) => return Err(e.into()), - Ok(None) | Err(_) => None, - }; - - convert_ws_message(ws_message) - } - - async fn recv_json_no_timeout(&mut self) -> Result> { - convert_ws_message(self.try_next().await?) - } -} - -#[async_trait] -impl SenderExt for SplitSink { - async fn send_json(&mut self, value: &Event) -> Result<()> { - Ok(crate::json::to_string(value) - .map(Message::Text) - .map_err(Error::from) - .map(|m| self.send(m))? - .await?) - } -} - -#[async_trait] -impl SenderExt for WsStream { - async fn send_json(&mut self, value: &Event) -> Result<()> { - Ok(crate::json::to_string(value) - .map(Message::Text) - .map_err(Error::from) - .map(|m| self.send(m))? - .await?) - } -} - #[inline] pub(crate) fn convert_ws_message(message: Option) -> Result> { Ok(match message { @@ -125,77 +103,3 @@ pub(crate) fn convert_ws_message(message: Option) -> Result None, }) } - -/// An error that occured while connecting over rustls -#[derive(Debug)] -#[non_exhaustive] -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -pub enum RustlsError { - /// An error with the handshake in tungstenite - HandshakeError, - /// Standard IO error happening while creating the tcp stream - Io(IoError), -} - -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -impl From for RustlsError { - fn from(e: IoError) -> Self { - RustlsError::Io(e) - } -} - -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -impl Display for RustlsError { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - RustlsError::HandshakeError => - f.write_str("TLS handshake failed when making the websocket connection"), - RustlsError::Io(inner) => Display::fmt(&inner, f), - } - } -} - -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -impl StdError for RustlsError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match self { - RustlsError::Io(inner) => Some(inner), - _ => None, - } - } -} - -#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))] -#[instrument] -pub(crate) async fn create_rustls_client(url: Url) -> Result { - let (stream, _) = tungstenite::tokio::connect_async_with_config::( - url, - Some(tungstenite::tungstenite::protocol::WebSocketConfig { - max_message_size: None, - max_frame_size: None, - max_send_queue: None, - ..Default::default() - }), - ) - .await - .map_err(|_| RustlsError::HandshakeError)?; - - Ok(stream) -} - -#[cfg(feature = "native-marker")] -#[instrument] -pub(crate) async fn create_native_tls_client(url: Url) -> Result { - let (stream, _) = tungstenite::tokio::connect_async_with_config::( - url, - Some(tungstenite::tungstenite::protocol::WebSocketConfig { - max_message_size: None, - max_frame_size: None, - max_send_queue: None, - ..Default::default() - }), - ) - .await?; - - Ok(stream) -}