Voice Rework -- Events, Track Queues (#806)

This implements a proof-of-concept for an improved audio frontend. The largest change is the introduction of events and event handling: both by time elapsed and by track events, such as ending or looping. Following on from this, the library now includes a basic, event-driven track queue system (which people seem to ask for unusually often). A new sample, `examples/13_voice_events`, demonstrates both the `TrackQueue` system and some basic events via the `~queue` and `~play_fade` commands.

Locks are removed from around the control of `Audio` objects, which should allow the backend to be moved to a more granular futures-based backend solution in a cleaner way.
This commit is contained in:
Kyle Simpson
2020-10-29 20:25:20 +00:00
committed by Alex M. M
commit 7e4392ae68
76 changed files with 8756 additions and 0 deletions

10
src/driver/config.rs Normal file
View File

@@ -0,0 +1,10 @@
use super::CryptoMode;
/// Configuration for the inner Driver.
///
/// At present, this cannot be changed.
#[derive(Clone, Debug, Default)]
pub struct Config {
/// Selected tagging mode for voice packet encryption.
pub crypto_mode: Option<CryptoMode>,
}

View File

@@ -0,0 +1,105 @@
//! Connection errors and convenience types.
use crate::{
driver::tasks::{error::Recipient, message::*},
ws::Error as WsError,
};
use flume::SendError;
use serde_json::Error as JsonError;
use std::{error::Error as ErrorTrait, fmt, io::Error as IoError};
use xsalsa20poly1305::aead::Error as CryptoError;
/// Errors encountered while connecting to a Discord voice server over the driver.
#[derive(Debug)]
pub enum Error {
/// An error occurred during [en/de]cryption of voice packets or key generation.
Crypto(CryptoError),
/// Server did not return the expected crypto mode during negotiation.
CryptoModeInvalid,
/// Selected crypto mode was not offered by server.
CryptoModeUnavailable,
/// An indicator that an endpoint URL was invalid.
EndpointUrl,
/// Discord hello/ready handshake was violated.
ExpectedHandshake,
/// Discord failed to correctly respond to IP discovery.
IllegalDiscoveryResponse,
/// Could not parse Discord's view of our IP.
IllegalIp,
/// Miscellaneous I/O error.
Io(IoError),
/// JSON (de)serialization error.
Json(JsonError),
/// Failed to message other background tasks after connection establishment.
InterconnectFailure(Recipient),
/// Error communicating with gateway server over WebSocket.
Ws(WsError),
}
impl From<CryptoError> for Error {
fn from(e: CryptoError) -> Self {
Error::Crypto(e)
}
}
impl From<IoError> for Error {
fn from(e: IoError) -> Error {
Error::Io(e)
}
}
impl From<JsonError> for Error {
fn from(e: JsonError) -> Error {
Error::Json(e)
}
}
impl From<SendError<WsMessage>> for Error {
fn from(_e: SendError<WsMessage>) -> Error {
Error::InterconnectFailure(Recipient::AuxNetwork)
}
}
impl From<SendError<EventMessage>> for Error {
fn from(_e: SendError<EventMessage>) -> Error {
Error::InterconnectFailure(Recipient::Event)
}
}
impl From<SendError<MixerMessage>> for Error {
fn from(_e: SendError<MixerMessage>) -> Error {
Error::InterconnectFailure(Recipient::Mixer)
}
}
impl From<WsError> for Error {
fn from(e: WsError) -> Error {
Error::Ws(e)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Failed to connect to Discord RTP server: ")?;
use Error::*;
match self {
Crypto(c) => write!(f, "cryptography error {}.", c),
CryptoModeInvalid => write!(f, "server changed negotiated encryption mode."),
CryptoModeUnavailable => write!(f, "server did not offer chosen encryption mode."),
EndpointUrl => write!(f, "endpoint URL received from gateway was invalid."),
ExpectedHandshake => write!(f, "voice initialisation protocol was violated."),
IllegalDiscoveryResponse =>
write!(f, "IP discovery/NAT punching response was invalid."),
IllegalIp => write!(f, "IP discovery/NAT punching response had bad IP value."),
Io(i) => write!(f, "I/O failure ({}).", i),
Json(j) => write!(f, "JSON (de)serialization issue ({}).", j),
InterconnectFailure(r) => write!(f, "failed to contact other task ({:?})", r),
Ws(w) => write!(f, "websocket issue ({:?}).", w),
}
}
}
impl ErrorTrait for Error {}
/// Convenience type for Discord voice/driver connection error handling.
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,321 @@
pub mod error;
use super::{
tasks::{message::*, udp_rx, udp_tx, ws as ws_task},
Config,
CryptoMode,
};
use crate::{
constants::*,
model::{
payload::{Identify, Resume, SelectProtocol},
Event as GatewayEvent,
ProtocolData,
},
ws::{self, ReceiverExt, SenderExt, WsStream},
ConnectionInfo,
};
use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket};
use error::{Error, Result};
use flume::Sender;
use std::{net::IpAddr, str::FromStr};
use tokio::net::UdpSocket;
use tracing::{debug, info, instrument};
use url::Url;
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
#[cfg(all(feature = "rustls", not(feature = "native")))]
use ws::create_rustls_client;
#[cfg(feature = "native")]
use ws::create_native_tls_client;
pub(crate) struct Connection {
pub(crate) info: ConnectionInfo,
pub(crate) ws: Sender<WsMessage>,
}
impl Connection {
pub(crate) async fn new(
mut info: ConnectionInfo,
interconnect: &Interconnect,
config: &Config,
) -> Result<Connection> {
let crypto_mode = config.crypto_mode.unwrap_or(CryptoMode::Normal);
let url = generate_url(&mut info.endpoint)?;
#[cfg(all(feature = "rustls", not(feature = "native")))]
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native")]
let mut client = create_native_tls_client(url).await?;
let mut hello = None;
let mut ready = None;
client
.send_json(&GatewayEvent::from(Identify {
server_id: info.guild_id.into(),
session_id: info.session_id.clone(),
token: info.token.clone(),
user_id: info.user_id.into(),
}))
.await?;
loop {
let value = match client.recv_json().await? {
Some(value) => value,
None => continue,
};
match value {
GatewayEvent::Ready(r) => {
ready = Some(r);
if hello.is_some() {
break;
}
},
GatewayEvent::Hello(h) => {
hello = Some(h);
if ready.is_some() {
break;
}
},
other => {
debug!("Expected ready/hello; got: {:?}", other);
return Err(Error::ExpectedHandshake);
},
}
}
let hello =
hello.expect("Hello packet expected in connection initialisation, but not found.");
let ready =
ready.expect("Ready packet expected in connection initialisation, but not found.");
if !has_valid_mode(&ready.modes, crypto_mode) {
return Err(Error::CryptoModeUnavailable);
}
let mut udp = UdpSocket::bind("0.0.0.0:0").await?;
udp.connect((ready.ip, ready.port)).await?;
// Follow Discord's IP Discovery procedures, in case NAT tunnelling is needed.
let mut bytes = [0; IpDiscoveryPacket::const_packet_size()];
{
let mut view = MutableIpDiscoveryPacket::new(&mut bytes[..]).expect(
"Too few bytes in 'bytes' for IPDiscovery packet.\
(Blame: IpDiscoveryPacket::const_packet_size()?)",
);
view.set_pkt_type(IpDiscoveryType::Request);
view.set_length(70);
view.set_ssrc(ready.ssrc);
}
udp.send(&bytes).await?;
let (len, _addr) = udp.recv_from(&mut bytes).await?;
{
let view =
IpDiscoveryPacket::new(&bytes[..len]).ok_or(Error::IllegalDiscoveryResponse)?;
if view.get_pkt_type() != IpDiscoveryType::Response {
return Err(Error::IllegalDiscoveryResponse);
}
// We could do something clever like binary search,
// but possibility of UDP spoofing preclueds us from
// making the assumption we can find a "left edge" of '\0's.
let nul_byte_index = view
.get_address_raw()
.iter()
.position(|&b| b == 0)
.ok_or(Error::IllegalIp)?;
let address_str = std::str::from_utf8(&view.get_address_raw()[..nul_byte_index])
.map_err(|_| Error::IllegalIp)?;
let address = IpAddr::from_str(&address_str).map_err(|e| {
println!("{:?}", e);
Error::IllegalIp
})?;
client
.send_json(&GatewayEvent::from(SelectProtocol {
protocol: "udp".into(),
data: ProtocolData {
address,
mode: crypto_mode.to_request_str().into(),
port: view.get_port(),
},
}))
.await?;
}
let cipher = init_cipher(&mut client, crypto_mode).await?;
info!("Connected to: {}", info.endpoint);
info!("WS heartbeat duration {}ms.", hello.heartbeat_interval,);
let (ws_msg_tx, ws_msg_rx) = flume::unbounded();
let (udp_sender_msg_tx, udp_sender_msg_rx) = flume::unbounded();
let (udp_receiver_msg_tx, udp_receiver_msg_rx) = flume::unbounded();
let (udp_rx, udp_tx) = udp.split();
let ssrc = ready.ssrc;
let mix_conn = MixerConnection {
cipher: cipher.clone(),
udp_rx: udp_receiver_msg_tx,
udp_tx: udp_sender_msg_tx,
};
interconnect
.mixer
.send(MixerMessage::Ws(Some(ws_msg_tx.clone())))?;
interconnect
.mixer
.send(MixerMessage::SetConn(mix_conn, ready.ssrc))?;
tokio::spawn(ws_task::runner(
interconnect.clone(),
ws_msg_rx,
client,
ssrc,
hello.heartbeat_interval,
));
tokio::spawn(udp_rx::runner(
interconnect.clone(),
udp_receiver_msg_rx,
cipher,
crypto_mode,
udp_rx,
));
tokio::spawn(udp_tx::runner(udp_sender_msg_rx, ssrc, udp_tx));
Ok(Connection {
info,
ws: ws_msg_tx,
})
}
#[instrument(skip(self))]
pub async fn reconnect(&mut self) -> Result<()> {
let url = generate_url(&mut self.info.endpoint)?;
// Thread may have died, we want to send to prompt a clean exit
// (if at all possible) and then proceed as normal.
#[cfg(all(feature = "rustls", not(feature = "native")))]
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native")]
let mut client = create_native_tls_client(url).await?;
client
.send_json(&GatewayEvent::from(Resume {
server_id: self.info.guild_id.into(),
session_id: self.info.session_id.clone(),
token: self.info.token.clone(),
}))
.await?;
let mut hello = None;
let mut resumed = None;
loop {
let value = match client.recv_json().await? {
Some(value) => value,
None => continue,
};
match value {
GatewayEvent::Resumed => {
resumed = Some(());
if hello.is_some() {
break;
}
},
GatewayEvent::Hello(h) => {
hello = Some(h);
if resumed.is_some() {
break;
}
},
other => {
debug!("Expected resumed/hello; got: {:?}", other);
return Err(Error::ExpectedHandshake);
},
}
}
let hello =
hello.expect("Hello packet expected in connection initialisation, but not found.");
self.ws
.send(WsMessage::SetKeepalive(hello.heartbeat_interval))?;
self.ws.send(WsMessage::Ws(Box::new(client)))?;
info!("Reconnected to: {}", &self.info.endpoint);
Ok(())
}
}
impl Drop for Connection {
fn drop(&mut self) {
info!("Disconnected");
}
}
fn generate_url(endpoint: &mut String) -> Result<Url> {
if endpoint.ends_with(":80") {
let len = endpoint.len();
endpoint.truncate(len - 3);
}
Url::parse(&format!("wss://{}/?v={}", endpoint, VOICE_GATEWAY_VERSION))
.or(Err(Error::EndpointUrl))
}
#[inline]
async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result<Cipher> {
loop {
let value = match client.recv_json().await? {
Some(value) => value,
None => continue,
};
match value {
GatewayEvent::SessionDescription(desc) => {
if desc.mode != mode.to_request_str() {
return Err(Error::CryptoModeInvalid);
}
return Ok(Cipher::new_varkey(&desc.secret_key)?);
},
other => {
debug!(
"Expected ready for key; got: op{}/v{:?}",
other.kind() as u8,
other
);
},
}
}
}
#[inline]
fn has_valid_mode<T, It>(modes: It, mode: CryptoMode) -> bool
where
T: for<'a> PartialEq<&'a str>,
It: IntoIterator<Item = T>,
{
modes.into_iter().any(|s| s == mode.to_request_str())
}

