Driver: Migrate to tokio_tungstenite (#138)

This places songbird, serenity, and twilight onto the same WS library, hopefully reducing the compile overhead for everyone.

Tested using `cargo make ready` and by running `examples/voice`.

Closes #129.
This commit is contained in:
Kyle Simpson
2022-07-25 17:19:55 +01:00
parent 13946b47ce
commit 76c9851034
5 changed files with 74 additions and 187 deletions

208
src/ws.rs
View File

@@ -1,25 +1,71 @@
use crate::{error::JsonError, model::Event};
use async_trait::async_trait;
use async_tungstenite::{
self as tungstenite,
tokio::ConnectStream,
tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message},
use futures::{SinkExt, StreamExt, TryStreamExt};
use tokio::{
net::TcpStream,
time::{timeout, Duration},
};
use tokio_tungstenite::{
tungstenite::{
error::Error as TungsteniteError,
protocol::{CloseFrame, WebSocketConfig as Config},
Message,
},
MaybeTlsStream,
WebSocketStream,
};
use futures::{SinkExt, StreamExt, TryStreamExt};
use tokio::time::{timeout, Duration};
use tracing::instrument;
use url::Url;
pub type WsStream = WebSocketStream<ConnectStream>;
pub struct WsStream(WebSocketStream<MaybeTlsStream<TcpStream>>);
impl WsStream {
#[instrument]
pub(crate) async fn connect(url: Url) -> Result<Self> {
let (stream, _) = tokio_tungstenite::connect_async_with_config::<Url>(
url,
Some(Config {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await?;
Ok(Self(stream))
}
pub(crate) async fn recv_json(&mut self) -> Result<Option<Event>> {
const TIMEOUT: Duration = Duration::from_millis(500);
let ws_message = match timeout(TIMEOUT, self.0.next()).await {
Ok(Some(Ok(v))) => Some(v),
Ok(Some(Err(e))) => return Err(e.into()),
Ok(None) | Err(_) => None,
};
convert_ws_message(ws_message)
}
pub(crate) async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
convert_ws_message(self.0.try_next().await?)
}
pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.0.send(m))?
.await?)
}
}
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
pub enum Error {
Json(JsonError),
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
Tls(RustlsError),
/// The discord voice gateway does not support or offer zlib compression.
/// As a result, only text messages are expected.
@@ -36,80 +82,12 @@ impl From<JsonError> for Error {
}
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl From<RustlsError> for Error {
fn from(e: RustlsError) -> Error {
Error::Tls(e)
}
}
impl From<TungsteniteError> for Error {
fn from(e: TungsteniteError) -> Error {
Error::Ws(e)
}
}
use futures::stream::SplitSink;
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
use std::{
error::Error as StdError,
fmt::{Display, Formatter, Result as FmtResult},
io::Error as IoError,
};
use url::Url;
#[async_trait]
pub trait ReceiverExt {
async fn recv_json(&mut self) -> Result<Option<Event>>;
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>>;
}
#[async_trait]
pub trait SenderExt {
async fn send_json(&mut self, value: &Event) -> Result<()>;
}
#[async_trait]
impl ReceiverExt for WsStream {
async fn recv_json(&mut self) -> Result<Option<Event>> {
const TIMEOUT: Duration = Duration::from_millis(500);
let ws_message = match timeout(TIMEOUT, self.next()).await {
Ok(Some(Ok(v))) => Some(v),
Ok(Some(Err(e))) => return Err(e.into()),
Ok(None) | Err(_) => None,
};
convert_ws_message(ws_message)
}
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
convert_ws_message(self.try_next().await?)
}
}
#[async_trait]
impl SenderExt for SplitSink<WsStream, Message> {
async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.send(m))?
.await?)
}
}
#[async_trait]
impl SenderExt for WsStream {
async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.send(m))?
.await?)
}
}
#[inline]
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
Ok(match message {
@@ -125,77 +103,3 @@ pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Even
_ => None,
})
}
/// An error that occured while connecting over rustls
#[derive(Debug)]
#[non_exhaustive]
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
pub enum RustlsError {
/// An error with the handshake in tungstenite
HandshakeError,
/// Standard IO error happening while creating the tcp stream
Io(IoError),
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl From<IoError> for RustlsError {
fn from(e: IoError) -> Self {
RustlsError::Io(e)
}
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl Display for RustlsError {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
RustlsError::HandshakeError =>
f.write_str("TLS handshake failed when making the websocket connection"),
RustlsError::Io(inner) => Display::fmt(&inner, f),
}
}
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl StdError for RustlsError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
RustlsError::Io(inner) => Some(inner),
_ => None,
}
}
}
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
#[instrument]
pub(crate) async fn create_rustls_client(url: Url) -> Result<WsStream> {
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
url,
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await
.map_err(|_| RustlsError::HandshakeError)?;
Ok(stream)
}
#[cfg(feature = "native-marker")]
#[instrument]
pub(crate) async fn create_native_tls_client(url: Url) -> Result<WsStream> {
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
url,
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await?;
Ok(stream)
}