Files
songbird/src/shards.rs
2025-02-21 12:41:44 +00:00

297 lines
9.7 KiB
Rust

//! Handlers for sending packets over sharded connections.
use crate::{error::JoinResult, id::*};
use async_trait::async_trait;
#[cfg(feature = "serenity")]
use dashmap::DashMap;
use derivative::Derivative;
#[cfg(feature = "serenity")]
use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender};
#[cfg(feature = "serenity")]
use parking_lot::{lock_api::RwLockWriteGuard, Mutex as PMutex, RwLock as PRwLock};
#[cfg(feature = "serenity")]
use serde_json::json;
#[cfg(feature = "serenity")]
use serenity::gateway::ShardRunnerMessage;
#[cfg(feature = "serenity")]
use std::result::Result as StdResult;
use std::sync::Arc;
#[cfg(feature = "serenity")]
use tracing::{debug, error};
#[cfg(feature = "twilight")]
use twilight_gateway::MessageSender;
#[cfg(feature = "twilight")]
use twilight_model::gateway::payload::outgoing::update_voice_state::UpdateVoiceState as TwilightVoiceState;
/// Map containing [`MessageSender`]s for Twilight.
///
/// Senders are keyed by shard ID. This type is only compiled when the
/// `twilight` feature is enabled.
///
/// [`MessageSender`]: twilight_gateway::MessageSender
#[cfg(feature = "twilight")]
#[derive(Debug)]
pub struct TwilightMap {
// Command sender for each shard, keyed by shard ID.
map: std::collections::HashMap<u32, MessageSender>,
}
#[cfg(feature = "twilight")]
impl TwilightMap {
    /// Wrap an existing table of per-shard command senders.
    ///
    /// For correctness, every shard should have an entry in `map`.
    #[must_use]
    pub fn new(map: std::collections::HashMap<u32, MessageSender>) -> Self {
        Self { map }
    }

    /// Look up the command sender registered for `shard_id`, if there is one.
    #[must_use]
    pub fn get(&self, shard_id: u32) -> Option<&MessageSender> {
        let senders = &self.map;
        senders.get(&shard_id)
    }

    /// Report how many shards have a sender stored in this map.
    #[must_use]
    pub fn shard_count(&self) -> u64 {
        let entries = self.map.len();
        entries as u64
    }
}
#[derive(Derivative)]
#[derivative(Debug)]
#[non_exhaustive]
/// Source of individual shard connection handles.
///
/// Each variant wraps one supported way of obtaining a [`Shard`]; use
/// [`Sharder::get_shard`] to fetch the handle for a given shard ID.
pub enum Sharder {
#[cfg(feature = "serenity")]
/// Serenity-specific wrapper for sharder state initialised by the library.
///
/// Only available when the `serenity` feature is enabled.
Serenity(SerenitySharder),
#[cfg(feature = "twilight")]
/// Twilight-specific wrapper for a map of command senders.
///
/// Only available when the `twilight` feature is enabled.
Twilight(Arc<TwilightMap>),
/// A generic shard handle source.
///
/// The wrapped value is omitted from this enum's `Debug` output.
Generic(#[derivative(Debug = "ignore")] Arc<dyn GenericSharder + Send + Sync>),
}
/// Trait for a generic shard cluster or other handle source.
///
/// This allows any Discord library to be integrated with Songbird, and offers a source
/// of generic shard handles.
#[async_trait]
pub trait GenericSharder {
/// Get access to a new shard handle for `shard_id`.
///
/// Returning `None` signals that no handle could be provided; [`Sharder::get_shard`]
/// passes that `None` straight through to its caller.
fn get_shard(&self, shard_id: u64) -> Option<Arc<dyn VoiceUpdate + Send + Sync>>;
}
impl Sharder {
    /// Returns a new handle to the required inner shard.
    ///
    /// Behaviour per variant:
    ///
    /// * `Serenity`: returns the shared handle for `shard_id`, creating an
    ///   empty one if this shard has not been seen before.
    /// * `Twilight`: returns a handle that pairs the sender map with
    ///   `shard_id`. The map is not consulted here; a missing sender is
    ///   reported when a voice update is actually sent.
    /// * `Generic`: delegates to [`GenericSharder::get_shard`] with the full
    ///   64-bit ID.
    ///
    /// Serenity and Twilight address shards with 32-bit IDs. A `shard_id` that
    /// does not fit in a `u32` yields `None` for those variants instead of
    /// being truncated onto an unrelated shard.
    #[allow(clippy::must_use_candidate)] // get_or_insert_shard_handle has side effects
    pub fn get_shard(&self, shard_id: u64) -> Option<Shard> {
        match self {
            #[cfg(feature = "serenity")]
            Sharder::Serenity(s) => {
                // Checked narrowing: an out-of-range ID cannot name a real shard.
                let id = u32::try_from(shard_id).ok()?;
                Some(Shard::Serenity(s.get_or_insert_shard_handle(id)))
            },
            #[cfg(feature = "twilight")]
            Sharder::Twilight(t) => {
                // Checked narrowing: an out-of-range ID cannot name a real shard.
                let id = u32::try_from(shard_id).ok()?;
                Some(Shard::Twilight(Arc::clone(t), id))
            },
            Sharder::Generic(src) => src.get_shard(shard_id).map(Shard::Generic),
        }
    }
}
#[cfg(feature = "serenity")]
impl Sharder {
    /// Installs `sender` as the live channel for shard `shard_id`.
    ///
    /// Only the `Serenity` variant supports this; on any other variant an
    /// error is logged and nothing else happens.
    #[allow(unreachable_patterns)]
    pub(crate) fn register_shard_handle(&self, shard_id: u32, sender: Sender<ShardRunnerMessage>) {
        match self {
            Sharder::Serenity(inner) => inner.register_shard_handle(shard_id, sender),
            _ => {
                error!("Called serenity management function on a non-serenity Songbird instance.");
            },
        }
    }