38
src/driver/crypto.rs Normal file
View File

@@ -0,0 +1,38 @@
//! Encryption schemes supported by Discord's secure RTP negotiation.
/// Variants of the XSalsa20Poly1305 encryption scheme.
///
/// At present, only `Normal` is supported or selectable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Mode {
/// The RTP header is used as the source of nonce bytes for the packet.
///
/// Equivalent to a nonce of at most 48b (6B) at no extra packet overhead:
/// the RTP sequence number and timestamp are the varying quantities.
Normal,
/// An additional random 24B suffix is used as the source of nonce bytes for the packet.
///
/// Full nonce width of 24B (192b), at an extra 24B per packet (~1.2 kB/s).
Suffix,
/// An additional random 24B suffix is used as the source of nonce bytes for the packet.
///
/// Nonce width of 4B (32b), at an extra 4B per packet (~0.2 kB/s).
Lite,
}
impl Mode {
/// Returns the name of a mode as it will appear during negotiation.
pub fn to_request_str(self) -> &'static str {
use Mode::*;
match self {
Normal => "xsalsa20_poly1305",
Suffix => "xsalsa20_poly1305_suffix",
Lite => "xsalsa20_poly1305_lite",
}
}
}
// TODO: implement encrypt + decrypt + nonce selection for each.
// This will probably need some research into correct handling of
// padding, reported length, SRTP profiles, and so on.

233
src/driver/mod.rs Normal file
View File

