Library: Add compatibility for legacy Tokio 0.2 (#40)

Adds support to the library for tokio 0.2 backward-compatibility. This should hopefully benefit, and prevent lavalink-rs from being blocked on this feature.

These can be reached using, e.g., `gateway-tokio-02`, `driver-tokio-02`, `serenity-rustls-tokio-02`, and `serenity-native-tokio-02` features.

Naturally, this requires some jiggering about with features and the underlying CI, which has been taken care of. Twilight can't be handled in this way, as their last tokio 0.2 version uses the deprecated Discord Gateway v6.
This commit is contained in:
Kyle Simpson
2021-02-04 02:34:07 +00:00
committed by GitHub
parent b2453091e7
commit aaab97511d
24 changed files with 353 additions and 146 deletions

View File

@@ -19,15 +19,18 @@ use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPa
use error::{Error, Result};
use flume::Sender;
use std::{net::IpAddr, str::FromStr, sync::Arc};
use tokio::net::UdpSocket;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::{net::UdpSocket, spawn};
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::{net::UdpSocket, spawn};
use tracing::{debug, info, instrument};
use url::Url;
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
#[cfg(all(feature = "rustls", not(feature = "native")))]
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
use ws::create_rustls_client;
#[cfg(feature = "native")]
#[cfg(feature = "native-marker")]
use ws::create_native_tls_client;
pub(crate) struct Connection {
@@ -43,10 +46,10 @@ impl Connection {
) -> Result<Connection> {
let url = generate_url(&mut info.endpoint)?;
#[cfg(all(feature = "rustls", not(feature = "native")))]
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native")]
#[cfg(feature = "native-marker")]
let mut client = create_native_tls_client(url).await?;
let mut hello = None;
@@ -97,7 +100,11 @@ impl Connection {
return Err(Error::CryptoModeUnavailable);
}
#[cfg(not(feature = "tokio-02-marker"))]
let udp = UdpSocket::bind("0.0.0.0:0").await?;
#[cfg(feature = "tokio-02-marker")]
let mut udp = UdpSocket::bind("0.0.0.0:0").await?;
udp.connect((ready.ip, ready.port)).await?;
// Follow Discord's IP Discovery procedures, in case NAT tunnelling is needed.
@@ -124,7 +131,7 @@ impl Connection {
}
// We could do something clever like binary search,
// but possibility of UDP spoofing preclueds us from
// but possibility of UDP spoofing precludes us from
// making the assumption we can find a "left edge" of '\0's.
let nul_byte_index = view
.get_address_raw()
@@ -162,8 +169,14 @@ impl Connection {
let (udp_sender_msg_tx, udp_sender_msg_rx) = flume::unbounded();
let (udp_receiver_msg_tx, udp_receiver_msg_rx) = flume::unbounded();
let udp_rx = Arc::new(udp);
let udp_tx = Arc::clone(&udp_rx);
#[cfg(not(feature = "tokio-02-marker"))]
let (udp_rx, udp_tx) = {
let udp_rx = Arc::new(udp);
let udp_tx = Arc::clone(&udp_rx);
(udp_rx, udp_tx)
};
#[cfg(feature = "tokio-02-marker")]
let (udp_rx, udp_tx) = udp.split();
let ssrc = ready.ssrc;
@@ -182,7 +195,7 @@ impl Connection {
.mixer
.send(MixerMessage::SetConn(mix_conn, ready.ssrc))?;
tokio::spawn(ws_task::runner(
spawn(ws_task::runner(
interconnect.clone(),
ws_msg_rx,
client,
@@ -190,14 +203,14 @@ impl Connection {
hello.heartbeat_interval,
));
tokio::spawn(udp_rx::runner(
spawn(udp_rx::runner(
interconnect.clone(),
udp_receiver_msg_rx,
cipher,
config.clone(),
udp_rx,
));
tokio::spawn(udp_tx::runner(udp_sender_msg_rx, ssrc, udp_tx));
spawn(udp_tx::runner(udp_sender_msg_rx, ssrc, udp_tx));
Ok(Connection {
info,
@@ -212,10 +225,10 @@ impl Connection {
// Thread may have died, we want to send to prompt a clean exit
// (if at all possible) and then proceed as normal.
#[cfg(all(feature = "rustls", not(feature = "native")))]
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native")]
#[cfg(feature = "native-marker")]
let mut client = create_native_tls_client(url).await?;
client

View File

@@ -11,6 +11,10 @@ mod ws;
pub use self::{core::*, disposal::*, events::*, mixer::*, udp_rx::*, udp_tx::*, ws::*};
use flume::Sender;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::spawn;
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::spawn;
use tracing::info;
#[derive(Clone, Debug)]
@@ -38,7 +42,7 @@ impl Interconnect {
self.events = evt_tx;
let ic = self.clone();
tokio::spawn(async move {
spawn(async move {
info!("Event processor restarted.");
super::events::runner(ic, evt_rx).await;
info!("Event processor finished.");

View File

@@ -18,7 +18,10 @@ use flume::{Receiver, Sender, TryRecvError};
use rand::random;
use spin_sleep::SpinSleeper;
use std::time::Instant;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::runtime::Handle;
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::runtime::Handle;
use tracing::{error, instrument};
use xsalsa20poly1305::TAG_SIZE;

View File

@@ -16,11 +16,14 @@ use super::{
use crate::events::CoreContext;
use flume::{Receiver, RecvError, Sender};
use message::*;
use tokio::runtime::Handle;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::{runtime::Handle, spawn};
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::{runtime::Handle, spawn};
use tracing::{error, info, instrument};
pub(crate) fn start(config: Config, rx: Receiver<CoreMessage>, tx: Sender<CoreMessage>) {
tokio::spawn(async move {
spawn(async move {
info!("Driver started.");
runner(config, rx, tx).await;
info!("Driver finished.");
@@ -38,7 +41,7 @@ fn start_internals(core: Sender<CoreMessage>, config: Config) -> Interconnect {
};
let ic = interconnect.clone();
tokio::spawn(async move {
spawn(async move {
info!("Event processor started.");
events::runner(ic, evt_rx).await;
info!("Event processor finished.");

View File

@@ -21,7 +21,10 @@ use discortp::{
};
use flume::Receiver;
use std::{collections::HashMap, sync::Arc};
use tokio::net::UdpSocket;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::{net::UdpSocket, select};
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::{net::udp::RecvHalf, select};
use tracing::{error, info, instrument, warn};
use xsalsa20poly1305::XSalsa20Poly1305 as Cipher;
@@ -236,14 +239,18 @@ struct UdpRx {
config: Config,
packet_buffer: [u8; VOICE_PACKET_MAX],
rx: Receiver<UdpRxMessage>,
#[cfg(not(feature = "tokio-02-marker"))]
udp_socket: Arc<UdpSocket>,
#[cfg(feature = "tokio-02-marker")]
udp_socket: RecvHalf,
}
impl UdpRx {
#[instrument(skip(self))]
async fn run(&mut self, interconnect: &mut Interconnect) {
loop {
tokio::select! {
select! {
Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => {
self.process_udp_message(interconnect, len);
}
@@ -385,6 +392,7 @@ impl UdpRx {
}
}
#[cfg(not(feature = "tokio-02-marker"))]
#[instrument(skip(interconnect, rx, cipher))]
pub(crate) async fn runner(
mut interconnect: Interconnect,
@@ -409,6 +417,31 @@ pub(crate) async fn runner(
info!("UDP receive handle stopped.");
}
#[cfg(feature = "tokio-02-marker")]
#[instrument(skip(interconnect, rx, cipher))]
pub(crate) async fn runner(
mut interconnect: Interconnect,
rx: Receiver<UdpRxMessage>,
cipher: Cipher,
config: Config,
udp_socket: RecvHalf,
) {
info!("UDP receive handle started.");
let mut state = UdpRx {
cipher,
decoder_map: Default::default(),
config,
packet_buffer: [0u8; VOICE_PACKET_MAX],
rx,
udp_socket,
};
state.run(&mut interconnect).await;
info!("UDP receive handle stopped.");
}
#[inline]
fn rtp_valid(packet: RtpPacket<'_>) -> bool {
packet.get_version() == RTP_VERSION && packet.get_payload_type() == RTP_PROFILE_TYPE

View File

@@ -3,48 +3,93 @@ use crate::constants::*;
use discortp::discord::MutableKeepalivePacket;
use flume::Receiver;
use std::sync::Arc;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::{
net::UdpSocket,
time::{timeout_at, Instant},
};
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::{
net::udp::SendHalf,
time::{timeout_at, Instant},
};
use tracing::{error, info, instrument, trace};
struct UdpTx {
ssrc: u32,
rx: Receiver<UdpTxMessage>,
#[cfg(not(feature = "tokio-02-marker"))]
udp_tx: Arc<UdpSocket>,
#[cfg(feature = "tokio-02-marker")]
udp_tx: SendHalf,
}
impl UdpTx {
async fn run(&mut self) {
let mut keepalive_bytes = [0u8; MutableKeepalivePacket::minimum_packet_size()];
let mut ka = MutableKeepalivePacket::new(&mut keepalive_bytes[..])
.expect("FATAL: Insufficient bytes given to keepalive packet.");
ka.set_ssrc(self.ssrc);
let mut ka_time = Instant::now() + UDP_KEEPALIVE_GAP;
loop {
use UdpTxMessage::*;
match timeout_at(ka_time, self.rx.recv_async()).await {
Err(_) => {
trace!("Sending UDP Keepalive.");
if let Err(e) = self.udp_tx.send(&keepalive_bytes[..]).await {
error!("Fatal UDP keepalive send error: {:?}.", e);
break;
}
ka_time += UDP_KEEPALIVE_GAP;
},
Ok(Ok(Packet(p))) =>
if let Err(e) = self.udp_tx.send(&p[..]).await {
error!("Fatal UDP packet send error: {:?}.", e);
break;
},
Ok(Err(e)) => {
error!("Fatal UDP packet receive error: {:?}.", e);
break;
},
Ok(Ok(Poison)) => {
break;
},
}
}
}
}
#[cfg(not(feature = "tokio-02-marker"))]
#[instrument(skip(udp_msg_rx))]
pub(crate) async fn runner(udp_msg_rx: Receiver<UdpTxMessage>, ssrc: u32, udp_tx: Arc<UdpSocket>) {
info!("UDP transmit handle started.");
let mut keepalive_bytes = [0u8; MutableKeepalivePacket::minimum_packet_size()];
let mut ka = MutableKeepalivePacket::new(&mut keepalive_bytes[..])
.expect("FATAL: Insufficient bytes given to keepalive packet.");
ka.set_ssrc(ssrc);
let mut txer = UdpTx {
ssrc,
rx: udp_msg_rx,
udp_tx,
};
let mut ka_time = Instant::now() + UDP_KEEPALIVE_GAP;
loop {
use UdpTxMessage::*;
match timeout_at(ka_time, udp_msg_rx.recv_async()).await {
Err(_) => {
trace!("Sending UDP Keepalive.");
if let Err(e) = udp_tx.send(&keepalive_bytes[..]).await {
error!("Fatal UDP keepalive send error: {:?}.", e);
break;
}
ka_time += UDP_KEEPALIVE_GAP;
},
Ok(Ok(Packet(p))) =>
if let Err(e) = udp_tx.send(&p[..]).await {
error!("Fatal UDP packet send error: {:?}.", e);
break;
},
Ok(Err(e)) => {
error!("Fatal UDP packet receive error: {:?}.", e);
break;
},
Ok(Ok(Poison)) => {
break;
},
}
}
txer.run().await;
info!("UDP transmit handle stopped.");
}
#[cfg(feature = "tokio-02-marker")]
#[instrument(skip(udp_msg_rx))]
pub(crate) async fn runner(udp_msg_rx: Receiver<UdpTxMessage>, ssrc: u32, udp_tx: SendHalf) {
info!("UDP transmit handle started.");
let mut txer = UdpTx {
ssrc,
rx: udp_msg_rx,
udp_tx,
};
txer.run().await;
info!("UDP transmit handle stopped.");
}

View File

@@ -10,11 +10,23 @@ use crate::{
},
ws::{Error as WsError, ReceiverExt, SenderExt, WsStream},
};
#[cfg(not(feature = "tokio-02-marker"))]
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
#[cfg(feature = "tokio-02-marker")]
use async_tungstenite_compat::tungstenite::protocol::frame::coding::CloseCode;
use flume::Receiver;
use rand::random;
use std::time::Duration;
use tokio::time::{self, Instant};
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::{
select,
time::{sleep_until, Instant},
};
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::{
select,
time::{delay_until as sleep_until, Instant},
};
use tracing::{error, info, instrument, trace, warn};
struct AuxNetwork {
@@ -57,9 +69,9 @@ impl AuxNetwork {
let mut ws_error = false;
let mut should_reconnect = false;
let hb = time::sleep_until(next_heartbeat);
let hb = sleep_until(next_heartbeat);
tokio::select! {
select! {
_ = hb => {
ws_error = match self.send_heartbeat().await {
Err(e) => {