    /// Removes the live channel for shard `shard_id`.
    ///
    /// Only the `Serenity` variant supports this; on any other variant an
    /// error is logged and nothing else happens.
    #[allow(unreachable_patterns)]
    pub(crate) fn deregister_shard_handle(&self, shard_id: u32) {
        match self {
            Sharder::Serenity(inner) => inner.deregister_shard_handle(shard_id),
            _ => {
                error!("Called serenity management function on a non-serenity Songbird instance.");
            },
        }
    }
}
#[cfg(feature = "serenity")]
#[derive(Debug, Default)]
/// Serenity-specific wrapper for sharder state initialised by the library.
///
/// This is updated and maintained by the library, and is designed to prevent
/// message loss during rebalances and reconnects.
///
/// Internally this is a concurrent map from shard ID to a shared
/// [`SerenityShardHandle`]; entries are created on first access.
pub struct SerenitySharder(DashMap<u32, Arc<SerenityShardHandle>>);
#[cfg(feature = "serenity")]
impl SerenitySharder {
    /// Fetches the shared handle for `shard_id`.
    ///
    /// If the shard has no entry yet, a default handle (no sender, empty
    /// queue) is inserted first.
    fn get_or_insert_shard_handle(&self, shard_id: u32) -> Arc<SerenityShardHandle> {
        let slot = self.0.entry(shard_id).or_default();
        Arc::clone(&*slot)
    }

    /// Attaches `sender` to the handle for `shard_id`, creating the handle if needed.
    fn register_shard_handle(&self, shard_id: u32, sender: Sender<ShardRunnerMessage>) {
        // The handle may not exist yet if nothing has asked for this shard before.
        self.get_or_insert_shard_handle(shard_id).register(sender);
    }