@@ -0,0 +1,233 @@
//! Runner for a voice connection.
//!
//! Songbird's driver is a mixed-sync system, using:
//! * Asynchronous connection management, event-handling, and gateway integration.
//! * Synchronous audio mixing, packet generation, and encoding.
//!
//! This splits up work according to its IO/compute bound nature, preventing packet
//! generation from being slowed down past its deadline, or from affecting other
//! asynchronous tasks your bot must handle.
mod config;
pub(crate) mod connection;
mod crypto;
pub(crate) mod tasks;
pub use config::Config;
use connection::error::Result;
pub use crypto::Mode as CryptoMode;
use crate::{
events::EventData,
input::Input,
tracks::{Track, TrackHandle},
ConnectionInfo,
Event,
EventHandler,
};
use audiopus::Bitrate;
use flume::{Receiver, SendError, Sender};
use tasks::message::CoreMessage;
use tracing::instrument;
/// The control object for a Discord voice connection, handling connection,
/// mixing, encoding, en/decryption, and event generation.
#[derive(Clone, Debug)]
pub struct Driver {
config: Config,
self_mute: bool,
sender: Sender<CoreMessage>,
}
impl Driver {
/// Creates a new voice driver.
///
/// This will create the core voice tasks in the background.
#[inline]
pub fn new(config: Config) -> Self {
let sender = Self::start_inner(config.clone());
Driver {
config,
self_mute: false,
sender,
}
}
fn start_inner(config: Config) -> Sender<CoreMessage> {
let (tx, rx) = flume::unbounded();
tasks::start(config, rx, tx.clone());
tx
}
fn restart_inner(&mut self) {
self.sender = Self::start_inner(self.config.clone());
self.mute(self.self_mute);
}
/// Connects to a voice channel using the specified server.
#[instrument(skip(self))]
pub fn connect(&mut self, info: ConnectionInfo) -> Receiver<Result<()>> {
let (tx, rx) = flume::bounded(1);
self.raw_connect(info, tx);
rx
}
/// Connects to a voice channel using the specified server.
#[instrument(skip(self))]
pub(crate) fn raw_connect(&mut self, info: ConnectionInfo, tx: Sender<Result<()>>) {
self.send(CoreMessage::ConnectWithResult(info, tx));
}
/// Leaves the current voice channel, disconnecting from it.
///
/// This does *not* forget settings, like whether to be self-deafened or
/// self-muted.
#[instrument(skip(self))]
pub fn leave(&mut self) {
self.send(CoreMessage::Disconnect);
}
/// Sets whether the current connection is to be muted.
///
/// If there is no live voice connection, then this only acts as a settings
/// update for future connections.
#[instrument(skip(self))]
pub fn mute(&mut self, mute: bool) {
self.self_mute = mute;
self.send(CoreMessage::Mute(mute));
}
/// Returns whether the driver is muted (i.e., processes audio internally
/// but submits none).
#[instrument(skip(self))]
pub fn is_mute(&self) -> bool {
self.self_mute
}
/// Plays audio from a source, returning a handle for further control.
///
/// This can be a source created via [`ffmpeg`] or [`ytdl`].
///
/// [`ffmpeg`]: ../input/fn.ffmpeg.html
/// [`ytdl`]: ../input/fn.ytdl.html
#[instrument(skip(self))]
pub fn play_source(&mut self, source: Input) -> TrackHandle {
let (player, handle) = super::create_player(source);
self.send(CoreMessage::AddTrack(player));
handle
}
/// Plays audio from a source, returning a handle for further control.
///
/// Unlike [`play_source`], this stops all other sources attached
/// to the channel.
///
/// [`play_source`]: #method.play_source
#[instrument(skip(self))]
pub fn play_only_source(&mut self, source: Input) -> TrackHandle {
let (player, handle) = super::create_player(source);
self.send(CoreMessage::SetTrack(Some(player)));
handle
}
/// Plays audio from a [`Track`] object.
///
/// This will be one half of the return value of [`create_player`].
/// The main difference between this function and [`play_source`] is
/// that this allows for direct manipulation of the [`Track`] object
/// before it is passed over to the voice and mixing contexts.
///
/// [`create_player`]: ../tracks/fn.create_player.html
/// [`Track`]: ../tracks/struct.Track.html
/// [`play_source`]: #method.play_source
#[instrument(skip(self))]
pub fn play(&mut self, track: Track) {
self.send(CoreMessage::AddTrack(track));
}
/// Exclusively plays audio from a [`Track`] object.
///
/// This will be one half of the return value of [`create_player`].
/// As in [`play_only_source`], this stops all other sources attached to the
/// channel. Like [`play`], however, this allows for direct manipulation of the
/// [`Track`] object before it is passed over to the voice and mixing contexts.
///
/// [`create_player`]: ../tracks/fn.create_player.html
/// [`Track`]: ../tracks/struct.Track.html
/// [`play_only_source`]: #method.play_only_source
/// [`play`]: #method.play
#[instrument(skip(self))]
pub fn play_only(&mut self, track: Track) {
self.send(CoreMessage::SetTrack(Some(track)));
}
/// Sets the bitrate for encoding Opus packets sent along
/// the channel being managed.
///
/// The default rate is 128 kbps.
/// Sensible values range between `Bits(512)` and `Bits(512_000)`
/// bits per second.
/// Alternatively, `Auto` and `Max` remain available.
#[instrument(skip(self))]
pub fn set_bitrate(&mut self, bitrate: Bitrate) {
self.send(CoreMessage::SetBitrate(bitrate))
}
/// Stops playing audio from all sources, if any are set.
#[instrument(skip(self))]
pub fn stop(&mut self) {
self.send(CoreMessage::SetTrack(None))
}
/// Attach a global event handler to an audio context. Global events may receive
/// any [`EventContext`].
///
/// Global timing events will tick regardless of whether audio is playing,
/// so long as the bot is connected to a voice channel, and have no tracks.
/// [`TrackEvent`]s will respond to all relevant tracks, giving some audio elements.
///
/// Users **must** ensure that no costly work or blocking occurs
/// within the supplied function or closure. *Taking excess time could prevent
/// timely sending of packets, causing audio glitches and delays*.
///
/// [`Track`]: ../tracks/struct.Track.html
/// [`TrackEvent`]: ../events/enum.TrackEvent.html
/// [`EventContext`]: ../events/enum.EventContext.html
#[instrument(skip(self, action))]
pub fn add_global_event<F: EventHandler + 'static>(&mut self, event: Event, action: F) {
self.send(CoreMessage::AddEvent(EventData::new(event, action)));
}
/// Sends a message to the inner tasks, restarting it if necessary.
fn send(&mut self, status: CoreMessage) {
// Restart thread if it errored.
if let Err(SendError(status)) = self.sender.send(status) {
self.restart_inner();
self.sender.send(status).unwrap();
}
}
}
impl Default for Driver {
fn default() -> Self {
Self::new(Default::default())
}
}
impl Drop for Driver {
/// Leaves the current connected voice channel, if connected to one, and
/// forgets all configurations relevant to this Handler.
fn drop(&mut self) {
self.leave();
let _ = self.sender.send(CoreMessage::Poison);
}
}

97
src/driver/tasks/error.rs Normal file
View File

@@ -0,0 +1,97 @@
use super::message::*;
use crate::ws::Error as WsError;
use audiopus::Error as OpusError;
use flume::SendError;
use std::io::Error as IoError;
use xsalsa20poly1305::aead::Error as CryptoError;
#[derive(Debug)]
pub enum Recipient {
AuxNetwork,
Event,
Mixer,
UdpRx,
UdpTx,
}
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
pub enum Error {
Crypto(CryptoError),
/// Received an illegal voice packet on the voice UDP socket.
IllegalVoicePacket,
InterconnectFailure(Recipient),
Io(IoError),
Opus(OpusError),
Ws(WsError),
}
impl Error {
pub(crate) fn should_trigger_connect(&self) -> bool {
matches!(
self,
Error::InterconnectFailure(Recipient::AuxNetwork)
| Error::InterconnectFailure(Recipient::UdpRx)
| Error::InterconnectFailure(Recipient::UdpTx)
)
}
pub(crate) fn should_trigger_interconnect_rebuild(&self) -> bool {
matches!(self, Error::InterconnectFailure(Recipient::Event))
}
}
impl From<CryptoError> for Error {
fn from(e: CryptoError) -> Self {
Error::Crypto(e)
}
}
impl From<IoError> for Error {
fn from(e: IoError) -> Error {
Error::Io(e)
}
}
impl From<OpusError> for Error {
fn from(e: OpusError) -> Error {
Error::Opus(e)
}
}
impl From<SendError<WsMessage>> for Error {
fn from(_e: SendError<WsMessage>) -> Error {
Error::InterconnectFailure(Recipient::AuxNetwork)
}
}
impl From<SendError<EventMessage>> for Error {
fn from(_e: SendError<EventMessage>) -> Error {
Error::InterconnectFailure(Recipient::Event)
}
}
impl From<SendError<MixerMessage>> for Error {
fn from(_e: SendError<MixerMessage>) -> Error {
Error::InterconnectFailure(Recipient::Mixer)
}
}
impl From<SendError<UdpRxMessage>> for Error {
fn from(_e: SendError<UdpRxMessage>) -> Error {
Error::InterconnectFailure(Recipient::UdpRx)
}
}
impl From<SendError<UdpTxMessage>> for Error {
fn from(_e: SendError<UdpTxMessage>) -> Error {
Error::InterconnectFailure(Recipient::UdpTx)
}
}
impl From<WsError> for Error {
fn from(e: WsError) -> Error {
Error::Ws(e)
}
}

118
src/driver/tasks/events.rs Normal file
View File

