Driver: Prune SsrcState after timeout/disconnect (#145)
`SsrcState` objects are created on a per-user basis when "receive" is enabled, but were previously never destroyed. This PR adds some shared dashmaps for the WS task to communicate SSRC-to-ID mappings to the UDP Rx task, as well as any disconnections. Additionally, decoder state is pruned a default 1 minute after a user last speaks. This was tested using `cargo make ready` and via `examples/serenity/voice_receive/`. Closes #133
This commit is contained in:
@@ -49,6 +49,13 @@ pub struct Config {
|
|||||||
/// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate
|
/// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate
|
||||||
pub decode_mode: DecodeMode,
|
pub decode_mode: DecodeMode,
|
||||||
|
|
||||||
|
#[cfg(all(feature = "driver", feature = "receive"))]
|
||||||
|
/// Configures the amount of time after a user/SSRC is inactive before their decoder state
|
||||||
|
/// should be removed.
|
||||||
|
///
|
||||||
|
/// Defaults to 1 minute.
|
||||||
|
pub decode_state_timeout: Duration,
|
||||||
|
|
||||||
#[cfg(feature = "gateway")]
|
#[cfg(feature = "gateway")]
|
||||||
/// Configures the amount of time to wait for Discord to reply with connection information
|
/// Configures the amount of time to wait for Discord to reply with connection information
|
||||||
/// if [`Call::join`]/[`join_gateway`] are used.
|
/// if [`Call::join`]/[`join_gateway`] are used.
|
||||||
@@ -155,6 +162,8 @@ impl Default for Config {
|
|||||||
crypto_mode: CryptoMode::Normal,
|
crypto_mode: CryptoMode::Normal,
|
||||||
#[cfg(all(feature = "driver", feature = "receive"))]
|
#[cfg(all(feature = "driver", feature = "receive"))]
|
||||||
decode_mode: DecodeMode::Decrypt,
|
decode_mode: DecodeMode::Decrypt,
|
||||||
|
#[cfg(all(feature = "driver", feature = "receive"))]
|
||||||
|
decode_state_timeout: Duration::from_secs(60),
|
||||||
#[cfg(feature = "gateway")]
|
#[cfg(feature = "gateway")]
|
||||||
gateway_timeout: Some(Duration::from_secs(10)),
|
gateway_timeout: Some(Duration::from_secs(10)),
|
||||||
#[cfg(feature = "driver")]
|
#[cfg(feature = "driver")]
|
||||||
@@ -198,6 +207,14 @@ impl Config {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "receive")]
|
||||||
|
/// Sets this `Config`'s received packet decoder cleanup timer.
|
||||||
|
#[must_use]
|
||||||
|
pub fn decode_state_timeout(mut self, decode_state_timeout: Duration) -> Self {
|
||||||
|
self.decode_state_timeout = decode_state_timeout;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Sets this `Config`'s audio mixing channel count.
|
/// Sets this `Config`'s audio mixing channel count.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn mix_mode(mut self, mix_mode: MixMode) -> Self {
|
pub fn mix_mode(mut self, mix_mode: MixMode) -> Self {
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ pub mod error;
|
|||||||
#[cfg(feature = "receive")]
|
#[cfg(feature = "receive")]
|
||||||
use super::tasks::udp_rx;
|
use super::tasks::udp_rx;
|
||||||
use super::{
|
use super::{
|
||||||
tasks::{message::*, ws as ws_task},
|
tasks::{
|
||||||
|
message::*,
|
||||||
|
ws::{self as ws_task, AuxNetwork},
|
||||||
|
},
|
||||||
Config,
|
Config,
|
||||||
CryptoMode,
|
CryptoMode,
|
||||||
};
|
};
|
||||||
@@ -21,6 +24,8 @@ use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPa
|
|||||||
use error::{Error, Result};
|
use error::{Error, Result};
|
||||||
use flume::Sender;
|
use flume::Sender;
|
||||||
use socket2::Socket;
|
use socket2::Socket;
|
||||||
|
#[cfg(feature = "receive")]
|
||||||
|
use std::sync::Arc;
|
||||||
use std::{net::IpAddr, str::FromStr};
|
use std::{net::IpAddr, str::FromStr};
|
||||||
use tokio::{net::UdpSocket, spawn, time::timeout};
|
use tokio::{net::UdpSocket, spawn, time::timeout};
|
||||||
use tracing::{debug, info, instrument};
|
use tracing::{debug, info, instrument};
|
||||||
@@ -217,15 +222,21 @@ impl Connection {
|
|||||||
.mixer
|
.mixer
|
||||||
.send(MixerMessage::SetConn(mix_conn, ready.ssrc))?;
|
.send(MixerMessage::SetConn(mix_conn, ready.ssrc))?;
|
||||||
|
|
||||||
spawn(ws_task::runner(
|
#[cfg(feature = "receive")]
|
||||||
interconnect.clone(),
|
let ssrc_tracker = Arc::new(SsrcTracker::default());
|
||||||
|
|
||||||
|
let ws_state = AuxNetwork::new(
|
||||||
ws_msg_rx,
|
ws_msg_rx,
|
||||||
client,
|
client,
|
||||||
ssrc,
|
ssrc,
|
||||||
hello.heartbeat_interval,
|
hello.heartbeat_interval,
|
||||||
idx,
|
idx,
|
||||||
info.clone(),
|
info.clone(),
|
||||||
));
|
#[cfg(feature = "receive")]
|
||||||
|
ssrc_tracker.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
spawn(ws_task::runner(interconnect.clone(), ws_state));
|
||||||
|
|
||||||
#[cfg(feature = "receive")]
|
#[cfg(feature = "receive")]
|
||||||
spawn(udp_rx::runner(
|
spawn(udp_rx::runner(
|
||||||
@@ -234,6 +245,7 @@ impl Connection {
|
|||||||
cipher,
|
cipher,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
udp_rx,
|
udp_rx,
|
||||||
|
ssrc_tracker,
|
||||||
));
|
));
|
||||||
|
|
||||||
Ok(Connection {
|
Ok(Connection {
|
||||||
|
|||||||
@@ -2,8 +2,16 @@
|
|||||||
|
|
||||||
use super::Interconnect;
|
use super::Interconnect;
|
||||||
use crate::driver::Config;
|
use crate::driver::Config;
|
||||||
|
use dashmap::{DashMap, DashSet};
|
||||||
|
use serenity_voice_model::id::UserId;
|
||||||
|
|
||||||
pub enum UdpRxMessage {
|
pub enum UdpRxMessage {
|
||||||
SetConfig(Config),
|
SetConfig(Config),
|
||||||
ReplaceInterconnect(Interconnect),
|
ReplaceInterconnect(Interconnect),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct SsrcTracker {
|
||||||
|
pub disconnected_users: DashSet<UserId>,
|
||||||
|
pub user_ssrc_map: DashMap<UserId, u32>,
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ use discortp::{
|
|||||||
PacketSize,
|
PacketSize,
|
||||||
};
|
};
|
||||||
use flume::Receiver;
|
use flume::Receiver;
|
||||||
use std::{collections::HashMap, convert::TryInto};
|
use std::{collections::HashMap, convert::TryInto, sync::Arc, time::Duration};
|
||||||
use tokio::{net::UdpSocket, select};
|
use tokio::{net::UdpSocket, select, time::Instant};
|
||||||
use tracing::{error, instrument, trace, warn};
|
use tracing::{error, instrument, trace, warn};
|
||||||
use xsalsa20poly1305::XSalsa20Poly1305 as Cipher;
|
use xsalsa20poly1305::XSalsa20Poly1305 as Cipher;
|
||||||
|
|
||||||
@@ -33,6 +33,8 @@ struct SsrcState {
|
|||||||
decoder: OpusDecoder,
|
decoder: OpusDecoder,
|
||||||
last_seq: u16,
|
last_seq: u16,
|
||||||
decode_size: PacketDecodeSize,
|
decode_size: PacketDecodeSize,
|
||||||
|
prune_time: Instant,
|
||||||
|
disconnected: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||||
@@ -84,13 +86,21 @@ enum SpeakingDelta {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SsrcState {
|
impl SsrcState {
|
||||||
fn new(pkt: &RtpPacket<'_>) -> Self {
|
fn new(pkt: &RtpPacket<'_>, state_timeout: Duration) -> Self {
|
||||||
Self {
|
Self {
|
||||||
silent_frame_count: 5, // We do this to make the first speech packet fire an event.
|
silent_frame_count: 5, // We do this to make the first speech packet fire an event.
|
||||||
decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo)
|
decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo)
|
||||||
.expect("Failed to create new Opus decoder for source."),
|
.expect("Failed to create new Opus decoder for source."),
|
||||||
last_seq: pkt.get_sequence().into(),
|
last_seq: pkt.get_sequence().into(),
|
||||||
decode_size: PacketDecodeSize::TwentyMillis,
|
decode_size: PacketDecodeSize::TwentyMillis,
|
||||||
|
prune_time: Instant::now() + state_timeout,
|
||||||
|
disconnected: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn refresh_timer(&mut self, state_timeout: Duration) {
|
||||||
|
if !self.disconnected {
|
||||||
|
self.prune_time = Instant::now() + state_timeout;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,21 +246,23 @@ impl SsrcState {
|
|||||||
struct UdpRx {
|
struct UdpRx {
|
||||||
cipher: Cipher,
|
cipher: Cipher,
|
||||||
decoder_map: HashMap<u32, SsrcState>,
|
decoder_map: HashMap<u32, SsrcState>,
|
||||||
#[allow(dead_code)]
|
|
||||||
config: Config,
|
config: Config,
|
||||||
packet_buffer: [u8; VOICE_PACKET_MAX],
|
packet_buffer: [u8; VOICE_PACKET_MAX],
|
||||||
rx: Receiver<UdpRxMessage>,
|
rx: Receiver<UdpRxMessage>,
|
||||||
|
ssrc_signalling: Arc<SsrcTracker>,
|
||||||
udp_socket: UdpSocket,
|
udp_socket: UdpSocket,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpRx {
|
impl UdpRx {
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
async fn run(&mut self, interconnect: &mut Interconnect) {
|
async fn run(&mut self, interconnect: &mut Interconnect) {
|
||||||
|
let mut cleanup_time = Instant::now();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
select! {
|
select! {
|
||||||
Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => {
|
Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => {
|
||||||
self.process_udp_message(interconnect, len);
|
self.process_udp_message(interconnect, len);
|
||||||
}
|
},
|
||||||
msg = self.rx.recv_async() => {
|
msg = self.rx.recv_async() => {
|
||||||
match msg {
|
match msg {
|
||||||
Ok(UdpRxMessage::ReplaceInterconnect(i)) => {
|
Ok(UdpRxMessage::ReplaceInterconnect(i)) => {
|
||||||
@@ -261,7 +273,41 @@ impl UdpRx {
|
|||||||
},
|
},
|
||||||
Err(flume::RecvError::Disconnected) => break,
|
Err(flume::RecvError::Disconnected) => break,
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
_ = tokio::time::sleep_until(cleanup_time) => {
|
||||||
|
// periodic cleanup.
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
// check ssrc map to see if the WS task has informed us of any disconnects.
|
||||||
|
loop {
|
||||||
|
// This is structured in an odd way to prevent deadlocks.
|
||||||
|
// while-let seemed to keep the dashmap iter() alive for block scope, rather than
|
||||||
|
// just the initialiser.
|
||||||
|
let id = {
|
||||||
|
if let Some(id) = self.ssrc_signalling.disconnected_users.iter().next().map(|v| *v.key()) {
|
||||||
|
id
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = self.ssrc_signalling.disconnected_users.remove(&id);
|
||||||
|
if let Some((_, ssrc)) = self.ssrc_signalling.user_ssrc_map.remove(&id) {
|
||||||
|
if let Some(state) = self.decoder_map.get_mut(&ssrc) {
|
||||||
|
// don't cleanup immediately: leave for later cycle
|
||||||
|
// this is key with reorder/jitter buffers where we may
|
||||||
|
// still need to decode post disconnect for ~0.2s.
|
||||||
|
state.prune_time = now + Duration::from_secs(1);
|
||||||
|
state.disconnected = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// now remove all dead ssrcs.
|
||||||
|
self.decoder_map.retain(|_, v| v.prune_time > now);
|
||||||
|
|
||||||
|
cleanup_time = now + Duration::from_secs(5);
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -310,7 +356,11 @@ impl UdpRx {
|
|||||||
let entry = self
|
let entry = self
|
||||||
.decoder_map
|
.decoder_map
|
||||||
.entry(rtp.get_ssrc())
|
.entry(rtp.get_ssrc())
|
||||||
.or_insert_with(|| SsrcState::new(&rtp));
|
.or_insert_with(|| SsrcState::new(&rtp, self.config.decode_state_timeout));
|
||||||
|
|
||||||
|
// Only do this on RTP, rather than RTCP -- this pins decoder state liveness
|
||||||
|
// to *speech* rather than just presence.
|
||||||
|
entry.refresh_timer(self.config.decode_state_timeout);
|
||||||
|
|
||||||
if let Ok((delta, audio)) = entry.process(
|
if let Ok((delta, audio)) = entry.process(
|
||||||
&rtp,
|
&rtp,
|
||||||
@@ -396,6 +446,7 @@ pub(crate) async fn runner(
|
|||||||
cipher: Cipher,
|
cipher: Cipher,
|
||||||
config: Config,
|
config: Config,
|
||||||
udp_socket: UdpSocket,
|
udp_socket: UdpSocket,
|
||||||
|
ssrc_signalling: Arc<SsrcTracker>,
|
||||||
) {
|
) {
|
||||||
trace!("UDP receive handle started.");
|
trace!("UDP receive handle started.");
|
||||||
|
|
||||||
@@ -405,6 +456,7 @@ pub(crate) async fn runner(
|
|||||||
config,
|
config,
|
||||||
packet_buffer: [0u8; VOICE_PACKET_MAX],
|
packet_buffer: [0u8; VOICE_PACKET_MAX],
|
||||||
rx,
|
rx,
|
||||||
|
ssrc_signalling,
|
||||||
udp_socket,
|
udp_socket,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use flume::Receiver;
|
use flume::Receiver;
|
||||||
use rand::random;
|
use rand::random;
|
||||||
|
#[cfg(feature = "receive")]
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
select,
|
select,
|
||||||
@@ -21,7 +23,7 @@ use tokio::{
|
|||||||
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
||||||
use tracing::{debug, info, instrument, trace, warn};
|
use tracing::{debug, info, instrument, trace, warn};
|
||||||
|
|
||||||
struct AuxNetwork {
|
pub(crate) struct AuxNetwork {
|
||||||
rx: Receiver<WsMessage>,
|
rx: Receiver<WsMessage>,
|
||||||
ws_client: WsStream,
|
ws_client: WsStream,
|
||||||
dont_send: bool,
|
dont_send: bool,
|
||||||
@@ -34,6 +36,9 @@ struct AuxNetwork {
|
|||||||
|
|
||||||
attempt_idx: usize,
|
attempt_idx: usize,
|
||||||
info: ConnectionInfo,
|
info: ConnectionInfo,
|
||||||
|
|
||||||
|
#[cfg(feature = "receive")]
|
||||||
|
ssrc_signalling: Arc<SsrcTracker>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuxNetwork {
|
impl AuxNetwork {
|
||||||
@@ -44,6 +49,7 @@ impl AuxNetwork {
|
|||||||
heartbeat_interval: f64,
|
heartbeat_interval: f64,
|
||||||
attempt_idx: usize,
|
attempt_idx: usize,
|
||||||
info: ConnectionInfo,
|
info: ConnectionInfo,
|
||||||
|
#[cfg(feature = "receive")] ssrc_signalling: Arc<SsrcTracker>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
rx: evt_rx,
|
rx: evt_rx,
|
||||||
@@ -58,6 +64,9 @@ impl AuxNetwork {
|
|||||||
|
|
||||||
attempt_idx,
|
attempt_idx,
|
||||||
info,
|
info,
|
||||||
|
|
||||||
|
#[cfg(feature = "receive")]
|
||||||
|
ssrc_signalling,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,6 +195,11 @@ impl AuxNetwork {
|
|||||||
fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) {
|
fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) {
|
||||||
match value {
|
match value {
|
||||||
GatewayEvent::Speaking(ev) => {
|
GatewayEvent::Speaking(ev) => {
|
||||||
|
#[cfg(feature = "receive")]
|
||||||
|
if let Some(user_id) = &ev.user_id {
|
||||||
|
self.ssrc_signalling.user_ssrc_map.insert(*user_id, ev.ssrc);
|
||||||
|
}
|
||||||
|
|
||||||
drop(interconnect.events.send(EventMessage::FireCoreEvent(
|
drop(interconnect.events.send(EventMessage::FireCoreEvent(
|
||||||
CoreContext::SpeakingStateUpdate(ev),
|
CoreContext::SpeakingStateUpdate(ev),
|
||||||
)));
|
)));
|
||||||
@@ -194,6 +208,11 @@ impl AuxNetwork {
|
|||||||
debug!("Received discontinued ClientConnect: {:?}", ev);
|
debug!("Received discontinued ClientConnect: {:?}", ev);
|
||||||
},
|
},
|
||||||
GatewayEvent::ClientDisconnect(ev) => {
|
GatewayEvent::ClientDisconnect(ev) => {
|
||||||
|
#[cfg(feature = "receive")]
|
||||||
|
{
|
||||||
|
self.ssrc_signalling.disconnected_users.insert(ev.user_id);
|
||||||
|
}
|
||||||
|
|
||||||
drop(interconnect.events.send(EventMessage::FireCoreEvent(
|
drop(interconnect.events.send(EventMessage::FireCoreEvent(
|
||||||
CoreContext::ClientDisconnect(ev),
|
CoreContext::ClientDisconnect(ev),
|
||||||
)));
|
)));
|
||||||
@@ -217,26 +236,9 @@ impl AuxNetwork {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(interconnect, ws_client))]
|
#[instrument(skip(interconnect, aux))]
|
||||||
pub(crate) async fn runner(
|
pub(crate) async fn runner(mut interconnect: Interconnect, mut aux: AuxNetwork) {
|
||||||
mut interconnect: Interconnect,
|
|
||||||
evt_rx: Receiver<WsMessage>,
|
|
||||||
ws_client: WsStream,
|
|
||||||
ssrc: u32,
|
|
||||||
heartbeat_interval: f64,
|
|
||||||
attempt_idx: usize,
|
|
||||||
info: ConnectionInfo,
|
|
||||||
) {
|
|
||||||
trace!("WS thread started.");
|
trace!("WS thread started.");
|
||||||
let mut aux = AuxNetwork::new(
|
|
||||||
evt_rx,
|
|
||||||
ws_client,
|
|
||||||
ssrc,
|
|
||||||
heartbeat_interval,
|
|
||||||
attempt_idx,
|
|
||||||
info,
|
|
||||||
);
|
|
||||||
|
|
||||||
aux.run(&mut interconnect).await;
|
aux.run(&mut interconnect).await;
|
||||||
trace!("WS thread finished.");
|
trace!("WS thread finished.");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user