    /// Detaches the sender from the handle for `shard_id`.
    fn deregister_shard_handle(&self, shard_id: u32) {
        // An absent entry is created (and left without a sender) rather than skipped.
        self.get_or_insert_shard_handle(shard_id).deregister();
    }
}
#[derive(Derivative, Clone)]
#[derivative(Debug)]
#[non_exhaustive]
/// A reference to an individual websocket connection.
///
/// Every variant keeps its connection state behind an [`Arc`], so a clone
/// refers to the same underlying shard rather than copying it.
pub enum Shard {
#[cfg(feature = "serenity")]
/// Handle to one of serenity's shard runners.
///
/// Only available when the `serenity` feature is enabled.
Serenity(Arc<SerenityShardHandle>),
#[cfg(feature = "twilight")]
/// Handle to a map of twilight command senders.
///
/// The second field is the ID of the shard whose sender is looked up in the
/// map. Only available when the `twilight` feature is enabled.
Twilight(Arc<TwilightMap>, u32),
/// Handle to a generic shard instance.
///
/// The wrapped value is omitted from this enum's `Debug` output.
Generic(#[derivative(Debug = "ignore")] Arc<dyn VoiceUpdate + Send + Sync>),
}
#[async_trait]
impl VoiceUpdate for Shard {
    /// Forwards a voice state update through whichever backend this handle wraps.
    ///
    /// * `Serenity`: serialises a gateway opcode 4 payload and passes it to
    ///   the shard handle.
    /// * `Twilight`: builds an `UpdateVoiceState` command and sends it via the
    ///   sender stored for this handle's shard ID, failing with
    ///   `JoinError::NoSender` if the map has no such sender.
    /// * `Generic`: delegates to the wrapped [`VoiceUpdate`] implementation.
    async fn update_voice_state(
        &self,
        guild_id: GuildId,
        channel_id: Option<ChannelId>,
        self_deaf: bool,
        self_mute: bool,
    ) -> JoinResult<()> {
        match self {
            #[cfg(feature = "serenity")]
            Shard::Serenity(handle) => {
                // Opcode 4 is Discord's "Update Voice State" gateway command.
                let payload = json!({
                    "op": 4,
                    "d": {
                        "channel_id": channel_id.map(|c| c.0),
                        "guild_id": guild_id.0,
                        "self_deaf": self_deaf,
                        "self_mute": self_mute,
                    }
                });
                let text = payload.to_string();

                handle.send(ShardRunnerMessage::Message(text.into()))?;
                Ok(())
            },
            #[cfg(feature = "twilight")]
            Shard::Twilight(map, shard_id) => {
                let target = channel_id.map(|c| c.0).map(From::from);
                let command = TwilightVoiceState::new(guild_id.0, target, self_deaf, self_mute);

                map.get(*shard_id)
                    .ok_or(crate::error::JoinError::NoSender)?
                    .command(&command)?;
                Ok(())
            },
            Shard::Generic(inner) => {
                inner
                    .update_voice_state(guild_id, channel_id, self_deaf, self_mute)
                    .await
            },
        }
    }
}
/// Trait for a generic shard handle to send voice state updates to Discord.
///
/// This allows any Discord library to be integrated with Songbird, and is intended to
/// wrap a message channel to a single shard. Songbird only needs to send `VoiceStateUpdate`s
/// to Discord to function.
///
/// Generic libraries must be sure to call [`Call::update_server`] and [`Call::update_state`]
/// in response to their own received messages.
///
/// [`Call::update_server`]: crate::Call::update_server
/// [`Call::update_state`]: crate::Call::update_state
#[async_trait]
pub trait VoiceUpdate {
/// Send a voice update message to the inner shard handle.
///
/// * `guild_id`: the guild whose voice state is being updated.
/// * `channel_id`: the voice channel to join. In Discord's gateway protocol a
///   null channel requests a disconnect, so `None` is presumably used to leave.
/// * `self_deaf` / `self_mute`: whether the client should appear deafened / muted.
///
/// # Errors
///
/// Implementations should return an error when the update could not be handed
/// to the underlying shard.
async fn update_voice_state(
&self,
guild_id: GuildId,
channel_id: Option<ChannelId>,
self_deaf: bool,
self_mute: bool,
) -> JoinResult<()>;
}
#[cfg(feature = "serenity")]
/// Handle to an individual shard designed to buffer unsent messages while
/// a reconnect/rebalance is ongoing.
///
/// The default value has no sender and an empty queue.
#[derive(Debug, Default)]
pub struct SerenityShardHandle {
// Channel to the shard runner; `None` while no runner is registered.
sender: PRwLock<Option<Sender<ShardRunnerMessage>>>,
// Messages accepted while `sender` was `None`; flushed on the next registration.
queue: PMutex<Vec<ShardRunnerMessage>>,
}
#[cfg(feature = "serenity")]
impl SerenityShardHandle {
    /// Installs `sender` as the live channel, then flushes buffered messages into it.
    ///
    /// The sender lock is taken before the queue lock, matching the order used
    /// by `send`. If forwarding a buffered message fails, the error is logged
    /// and flushing stops; the messages still in the queue at that point are
    /// discarded along with the drain iterator.
    fn register(&self, sender: Sender<ShardRunnerMessage>) {
        debug!("Adding shard handle send channel...");
        let mut write_guard = self.sender.write();
        *write_guard = Some(sender);
        debug!("Added shard handle send channel.");

        // Keep the sender readable (but not replaceable) while the queue is flushed.
        let read_guard = RwLockWriteGuard::downgrade(write_guard);
        let mut pending = self.queue.lock();

        debug!("Clearing queued messages...");
        if let Some(tx) = &*read_guard {
            let mut sent = 0;
            for queued in pending.drain(..) {
                match tx.unbounded_send(queued) {
                    Ok(()) => sent += 1,
                    Err(e) => {
                        error!("Error while clearing gateway message queue: {:?}", e);
                        break;
                    },
                }
            }

            if sent > 0 {
                debug!("{} buffered messages sent to Serenity.", sent);
            }
        }
        debug!("Cleared queued messages.");
    }

    /// Drops the current sender so that later messages are buffered instead of sent.
    fn deregister(&self) {
        debug!("Removing shard handle send channel...");
        let mut guard = self.sender.write();
        // Release the old channel right away; the write lock is held until return.
        drop(guard.take());
        debug!("Removed shard handle send channel.");
    }

    /// Sends `message` to the shard runner, or buffers it if no runner is registered.
    ///
    /// # Errors
    ///
    /// Returns the boxed channel error if a sender is registered but the send
    /// fails. In that case the message is not buffered.
    fn send(
        &self,
        message: ShardRunnerMessage,
    ) -> StdResult<(), Box<TrySendError<ShardRunnerMessage>>> {
        let guard = self.sender.read();
        match &*guard {
            Some(tx) => tx.unbounded_send(message).map_err(Box::new),
            None => {
                debug!("Serenity shard temporarily disconnected: buffering message...");
                let mut pending = self.queue.lock();
                pending.push(message);
                debug!("Buffered message.");
                Ok(())
            },
        }
    }
}