@@ -0,0 +1,118 @@
use super::message::*;
use crate::{
events::{EventStore, GlobalEvents, TrackEvent},
tracks::{TrackHandle, TrackState},
};
use flume::Receiver;
use tracing::{debug, info, instrument, trace};
#[instrument(skip(_interconnect, evt_rx))]
pub(crate) async fn runner(_interconnect: Interconnect, evt_rx: Receiver<EventMessage>) {
let mut global = GlobalEvents::default();
let mut events: Vec<EventStore> = vec![];
let mut states: Vec<TrackState> = vec![];
let mut handles: Vec<TrackHandle> = vec![];
loop {
use EventMessage::*;
match evt_rx.recv_async().await {
Ok(AddGlobalEvent(data)) => {
info!("Global event added.");
global.add_event(data);
},
Ok(AddTrackEvent(i, data)) => {
info!("Adding event to track {}.", i);
let event_store = events
.get_mut(i)
.expect("Event thread was given an illegal store index for AddTrackEvent.");
let state = states
.get_mut(i)
.expect("Event thread was given an illegal state index for AddTrackEvent.");
event_store.add_event(data, state.position);
},
Ok(FireCoreEvent(ctx)) => {
let ctx = ctx.to_user_context();
let evt = ctx
.to_core_event()
.expect("Event thread was passed a non-core event in FireCoreEvent.");
trace!("Firing core event {:?}.", evt);
global.fire_core_event(evt, ctx).await;
},
Ok(AddTrack(store, state, handle)) => {
events.push(store);
states.push(state);
handles.push(handle);
info!("Event state for track {} added", events.len());
},
Ok(ChangeState(i, change)) => {
use TrackStateChange::*;
let max_states = states.len();
debug!(
"Changing state for track {} of {}: {:?}",
i, max_states, change
);
let state = states
.get_mut(i)
.expect("Event thread was given an illegal state index for ChangeState.");
match change {
Mode(mode) => {
let old = state.playing;
state.playing = mode;
if old != mode && mode.is_done() {
global.fire_track_event(TrackEvent::End, i);
}
},
Volume(vol) => {
state.volume = vol;
},
Position(pos) => {
// Currently, only Tick should fire time events.
state.position = pos;
},
Loops(loops, user_set) => {
state.loops = loops;
if !user_set {
global.fire_track_event(TrackEvent::Loop, i);
}
},
Total(new) => {
// Massive, unprecedented state changes.
*state = new;
},
}
},
Ok(RemoveTrack(i)) => {
info!("Event state for track {} of {} removed.", i, events.len());
events.remove(i);
states.remove(i);
handles.remove(i);
},
Ok(RemoveAllTracks) => {
info!("Event state for all tracks removed.");
events.clear();
states.clear();
handles.clear();
},
Ok(Tick) => {
// NOTE: this should fire saved up blocks of state change evts.
global.tick(&mut events, &mut states, &mut handles).await;
},
Err(_) | Ok(Poison) => {
break;
},
}
}
info!("Event thread exited.");
}

View File

@@ -0,0 +1,24 @@
use crate::{
driver::connection::error::Error,
events::EventData,
tracks::Track,
Bitrate,
ConnectionInfo,
};
use flume::Sender;
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum CoreMessage {
ConnectWithResult(ConnectionInfo, Sender<Result<(), Error>>),
Disconnect,
SetTrack(Option<Track>),
AddTrack(Track),
SetBitrate(Bitrate),
AddEvent(EventData),
Mute(bool),
Reconnect,
FullReconnect,
RebuildInterconnect,
Poison,
}

View File

@@ -0,0 +1,31 @@
use crate::{
events::{CoreContext, EventData, EventStore},
tracks::{LoopState, PlayMode, TrackHandle, TrackState},
};
use std::time::Duration;
pub(crate) enum EventMessage {
// Event related.
// Track events should fire off the back of state changes.
AddGlobalEvent(EventData),
AddTrackEvent(usize, EventData),
FireCoreEvent(CoreContext),
AddTrack(EventStore, TrackState, TrackHandle),
ChangeState(usize, TrackStateChange),
RemoveTrack(usize),
RemoveAllTracks,
Tick,
Poison,
}
#[derive(Debug)]
pub enum TrackStateChange {
Mode(PlayMode),
Volume(f32),
Position(Duration),
// Bool indicates user-set.
Loops(LoopState, bool),
Total(TrackState),
}

View File

@@ -0,0 +1,32 @@
use super::{Interconnect, UdpRxMessage, UdpTxMessage, WsMessage};
use crate::{tracks::Track, Bitrate};
use flume::Sender;
use xsalsa20poly1305::XSalsa20Poly1305 as Cipher;
pub(crate) struct MixerConnection {
pub cipher: Cipher,
pub udp_rx: Sender<UdpRxMessage>,
pub udp_tx: Sender<UdpTxMessage>,
}
impl Drop for MixerConnection {
fn drop(&mut self) {
let _ = self.udp_rx.send(UdpRxMessage::Poison);
let _ = self.udp_tx.send(UdpTxMessage::Poison);
}
}
pub(crate) enum MixerMessage {
AddTrack(Track),
SetTrack(Option<Track>),
SetBitrate(Bitrate),
SetMute(bool),
SetConn(MixerConnection, u32),
DropConn,
ReplaceInterconnect(Interconnect),
RebuildEncoder,
Ws(Option<Sender<WsMessage>>),
Poison,
}

View File

@@ -0,0 +1,49 @@
mod core;
mod events;
mod mixer;
mod udp_rx;
mod udp_tx;
mod ws;
pub(crate) use self::{core::*, events::*, mixer::*, udp_rx::*, udp_tx::*, ws::*};
use flume::Sender;
use tracing::info;
#[derive(Clone, Debug)]
pub(crate) struct Interconnect {
pub core: Sender<CoreMessage>,
pub events: Sender<EventMessage>,
pub mixer: Sender<MixerMessage>,
}
impl Interconnect {
pub fn poison(&self) {
let _ = self.events.send(EventMessage::Poison);
}
pub fn poison_all(&self) {
self.poison();
let _ = self.mixer.send(MixerMessage::Poison);
}
pub fn restart_volatile_internals(&mut self) {
self.poison();
let (evt_tx, evt_rx) = flume::unbounded();
self.events = evt_tx;
let ic = self.clone();
tokio::spawn(async move {
info!("Event processor restarted.");
super::events::runner(ic, evt_rx).await;
info!("Event processor finished.");
});
// Make mixer aware of new targets...
let _ = self
.mixer
.send(MixerMessage::ReplaceInterconnect(self.clone()));
}
}

View File

@@ -0,0 +1,7 @@
use super::Interconnect;
pub(crate) enum UdpRxMessage {
ReplaceInterconnect(Interconnect),
Poison,
}

View File

@@ -0,0 +1,4 @@
pub enum UdpTxMessage {
Packet(Vec<u8>), // TODO: do something cheaper.
Poison,
}

View File

@@ -0,0 +1,12 @@
use super::Interconnect;
use crate::ws::WsStream;
#[allow(dead_code)]
pub(crate) enum WsMessage {
Ws(Box<WsStream>),
ReplaceInterconnect(Interconnect),
SetKeepalive(f64),
Speaking(bool),
Poison,
}

516
src/driver/tasks/mixer.rs Normal file
View File

