Driver: Prune SsrcState after timeout/disconnect (#145)

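`SsrcState` objects are created on a per-SSRC (i.e., per-user) basis when the "receive" feature is enabled, but were previously never destroyed. This PR adds a shared `SsrcTracker` (a `DashMap` of user-ID-to-SSRC mappings plus a `DashSet` of disconnected users) which the WS task uses to inform the UDP Rx task of SSRC mappings and of any disconnections. Additionally, decoder state is pruned once a user has not spoken for a configurable timeout (`Config::decode_state_timeout`), which defaults to 1 minute.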

This was tested using `cargo make ready` and via `examples/serenity/voice_receive/`.

Closes #133
This commit is contained in:
Kyle Simpson
2022-08-08 14:36:27 +01:00
parent 6769131fa2
commit 893dbaae34
5 changed files with 122 additions and 31 deletions

View File

@@ -49,6 +49,13 @@ pub struct Config {
/// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate /// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate
pub decode_mode: DecodeMode, pub decode_mode: DecodeMode,
#[cfg(all(feature = "driver", feature = "receive"))]
/// Configures the amount of time after a user/SSRC is inactive before their decoder state
/// should be removed.
///
/// Defaults to 1 minute.
pub decode_state_timeout: Duration,
#[cfg(feature = "gateway")] #[cfg(feature = "gateway")]
/// Configures the amount of time to wait for Discord to reply with connection information /// Configures the amount of time to wait for Discord to reply with connection information
/// if [`Call::join`]/[`join_gateway`] are used. /// if [`Call::join`]/[`join_gateway`] are used.
@@ -155,6 +162,8 @@ impl Default for Config {
crypto_mode: CryptoMode::Normal, crypto_mode: CryptoMode::Normal,
#[cfg(all(feature = "driver", feature = "receive"))] #[cfg(all(feature = "driver", feature = "receive"))]
decode_mode: DecodeMode::Decrypt, decode_mode: DecodeMode::Decrypt,
#[cfg(all(feature = "driver", feature = "receive"))]
decode_state_timeout: Duration::from_secs(60),
#[cfg(feature = "gateway")] #[cfg(feature = "gateway")]
gateway_timeout: Some(Duration::from_secs(10)), gateway_timeout: Some(Duration::from_secs(10)),
#[cfg(feature = "driver")] #[cfg(feature = "driver")]
@@ -198,6 +207,14 @@ impl Config {
self self
} }
#[cfg(feature = "receive")]
/// Sets how long a user/SSRC may stay inactive before this `Config` allows
/// their received-packet decoder state to be cleaned up.
///
/// Consumes the `Config` and returns it with `decode_state_timeout` replaced,
/// so calls can be chained builder-style.
#[must_use]
pub fn decode_state_timeout(self, decode_state_timeout: Duration) -> Self {
    let mut config = self;
    config.decode_state_timeout = decode_state_timeout;
    config
}
/// Sets this `Config`'s audio mixing channel count. /// Sets this `Config`'s audio mixing channel count.
#[must_use] #[must_use]
pub fn mix_mode(mut self, mix_mode: MixMode) -> Self { pub fn mix_mode(mut self, mix_mode: MixMode) -> Self {

View File

@@ -3,7 +3,10 @@ pub mod error;
#[cfg(feature = "receive")] #[cfg(feature = "receive")]
use super::tasks::udp_rx; use super::tasks::udp_rx;
use super::{ use super::{
tasks::{message::*, ws as ws_task}, tasks::{
message::*,
ws::{self as ws_task, AuxNetwork},
},
Config, Config,
CryptoMode, CryptoMode,
}; };
@@ -21,6 +24,8 @@ use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPa
use error::{Error, Result}; use error::{Error, Result};
use flume::Sender; use flume::Sender;
use socket2::Socket; use socket2::Socket;
#[cfg(feature = "receive")]
use std::sync::Arc;
use std::{net::IpAddr, str::FromStr}; use std::{net::IpAddr, str::FromStr};
use tokio::{net::UdpSocket, spawn, time::timeout}; use tokio::{net::UdpSocket, spawn, time::timeout};
use tracing::{debug, info, instrument}; use tracing::{debug, info, instrument};
@@ -217,15 +222,21 @@ impl Connection {
.mixer .mixer
.send(MixerMessage::SetConn(mix_conn, ready.ssrc))?; .send(MixerMessage::SetConn(mix_conn, ready.ssrc))?;
spawn(ws_task::runner( #[cfg(feature = "receive")]
interconnect.clone(), let ssrc_tracker = Arc::new(SsrcTracker::default());
let ws_state = AuxNetwork::new(
ws_msg_rx, ws_msg_rx,
client, client,
ssrc, ssrc,
hello.heartbeat_interval, hello.heartbeat_interval,
idx, idx,
info.clone(), info.clone(),
)); #[cfg(feature = "receive")]
ssrc_tracker.clone(),
);
spawn(ws_task::runner(interconnect.clone(), ws_state));
#[cfg(feature = "receive")] #[cfg(feature = "receive")]
spawn(udp_rx::runner( spawn(udp_rx::runner(
@@ -234,6 +245,7 @@ impl Connection {
cipher, cipher,
config.clone(), config.clone(),
udp_rx, udp_rx,
ssrc_tracker,
)); ));
Ok(Connection { Ok(Connection {

View File

@@ -2,8 +2,16 @@
use super::Interconnect; use super::Interconnect;
use crate::driver::Config; use crate::driver::Config;
use dashmap::{DashMap, DashSet};
use serenity_voice_model::id::UserId;
pub enum UdpRxMessage { pub enum UdpRxMessage {
SetConfig(Config), SetConfig(Config),
ReplaceInterconnect(Interconnect), ReplaceInterconnect(Interconnect),
} }
/// Shared signalling state between the WS task and the UDP Rx task.
///
/// Both tasks hold this behind an `Arc`: the WS task writes into it as gateway
/// events arrive, and the UDP Rx task reads and drains it during its periodic
/// decoder-state cleanup.
#[derive(Debug, Default)]
pub struct SsrcTracker {
/// Users for whom the WS task has received a `ClientDisconnect` event and
/// whose disconnection the UDP Rx task has not yet processed. The UDP Rx
/// task removes each entry once it has handled it.
pub disconnected_users: DashSet<UserId>,
/// Maps each user ID to the SSRC most recently reported for them in a
/// `Speaking` gateway event. The entry is removed when the UDP Rx task
/// processes that user's disconnection.
pub user_ssrc_map: DashMap<UserId, u32>,
}

View File

@@ -22,8 +22,8 @@ use discortp::{
PacketSize, PacketSize,
}; };
use flume::Receiver; use flume::Receiver;
use std::{collections::HashMap, convert::TryInto}; use std::{collections::HashMap, convert::TryInto, sync::Arc, time::Duration};
use tokio::{net::UdpSocket, select}; use tokio::{net::UdpSocket, select, time::Instant};
use tracing::{error, instrument, trace, warn}; use tracing::{error, instrument, trace, warn};
use xsalsa20poly1305::XSalsa20Poly1305 as Cipher; use xsalsa20poly1305::XSalsa20Poly1305 as Cipher;
@@ -33,6 +33,8 @@ struct SsrcState {
decoder: OpusDecoder, decoder: OpusDecoder,
last_seq: u16, last_seq: u16,
decode_size: PacketDecodeSize, decode_size: PacketDecodeSize,
prune_time: Instant,
disconnected: bool,
} }
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
@@ -84,13 +86,21 @@ enum SpeakingDelta {
} }
impl SsrcState { impl SsrcState {
fn new(pkt: &RtpPacket<'_>) -> Self { fn new(pkt: &RtpPacket<'_>, state_timeout: Duration) -> Self {
Self { Self {
silent_frame_count: 5, // We do this to make the first speech packet fire an event. silent_frame_count: 5, // We do this to make the first speech packet fire an event.
decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo) decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo)
.expect("Failed to create new Opus decoder for source."), .expect("Failed to create new Opus decoder for source."),
last_seq: pkt.get_sequence().into(), last_seq: pkt.get_sequence().into(),
decode_size: PacketDecodeSize::TwentyMillis, decode_size: PacketDecodeSize::TwentyMillis,
prune_time: Instant::now() + state_timeout,
disconnected: false,
}
}
fn refresh_timer(&mut self, state_timeout: Duration) {
if !self.disconnected {
self.prune_time = Instant::now() + state_timeout;
} }
} }
@@ -236,21 +246,23 @@ impl SsrcState {
struct UdpRx { struct UdpRx {
cipher: Cipher, cipher: Cipher,
decoder_map: HashMap<u32, SsrcState>, decoder_map: HashMap<u32, SsrcState>,
#[allow(dead_code)]
config: Config, config: Config,
packet_buffer: [u8; VOICE_PACKET_MAX], packet_buffer: [u8; VOICE_PACKET_MAX],
rx: Receiver<UdpRxMessage>, rx: Receiver<UdpRxMessage>,
ssrc_signalling: Arc<SsrcTracker>,
udp_socket: UdpSocket, udp_socket: UdpSocket,
} }
impl UdpRx { impl UdpRx {
#[instrument(skip(self))] #[instrument(skip(self))]
async fn run(&mut self, interconnect: &mut Interconnect) { async fn run(&mut self, interconnect: &mut Interconnect) {
let mut cleanup_time = Instant::now();
loop { loop {
select! { select! {
Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => { Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => {
self.process_udp_message(interconnect, len); self.process_udp_message(interconnect, len);
} },
msg = self.rx.recv_async() => { msg = self.rx.recv_async() => {
match msg { match msg {
Ok(UdpRxMessage::ReplaceInterconnect(i)) => { Ok(UdpRxMessage::ReplaceInterconnect(i)) => {
@@ -261,7 +273,41 @@ impl UdpRx {
}, },
Err(flume::RecvError::Disconnected) => break, Err(flume::RecvError::Disconnected) => break,
} }
} },
// Periodic cleanup arm: fires once `cleanup_time` has been reached, handles
// any disconnects signalled by the WS task, prunes expired decoder state,
// then reschedules itself.
_ = tokio::time::sleep_until(cleanup_time) => {
// Take a single timestamp so that every comparison and new deadline in this
// pass is relative to the same instant.
let now = Instant::now();
// Drain the set of disconnected users which the WS task has signalled to us.
loop {
// This is structured in an odd way to prevent deadlocks: the ID is copied out
// inside its own block so that the dashmap `iter()` is no longer alive when
// `remove` is called below. A `while let` seemed to keep the `iter()` alive
// for the whole loop body, rather than just the initialiser.
let id = {
if let Some(id) = self.ssrc_signalling.disconnected_users.iter().next().map(|v| *v.key()) {
id
} else {
// No disconnected users left to process.
break;
}
};
let _ = self.ssrc_signalling.disconnected_users.remove(&id);
// Only users with a known SSRC mapping *and* live decoder state need any
// further handling; the mapping itself is dropped either way.
if let Some((_, ssrc)) = self.ssrc_signalling.user_ssrc_map.remove(&id) {
if let Some(state) = self.decoder_map.get_mut(&ssrc) {
// Don't clean up immediately: leave removal for a later cleanup cycle.
// This is key with reorder/jitter buffers, where we may
// still need to decode post-disconnect for ~0.2s.
state.prune_time = now + Duration::from_secs(1);
// Once marked as disconnected, `refresh_timer` will no longer push
// `prune_time` back, so a later pass is guaranteed to drop this state.
state.disconnected = true;
}
}
}
// Now remove all dead SSRCs: only state whose `prune_time` is still strictly
// in the future is kept.
self.decoder_map.retain(|_, v| v.prune_time > now);
// Schedule the next cleanup pass for 5 seconds from this one.
cleanup_time = now + Duration::from_secs(5);
},
} }
} }
} }
@@ -310,7 +356,11 @@ impl UdpRx {
let entry = self let entry = self
.decoder_map .decoder_map
.entry(rtp.get_ssrc()) .entry(rtp.get_ssrc())
.or_insert_with(|| SsrcState::new(&rtp)); .or_insert_with(|| SsrcState::new(&rtp, self.config.decode_state_timeout));
// Only do this on RTP, rather than RTCP -- this pins decoder state liveness
// to *speech* rather than just presence.
entry.refresh_timer(self.config.decode_state_timeout);
if let Ok((delta, audio)) = entry.process( if let Ok((delta, audio)) = entry.process(
&rtp, &rtp,
@@ -396,6 +446,7 @@ pub(crate) async fn runner(
cipher: Cipher, cipher: Cipher,
config: Config, config: Config,
udp_socket: UdpSocket, udp_socket: UdpSocket,
ssrc_signalling: Arc<SsrcTracker>,
) { ) {
trace!("UDP receive handle started."); trace!("UDP receive handle started.");
@@ -405,6 +456,7 @@ pub(crate) async fn runner(
config, config,
packet_buffer: [0u8; VOICE_PACKET_MAX], packet_buffer: [0u8; VOICE_PACKET_MAX],
rx, rx,
ssrc_signalling,
udp_socket, udp_socket,
}; };

View File

@@ -13,6 +13,8 @@ use crate::{
}; };
use flume::Receiver; use flume::Receiver;
use rand::random; use rand::random;
#[cfg(feature = "receive")]
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::{ use tokio::{
select, select,
@@ -21,7 +23,7 @@ use tokio::{
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tracing::{debug, info, instrument, trace, warn}; use tracing::{debug, info, instrument, trace, warn};
struct AuxNetwork { pub(crate) struct AuxNetwork {
rx: Receiver<WsMessage>, rx: Receiver<WsMessage>,
ws_client: WsStream, ws_client: WsStream,
dont_send: bool, dont_send: bool,
@@ -34,6 +36,9 @@ struct AuxNetwork {
attempt_idx: usize, attempt_idx: usize,
info: ConnectionInfo, info: ConnectionInfo,
#[cfg(feature = "receive")]
ssrc_signalling: Arc<SsrcTracker>,
} }
impl AuxNetwork { impl AuxNetwork {
@@ -44,6 +49,7 @@ impl AuxNetwork {
heartbeat_interval: f64, heartbeat_interval: f64,
attempt_idx: usize, attempt_idx: usize,
info: ConnectionInfo, info: ConnectionInfo,
#[cfg(feature = "receive")] ssrc_signalling: Arc<SsrcTracker>,
) -> Self { ) -> Self {
Self { Self {
rx: evt_rx, rx: evt_rx,
@@ -58,6 +64,9 @@ impl AuxNetwork {
attempt_idx, attempt_idx,
info, info,
#[cfg(feature = "receive")]
ssrc_signalling,
} }
} }
@@ -186,6 +195,11 @@ impl AuxNetwork {
fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) { fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) {
match value { match value {
GatewayEvent::Speaking(ev) => { GatewayEvent::Speaking(ev) => {
#[cfg(feature = "receive")]
if let Some(user_id) = &ev.user_id {
self.ssrc_signalling.user_ssrc_map.insert(*user_id, ev.ssrc);
}
drop(interconnect.events.send(EventMessage::FireCoreEvent( drop(interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::SpeakingStateUpdate(ev), CoreContext::SpeakingStateUpdate(ev),
))); )));
@@ -194,6 +208,11 @@ impl AuxNetwork {
debug!("Received discontinued ClientConnect: {:?}", ev); debug!("Received discontinued ClientConnect: {:?}", ev);
}, },
GatewayEvent::ClientDisconnect(ev) => { GatewayEvent::ClientDisconnect(ev) => {
#[cfg(feature = "receive")]
{
self.ssrc_signalling.disconnected_users.insert(ev.user_id);
}
drop(interconnect.events.send(EventMessage::FireCoreEvent( drop(interconnect.events.send(EventMessage::FireCoreEvent(
CoreContext::ClientDisconnect(ev), CoreContext::ClientDisconnect(ev),
))); )));
@@ -217,26 +236,9 @@ impl AuxNetwork {
} }
} }
#[instrument(skip(interconnect, ws_client))] #[instrument(skip(interconnect, aux))]
pub(crate) async fn runner( pub(crate) async fn runner(mut interconnect: Interconnect, mut aux: AuxNetwork) {
mut interconnect: Interconnect,
evt_rx: Receiver<WsMessage>,
ws_client: WsStream,
ssrc: u32,
heartbeat_interval: f64,
attempt_idx: usize,
info: ConnectionInfo,
) {
trace!("WS thread started."); trace!("WS thread started.");
let mut aux = AuxNetwork::new(
evt_rx,
ws_client,
ssrc,
heartbeat_interval,
attempt_idx,
info,
);
aux.run(&mut interconnect).await; aux.run(&mut interconnect).await;
trace!("WS thread finished."); trace!("WS thread finished.");
} }