//! Handlers for sending packets over sharded connections. use crate::{error::JoinResult, id::*}; use async_trait::async_trait; #[cfg(feature = "serenity")] use dashmap::DashMap; use derivative::Derivative; #[cfg(feature = "serenity")] use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender}; #[cfg(feature = "serenity")] use parking_lot::{lock_api::RwLockWriteGuard, Mutex as PMutex, RwLock as PRwLock}; #[cfg(feature = "serenity")] use serde_json::json; #[cfg(feature = "serenity")] use serenity::gateway::InterMessage; #[cfg(feature = "serenity")] use std::result::Result as StdResult; use std::sync::Arc; #[cfg(feature = "serenity")] use tracing::{debug, error}; #[cfg(feature = "twilight")] use twilight_gateway::MessageSender; #[cfg(feature = "twilight")] use twilight_model::gateway::payload::outgoing::update_voice_state::UpdateVoiceState as TwilightVoiceState; /// Map containing [`MessageSender`]s for Twilight. /// /// [`MessageSender`]: twilight_gateway::MessageSender #[cfg(feature = "twilight")] #[derive(Debug)] pub struct TwilightMap { map: std::collections::HashMap, } #[cfg(feature = "twilight")] impl TwilightMap { /// Construct a map of shards and command senders to those shards. /// /// For correctness all shards should be in the map. pub fn new(map: std::collections::HashMap) -> Self { TwilightMap { map } } /// Get the message sender for `shard_id`. pub fn get(&self, shard_id: u64) -> Option<&MessageSender> { self.map.get(&shard_id) } /// Get the total number of shards in the map. pub fn shard_count(&self) -> u64 { self.map.len() as u64 } } #[derive(Derivative)] #[derivative(Debug)] #[non_exhaustive] /// Source of individual shard connection handles. pub enum Sharder { #[cfg(feature = "serenity")] /// Serenity-specific wrapper for sharder state initialised by the library. Serenity(SerenitySharder), #[cfg(feature = "twilight")] /// Twilight-specific wrapper for a map of command senders. Twilight(Arc), /// A generic shard handle source. Generic(#[derivative(Debug = "ignore")] Arc), } /// Trait for a generic shard cluster or other handle source. /// /// This allows any Discord library to be integrated with Songbird, and offers a source /// of generic shard handles. #[async_trait] pub trait GenericSharder { /// Get access to a new shard fn get_shard(&self, shard_id: u64) -> Option>; } impl Sharder { /// Returns a new handle to the required inner shard. #[allow(clippy::must_use_candidate)] // get_or_insert_shard_handle has side effects pub fn get_shard(&self, shard_id: u64) -> Option { match self { #[cfg(feature = "serenity")] Sharder::Serenity(s) => Some(Shard::Serenity( s.get_or_insert_shard_handle(shard_id as u32), )), #[cfg(feature = "twilight")] Sharder::Twilight(t) => Some(Shard::Twilight(t.clone(), shard_id)), Sharder::Generic(src) => src.get_shard(shard_id).map(Shard::Generic), } } } #[cfg(feature = "serenity")] impl Sharder { #[allow(unreachable_patterns)] pub(crate) fn register_shard_handle(&self, shard_id: u32, sender: Sender) { if let Sharder::Serenity(s) = self { s.register_shard_handle(shard_id, sender); } else { error!("Called serenity management function on a non-serenity Songbird instance."); } } #[allow(unreachable_patterns)] pub(crate) fn deregister_shard_handle(&self, shard_id: u32) { if let Sharder::Serenity(s) = self { s.deregister_shard_handle(shard_id); } else { error!("Called serenity management function on a non-serenity Songbird instance."); } } } #[cfg(feature = "serenity")] #[derive(Debug, Default)] /// Serenity-specific wrapper for sharder state initialised by the library. /// /// This is updated and maintained by the library, and is designed to prevent /// message loss during rebalances and reconnects. pub struct SerenitySharder(DashMap>); #[cfg(feature = "serenity")] impl SerenitySharder { fn get_or_insert_shard_handle(&self, shard_id: u32) -> Arc { self.0.entry(shard_id).or_default().clone() } fn register_shard_handle(&self, shard_id: u32, sender: Sender) { // Write locks are only used to add new entries to the map. let handle = self.get_or_insert_shard_handle(shard_id); handle.register(sender); } fn deregister_shard_handle(&self, shard_id: u32) { // Write locks are only used to add new entries to the map. let handle = self.get_or_insert_shard_handle(shard_id); handle.deregister(); } } #[derive(Derivative, Clone)] #[derivative(Debug)] #[non_exhaustive] /// A reference to an individual websocket connection. pub enum Shard { #[cfg(feature = "serenity")] /// Handle to one of serenity's shard runners. Serenity(Arc), #[cfg(feature = "twilight")] /// Handle to a map of twilight command senders. Twilight(Arc, u64), /// Handle to a generic shard instance. Generic(#[derivative(Debug = "ignore")] Arc), } #[async_trait] impl VoiceUpdate for Shard { async fn update_voice_state( &self, guild_id: GuildId, channel_id: Option, self_deaf: bool, self_mute: bool, ) -> JoinResult<()> { match self { #[cfg(feature = "serenity")] Shard::Serenity(handle) => { let map = json!({ "op": 4, "d": { "channel_id": channel_id.map(|c| c.0), "guild_id": guild_id.0, "self_deaf": self_deaf, "self_mute": self_mute, } }); handle.send(InterMessage::json(map.to_string()))?; Ok(()) }, #[cfg(feature = "twilight")] Shard::Twilight(map, shard_id) => { let channel_id = channel_id.map(|c| c.0).map(From::from); let cmd = TwilightVoiceState::new(guild_id.0, channel_id, self_deaf, self_mute); let sender = map .get(*shard_id) .ok_or(crate::error::JoinError::NoSender)?; sender.command(&cmd)?; Ok(()) }, Shard::Generic(g) => g.update_voice_state(guild_id, channel_id, self_deaf, self_mute) .await, } } } /// Trait for a generic shard handle to send voice state updates to Discord. /// /// This allows any Discord library to be integrated with Songbird, and is intended to /// wrap a message channel to a single shard. Songbird only needs to send `VoiceStateUpdate`s /// to Discord to function. /// /// Generic libraries must be sure to call [`Call::update_server`] and [`Call::update_state`] /// in response to their own received messages. /// /// [`Call::update_server`]: crate::Call::update_server /// [`Call::update_state`]: crate::Call::update_state #[async_trait] pub trait VoiceUpdate { /// Send a voice update message to the inner shard handle. async fn update_voice_state( &self, guild_id: GuildId, channel_id: Option, self_deaf: bool, self_mute: bool, ) -> JoinResult<()>; } #[cfg(feature = "serenity")] /// Handle to an individual shard designed to buffer unsent messages while /// a reconnect/rebalance is ongoing. #[derive(Debug, Default)] pub struct SerenityShardHandle { sender: PRwLock>>, queue: PMutex>, } #[cfg(feature = "serenity")] impl SerenityShardHandle { fn register(&self, sender: Sender) { debug!("Adding shard handle send channel..."); let mut sender_lock = self.sender.write(); *sender_lock = Some(sender); debug!("Added shard handle send channel."); let sender_lock = RwLockWriteGuard::downgrade(sender_lock); let mut messages_lock = self.queue.lock(); debug!("Clearing queued messages..."); if let Some(sender) = &*sender_lock { let mut i = 0; for msg in messages_lock.drain(..) { if let Err(e) = sender.unbounded_send(msg) { error!("Error while clearing gateway message queue: {:?}", e); break; } i += 1; } if i > 0 { debug!("{} buffered messages sent to Serenity.", i); } } debug!("Cleared queued messages."); } fn deregister(&self) { debug!("Removing shard handle send channel..."); let mut sender_lock = self.sender.write(); *sender_lock = None; debug!("Removed shard handle send channel."); } fn send(&self, message: InterMessage) -> StdResult<(), TrySendError> { let sender_lock = self.sender.read(); if let Some(sender) = &*sender_lock { sender.unbounded_send(message) } else { debug!("Serenity shard temporarily disconnected: buffering message..."); let mut messages_lock = self.queue.lock(); messages_lock.push(message); debug!("Buffered message."); Ok(()) } } }