@@ -0,0 +1,516 @@
use super::{error::Result, message::*};
use crate::{
constants::*,
tracks::{PlayMode, Track},
};
use audiopus::{
coder::Encoder as OpusEncoder,
softclip::SoftClip,
Application as CodingMode,
Bitrate,
Channels,
};
use discortp::{
rtp::{MutableRtpPacket, RtpPacket},
MutablePacket,
Packet,
};
use flume::{Receiver, Sender, TryRecvError};
use rand::random;
use spin_sleep::SpinSleeper;
use std::time::Instant;
use tokio::runtime::Handle;
use tracing::{error, instrument};
use xsalsa20poly1305::{aead::AeadInPlace, Nonce, TAG_SIZE};
struct Mixer {
async_handle: Handle,
bitrate: Bitrate,
conn_active: Option<MixerConnection>,
deadline: Instant,
encoder: OpusEncoder,
interconnect: Interconnect,
mix_rx: Receiver<MixerMessage>,
muted: bool,
packet: [u8; VOICE_PACKET_MAX],
prevent_events: bool,
silence_frames: u8,
sleeper: SpinSleeper,
soft_clip: SoftClip,
tracks: Vec<Track>,
ws: Option<Sender<WsMessage>>,
}
fn new_encoder(bitrate: Bitrate) -> Result<OpusEncoder> {
let mut encoder = OpusEncoder::new(SAMPLE_RATE, Channels::Stereo, CodingMode::Audio)?;
encoder.set_bitrate(bitrate)?;
Ok(encoder)
}
impl Mixer {
fn new(
mix_rx: Receiver<MixerMessage>,
async_handle: Handle,
interconnect: Interconnect,
) -> Self {
let bitrate = DEFAULT_BITRATE;
let encoder = new_encoder(bitrate)
.expect("Failed to create encoder in mixing thread with known-good values.");
let soft_clip = SoftClip::new(Channels::Stereo);
let mut packet = [0u8; VOICE_PACKET_MAX];
let mut rtp = MutableRtpPacket::new(&mut packet[..]).expect(
"FATAL: Too few bytes in self.packet for RTP header.\
(Blame: VOICE_PACKET_MAX?)",
);
rtp.set_version(RTP_VERSION);
rtp.set_payload_type(RTP_PROFILE_TYPE);
rtp.set_sequence(random::<u16>().into());
rtp.set_timestamp(random::<u32>().into());
Self {
async_handle,
bitrate,
conn_active: None,
deadline: Instant::now(),
encoder,
interconnect,
mix_rx,
muted: false,
packet,
prevent_events: false,
silence_frames: 0,
sleeper: Default::default(),
soft_clip,
tracks: vec![],
ws: None,
}
}
fn run(&mut self) {
let mut events_failure = false;
let mut conn_failure = false;
'runner: loop {
loop {
use MixerMessage::*;
let error = match self.mix_rx.try_recv() {
Ok(AddTrack(mut t)) => {
t.source.prep_with_handle(self.async_handle.clone());
self.add_track(t)
},
Ok(SetTrack(t)) => {
self.tracks.clear();
let mut out = self.fire_event(EventMessage::RemoveAllTracks);
if let Some(mut t) = t {
t.source.prep_with_handle(self.async_handle.clone());
// Do this unconditionally: this affects local state infallibly,
// with the event installation being the remote part.
if let Err(e) = self.add_track(t) {
out = Err(e);
}
}
out
},
Ok(SetBitrate(b)) => {
self.bitrate = b;
if let Err(e) = self.set_bitrate(b) {
error!("Failed to update bitrate {:?}", e);
}
Ok(())
},
Ok(SetMute(m)) => {
self.muted = m;
Ok(())
},
Ok(SetConn(conn, ssrc)) => {
self.conn_active = Some(conn);
let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect(
"Too few bytes in self.packet for RTP header.\
(Blame: VOICE_PACKET_MAX?)",
);
rtp.set_ssrc(ssrc);
self.deadline = Instant::now();
Ok(())
},
Ok(DropConn) => {
self.conn_active = None;
Ok(())
},
Ok(ReplaceInterconnect(i)) => {
self.prevent_events = false;
if let Some(ws) = &self.ws {
conn_failure |=
ws.send(WsMessage::ReplaceInterconnect(i.clone())).is_err();
}
if let Some(conn) = &self.conn_active {
conn_failure |= conn
.udp_rx
.send(UdpRxMessage::ReplaceInterconnect(i.clone()))
.is_err();
}
self.interconnect = i;
self.rebuild_tracks()
},
Ok(RebuildEncoder) => match new_encoder(self.bitrate) {
Ok(encoder) => {
self.encoder = encoder;
Ok(())
},
Err(e) => {
error!("Failed to rebuild encoder. Resetting bitrate. {:?}", e);
self.bitrate = DEFAULT_BITRATE;
self.encoder = new_encoder(self.bitrate)
.expect("Failed fallback rebuild of OpusEncoder with safe inputs.");
Ok(())
},
},
Ok(Ws(new_ws_handle)) => {
self.ws = new_ws_handle;
Ok(())
},
Err(TryRecvError::Disconnected) | Ok(Poison) => {
break 'runner;
},
Err(TryRecvError::Empty) => {
break;
},
};
if let Err(e) = error {
events_failure |= e.should_trigger_interconnect_rebuild();
conn_failure |= e.should_trigger_connect();
}
}
if let Err(e) = self.cycle().and_then(|_| self.audio_commands_events()) {
events_failure |= e.should_trigger_interconnect_rebuild();
conn_failure |= e.should_trigger_connect();
error!("Mixer thread cycle: {:?}", e);
}
// event failure? rebuild interconnect.
// ws or udp failure? full connect
// (soft reconnect is covered by the ws task.)
if events_failure {
self.prevent_events = true;
self.interconnect
.core
.send(CoreMessage::RebuildInterconnect)
.expect("FATAL: No way to rebuild driver core from mixer.");
events_failure = false;
}
if conn_failure {
self.interconnect
.core
.send(CoreMessage::FullReconnect)
.expect("FATAL: No way to rebuild driver core from mixer.");
conn_failure = false;
}
}
}
#[inline]
fn fire_event(&self, event: EventMessage) -> Result<()> {
// As this task is responsible for noticing the potential death of an event context,
// it's responsible for not forcibly recreating said context repeatedly.
if !self.prevent_events {
self.interconnect.events.send(event)?;
Ok(())
} else {
Ok(())
}
}
#[inline]
fn add_track(&mut self, mut track: Track) -> Result<()> {
let evts = track.events.take().unwrap_or_default();
let state = track.state();
let handle = track.handle.clone();
self.tracks.push(track);
self.interconnect
.events
.send(EventMessage::AddTrack(evts, state, handle))?;
Ok(())
}
// rebuilds the event thread's view of each track, in event of a full rebuild.
#[inline]
fn rebuild_tracks(&mut self) -> Result<()> {
for track in self.tracks.iter_mut() {
let evts = track.events.take().unwrap_or_default();
let state = track.state();
let handle = track.handle.clone();
self.interconnect
.events
.send(EventMessage::AddTrack(evts, state, handle))?;
}
Ok(())
}
#[inline]
fn mix_tracks<'a>(
&mut self,
opus_frame: &'a mut [u8],
mix_buffer: &mut [f32; STEREO_FRAME_SIZE],
) -> Result<(usize, &'a [u8])> {
let mut len = 0;
// Opus frame passthrough.
// This requires that we have only one track, who has volume 1.0, and an
// Opus codec type.
let do_passthrough = self.tracks.len() == 1 && {
let track = &self.tracks[0];
(track.volume - 1.0).abs() < f32::EPSILON && track.source.supports_passthrough()
};
for (i, track) in self.tracks.iter_mut().enumerate() {
let vol = track.volume;
let stream = &mut track.source;
if track.playing != PlayMode::Play {
continue;
}
let (temp_len, opus_len) = if do_passthrough {
(0, track.source.read_opus_frame(opus_frame).ok())
} else {
(stream.mix(mix_buffer, vol), None)
};
len = len.max(temp_len);
if temp_len > 0 || opus_len.is_some() {
track.step_frame();
} else if track.do_loop() {
if let Some(time) = track.seek_time(Default::default()) {
// have to reproduce self.fire_event here
// to circumvent the borrow checker's lack of knowledge.
//
// In event of error, one of the later event calls will
// trigger the event thread rebuild: it is more prudent that
// the mixer works as normal right now.
if !self.prevent_events {
let _ = self.interconnect.events.send(EventMessage::ChangeState(
i,
TrackStateChange::Position(time),
));
let _ = self.interconnect.events.send(EventMessage::ChangeState(
i,
TrackStateChange::Loops(track.loops, false),
));
}
}
} else {
track.end();
}
if let Some(opus_len) = opus_len {
return Ok((STEREO_FRAME_SIZE, &opus_frame[..opus_len]));
}
}
Ok((len, &opus_frame[..0]))
}
#[inline]
fn audio_commands_events(&mut self) -> Result<()> {
// Apply user commands.
for (i, track) in self.tracks.iter_mut().enumerate() {
// This causes fallible event system changes,
// but if the event thread has died then we'll certainly
// detect that on the tick later.
// Changes to play state etc. MUST all be handled.
track.process_commands(i, &self.interconnect);
}
// TODO: do without vec?
let mut i = 0;
let mut to_remove = Vec::with_capacity(self.tracks.len());
while i < self.tracks.len() {
let track = self
.tracks
.get_mut(i)
.expect("Tried to remove an illegal track index.");
if track.playing.is_done() {
let p_state = track.playing();
self.tracks.remove(i);
to_remove.push(i);
self.fire_event(EventMessage::ChangeState(
i,
TrackStateChange::Mode(p_state),
))?;
} else {
i += 1;
}
}
// Tick
self.fire_event(EventMessage::Tick)?;
// Then do removals.
for i in &to_remove[..] {
self.fire_event(EventMessage::RemoveTrack(*i))?;
}
Ok(())
}
#[inline]
fn march_deadline(&mut self) {
self.sleeper
.sleep(self.deadline.saturating_duration_since(Instant::now()));
self.deadline += TIMESTEP_LENGTH;
}
fn cycle(&mut self) -> Result<()> {
if self.conn_active.is_none() {
self.march_deadline();
return Ok(());
}
// TODO: can we make opus_frame_backing *actually* a view over
// some region of self.packet, derived using the encryption mode?
// This saves a copy on Opus passthrough.
let mut opus_frame_backing = [0u8; STEREO_FRAME_SIZE];
let mut mix_buffer = [0f32; STEREO_FRAME_SIZE];
// Slice which mix tracks may use to passthrough direct Opus frames.
let mut opus_space = &mut opus_frame_backing[..];
// Walk over all the audio files, combining into one audio frame according
// to volume, play state, etc.
let (mut len, mut opus_frame) = self.mix_tracks(&mut opus_space, &mut mix_buffer)?;
self.soft_clip.apply(&mut mix_buffer[..])?;
if self.muted {
len = 0;
}
if len == 0 {
if self.silence_frames > 0 {
self.silence_frames -= 1;
// Explicit "Silence" frame.
opus_frame = &SILENT_FRAME[..];
} else {
// Per official guidelines, send 5x silence BEFORE we stop speaking.
if let Some(ws) = &self.ws {
// NOTE: this should prevent a catastrophic thread pileup.
// A full reconnect might cause an inner closed connection.
// It's safer to leave the central task to clean this up and
// pass the mixer a new channel.
let _ = ws.send(WsMessage::Speaking(false));
}
self.march_deadline();
return Ok(());
}
} else {
self.silence_frames = 5;
}
if let Some(ws) = &self.ws {
ws.send(WsMessage::Speaking(true))?;
}
self.march_deadline();
self.prep_and_send_packet(mix_buffer, opus_frame)?;
Ok(())
}
fn set_bitrate(&mut self, bitrate: Bitrate) -> Result<()> {
self.encoder.set_bitrate(bitrate).map_err(Into::into)
}
fn prep_and_send_packet(&mut self, buffer: [f32; 1920], opus_frame: &[u8]) -> Result<()> {
let conn = self
.conn_active
.as_mut()
.expect("Shouldn't be mixing packets without access to a cipher + UDP dest.");
let mut nonce = Nonce::default();
let index = {
let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect(
"FATAL: Too few bytes in self.packet for RTP header.\
(Blame: VOICE_PACKET_MAX?)",
);
let pkt = rtp.packet();
let rtp_len = RtpPacket::minimum_packet_size();
nonce[..rtp_len].copy_from_slice(&pkt[..rtp_len]);
let payload = rtp.payload_mut();
let payload_len = if opus_frame.is_empty() {
self.encoder
.encode_float(&buffer[..STEREO_FRAME_SIZE], &mut payload[TAG_SIZE..])?
} else {
let len = opus_frame.len();
payload[TAG_SIZE..TAG_SIZE + len].clone_from_slice(opus_frame);
len
};
let final_payload_size = TAG_SIZE + payload_len;
let tag = conn.cipher.encrypt_in_place_detached(
&nonce,
b"",
&mut payload[TAG_SIZE..final_payload_size],
)?;
payload[..TAG_SIZE].copy_from_slice(&tag[..]);
rtp_len + final_payload_size
};
// TODO: This is dog slow, don't do this.
// Can we replace this with a shared ring buffer + semaphore?
// i.e., do something like double/triple buffering in graphics.
conn.udp_tx
.send(UdpTxMessage::Packet(self.packet[..index].to_vec()))?;
let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect(
"FATAL: Too few bytes in self.packet for RTP header.\
(Blame: VOICE_PACKET_MAX?)",
);
rtp.set_sequence(rtp.get_sequence() + 1);
rtp.set_timestamp(rtp.get_timestamp() + MONO_FRAME_SIZE as u32);
Ok(())
}
}
/// The mixing thread is a synchronous context due to its compute-bound nature.
///
/// We pass in an async handle for the benefit of some Input classes (e.g., restartables)
/// who need to run their restart code elsewhere and return blank data until such time.
#[instrument(skip(interconnect, mix_rx, async_handle))]
pub(crate) fn runner(
interconnect: Interconnect,
mix_rx: Receiver<MixerMessage>,
async_handle: Handle,
) {
let mut mixer = Mixer::new(mix_rx, async_handle, interconnect);
mixer.run();
}

155
src/driver/tasks/mod.rs Normal file
View File

@@ -0,0 +1,155 @@
pub mod error;
mod events;
pub(crate) mod message;
mod mixer;
pub(crate) mod udp_rx;
pub(crate) mod udp_tx;
pub(crate) mod ws;
use super::{
connection::{error::Error as ConnectionError, Connection},
Config,
};
use flume::{Receiver, RecvError, Sender};
use message::*;
use tokio::runtime::Handle;
use tracing::{error, info, instrument};
pub(crate) fn start(config: Config, rx: Receiver<CoreMessage>, tx: Sender<CoreMessage>) {
tokio::spawn(async move {
info!("Driver started.");
runner(config, rx, tx).await;
info!("Driver finished.");
});
}
fn start_internals(core: Sender<CoreMessage>) -> Interconnect {
let (evt_tx, evt_rx) = flume::unbounded();
let (mix_tx, mix_rx) = flume::unbounded();
let interconnect = Interconnect {
core,
events: evt_tx,
mixer: mix_tx,
};
let ic = interconnect.clone();
tokio::spawn(async move {
info!("Event processor started.");
events::runner(ic, evt_rx).await;
info!("Event processor finished.");
});
let ic = interconnect.clone();
let handle = Handle::current();
std::thread::spawn(move || {
info!("Mixer started.");
mixer::runner(ic, mix_rx, handle);
info!("Mixer finished.");
});
interconnect
}
#[instrument(skip(rx, tx))]
async fn runner(config: Config, rx: Receiver<CoreMessage>, tx: Sender<CoreMessage>) {
let mut connection = None;
let mut interconnect = start_internals(tx);
loop {
match rx.recv_async().await {
Ok(CoreMessage::ConnectWithResult(info, tx)) => {
connection = match Connection::new(info, &interconnect, &config).await {
Ok(connection) => {
// Other side may not be listening: this is fine.
let _ = tx.send(Ok(()));
Some(connection)
},
Err(why) => {
// See above.
let _ = tx.send(Err(why));
None
},
};
},
Ok(CoreMessage::Disconnect) => {
connection = None;
let _ = interconnect.mixer.send(MixerMessage::DropConn);
let _ = interconnect.mixer.send(MixerMessage::RebuildEncoder);
},
Ok(CoreMessage::SetTrack(s)) => {
let _ = interconnect.mixer.send(MixerMessage::SetTrack(s));
},
Ok(CoreMessage::AddTrack(s)) => {
let _ = interconnect.mixer.send(MixerMessage::AddTrack(s));
},
Ok(CoreMessage::SetBitrate(b)) => {
let _ = interconnect.mixer.send(MixerMessage::SetBitrate(b));
},
Ok(CoreMessage::AddEvent(evt)) => {
let _ = interconnect.events.send(EventMessage::AddGlobalEvent(evt));
},
Ok(CoreMessage::Mute(m)) => {
let _ = interconnect.mixer.send(MixerMessage::SetMute(m));
},
Ok(CoreMessage::Reconnect) => {
if let Some(mut conn) = connection.take() {
// try once: if interconnect, try again.
// if still issue, full connect.
let info = conn.info.clone();
let full_connect = match conn.reconnect().await {
Ok(()) => {
connection = Some(conn);
false
},
Err(ConnectionError::InterconnectFailure(_)) => {
interconnect.restart_volatile_internals();
match conn.reconnect().await {
Ok(()) => {
connection = Some(conn);
false
},
_ => true,
}
},
_ => true,
};
if full_connect {
connection = Connection::new(info, &interconnect, &config)
.await
.map_err(|e| {
error!("Catastrophic connection failure. Stopping. {:?}", e);
e
})
.ok();
}
}
},
Ok(CoreMessage::FullReconnect) =>
if let Some(conn) = connection.take() {
let info = conn.info.clone();
connection = Connection::new(info, &interconnect, &config)
.await
.map_err(|e| {
error!("Catastrophic connection failure. Stopping. {:?}", e);
e
})
.ok();
},
Ok(CoreMessage::RebuildInterconnect) => {
interconnect.restart_volatile_internals();
},
Err(RecvError::Disconnected) | Ok(CoreMessage::Poison) => {
break;
},
}
}
info!("Main thread exited");
interconnect.poison_all();
}

286
src/driver/tasks/udp_rx.rs Normal file
View File

@@ -0,0 +1,286 @@
use super::{
error::{Error, Result},
message::*,
};
use crate::{constants::*, driver::CryptoMode, events::CoreContext};
use audiopus::{coder::Decoder as OpusDecoder, Channels};
use discortp::{
demux::{self, DemuxedMut},
rtp::{RtpExtensionPacket, RtpPacket},
FromPacket,
MutablePacket,
Packet,
PacketSize,
};
use flume::Receiver;
use std::collections::HashMap;
use tokio::net::udp::RecvHalf;
use tracing::{error, info, instrument, warn};
use xsalsa20poly1305::{aead::AeadInPlace, Nonce, Tag, XSalsa20Poly1305 as Cipher, TAG_SIZE};
#[derive(Debug)]
struct SsrcState {
silent_frame_count: u16,
decoder: OpusDecoder,
last_seq: u16,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SpeakingDelta {
Same,
Start,
Stop,
}
impl SsrcState {
fn new(pkt: RtpPacket<'_>) -> Self {
Self {
silent_frame_count: 5, // We do this to make the first speech packet fire an event.
decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo)
.expect("Failed to create new Opus decoder for source."),
last_seq: pkt.get_sequence().into(),
}
}
fn process(
&mut self,
pkt: RtpPacket<'_>,
data_offset: usize,
) -> Result<(SpeakingDelta, Vec<i16>)> {
let new_seq: u16 = pkt.get_sequence().into();
let extensions = pkt.get_extension() != 0;
let seq_delta = new_seq.wrapping_sub(self.last_seq);
Ok(if seq_delta >= (1 << 15) {
// Overflow, reordered (previously missing) packet.
(SpeakingDelta::Same, vec![])
} else {
self.last_seq = new_seq;
let missed_packets = seq_delta.saturating_sub(1);
let (audio, pkt_size) =
self.scan_and_decode(&pkt.payload()[data_offset..], extensions, missed_packets)?;
let delta = if pkt_size == SILENT_FRAME.len() {
// Frame is silent.
let old = self.silent_frame_count;
self.silent_frame_count =
self.silent_frame_count.saturating_add(1 + missed_packets);
if self.silent_frame_count >= 5 && old < 5 {
SpeakingDelta::Stop
} else {
SpeakingDelta::Same
}
} else {
// Frame has meaningful audio.
let out = if self.silent_frame_count >= 5 {
SpeakingDelta::Start
} else {
SpeakingDelta::Same
};
self.silent_frame_count = 0;
out
};
(delta, audio)
})
}
fn scan_and_decode(
&mut self,
data: &[u8],
extension: bool,
missed_packets: u16,
) -> Result<(Vec<i16>, usize)> {
let mut out = vec![0; STEREO_FRAME_SIZE];
let start = if extension {
RtpExtensionPacket::new(data)
.map(|pkt| pkt.packet_size())
.ok_or_else(|| {
error!("Extension packet indicated, but insufficient space.");
Error::IllegalVoicePacket
})
} else {
Ok(0)
}?;
for _ in 0..missed_packets {
let missing_frame: Option<&[u8]> = None;
if let Err(e) = self.decoder.decode(missing_frame, &mut out[..], false) {
warn!("Issue while decoding for missed packet: {:?}.", e);
}
}
let audio_len = self
.decoder
.decode(Some(&data[start..]), &mut out[..], false)
.map_err(|e| {
error!("Failed to decode received packet: {:?}.", e);
e
})?;
// Decoding to stereo: audio_len refers to sample count irrespective of channel count.
// => multiply by number of channels.
out.truncate(2 * audio_len);
Ok((out, data.len() - start))
}
}
struct UdpRx {
cipher: Cipher,
decoder_map: HashMap<u32, SsrcState>,
#[allow(dead_code)]
mode: CryptoMode, // In future, this will allow crypto mode selection.
packet_buffer: [u8; VOICE_PACKET_MAX],
rx: Receiver<UdpRxMessage>,
udp_socket: RecvHalf,
}
impl UdpRx {
#[instrument(skip(self))]
async fn run(&mut self, interconnect: &mut Interconnect) {
loop {
tokio::select! {
Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => {
self.process_udp_message(interconnect, len);
}
msg = self.rx.recv_async() => {
use UdpRxMessage::*;
match msg {
Ok(ReplaceInterconnect(i)) => {
*interconnect = i;
}
Ok(Poison) | Err(_) => break,
}
}
}
}
}
fn process_udp_message(&mut self, interconnect: &Interconnect, len: usize) {
// NOTE: errors here (and in general for UDP) are not fatal to the connection.
// Panics should be avoided due to adversarial nature of rx'd packets,
// but correct handling should not prompt a reconnect.
//
// For simplicity, we nominate the mixing context to rebuild the event
// context if it fails (hence, the `let _ =` statements.), as it will try to
// make contact every 20ms.
let packet = &mut self.packet_buffer[..len];
match demux::demux_mut(packet) {
DemuxedMut::Rtp(mut rtp) => {
if !rtp_valid(rtp.to_immutable()) {
error!("Illegal RTP message received.");
return;
}
let rtp_body_start =
decrypt_in_place(&mut rtp, &self.cipher).expect("RTP decryption failed.");
let entry = self
.decoder_map
.entry(rtp.get_ssrc())
.or_insert_with(|| SsrcState::new(rtp.to_immutable()));
if let Ok((delta, audio)) = entry.process(rtp.to_immutable(), rtp_body_start) {
match delta {
SpeakingDelta::Start => {
let _ = interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::SpeakingUpdate {
ssrc: rtp.get_ssrc(),
speaking: true,
},
));
},
SpeakingDelta::Stop => {
let _ = interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::SpeakingUpdate {
ssrc: rtp.get_ssrc(),
speaking: false,
},
));
},
_ => {},
}
let _ = interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::VoicePacket {
audio,
packet: rtp.from_packet(),
payload_offset: rtp_body_start,
},
));
} else {
warn!("RTP decoding/decrytion failed.");
}
},
DemuxedMut::Rtcp(mut rtcp) => {
let rtcp_body_start = decrypt_in_place(&mut rtcp, &self.cipher);
if let Ok(start) = rtcp_body_start {
let _ = interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::RtcpPacket {
packet: rtcp.from_packet(),
payload_offset: start,
},
));
} else {
warn!("RTCP decryption failed.");
}
},
DemuxedMut::FailedParse(t) => {
warn!("Failed to parse message of type {:?}.", t);
},
_ => {
warn!("Illegal UDP packet from voice server.");
},
}
}
}
#[instrument(skip(interconnect, rx, cipher))]
pub(crate) async fn runner(
mut interconnect: Interconnect,
rx: Receiver<UdpRxMessage>,
cipher: Cipher,
mode: CryptoMode,
udp_socket: RecvHalf,
) {
info!("UDP receive handle started.");
let mut state = UdpRx {
cipher,
decoder_map: Default::default(),
mode,
packet_buffer: [0u8; VOICE_PACKET_MAX],
rx,
udp_socket,
};
state.run(&mut interconnect).await;
info!("UDP receive handle stopped.");
}
#[inline]
fn decrypt_in_place(packet: &mut impl MutablePacket, cipher: &Cipher) -> Result<usize> {
// Applies discord's cheapest.
// In future, might want to make a choice...
let header_len = packet.packet().len() - packet.payload().len();
let mut nonce = Nonce::default();
nonce[..header_len].copy_from_slice(&packet.packet()[..header_len]);
let data = packet.payload_mut();
let (tag_bytes, data_bytes) = data.split_at_mut(TAG_SIZE);
let tag = Tag::from_slice(tag_bytes);
Ok(cipher
.decrypt_in_place_detached(&nonce, b"", data_bytes, tag)
.map(|_| TAG_SIZE)?)
}
#[inline]
fn rtp_valid(packet: RtpPacket<'_>) -> bool {
packet.get_version() == RTP_VERSION && packet.get_payload_type() == RTP_PROFILE_TYPE
}

View File

@@ -0,0 +1,45 @@
use super::message::*;
use crate::constants::*;
use discortp::discord::MutableKeepalivePacket;
use flume::Receiver;
use tokio::{
net::udp::SendHalf,
time::{timeout_at, Elapsed, Instant},
};
use tracing::{error, info, instrument, trace};
#[instrument(skip(udp_msg_rx))]
pub(crate) async fn runner(udp_msg_rx: Receiver<UdpTxMessage>, ssrc: u32, mut udp_tx: SendHalf) {
info!("UDP transmit handle started.");
let mut keepalive_bytes = [0u8; MutableKeepalivePacket::minimum_packet_size()];
let mut ka = MutableKeepalivePacket::new(&mut keepalive_bytes[..])
.expect("FATAL: Insufficient bytes given to keepalive packet.");
ka.set_ssrc(ssrc);
let mut ka_time = Instant::now() + UDP_KEEPALIVE_GAP;
loop {
use UdpTxMessage::*;
match timeout_at(ka_time, udp_msg_rx.recv_async()).await {
Err(Elapsed { .. }) => {
trace!("Sending UDP Keepalive.");
if let Err(e) = udp_tx.send(&keepalive_bytes[..]).await {
error!("Fatal UDP keepalive send error: {:?}.", e);
break;
}
ka_time += UDP_KEEPALIVE_GAP;
},
Ok(Ok(Packet(p))) =>
if let Err(e) = udp_tx.send(&p[..]).await {
error!("Fatal UDP packet send error: {:?}.", e);
break;
},
Ok(Err(_)) | Ok(Ok(Poison)) => {
break;
},
}
}
info!("UDP transmit handle stopped.");
}

205
src/driver/tasks/ws.rs Normal file
View File

@@ -0,0 +1,205 @@
use super::{error::Result, message::*};
use crate::{
events::CoreContext,
model::{
payload::{Heartbeat, Speaking},
Event as GatewayEvent,
SpeakingState,
},
ws::{Error as WsError, ReceiverExt, SenderExt, WsStream},
};
use flume::Receiver;
use rand::random;
use std::time::Duration;
use tokio::time::{self, Instant};
use tracing::{error, info, instrument, trace, warn};
struct AuxNetwork {
rx: Receiver<WsMessage>,
ws_client: WsStream,
dont_send: bool,
ssrc: u32,
heartbeat_interval: Duration,
speaking: SpeakingState,
last_heartbeat_nonce: Option<u64>,
}
impl AuxNetwork {
pub(crate) fn new(
evt_rx: Receiver<WsMessage>,
ws_client: WsStream,
ssrc: u32,
heartbeat_interval: f64,
) -> Self {
Self {
rx: evt_rx,
ws_client,
dont_send: false,
ssrc,
heartbeat_interval: Duration::from_secs_f64(heartbeat_interval / 1000.0),
speaking: SpeakingState::empty(),
last_heartbeat_nonce: None,
}
}
#[instrument(skip(self))]
async fn run(&mut self, interconnect: &mut Interconnect) {
let mut next_heartbeat = Instant::now() + self.heartbeat_interval;
loop {
let mut ws_error = false;
let hb = time::delay_until(next_heartbeat);
tokio::select! {
_ = hb => {
ws_error = match self.send_heartbeat().await {
Err(e) => {
error!("Heartbeat send failure {:?}.", e);
true
},
_ => false,
};
next_heartbeat = self.next_heartbeat();
}
ws_msg = self.ws_client.recv_json_no_timeout(), if !self.dont_send => {
ws_error = match ws_msg {
Err(WsError::Json(e)) => {
warn!("Unexpected JSON {:?}.", e);
false
},
Err(e) => {
error!("Error processing ws {:?}.", e);
true
},
Ok(Some(msg)) => {
self.process_ws(interconnect, msg);
false
},
_ => false,
};
}
inner_msg = self.rx.recv_async() => {
match inner_msg {
Ok(WsMessage::Ws(data)) => {
self.ws_client = *data;
next_heartbeat = self.next_heartbeat();
self.dont_send = false;
},
Ok(WsMessage::ReplaceInterconnect(i)) => {
*interconnect = i;
},
Ok(WsMessage::SetKeepalive(keepalive)) => {
self.heartbeat_interval = Duration::from_secs_f64(keepalive / 1000.0);
next_heartbeat = self.next_heartbeat();
},
Ok(WsMessage::Speaking(is_speaking)) => {
if self.speaking.contains(SpeakingState::MICROPHONE) != is_speaking && !self.dont_send {
self.speaking.set(SpeakingState::MICROPHONE, is_speaking);
info!("Changing to {:?}", self.speaking);
let ssu_status = self.ws_client
.send_json(&GatewayEvent::from(Speaking {
delay: Some(0),
speaking: self.speaking,
ssrc: self.ssrc,
user_id: None,
}))
.await;
ws_error |= match ssu_status {
Err(e) => {
error!("Issue sending speaking update {:?}.", e);
true
},
_ => false,
}
}
},
Err(_) | Ok(WsMessage::Poison) => {
break;
},
}
}
}
if ws_error {
let _ = interconnect.core.send(CoreMessage::Reconnect);
self.dont_send = true;
}
}
}
fn next_heartbeat(&self) -> Instant {
Instant::now() + self.heartbeat_interval
}
async fn send_heartbeat(&mut self) -> Result<()> {
let nonce = random::<u64>();
self.last_heartbeat_nonce = Some(nonce);
trace!("Sent heartbeat {:?}", self.speaking);
if !self.dont_send {
self.ws_client
.send_json(&GatewayEvent::from(Heartbeat { nonce }))
.await?;
}
Ok(())
}
fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) {
match value {
GatewayEvent::Speaking(ev) => {
let _ = interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::SpeakingStateUpdate(ev),
));
},
GatewayEvent::ClientConnect(ev) => {
let _ = interconnect
.events
.send(EventMessage::FireCoreEvent(CoreContext::ClientConnect(ev)));
},
GatewayEvent::ClientDisconnect(ev) => {
let _ = interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::ClientDisconnect(ev),
));
},
GatewayEvent::HeartbeatAck(ev) => {
if let Some(nonce) = self.last_heartbeat_nonce.take() {
if ev.nonce == nonce {
trace!("Heartbeat ACK received.");
} else {
warn!(
"Heartbeat nonce mismatch! Expected {}, saw {}.",
nonce, ev.nonce
);
}
}
},
other => {
trace!("Received other websocket data: {:?}", other);
},
}
}
}
#[instrument(skip(interconnect, ws_client))]
pub(crate) async fn runner(
mut interconnect: Interconnect,
evt_rx: Receiver<WsMessage>,
ws_client: WsStream,
ssrc: u32,
heartbeat_interval: f64,
) {
info!("WS thread started.");
let mut aux = AuxNetwork::new(evt_rx, ws_client, ssrc, heartbeat_interval);
aux.run(&mut interconnect).await;
info!("WS thread finished.");
}