commit 7e4392ae68f97311f2389fdf8835e70a25912ff3 Author: Kyle Simpson Date: Thu Oct 29 20:25:20 2020 +0000 Voice Rework -- Events, Track Queues (#806) This implements a proof-of-concept for an improved audio frontend. The largest change is the introduction of events and event handling: both by time elapsed and by track events, such as ending or looping. Following on from this, the library now includes a basic, event-driven track queue system (which people seem to ask for unusually often). A new sample, `examples/13_voice_events`, demonstrates both the `TrackQueue` system and some basic events via the `~queue` and `~play_fade` commands. Locks are removed from around the control of `Audio` objects, which should allow the backend to be moved to a more granular futures-based backend solution in a cleaner way. diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8b12454 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,155 @@ +[package] +authors = ["Kyle Simpson "] +description = "An async Rust library for the Discord voice API." +documentation = "https://docs.rs/songbird" +edition = "2018" +homepage = "https://github.com/serenity-rs/serenity" +include = ["src/**/*.rs", "Cargo.toml"] +keywords = ["discord", "api", "rtp", "audio"] +license = "ISC" +name = "songbird" +readme = "README.md" +repository = "https://github.com/serenity-rs/serenity.git" +version = "0.1.0" + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tracing = "0.1" +tracing-futures = "0.2" + +[dependencies.async-trait] +optional = true +version = "0.1" + +[dependencies.async-tungstenite] +default-features = false +features = ["tokio-runtime"] +optional = true +version = "0.9" + +[dependencies.audiopus] +optional = true +version = "0.2" + +[dependencies.byteorder] +optional = true +version = "1" + +[dependencies.discortp] +features = ["discord-full"] +optional = true +version = "0.2" + +[dependencies.flume] +optional = true +version = "0.9" + +[dependencies.futures] +version = "0.3" + +[dependencies.parking_lot] +optional = true +version = "0.11" + +[dependencies.rand] +optional = true +version = "0.7" + +[dependencies.serenity] +optional = true +features = ["voice", "gateway"] +path = "../" +version = "0.9.0-rc.2" + +[dependencies.serenity-voice-model] +optional = true +path = "../voice-model" +version = "0.9.0-rc.2" + +[dependencies.spin_sleep] +optional = true +version = "1" + +[dependencies.streamcatcher] +optional = true +version = "0.1" + +[dependencies.tokio] +optional = true +version = "0.2" +default-features = false + +[dependencies.twilight-gateway] +optional = true +version = "0.1" +default-features = false + +[dependencies.twilight-model] +optional = true +version = "0.1" +default-features = false + +[dependencies.url] +optional = true +version = "2" + +[dependencies.xsalsa20poly1305] +optional = true +version = "0.5" + +[dev-dependencies] +criterion = "0.3" +utils = { path = "utils" } + +[features] +default = [ + "serenity-rustls", + "driver", + "gateway", +] +gateway = [ + "flume", + "parking_lot", + "tokio/sync", +] +driver = [ + "async-trait", + "async-tungstenite", + "audiopus", + "byteorder", + "discortp", + "flume", + "parking_lot", + "rand", + "serenity-voice-model", + "spin_sleep", + "streamcatcher", + "tokio/fs", + "tokio/io-util", + "tokio/net", + "tokio/rt-core", + "tokio/time", + "tokio/process", + "tokio/sync", + "url", + "xsalsa20poly1305", +] +rustls = ["async-tungstenite/tokio-rustls"] +native = ["async-tungstenite/tokio-native-tls"] +serenity-rustls = 
["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"] +serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"] +twilight-rustls = ["twilight", "twilight-gateway/rustls", "rustls", "gateway"] +twilight-native = ["twilight", "twilight-gateway/native", "native", "gateway"] +twilight = ["twilight-model"] +simd-zlib = ["twilight-gateway/simd-zlib"] +stock-zlib = ["twilight-gateway/stock-zlib"] +serenity-deps = ["async-trait"] + +[[bench]] +name = "mixing" +path = "benches/mixing.rs" +harness = false + +[package.metadata.docs.rs] +all-features = true diff --git a/README.md b/README.md new file mode 100644 index 0000000..72d271b --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +# Songbird + +![](songbird.png) + +Songbird is an async, cross-library compatible voice system for Discord, written in Rust. +The library offers: + * A standalone gateway frontend compatible with [serenity] and [twilight] using the + `"gateway"` and `"[serenity/twilight]-[rustls/native]"` features. You can even run + driverless, to help manage your [lavalink] sessions. + * A standalone driver for voice calls, via the `"driver"` feature. If you can create + a `ConnectionInfo` using any other gateway, or language for your bot, then you + can run the songbird voice driver. + * And, by default, a fully featured voice system featuring events, queues, RT(C)P packet + handling, seeking on compatible streams, shared multithreaded audio stream caches, + and direct Opus data passthrough from DCA files. + +## Examples +Full examples showing various types of functionality and integrations can be found as part of [serenity's examples], and in [this crate's examples directory]. + +## Attribution + +Songbird's logo is based upon the copyright-free image ["Black-Capped Chickadee"] by George Gorgas White. 
+ +[serenity]: https://github.com/serenity-rs/serenity +[twilight]: https://github.com/twilight-rs/twilight +["Black-Capped Chickadee"]: https://www.oldbookillustrations.com/illustrations/black-capped-chickadee/ +[lavalink]: https://github.com/Frederikam/Lavalink +[serenity's examples]: https://github.com/serenity-rs/serenity/tree/current/examples +[this crate's examples directory]: https://github.com/serenity-rs/serenity/tree/current/songbird/examples diff --git a/benches/mixing.rs b/benches/mixing.rs new file mode 100644 index 0000000..7828bae --- /dev/null +++ b/benches/mixing.rs @@ -0,0 +1,30 @@ +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use songbird::{constants::*, input::Input}; + +pub fn mix_one_frame(c: &mut Criterion) { + let floats = utils::make_sine(STEREO_FRAME_SIZE, true); + let mut raw_buf = [0f32; STEREO_FRAME_SIZE]; + + c.bench_function("Mix stereo source", |b| { + b.iter_batched_ref( + || black_box(Input::float_pcm(true, floats.clone().into())), + |input| { + input.mix(black_box(&mut raw_buf), black_box(1.0)); + }, + BatchSize::SmallInput, + ) + }); + + c.bench_function("Mix mono source", |b| { + b.iter_batched_ref( + || black_box(Input::float_pcm(false, floats.clone().into())), + |input| { + input.mix(black_box(&mut raw_buf), black_box(1.0)); + }, + BatchSize::SmallInput, + ) + }); +} + +criterion_group!(benches, mix_one_frame); +criterion_main!(benches); diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..85b9e90 --- /dev/null +++ b/build.rs @@ -0,0 +1,23 @@ +#[cfg(all(feature = "driver", not(any(feature = "rustls", feature = "native"))))] +compile_error!( + "You have the `driver` feature enabled: \ + either the `rustls` or `native` feature must be + selected to let Songbird's driver use websockets.\n\ + - `rustls` uses Rustls, a pure Rust TLS-implemenation.\n\ + - `native` uses SChannel on Windows, Secure Transport on macOS, \ + and OpenSSL on other platforms.\n\ + If you are unsure, go with `rustls`." +); + +#[cfg(all( + feature = "twilight", + not(any(feature = "simd-zlib", feature = "stock-zlib")) +))] +compile_error!( + "Twilight requires you to specify a zlib backend: \ + either the `simd-zlib` or `stock-zlib` feature must be + selected.\n\ + If you are unsure, go with `stock-zlib`." +); + +fn main() {} diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..45fcf68 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,3 @@ +# Songbird examples + +These examples show more advanced use of Songbird, or how to include Songbird in bots built on other libraries, such as twilight. \ No newline at end of file diff --git a/examples/twilight/Cargo.toml b/examples/twilight/Cargo.toml new file mode 100644 index 0000000..04473be --- /dev/null +++ b/examples/twilight/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "basic-twilight-bot" +version = "0.1.0" +authors = ["Twilight and Serenity Contributors"] +edition = "2018" + +[dependencies] +futures = "0.3" +tracing = "0.1" +tracing-subscriber = "0.2" +serde_json = { version = "1" } +tokio = { features = ["macros", "rt-threaded", "sync"], version = "0.2" } +twilight-gateway = "0.1" +twilight-http = "0.1" +twilight-model = "0.1" +twilight-standby = "0.1" + +[dependencies.songbird] +path = "../.." 
+default-features = false +features = ["twilight-rustls", "gateway", "driver", "stock-zlib"] diff --git a/examples/twilight/src/main.rs b/examples/twilight/src/main.rs new file mode 100644 index 0000000..d8a49b9 --- /dev/null +++ b/examples/twilight/src/main.rs @@ -0,0 +1,378 @@ +//! This example adapts Twilight's [basic lavalink bot] to use Songbird as its voice driver. +//! +//! # Twilight-rs attribution +//! ISC License (ISC) +//! +//! Copyright (c) 2019, 2020 (c) The Twilight Contributors +//! +//! Permission to use, copy, modify, and/or distribute this software for any purpose +//! with or without fee is hereby granted, provided that the above copyright notice +//! and this permission notice appear in all copies. +//! +//! THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +//! REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +//! FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +//! INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +//! OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +//! TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF +//! THIS SOFTWARE. +//! +//! +//! [basic lavalink bot]: https://github.com/twilight-rs/twilight/tree/trunk/lavalink/examples/basic-lavalink-bot + +use futures::StreamExt; +use std::{collections::HashMap, env, error::Error, future::Future, sync::Arc}; +use songbird::{input::{Input, Restartable}, tracks::{PlayMode, TrackHandle}, Songbird}; +use tokio::sync::RwLock; +use twilight_gateway::{Cluster, Event}; +use twilight_http::Client as HttpClient; +use twilight_model::{channel::Message, gateway::payload::MessageCreate, id::GuildId}; +use twilight_standby::Standby; + +#[derive(Clone, Debug)] +struct State { + cluster: Cluster, + http: HttpClient, + trackdata: Arc>>, + songbird: Arc, + standby: Standby, +} + +fn spawn( + fut: impl Future>> + Send + 'static, +) { + tokio::spawn(async move { + if let Err(why) = fut.await { + tracing::debug!("handler error: {:?}", why); + } + }); +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize the tracing subscriber. 
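+ // (If tracing-subscriber's optional `env-filter` feature is enabled, the `RUST_LOG`
+ // environment variable controls which of the bot's `tracing::debug!` lines are shown.)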
+ tracing_subscriber::fmt::init(); + + let state = { + let token = env::var("DISCORD_TOKEN")?; + + let http = HttpClient::new(&token); + let user_id = http.current_user().await?.id; + + let cluster = Cluster::new(token).await?; + + let shard_count = cluster.shards().len(); + let songbird = Songbird::twilight(cluster.clone(), shard_count as u64, user_id); + + cluster.up().await; + + State { + cluster, + http, + trackdata: Default::default(), + songbird, + standby: Standby::new(), + } + }; + + let mut events = state.cluster.events(); + + while let Some(event) = events.next().await { + state.standby.process(&event.1); + state.songbird.process(&event.1).await; + + if let Event::MessageCreate(msg) = event.1 { + if msg.guild_id.is_none() || !msg.content.starts_with('!') { + continue; + } + + match msg.content.splitn(2, ' ').next() { + Some("!join") => spawn(join(msg.0, state.clone())), + Some("!leave") => spawn(leave(msg.0, state.clone())), + Some("!pause") => spawn(pause(msg.0, state.clone())), + Some("!play") => spawn(play(msg.0, state.clone())), + Some("!seek") => spawn(seek(msg.0, state.clone())), + Some("!stop") => spawn(stop(msg.0, state.clone())), + Some("!volume") => spawn(volume(msg.0, state.clone())), + _ => continue, + } + } + } + + Ok(()) +} + +async fn join(msg: Message, state: State) -> Result<(), Box> { + state + .http + .create_message(msg.channel_id) + .content("What's the channel ID you want me to join?")? + .await?; + + let author_id = msg.author.id; + let msg = state + .standby + .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| { + new_msg.author.id == author_id + }) + .await?; + let channel_id = msg.content.parse::()?; + + let guild_id = msg.guild_id.ok_or("Can't join a non-guild channel.")?; + + let (_handle, success) = state + .songbird + .join(guild_id, channel_id) + .await; + + let content = match success?.recv_async().await { + Ok(Ok(())) => format!("Joined <#{}>!", channel_id), + Ok(Err(e)) => format!("Failed to join <#{}>! Why: {:?}", channel_id, e), + _ => format!("Failed to join <#{}>: Gateway error!", channel_id), + }; + + state + .http + .create_message(msg.channel_id) + .content(content)? + .await?; + + Ok(()) +} + +async fn leave(msg: Message, state: State) -> Result<(), Box> { + tracing::debug!( + "leave command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + + let guild_id = msg.guild_id.unwrap(); + + state + .songbird + .leave(guild_id) + .await?; + + state + .http + .create_message(msg.channel_id) + .content("Left the channel")? + .await?; + + Ok(()) +} + +async fn play(msg: Message, state: State) -> Result<(), Box> { + tracing::debug!( + "play command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + state + .http + .create_message(msg.channel_id) + .content("What's the URL of the audio to play?")? + .await?; + + let author_id = msg.author.id; + let msg = state + .standby + .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| { + new_msg.author.id == author_id + }) + .await?; + + let guild_id = msg.guild_id.unwrap(); + + if let Ok(song) = Restartable::ytdl(msg.content.clone()) { + let input = Input::from(song); + + let content = format!( + "Playing **{:?}** by **{:?}**", + input.metadata.title.as_ref().unwrap_or(&"".to_string()), + input.metadata.artist.as_ref().unwrap_or(&"".to_string()), + ); + + state + .http + .create_message(msg.channel_id) + .content(content)? 
+ .await?; + + if let Some(call_lock) = state.songbird.get(guild_id) { + let mut call = call_lock.lock().await; + let handle = call.play_source(input); + + let mut store = state.trackdata.write().await; + store.insert(guild_id, handle); + } + } else { + state + .http + .create_message(msg.channel_id) + .content("Didn't find any results")? + .await?; + } + + Ok(()) +} + +async fn pause(msg: Message, state: State) -> Result<(), Box> { + tracing::debug!( + "pause command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + + let guild_id = msg.guild_id.unwrap(); + + let store = state.trackdata.read().await; + + let content = if let Some(handle) = store.get(&guild_id) { + let info = handle.get_info()? + .await?; + + let paused = match info.playing { + PlayMode::Play => { + let _success = handle.pause(); + false + } + _ => { + let _success = handle.play(); + true + } + }; + + let action = if paused { "Unpaused" } else { "Paused" }; + + format!("{} the track", action) + } else { + format!("No track to (un)pause!") + }; + + state + .http + .create_message(msg.channel_id) + .content(content)? + .await?; + + Ok(()) +} + +async fn seek(msg: Message, state: State) -> Result<(), Box> { + tracing::debug!( + "seek command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + state + .http + .create_message(msg.channel_id) + .content("Where in the track do you want to seek to (in seconds)?")? + .await?; + + let author_id = msg.author.id; + let msg = state + .standby + .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| { + new_msg.author.id == author_id + }) + .await?; + let guild_id = msg.guild_id.unwrap(); + let position = msg.content.parse::()?; + + let store = state.trackdata.read().await; + + let content = if let Some(handle) = store.get(&guild_id) { + if handle.is_seekable() { + let _success = handle.seek_time(std::time::Duration::from_secs(position)); + format!("Seeked to {}s", position) + } else { + format!("Track is not compatible with seeking!") + } + } else { + format!("No track to seek over!") + }; + + state + .http + .create_message(msg.channel_id) + .content(content)? + .await?; + + Ok(()) +} + +async fn stop(msg: Message, state: State) -> Result<(), Box> { + tracing::debug!( + "stop command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + + let guild_id = msg.guild_id.unwrap(); + + if let Some(call_lock) = state.songbird.get(guild_id) { + let mut call = call_lock.lock().await; + let _ = call.stop(); + } + + state + .http + .create_message(msg.channel_id) + .content("Stopped the track")? + .await?; + + Ok(()) +} + +async fn volume(msg: Message, state: State) -> Result<(), Box> { + tracing::debug!( + "volume command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + state + .http + .create_message(msg.channel_id) + .content("What's the volume you want to set (0.0-10.0, 1.0 being the default)?")? + .await?; + + let author_id = msg.author.id; + let msg = state + .standby + .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| { + new_msg.author.id == author_id + }) + .await?; + let guild_id = msg.guild_id.unwrap(); + let volume = msg.content.parse::()?; + + if !volume.is_finite() || volume > 10.0 || volume < 0.0 { + state + .http + .create_message(msg.channel_id) + .content("Invalid volume!")? 
+ .await?; + + return Ok(()); + } + + let store = state.trackdata.read().await; + + let content = if let Some(handle) = store.get(&guild_id) { + let _success = handle.set_volume(volume as f32); + format!("Set the volume to {}", volume) + } else { + format!("No track to change volume!") + }; + + state + .http + .create_message(msg.channel_id) + .content(content)? + .await?; + + Ok(()) +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..0e82264 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,6 @@ +imports_layout = "HorizontalVertical" +match_arm_blocks = false +match_block_trailing_comma = true +newline_style = "Unix" +use_field_init_shorthand = true +use_try_shorthand = true diff --git a/songbird-ico.png b/songbird-ico.png new file mode 100644 index 0000000..f9783d7 Binary files /dev/null and b/songbird-ico.png differ diff --git a/songbird.png b/songbird.png new file mode 100644 index 0000000..1870a97 Binary files /dev/null and b/songbird.png differ diff --git a/songbird.svg b/songbird.svg new file mode 100644 index 0000000..39e6e72 --- /dev/null +++ b/songbird.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/constants.rs b/src/constants.rs new file mode 100644 index 0000000..d6b757a --- /dev/null +++ b/src/constants.rs @@ -0,0 +1,75 @@ +//! Constants affecting driver function and API handling. + +#[cfg(feature = "driver")] +use audiopus::{Bitrate, SampleRate}; +#[cfg(feature = "driver")] +use discortp::rtp::RtpType; +use std::time::Duration; + +#[cfg(feature = "driver")] +/// The voice gateway version used by the library. +pub const VOICE_GATEWAY_VERSION: u8 = crate::model::constants::GATEWAY_VERSION; + +#[cfg(feature = "driver")] +/// Sample rate of audio to be sent to Discord. +pub const SAMPLE_RATE: SampleRate = SampleRate::Hz48000; + +/// Sample rate of audio to be sent to Discord. +pub const SAMPLE_RATE_RAW: usize = 48_000; + +/// Number of audio frames/packets to be sent per second. +pub const AUDIO_FRAME_RATE: usize = 50; + +/// Length of time between any two audio frames. +pub const TIMESTEP_LENGTH: Duration = Duration::from_millis(1000 / AUDIO_FRAME_RATE as u64); + +#[cfg(feature = "driver")] +/// Default bitrate for audio. +pub const DEFAULT_BITRATE: Bitrate = Bitrate::BitsPerSecond(128_000); + +/// Number of samples in one complete frame of audio per channel. +/// +/// This is equally the number of stereo (joint) samples in an audio frame. +pub const MONO_FRAME_SIZE: usize = SAMPLE_RATE_RAW / AUDIO_FRAME_RATE; + +/// Number of individual samples in one complete frame of stereo audio. +pub const STEREO_FRAME_SIZE: usize = 2 * MONO_FRAME_SIZE; + +/// Number of bytes in one complete frame of raw `f32`-encoded mono audio. +pub const MONO_FRAME_BYTE_SIZE: usize = MONO_FRAME_SIZE * std::mem::size_of::(); + +/// Number of bytes in one complete frame of raw `f32`-encoded stereo audio. +pub const STEREO_FRAME_BYTE_SIZE: usize = STEREO_FRAME_SIZE * std::mem::size_of::(); + +/// Length (in milliseconds) of any audio frame. +pub const FRAME_LEN_MS: usize = 1000 / AUDIO_FRAME_RATE; + +/// Maximum number of audio frames/packets to be sent per second to be buffered. +pub const CHILD_BUFFER_LEN: usize = AUDIO_FRAME_RATE / 2; + +/// Maximum packet size for a voice packet. +/// +/// Set a safe amount below the Ethernet MTU to avoid fragmentation/rejection. +pub const VOICE_PACKET_MAX: usize = 1460; + +/// Delay between sends of UDP keepalive frames. 
+/// +/// Passive monitoring of Discord itself shows that these fire every 5 seconds +/// irrespective of outgoing UDP traffic. +pub const UDP_KEEPALIVE_GAP_MS: u64 = 5_000; + +/// Type-converted delay between sends of UDP keepalive frames. +/// +/// Passive monitoring of Discord itself shows that these fire every 5 seconds +/// irrespective of outgoing UDP traffic. +pub const UDP_KEEPALIVE_GAP: Duration = Duration::from_millis(UDP_KEEPALIVE_GAP_MS); + +/// Opus silent frame, used to signal speech start and end (and prevent audio glitching). +pub const SILENT_FRAME: [u8; 3] = [0xf8, 0xff, 0xfe]; + +/// The one (and only) RTP version. +pub const RTP_VERSION: u8 = 2; + +#[cfg(feature = "driver")] +/// Profile type used by Discord's Opus audio traffic. +pub const RTP_PROFILE_TYPE: RtpType = RtpType::Dynamic(120); diff --git a/src/driver/config.rs b/src/driver/config.rs new file mode 100644 index 0000000..c5349b6 --- /dev/null +++ b/src/driver/config.rs @@ -0,0 +1,10 @@ +use super::CryptoMode; + +/// Configuration for the inner Driver. +/// +/// At present, this cannot be changed. +#[derive(Clone, Debug, Default)] +pub struct Config { + /// Selected tagging mode for voice packet encryption. + pub crypto_mode: Option, +} diff --git a/src/driver/connection/error.rs b/src/driver/connection/error.rs new file mode 100644 index 0000000..cb6f8c3 --- /dev/null +++ b/src/driver/connection/error.rs @@ -0,0 +1,105 @@ +//! Connection errors and convenience types. + +use crate::{ + driver::tasks::{error::Recipient, message::*}, + ws::Error as WsError, +}; +use flume::SendError; +use serde_json::Error as JsonError; +use std::{error::Error as ErrorTrait, fmt, io::Error as IoError}; +use xsalsa20poly1305::aead::Error as CryptoError; + +/// Errors encountered while connecting to a Discord voice server over the driver. +#[derive(Debug)] +pub enum Error { + /// An error occurred during [en/de]cryption of voice packets or key generation. + Crypto(CryptoError), + /// Server did not return the expected crypto mode during negotiation. + CryptoModeInvalid, + /// Selected crypto mode was not offered by server. + CryptoModeUnavailable, + /// An indicator that an endpoint URL was invalid. + EndpointUrl, + /// Discord hello/ready handshake was violated. + ExpectedHandshake, + /// Discord failed to correctly respond to IP discovery. + IllegalDiscoveryResponse, + /// Could not parse Discord's view of our IP. + IllegalIp, + /// Miscellaneous I/O error. + Io(IoError), + /// JSON (de)serialization error. + Json(JsonError), + /// Failed to message other background tasks after connection establishment. + InterconnectFailure(Recipient), + /// Error communicating with gateway server over WebSocket. 
+ Ws(WsError), +} + +impl From for Error { + fn from(e: CryptoError) -> Self { + Error::Crypto(e) + } +} + +impl From for Error { + fn from(e: IoError) -> Error { + Error::Io(e) + } +} + +impl From for Error { + fn from(e: JsonError) -> Error { + Error::Json(e) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::AuxNetwork) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::Event) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::Mixer) + } +} + +impl From for Error { + fn from(e: WsError) -> Error { + Error::Ws(e) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Failed to connect to Discord RTP server: ")?; + use Error::*; + match self { + Crypto(c) => write!(f, "cryptography error {}.", c), + CryptoModeInvalid => write!(f, "server changed negotiated encryption mode."), + CryptoModeUnavailable => write!(f, "server did not offer chosen encryption mode."), + EndpointUrl => write!(f, "endpoint URL received from gateway was invalid."), + ExpectedHandshake => write!(f, "voice initialisation protocol was violated."), + IllegalDiscoveryResponse => + write!(f, "IP discovery/NAT punching response was invalid."), + IllegalIp => write!(f, "IP discovery/NAT punching response had bad IP value."), + Io(i) => write!(f, "I/O failure ({}).", i), + Json(j) => write!(f, "JSON (de)serialization issue ({}).", j), + InterconnectFailure(r) => write!(f, "failed to contact other task ({:?})", r), + Ws(w) => write!(f, "websocket issue ({:?}).", w), + } + } +} + +impl ErrorTrait for Error {} + +/// Convenience type for Discord voice/driver connection error handling. 
+pub type Result = std::result::Result; diff --git a/src/driver/connection/mod.rs b/src/driver/connection/mod.rs new file mode 100644 index 0000000..ee5a416 --- /dev/null +++ b/src/driver/connection/mod.rs @@ -0,0 +1,321 @@ +pub mod error; + +use super::{ + tasks::{message::*, udp_rx, udp_tx, ws as ws_task}, + Config, + CryptoMode, +}; +use crate::{ + constants::*, + model::{ + payload::{Identify, Resume, SelectProtocol}, + Event as GatewayEvent, + ProtocolData, + }, + ws::{self, ReceiverExt, SenderExt, WsStream}, + ConnectionInfo, +}; +use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket}; +use error::{Error, Result}; +use flume::Sender; +use std::{net::IpAddr, str::FromStr}; +use tokio::net::UdpSocket; +use tracing::{debug, info, instrument}; +use url::Url; +use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher}; + +#[cfg(all(feature = "rustls", not(feature = "native")))] +use ws::create_rustls_client; + +#[cfg(feature = "native")] +use ws::create_native_tls_client; + +pub(crate) struct Connection { + pub(crate) info: ConnectionInfo, + pub(crate) ws: Sender, +} + +impl Connection { + pub(crate) async fn new( + mut info: ConnectionInfo, + interconnect: &Interconnect, + config: &Config, + ) -> Result { + let crypto_mode = config.crypto_mode.unwrap_or(CryptoMode::Normal); + + let url = generate_url(&mut info.endpoint)?; + + #[cfg(all(feature = "rustls", not(feature = "native")))] + let mut client = create_rustls_client(url).await?; + + #[cfg(feature = "native")] + let mut client = create_native_tls_client(url).await?; + + let mut hello = None; + let mut ready = None; + + client + .send_json(&GatewayEvent::from(Identify { + server_id: info.guild_id.into(), + session_id: info.session_id.clone(), + token: info.token.clone(), + user_id: info.user_id.into(), + })) + .await?; + + loop { + let value = match client.recv_json().await? { + Some(value) => value, + None => continue, + }; + + match value { + GatewayEvent::Ready(r) => { + ready = Some(r); + if hello.is_some() { + break; + } + }, + GatewayEvent::Hello(h) => { + hello = Some(h); + if ready.is_some() { + break; + } + }, + other => { + debug!("Expected ready/hello; got: {:?}", other); + + return Err(Error::ExpectedHandshake); + }, + } + } + + let hello = + hello.expect("Hello packet expected in connection initialisation, but not found."); + let ready = + ready.expect("Ready packet expected in connection initialisation, but not found."); + + if !has_valid_mode(&ready.modes, crypto_mode) { + return Err(Error::CryptoModeUnavailable); + } + + let mut udp = UdpSocket::bind("0.0.0.0:0").await?; + udp.connect((ready.ip, ready.port)).await?; + + // Follow Discord's IP Discovery procedures, in case NAT tunnelling is needed. 
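+ // A discovery request (packet type + length + our SSRC) is sent over the freshly
+ // connected UDP socket; the response echoes back the external address and port that
+ // Discord observed for us, which are then reported via SelectProtocol below.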
+ let mut bytes = [0; IpDiscoveryPacket::const_packet_size()]; + { + let mut view = MutableIpDiscoveryPacket::new(&mut bytes[..]).expect( + "Too few bytes in 'bytes' for IPDiscovery packet.\ + (Blame: IpDiscoveryPacket::const_packet_size()?)", + ); + view.set_pkt_type(IpDiscoveryType::Request); + view.set_length(70); + view.set_ssrc(ready.ssrc); + } + + udp.send(&bytes).await?; + + let (len, _addr) = udp.recv_from(&mut bytes).await?; + { + let view = + IpDiscoveryPacket::new(&bytes[..len]).ok_or(Error::IllegalDiscoveryResponse)?; + + if view.get_pkt_type() != IpDiscoveryType::Response { + return Err(Error::IllegalDiscoveryResponse); + } + + // We could do something clever like binary search, + // but possibility of UDP spoofing preclueds us from + // making the assumption we can find a "left edge" of '\0's. + let nul_byte_index = view + .get_address_raw() + .iter() + .position(|&b| b == 0) + .ok_or(Error::IllegalIp)?; + + let address_str = std::str::from_utf8(&view.get_address_raw()[..nul_byte_index]) + .map_err(|_| Error::IllegalIp)?; + + let address = IpAddr::from_str(&address_str).map_err(|e| { + println!("{:?}", e); + Error::IllegalIp + })?; + + client + .send_json(&GatewayEvent::from(SelectProtocol { + protocol: "udp".into(), + data: ProtocolData { + address, + mode: crypto_mode.to_request_str().into(), + port: view.get_port(), + }, + })) + .await?; + } + + let cipher = init_cipher(&mut client, crypto_mode).await?; + + info!("Connected to: {}", info.endpoint); + + info!("WS heartbeat duration {}ms.", hello.heartbeat_interval,); + + let (ws_msg_tx, ws_msg_rx) = flume::unbounded(); + let (udp_sender_msg_tx, udp_sender_msg_rx) = flume::unbounded(); + let (udp_receiver_msg_tx, udp_receiver_msg_rx) = flume::unbounded(); + let (udp_rx, udp_tx) = udp.split(); + + let ssrc = ready.ssrc; + + let mix_conn = MixerConnection { + cipher: cipher.clone(), + udp_rx: udp_receiver_msg_tx, + udp_tx: udp_sender_msg_tx, + }; + + interconnect + .mixer + .send(MixerMessage::Ws(Some(ws_msg_tx.clone())))?; + + interconnect + .mixer + .send(MixerMessage::SetConn(mix_conn, ready.ssrc))?; + + tokio::spawn(ws_task::runner( + interconnect.clone(), + ws_msg_rx, + client, + ssrc, + hello.heartbeat_interval, + )); + + tokio::spawn(udp_rx::runner( + interconnect.clone(), + udp_receiver_msg_rx, + cipher, + crypto_mode, + udp_rx, + )); + tokio::spawn(udp_tx::runner(udp_sender_msg_rx, ssrc, udp_tx)); + + Ok(Connection { + info, + ws: ws_msg_tx, + }) + } + + #[instrument(skip(self))] + pub async fn reconnect(&mut self) -> Result<()> { + let url = generate_url(&mut self.info.endpoint)?; + + // Thread may have died, we want to send to prompt a clean exit + // (if at all possible) and then proceed as normal. + + #[cfg(all(feature = "rustls", not(feature = "native")))] + let mut client = create_rustls_client(url).await?; + + #[cfg(feature = "native")] + let mut client = create_native_tls_client(url).await?; + + client + .send_json(&GatewayEvent::from(Resume { + server_id: self.info.guild_id.into(), + session_id: self.info.session_id.clone(), + token: self.info.token.clone(), + })) + .await?; + + let mut hello = None; + let mut resumed = None; + + loop { + let value = match client.recv_json().await? 
{ + Some(value) => value, + None => continue, + }; + + match value { + GatewayEvent::Resumed => { + resumed = Some(()); + if hello.is_some() { + break; + } + }, + GatewayEvent::Hello(h) => { + hello = Some(h); + if resumed.is_some() { + break; + } + }, + other => { + debug!("Expected resumed/hello; got: {:?}", other); + + return Err(Error::ExpectedHandshake); + }, + } + } + + let hello = + hello.expect("Hello packet expected in connection initialisation, but not found."); + + self.ws + .send(WsMessage::SetKeepalive(hello.heartbeat_interval))?; + self.ws.send(WsMessage::Ws(Box::new(client)))?; + + info!("Reconnected to: {}", &self.info.endpoint); + Ok(()) + } +} + +impl Drop for Connection { + fn drop(&mut self) { + info!("Disconnected"); + } +} + +fn generate_url(endpoint: &mut String) -> Result { + if endpoint.ends_with(":80") { + let len = endpoint.len(); + + endpoint.truncate(len - 3); + } + + Url::parse(&format!("wss://{}/?v={}", endpoint, VOICE_GATEWAY_VERSION)) + .or(Err(Error::EndpointUrl)) +} + +#[inline] +async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result { + loop { + let value = match client.recv_json().await? { + Some(value) => value, + None => continue, + }; + + match value { + GatewayEvent::SessionDescription(desc) => { + if desc.mode != mode.to_request_str() { + return Err(Error::CryptoModeInvalid); + } + + return Ok(Cipher::new_varkey(&desc.secret_key)?); + }, + other => { + debug!( + "Expected ready for key; got: op{}/v{:?}", + other.kind() as u8, + other + ); + }, + } + } +} + +#[inline] +fn has_valid_mode(modes: It, mode: CryptoMode) -> bool +where + T: for<'a> PartialEq<&'a str>, + It: IntoIterator, +{ + modes.into_iter().any(|s| s == mode.to_request_str()) +} diff --git a/src/driver/crypto.rs b/src/driver/crypto.rs new file mode 100644 index 0000000..e7a306d --- /dev/null +++ b/src/driver/crypto.rs @@ -0,0 +1,38 @@ +//! Encryption schemes supported by Discord's secure RTP negotiation. + +/// Variants of the XSalsa20Poly1305 encryption scheme. +/// +/// At present, only `Normal` is supported or selectable. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum Mode { + /// The RTP header is used as the source of nonce bytes for the packet. + /// + /// Equivalent to a nonce of at most 48b (6B) at no extra packet overhead: + /// the RTP sequence number and timestamp are the varying quantities. + Normal, + /// An additional random 24B suffix is used as the source of nonce bytes for the packet. + /// + /// Full nonce width of 24B (192b), at an extra 24B per packet (~1.2 kB/s). + Suffix, + /// An additional random 24B suffix is used as the source of nonce bytes for the packet. + /// + /// Nonce width of 4B (32b), at an extra 4B per packet (~0.2 kB/s). + Lite, +} + +impl Mode { + /// Returns the name of a mode as it will appear during negotiation. + pub fn to_request_str(self) -> &'static str { + use Mode::*; + match self { + Normal => "xsalsa20_poly1305", + Suffix => "xsalsa20_poly1305_suffix", + Lite => "xsalsa20_poly1305_lite", + } + } +} + +// TODO: implement encrypt + decrypt + nonce selection for each. +// This will probably need some research into correct handling of +// padding, reported length, SRTP profiles, and so on. diff --git a/src/driver/mod.rs b/src/driver/mod.rs new file mode 100644 index 0000000..cd148bc --- /dev/null +++ b/src/driver/mod.rs @@ -0,0 +1,233 @@ +//! Runner for a voice connection. +//! +//! Songbird's driver is a mixed-sync system, using: +//! 
* Asynchronous connection management, event-handling, and gateway integration. +//! * Synchronous audio mixing, packet generation, and encoding. +//! +//! This splits up work according to its IO/compute bound nature, preventing packet +//! generation from being slowed down past its deadline, or from affecting other +//! asynchronous tasks your bot must handle. + +mod config; +pub(crate) mod connection; +mod crypto; +pub(crate) mod tasks; + +pub use config::Config; +use connection::error::Result; +pub use crypto::Mode as CryptoMode; + +use crate::{ + events::EventData, + input::Input, + tracks::{Track, TrackHandle}, + ConnectionInfo, + Event, + EventHandler, +}; +use audiopus::Bitrate; +use flume::{Receiver, SendError, Sender}; +use tasks::message::CoreMessage; +use tracing::instrument; + +/// The control object for a Discord voice connection, handling connection, +/// mixing, encoding, en/decryption, and event generation. +#[derive(Clone, Debug)] +pub struct Driver { + config: Config, + self_mute: bool, + sender: Sender, +} + +impl Driver { + /// Creates a new voice driver. + /// + /// This will create the core voice tasks in the background. + #[inline] + pub fn new(config: Config) -> Self { + let sender = Self::start_inner(config.clone()); + + Driver { + config, + self_mute: false, + sender, + } + } + + fn start_inner(config: Config) -> Sender { + let (tx, rx) = flume::unbounded(); + + tasks::start(config, rx, tx.clone()); + + tx + } + + fn restart_inner(&mut self) { + self.sender = Self::start_inner(self.config.clone()); + + self.mute(self.self_mute); + } + + /// Connects to a voice channel using the specified server. + #[instrument(skip(self))] + pub fn connect(&mut self, info: ConnectionInfo) -> Receiver> { + let (tx, rx) = flume::bounded(1); + + self.raw_connect(info, tx); + + rx + } + + /// Connects to a voice channel using the specified server. + #[instrument(skip(self))] + pub(crate) fn raw_connect(&mut self, info: ConnectionInfo, tx: Sender>) { + self.send(CoreMessage::ConnectWithResult(info, tx)); + } + + /// Leaves the current voice channel, disconnecting from it. + /// + /// This does *not* forget settings, like whether to be self-deafened or + /// self-muted. + #[instrument(skip(self))] + pub fn leave(&mut self) { + self.send(CoreMessage::Disconnect); + } + + /// Sets whether the current connection is to be muted. + /// + /// If there is no live voice connection, then this only acts as a settings + /// update for future connections. + #[instrument(skip(self))] + pub fn mute(&mut self, mute: bool) { + self.self_mute = mute; + self.send(CoreMessage::Mute(mute)); + } + + /// Returns whether the driver is muted (i.e., processes audio internally + /// but submits none). + #[instrument(skip(self))] + pub fn is_mute(&self) -> bool { + self.self_mute + } + + /// Plays audio from a source, returning a handle for further control. + /// + /// This can be a source created via [`ffmpeg`] or [`ytdl`]. + /// + /// [`ffmpeg`]: ../input/fn.ffmpeg.html + /// [`ytdl`]: ../input/fn.ytdl.html + #[instrument(skip(self))] + pub fn play_source(&mut self, source: Input) -> TrackHandle { + let (player, handle) = super::create_player(source); + self.send(CoreMessage::AddTrack(player)); + + handle + } + + /// Plays audio from a source, returning a handle for further control. + /// + /// Unlike [`play_source`], this stops all other sources attached + /// to the channel. 
+ /// + /// [`play_source`]: #method.play_source + #[instrument(skip(self))] + pub fn play_only_source(&mut self, source: Input) -> TrackHandle { + let (player, handle) = super::create_player(source); + self.send(CoreMessage::SetTrack(Some(player))); + + handle + } + + /// Plays audio from a [`Track`] object. + /// + /// This will be one half of the return value of [`create_player`]. + /// The main difference between this function and [`play_source`] is + /// that this allows for direct manipulation of the [`Track`] object + /// before it is passed over to the voice and mixing contexts. + /// + /// [`create_player`]: ../tracks/fn.create_player.html + /// [`Track`]: ../tracks/struct.Track.html + /// [`play_source`]: #method.play_source + #[instrument(skip(self))] + pub fn play(&mut self, track: Track) { + self.send(CoreMessage::AddTrack(track)); + } + + /// Exclusively plays audio from a [`Track`] object. + /// + /// This will be one half of the return value of [`create_player`]. + /// As in [`play_only_source`], this stops all other sources attached to the + /// channel. Like [`play`], however, this allows for direct manipulation of the + /// [`Track`] object before it is passed over to the voice and mixing contexts. + /// + /// [`create_player`]: ../tracks/fn.create_player.html + /// [`Track`]: ../tracks/struct.Track.html + /// [`play_only_source`]: #method.play_only_source + /// [`play`]: #method.play + #[instrument(skip(self))] + pub fn play_only(&mut self, track: Track) { + self.send(CoreMessage::SetTrack(Some(track))); + } + + /// Sets the bitrate for encoding Opus packets sent along + /// the channel being managed. + /// + /// The default rate is 128 kbps. + /// Sensible values range between `Bits(512)` and `Bits(512_000)` + /// bits per second. + /// Alternatively, `Auto` and `Max` remain available. + #[instrument(skip(self))] + pub fn set_bitrate(&mut self, bitrate: Bitrate) { + self.send(CoreMessage::SetBitrate(bitrate)) + } + + /// Stops playing audio from all sources, if any are set. + #[instrument(skip(self))] + pub fn stop(&mut self) { + self.send(CoreMessage::SetTrack(None)) + } + + /// Attach a global event handler to an audio context. Global events may receive + /// any [`EventContext`]. + /// + /// Global timing events will tick regardless of whether audio is playing, + /// so long as the bot is connected to a voice channel, and have no tracks. + /// [`TrackEvent`]s will respond to all relevant tracks, giving some audio elements. + /// + /// Users **must** ensure that no costly work or blocking occurs + /// within the supplied function or closure. *Taking excess time could prevent + /// timely sending of packets, causing audio glitches and delays*. + /// + /// [`Track`]: ../tracks/struct.Track.html + /// [`TrackEvent`]: ../events/enum.TrackEvent.html + /// [`EventContext`]: ../events/enum.EventContext.html + #[instrument(skip(self, action))] + pub fn add_global_event(&mut self, event: Event, action: F) { + self.send(CoreMessage::AddEvent(EventData::new(event, action))); + } + + /// Sends a message to the inner tasks, restarting it if necessary. + fn send(&mut self, status: CoreMessage) { + // Restart thread if it errored. 
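+ // A failed send means the core task has stopped: the background tasks are respawned
+ // via `restart_inner`, and the original message (recovered from the `SendError`) is
+ // retried on the fresh channel.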
+ if let Err(SendError(status)) = self.sender.send(status) { + self.restart_inner(); + + self.sender.send(status).unwrap(); + } + } +} + +impl Default for Driver { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl Drop for Driver { + /// Leaves the current connected voice channel, if connected to one, and + /// forgets all configurations relevant to this Handler. + fn drop(&mut self) { + self.leave(); + let _ = self.sender.send(CoreMessage::Poison); + } +} diff --git a/src/driver/tasks/error.rs b/src/driver/tasks/error.rs new file mode 100644 index 0000000..c9e1fdb --- /dev/null +++ b/src/driver/tasks/error.rs @@ -0,0 +1,97 @@ +use super::message::*; +use crate::ws::Error as WsError; +use audiopus::Error as OpusError; +use flume::SendError; +use std::io::Error as IoError; +use xsalsa20poly1305::aead::Error as CryptoError; + +#[derive(Debug)] +pub enum Recipient { + AuxNetwork, + Event, + Mixer, + UdpRx, + UdpTx, +} + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum Error { + Crypto(CryptoError), + /// Received an illegal voice packet on the voice UDP socket. + IllegalVoicePacket, + InterconnectFailure(Recipient), + Io(IoError), + Opus(OpusError), + Ws(WsError), +} + +impl Error { + pub(crate) fn should_trigger_connect(&self) -> bool { + matches!( + self, + Error::InterconnectFailure(Recipient::AuxNetwork) + | Error::InterconnectFailure(Recipient::UdpRx) + | Error::InterconnectFailure(Recipient::UdpTx) + ) + } + + pub(crate) fn should_trigger_interconnect_rebuild(&self) -> bool { + matches!(self, Error::InterconnectFailure(Recipient::Event)) + } +} + +impl From for Error { + fn from(e: CryptoError) -> Self { + Error::Crypto(e) + } +} + +impl From for Error { + fn from(e: IoError) -> Error { + Error::Io(e) + } +} + +impl From for Error { + fn from(e: OpusError) -> Error { + Error::Opus(e) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::AuxNetwork) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::Event) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::Mixer) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::UdpRx) + } +} + +impl From> for Error { + fn from(_e: SendError) -> Error { + Error::InterconnectFailure(Recipient::UdpTx) + } +} + +impl From for Error { + fn from(e: WsError) -> Error { + Error::Ws(e) + } +} diff --git a/src/driver/tasks/events.rs b/src/driver/tasks/events.rs new file mode 100644 index 0000000..bb28895 --- /dev/null +++ b/src/driver/tasks/events.rs @@ -0,0 +1,118 @@ +use super::message::*; +use crate::{ + events::{EventStore, GlobalEvents, TrackEvent}, + tracks::{TrackHandle, TrackState}, +}; +use flume::Receiver; +use tracing::{debug, info, instrument, trace}; + +#[instrument(skip(_interconnect, evt_rx))] +pub(crate) async fn runner(_interconnect: Interconnect, evt_rx: Receiver) { + let mut global = GlobalEvents::default(); + + let mut events: Vec = vec![]; + let mut states: Vec = vec![]; + let mut handles: Vec = vec![]; + + loop { + use EventMessage::*; + match evt_rx.recv_async().await { + Ok(AddGlobalEvent(data)) => { + info!("Global event added."); + global.add_event(data); + }, + Ok(AddTrackEvent(i, data)) => { + info!("Adding event to track {}.", i); + + let event_store = events + .get_mut(i) + .expect("Event thread was given an illegal store index for AddTrackEvent."); + let state = 
states + .get_mut(i) + .expect("Event thread was given an illegal state index for AddTrackEvent."); + + event_store.add_event(data, state.position); + }, + Ok(FireCoreEvent(ctx)) => { + let ctx = ctx.to_user_context(); + let evt = ctx + .to_core_event() + .expect("Event thread was passed a non-core event in FireCoreEvent."); + + trace!("Firing core event {:?}.", evt); + + global.fire_core_event(evt, ctx).await; + }, + Ok(AddTrack(store, state, handle)) => { + events.push(store); + states.push(state); + handles.push(handle); + + info!("Event state for track {} added", events.len()); + }, + Ok(ChangeState(i, change)) => { + use TrackStateChange::*; + + let max_states = states.len(); + debug!( + "Changing state for track {} of {}: {:?}", + i, max_states, change + ); + + let state = states + .get_mut(i) + .expect("Event thread was given an illegal state index for ChangeState."); + + match change { + Mode(mode) => { + let old = state.playing; + state.playing = mode; + if old != mode && mode.is_done() { + global.fire_track_event(TrackEvent::End, i); + } + }, + Volume(vol) => { + state.volume = vol; + }, + Position(pos) => { + // Currently, only Tick should fire time events. + state.position = pos; + }, + Loops(loops, user_set) => { + state.loops = loops; + if !user_set { + global.fire_track_event(TrackEvent::Loop, i); + } + }, + Total(new) => { + // Massive, unprecedented state changes. + *state = new; + }, + } + }, + Ok(RemoveTrack(i)) => { + info!("Event state for track {} of {} removed.", i, events.len()); + + events.remove(i); + states.remove(i); + handles.remove(i); + }, + Ok(RemoveAllTracks) => { + info!("Event state for all tracks removed."); + + events.clear(); + states.clear(); + handles.clear(); + }, + Ok(Tick) => { + // NOTE: this should fire saved up blocks of state change evts. + global.tick(&mut events, &mut states, &mut handles).await; + }, + Err(_) | Ok(Poison) => { + break; + }, + } + } + + info!("Event thread exited."); +} diff --git a/src/driver/tasks/message/core.rs b/src/driver/tasks/message/core.rs new file mode 100644 index 0000000..3c5c017 --- /dev/null +++ b/src/driver/tasks/message/core.rs @@ -0,0 +1,24 @@ +use crate::{ + driver::connection::error::Error, + events::EventData, + tracks::Track, + Bitrate, + ConnectionInfo, +}; +use flume::Sender; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub enum CoreMessage { + ConnectWithResult(ConnectionInfo, Sender>), + Disconnect, + SetTrack(Option), + AddTrack(Track), + SetBitrate(Bitrate), + AddEvent(EventData), + Mute(bool), + Reconnect, + FullReconnect, + RebuildInterconnect, + Poison, +} diff --git a/src/driver/tasks/message/events.rs b/src/driver/tasks/message/events.rs new file mode 100644 index 0000000..197ebe8 --- /dev/null +++ b/src/driver/tasks/message/events.rs @@ -0,0 +1,31 @@ +use crate::{ + events::{CoreContext, EventData, EventStore}, + tracks::{LoopState, PlayMode, TrackHandle, TrackState}, +}; +use std::time::Duration; + +pub(crate) enum EventMessage { + // Event related. + // Track events should fire off the back of state changes. + AddGlobalEvent(EventData), + AddTrackEvent(usize, EventData), + FireCoreEvent(CoreContext), + + AddTrack(EventStore, TrackState, TrackHandle), + ChangeState(usize, TrackStateChange), + RemoveTrack(usize), + RemoveAllTracks, + Tick, + + Poison, +} + +#[derive(Debug)] +pub enum TrackStateChange { + Mode(PlayMode), + Volume(f32), + Position(Duration), + // Bool indicates user-set. 
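+ // (A `false` here marks an automatic loop decrement, which the event task
+ // translates into a `TrackEvent::Loop` firing.)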
+ Loops(LoopState, bool), + Total(TrackState), +} diff --git a/src/driver/tasks/message/mixer.rs b/src/driver/tasks/message/mixer.rs new file mode 100644 index 0000000..4c2eec5 --- /dev/null +++ b/src/driver/tasks/message/mixer.rs @@ -0,0 +1,32 @@ +use super::{Interconnect, UdpRxMessage, UdpTxMessage, WsMessage}; + +use crate::{tracks::Track, Bitrate}; +use flume::Sender; +use xsalsa20poly1305::XSalsa20Poly1305 as Cipher; + +pub(crate) struct MixerConnection { + pub cipher: Cipher, + pub udp_rx: Sender, + pub udp_tx: Sender, +} + +impl Drop for MixerConnection { + fn drop(&mut self) { + let _ = self.udp_rx.send(UdpRxMessage::Poison); + let _ = self.udp_tx.send(UdpTxMessage::Poison); + } +} + +pub(crate) enum MixerMessage { + AddTrack(Track), + SetTrack(Option), + SetBitrate(Bitrate), + SetMute(bool), + SetConn(MixerConnection, u32), + DropConn, + ReplaceInterconnect(Interconnect), + RebuildEncoder, + + Ws(Option>), + Poison, +} diff --git a/src/driver/tasks/message/mod.rs b/src/driver/tasks/message/mod.rs new file mode 100644 index 0000000..1831839 --- /dev/null +++ b/src/driver/tasks/message/mod.rs @@ -0,0 +1,49 @@ +mod core; +mod events; +mod mixer; +mod udp_rx; +mod udp_tx; +mod ws; + +pub(crate) use self::{core::*, events::*, mixer::*, udp_rx::*, udp_tx::*, ws::*}; + +use flume::Sender; +use tracing::info; + +#[derive(Clone, Debug)] +pub(crate) struct Interconnect { + pub core: Sender, + pub events: Sender, + pub mixer: Sender, +} + +impl Interconnect { + pub fn poison(&self) { + let _ = self.events.send(EventMessage::Poison); + } + + pub fn poison_all(&self) { + self.poison(); + let _ = self.mixer.send(MixerMessage::Poison); + } + + pub fn restart_volatile_internals(&mut self) { + self.poison(); + + let (evt_tx, evt_rx) = flume::unbounded(); + + self.events = evt_tx; + + let ic = self.clone(); + tokio::spawn(async move { + info!("Event processor restarted."); + super::events::runner(ic, evt_rx).await; + info!("Event processor finished."); + }); + + // Make mixer aware of new targets... + let _ = self + .mixer + .send(MixerMessage::ReplaceInterconnect(self.clone())); + } +} diff --git a/src/driver/tasks/message/udp_rx.rs b/src/driver/tasks/message/udp_rx.rs new file mode 100644 index 0000000..91e740d --- /dev/null +++ b/src/driver/tasks/message/udp_rx.rs @@ -0,0 +1,7 @@ +use super::Interconnect; + +pub(crate) enum UdpRxMessage { + ReplaceInterconnect(Interconnect), + + Poison, +} diff --git a/src/driver/tasks/message/udp_tx.rs b/src/driver/tasks/message/udp_tx.rs new file mode 100644 index 0000000..349d524 --- /dev/null +++ b/src/driver/tasks/message/udp_tx.rs @@ -0,0 +1,4 @@ +pub enum UdpTxMessage { + Packet(Vec), // TODO: do something cheaper. 
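+ // Asks the UDP send task to shut down, mirroring the Poison variants of the
+ // other task messages.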
+ Poison, +} diff --git a/src/driver/tasks/message/ws.rs b/src/driver/tasks/message/ws.rs new file mode 100644 index 0000000..7ce5f07 --- /dev/null +++ b/src/driver/tasks/message/ws.rs @@ -0,0 +1,12 @@ +use super::Interconnect; +use crate::ws::WsStream; + +#[allow(dead_code)] +pub(crate) enum WsMessage { + Ws(Box), + ReplaceInterconnect(Interconnect), + SetKeepalive(f64), + Speaking(bool), + + Poison, +} diff --git a/src/driver/tasks/mixer.rs b/src/driver/tasks/mixer.rs new file mode 100644 index 0000000..3fa5d1d --- /dev/null +++ b/src/driver/tasks/mixer.rs @@ -0,0 +1,516 @@ +use super::{error::Result, message::*}; +use crate::{ + constants::*, + tracks::{PlayMode, Track}, +}; +use audiopus::{ + coder::Encoder as OpusEncoder, + softclip::SoftClip, + Application as CodingMode, + Bitrate, + Channels, +}; +use discortp::{ + rtp::{MutableRtpPacket, RtpPacket}, + MutablePacket, + Packet, +}; +use flume::{Receiver, Sender, TryRecvError}; +use rand::random; +use spin_sleep::SpinSleeper; +use std::time::Instant; +use tokio::runtime::Handle; +use tracing::{error, instrument}; +use xsalsa20poly1305::{aead::AeadInPlace, Nonce, TAG_SIZE}; + +struct Mixer { + async_handle: Handle, + bitrate: Bitrate, + conn_active: Option, + deadline: Instant, + encoder: OpusEncoder, + interconnect: Interconnect, + mix_rx: Receiver, + muted: bool, + packet: [u8; VOICE_PACKET_MAX], + prevent_events: bool, + silence_frames: u8, + sleeper: SpinSleeper, + soft_clip: SoftClip, + tracks: Vec, + ws: Option>, +} + +fn new_encoder(bitrate: Bitrate) -> Result { + let mut encoder = OpusEncoder::new(SAMPLE_RATE, Channels::Stereo, CodingMode::Audio)?; + encoder.set_bitrate(bitrate)?; + + Ok(encoder) +} + +impl Mixer { + fn new( + mix_rx: Receiver, + async_handle: Handle, + interconnect: Interconnect, + ) -> Self { + let bitrate = DEFAULT_BITRATE; + let encoder = new_encoder(bitrate) + .expect("Failed to create encoder in mixing thread with known-good values."); + let soft_clip = SoftClip::new(Channels::Stereo); + + let mut packet = [0u8; VOICE_PACKET_MAX]; + + let mut rtp = MutableRtpPacket::new(&mut packet[..]).expect( + "FATAL: Too few bytes in self.packet for RTP header.\ + (Blame: VOICE_PACKET_MAX?)", + ); + rtp.set_version(RTP_VERSION); + rtp.set_payload_type(RTP_PROFILE_TYPE); + rtp.set_sequence(random::().into()); + rtp.set_timestamp(random::().into()); + + Self { + async_handle, + bitrate, + conn_active: None, + deadline: Instant::now(), + encoder, + interconnect, + mix_rx, + muted: false, + packet, + prevent_events: false, + silence_frames: 0, + sleeper: Default::default(), + soft_clip, + tracks: vec![], + ws: None, + } + } + + fn run(&mut self) { + let mut events_failure = false; + let mut conn_failure = false; + + 'runner: loop { + loop { + use MixerMessage::*; + + let error = match self.mix_rx.try_recv() { + Ok(AddTrack(mut t)) => { + t.source.prep_with_handle(self.async_handle.clone()); + self.add_track(t) + }, + Ok(SetTrack(t)) => { + self.tracks.clear(); + + let mut out = self.fire_event(EventMessage::RemoveAllTracks); + + if let Some(mut t) = t { + t.source.prep_with_handle(self.async_handle.clone()); + + // Do this unconditionally: this affects local state infallibly, + // with the event installation being the remote part. 
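+ // If the event task cannot be reached, the error is kept in `out` so the usual
+ // failure handling below still rebuilds the interconnect; the track itself has
+ // already been added to the local list either way.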
+ if let Err(e) = self.add_track(t) { + out = Err(e); + } + } + + out + }, + Ok(SetBitrate(b)) => { + self.bitrate = b; + if let Err(e) = self.set_bitrate(b) { + error!("Failed to update bitrate {:?}", e); + } + Ok(()) + }, + Ok(SetMute(m)) => { + self.muted = m; + Ok(()) + }, + Ok(SetConn(conn, ssrc)) => { + self.conn_active = Some(conn); + let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect( + "Too few bytes in self.packet for RTP header.\ + (Blame: VOICE_PACKET_MAX?)", + ); + rtp.set_ssrc(ssrc); + self.deadline = Instant::now(); + Ok(()) + }, + Ok(DropConn) => { + self.conn_active = None; + Ok(()) + }, + Ok(ReplaceInterconnect(i)) => { + self.prevent_events = false; + if let Some(ws) = &self.ws { + conn_failure |= + ws.send(WsMessage::ReplaceInterconnect(i.clone())).is_err(); + } + if let Some(conn) = &self.conn_active { + conn_failure |= conn + .udp_rx + .send(UdpRxMessage::ReplaceInterconnect(i.clone())) + .is_err(); + } + self.interconnect = i; + + self.rebuild_tracks() + }, + Ok(RebuildEncoder) => match new_encoder(self.bitrate) { + Ok(encoder) => { + self.encoder = encoder; + Ok(()) + }, + Err(e) => { + error!("Failed to rebuild encoder. Resetting bitrate. {:?}", e); + self.bitrate = DEFAULT_BITRATE; + self.encoder = new_encoder(self.bitrate) + .expect("Failed fallback rebuild of OpusEncoder with safe inputs."); + Ok(()) + }, + }, + Ok(Ws(new_ws_handle)) => { + self.ws = new_ws_handle; + Ok(()) + }, + + Err(TryRecvError::Disconnected) | Ok(Poison) => { + break 'runner; + }, + + Err(TryRecvError::Empty) => { + break; + }, + }; + + if let Err(e) = error { + events_failure |= e.should_trigger_interconnect_rebuild(); + conn_failure |= e.should_trigger_connect(); + } + } + + if let Err(e) = self.cycle().and_then(|_| self.audio_commands_events()) { + events_failure |= e.should_trigger_interconnect_rebuild(); + conn_failure |= e.should_trigger_connect(); + + error!("Mixer thread cycle: {:?}", e); + } + + // event failure? rebuild interconnect. + // ws or udp failure? full connect + // (soft reconnect is covered by the ws task.) + if events_failure { + self.prevent_events = true; + self.interconnect + .core + .send(CoreMessage::RebuildInterconnect) + .expect("FATAL: No way to rebuild driver core from mixer."); + events_failure = false; + } + + if conn_failure { + self.interconnect + .core + .send(CoreMessage::FullReconnect) + .expect("FATAL: No way to rebuild driver core from mixer."); + conn_failure = false; + } + } + } + + #[inline] + fn fire_event(&self, event: EventMessage) -> Result<()> { + // As this task is responsible for noticing the potential death of an event context, + // it's responsible for not forcibly recreating said context repeatedly. + if !self.prevent_events { + self.interconnect.events.send(event)?; + Ok(()) + } else { + Ok(()) + } + } + + #[inline] + fn add_track(&mut self, mut track: Track) -> Result<()> { + let evts = track.events.take().unwrap_or_default(); + let state = track.state(); + let handle = track.handle.clone(); + + self.tracks.push(track); + + self.interconnect + .events + .send(EventMessage::AddTrack(evts, state, handle))?; + + Ok(()) + } + + // rebuilds the event thread's view of each track, in event of a full rebuild. 
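+ // Each track's event store, state snapshot, and handle are re-sent in order, so the
+ // freshly spawned event task ends up with the same per-track indices as the mixer's
+ // local `tracks` list.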
+ #[inline] + fn rebuild_tracks(&mut self) -> Result<()> { + for track in self.tracks.iter_mut() { + let evts = track.events.take().unwrap_or_default(); + let state = track.state(); + let handle = track.handle.clone(); + + self.interconnect + .events + .send(EventMessage::AddTrack(evts, state, handle))?; + } + + Ok(()) + } + + #[inline] + fn mix_tracks<'a>( + &mut self, + opus_frame: &'a mut [u8], + mix_buffer: &mut [f32; STEREO_FRAME_SIZE], + ) -> Result<(usize, &'a [u8])> { + let mut len = 0; + + // Opus frame passthrough. + // This requires that we have only one track, who has volume 1.0, and an + // Opus codec type. + let do_passthrough = self.tracks.len() == 1 && { + let track = &self.tracks[0]; + (track.volume - 1.0).abs() < f32::EPSILON && track.source.supports_passthrough() + }; + + for (i, track) in self.tracks.iter_mut().enumerate() { + let vol = track.volume; + let stream = &mut track.source; + + if track.playing != PlayMode::Play { + continue; + } + + let (temp_len, opus_len) = if do_passthrough { + (0, track.source.read_opus_frame(opus_frame).ok()) + } else { + (stream.mix(mix_buffer, vol), None) + }; + + len = len.max(temp_len); + if temp_len > 0 || opus_len.is_some() { + track.step_frame(); + } else if track.do_loop() { + if let Some(time) = track.seek_time(Default::default()) { + // have to reproduce self.fire_event here + // to circumvent the borrow checker's lack of knowledge. + // + // In event of error, one of the later event calls will + // trigger the event thread rebuild: it is more prudent that + // the mixer works as normal right now. + if !self.prevent_events { + let _ = self.interconnect.events.send(EventMessage::ChangeState( + i, + TrackStateChange::Position(time), + )); + let _ = self.interconnect.events.send(EventMessage::ChangeState( + i, + TrackStateChange::Loops(track.loops, false), + )); + } + } + } else { + track.end(); + } + + if let Some(opus_len) = opus_len { + return Ok((STEREO_FRAME_SIZE, &opus_frame[..opus_len])); + } + } + + Ok((len, &opus_frame[..0])) + } + + #[inline] + fn audio_commands_events(&mut self) -> Result<()> { + // Apply user commands. + for (i, track) in self.tracks.iter_mut().enumerate() { + // This causes fallible event system changes, + // but if the event thread has died then we'll certainly + // detect that on the tick later. + // Changes to play state etc. MUST all be handled. + track.process_commands(i, &self.interconnect); + } + + // TODO: do without vec? + let mut i = 0; + let mut to_remove = Vec::with_capacity(self.tracks.len()); + while i < self.tracks.len() { + let track = self + .tracks + .get_mut(i) + .expect("Tried to remove an illegal track index."); + + if track.playing.is_done() { + let p_state = track.playing(); + self.tracks.remove(i); + to_remove.push(i); + self.fire_event(EventMessage::ChangeState( + i, + TrackStateChange::Mode(p_state), + ))?; + } else { + i += 1; + } + } + + // Tick + self.fire_event(EventMessage::Tick)?; + + // Then do removals. + for i in &to_remove[..] { + self.fire_event(EventMessage::RemoveTrack(*i))?; + } + + Ok(()) + } + + #[inline] + fn march_deadline(&mut self) { + self.sleeper + .sleep(self.deadline.saturating_duration_since(Instant::now())); + self.deadline += TIMESTEP_LENGTH; + } + + fn cycle(&mut self) -> Result<()> { + if self.conn_active.is_none() { + self.march_deadline(); + return Ok(()); + } + + // TODO: can we make opus_frame_backing *actually* a view over + // some region of self.packet, derived using the encryption mode? + // This saves a copy on Opus passthrough. 
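+ // A rough sketch of the layout assumed here (see `prep_and_send_packet` below):
+ // [ RTP header | Poly1305 tag (TAG_SIZE) | Opus payload ], so such a view
+ // would begin at `RtpPacket::minimum_packet_size() + TAG_SIZE`.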
+ let mut opus_frame_backing = [0u8; STEREO_FRAME_SIZE]; + let mut mix_buffer = [0f32; STEREO_FRAME_SIZE]; + + // Slice which mix tracks may use to passthrough direct Opus frames. + let mut opus_space = &mut opus_frame_backing[..]; + + // Walk over all the audio files, combining into one audio frame according + // to volume, play state, etc. + let (mut len, mut opus_frame) = self.mix_tracks(&mut opus_space, &mut mix_buffer)?; + + self.soft_clip.apply(&mut mix_buffer[..])?; + + if self.muted { + len = 0; + } + + if len == 0 { + if self.silence_frames > 0 { + self.silence_frames -= 1; + + // Explicit "Silence" frame. + opus_frame = &SILENT_FRAME[..]; + } else { + // Per official guidelines, send 5x silence BEFORE we stop speaking. + if let Some(ws) = &self.ws { + // NOTE: this should prevent a catastrophic thread pileup. + // A full reconnect might cause an inner closed connection. + // It's safer to leave the central task to clean this up and + // pass the mixer a new channel. + let _ = ws.send(WsMessage::Speaking(false)); + } + + self.march_deadline(); + + return Ok(()); + } + } else { + self.silence_frames = 5; + } + + if let Some(ws) = &self.ws { + ws.send(WsMessage::Speaking(true))?; + } + + self.march_deadline(); + self.prep_and_send_packet(mix_buffer, opus_frame)?; + + Ok(()) + } + + fn set_bitrate(&mut self, bitrate: Bitrate) -> Result<()> { + self.encoder.set_bitrate(bitrate).map_err(Into::into) + } + + fn prep_and_send_packet(&mut self, buffer: [f32; 1920], opus_frame: &[u8]) -> Result<()> { + let conn = self + .conn_active + .as_mut() + .expect("Shouldn't be mixing packets without access to a cipher + UDP dest."); + + let mut nonce = Nonce::default(); + let index = { + let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect( + "FATAL: Too few bytes in self.packet for RTP header.\ + (Blame: VOICE_PACKET_MAX?)", + ); + + let pkt = rtp.packet(); + let rtp_len = RtpPacket::minimum_packet_size(); + nonce[..rtp_len].copy_from_slice(&pkt[..rtp_len]); + + let payload = rtp.payload_mut(); + + let payload_len = if opus_frame.is_empty() { + self.encoder + .encode_float(&buffer[..STEREO_FRAME_SIZE], &mut payload[TAG_SIZE..])? + } else { + let len = opus_frame.len(); + payload[TAG_SIZE..TAG_SIZE + len].clone_from_slice(opus_frame); + len + }; + + let final_payload_size = TAG_SIZE + payload_len; + + let tag = conn.cipher.encrypt_in_place_detached( + &nonce, + b"", + &mut payload[TAG_SIZE..final_payload_size], + )?; + payload[..TAG_SIZE].copy_from_slice(&tag[..]); + + rtp_len + final_payload_size + }; + + // TODO: This is dog slow, don't do this. + // Can we replace this with a shared ring buffer + semaphore? + // i.e., do something like double/triple buffering in graphics. + conn.udp_tx + .send(UdpTxMessage::Packet(self.packet[..index].to_vec()))?; + + let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect( + "FATAL: Too few bytes in self.packet for RTP header.\ + (Blame: VOICE_PACKET_MAX?)", + ); + rtp.set_sequence(rtp.get_sequence() + 1); + rtp.set_timestamp(rtp.get_timestamp() + MONO_FRAME_SIZE as u32); + + Ok(()) + } +} + +/// The mixing thread is a synchronous context due to its compute-bound nature. +/// +/// We pass in an async handle for the benefit of some Input classes (e.g., restartables) +/// who need to run their restart code elsewhere and return blank data until such time. 
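+///
+/// A rough sketch of how the driver spawns this thread (mirroring
+/// `start_internals` in `tasks/mod.rs`, whose names are reused here):
+///
+/// ```rust,ignore
+/// let handle = tokio::runtime::Handle::current();
+/// std::thread::spawn(move || mixer::runner(interconnect, mix_rx, handle));
+/// ```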
+#[instrument(skip(interconnect, mix_rx, async_handle))] +pub(crate) fn runner( + interconnect: Interconnect, + mix_rx: Receiver, + async_handle: Handle, +) { + let mut mixer = Mixer::new(mix_rx, async_handle, interconnect); + + mixer.run(); +} diff --git a/src/driver/tasks/mod.rs b/src/driver/tasks/mod.rs new file mode 100644 index 0000000..2e0b2d0 --- /dev/null +++ b/src/driver/tasks/mod.rs @@ -0,0 +1,155 @@ +pub mod error; +mod events; +pub(crate) mod message; +mod mixer; +pub(crate) mod udp_rx; +pub(crate) mod udp_tx; +pub(crate) mod ws; + +use super::{ + connection::{error::Error as ConnectionError, Connection}, + Config, +}; +use flume::{Receiver, RecvError, Sender}; +use message::*; +use tokio::runtime::Handle; +use tracing::{error, info, instrument}; + +pub(crate) fn start(config: Config, rx: Receiver, tx: Sender) { + tokio::spawn(async move { + info!("Driver started."); + runner(config, rx, tx).await; + info!("Driver finished."); + }); +} + +fn start_internals(core: Sender) -> Interconnect { + let (evt_tx, evt_rx) = flume::unbounded(); + let (mix_tx, mix_rx) = flume::unbounded(); + + let interconnect = Interconnect { + core, + events: evt_tx, + mixer: mix_tx, + }; + + let ic = interconnect.clone(); + tokio::spawn(async move { + info!("Event processor started."); + events::runner(ic, evt_rx).await; + info!("Event processor finished."); + }); + + let ic = interconnect.clone(); + let handle = Handle::current(); + std::thread::spawn(move || { + info!("Mixer started."); + mixer::runner(ic, mix_rx, handle); + info!("Mixer finished."); + }); + + interconnect +} + +#[instrument(skip(rx, tx))] +async fn runner(config: Config, rx: Receiver, tx: Sender) { + let mut connection = None; + let mut interconnect = start_internals(tx); + + loop { + match rx.recv_async().await { + Ok(CoreMessage::ConnectWithResult(info, tx)) => { + connection = match Connection::new(info, &interconnect, &config).await { + Ok(connection) => { + // Other side may not be listening: this is fine. + let _ = tx.send(Ok(())); + Some(connection) + }, + Err(why) => { + // See above. + let _ = tx.send(Err(why)); + + None + }, + }; + }, + Ok(CoreMessage::Disconnect) => { + connection = None; + let _ = interconnect.mixer.send(MixerMessage::DropConn); + let _ = interconnect.mixer.send(MixerMessage::RebuildEncoder); + }, + Ok(CoreMessage::SetTrack(s)) => { + let _ = interconnect.mixer.send(MixerMessage::SetTrack(s)); + }, + Ok(CoreMessage::AddTrack(s)) => { + let _ = interconnect.mixer.send(MixerMessage::AddTrack(s)); + }, + Ok(CoreMessage::SetBitrate(b)) => { + let _ = interconnect.mixer.send(MixerMessage::SetBitrate(b)); + }, + Ok(CoreMessage::AddEvent(evt)) => { + let _ = interconnect.events.send(EventMessage::AddGlobalEvent(evt)); + }, + Ok(CoreMessage::Mute(m)) => { + let _ = interconnect.mixer.send(MixerMessage::SetMute(m)); + }, + Ok(CoreMessage::Reconnect) => { + if let Some(mut conn) = connection.take() { + // try once: if interconnect, try again. + // if still issue, full connect. + let info = conn.info.clone(); + + let full_connect = match conn.reconnect().await { + Ok(()) => { + connection = Some(conn); + false + }, + Err(ConnectionError::InterconnectFailure(_)) => { + interconnect.restart_volatile_internals(); + + match conn.reconnect().await { + Ok(()) => { + connection = Some(conn); + false + }, + _ => true, + } + }, + _ => true, + }; + + if full_connect { + connection = Connection::new(info, &interconnect, &config) + .await + .map_err(|e| { + error!("Catastrophic connection failure. Stopping. 
{:?}", e); + e + }) + .ok(); + } + } + }, + Ok(CoreMessage::FullReconnect) => + if let Some(conn) = connection.take() { + let info = conn.info.clone(); + + connection = Connection::new(info, &interconnect, &config) + .await + .map_err(|e| { + error!("Catastrophic connection failure. Stopping. {:?}", e); + e + }) + .ok(); + }, + Ok(CoreMessage::RebuildInterconnect) => { + interconnect.restart_volatile_internals(); + }, + Err(RecvError::Disconnected) | Ok(CoreMessage::Poison) => { + break; + }, + } + } + + info!("Main thread exited"); + interconnect.poison_all(); +} diff --git a/src/driver/tasks/udp_rx.rs b/src/driver/tasks/udp_rx.rs new file mode 100644 index 0000000..263ef76 --- /dev/null +++ b/src/driver/tasks/udp_rx.rs @@ -0,0 +1,286 @@ +use super::{ + error::{Error, Result}, + message::*, +}; +use crate::{constants::*, driver::CryptoMode, events::CoreContext}; +use audiopus::{coder::Decoder as OpusDecoder, Channels}; +use discortp::{ + demux::{self, DemuxedMut}, + rtp::{RtpExtensionPacket, RtpPacket}, + FromPacket, + MutablePacket, + Packet, + PacketSize, +}; +use flume::Receiver; +use std::collections::HashMap; +use tokio::net::udp::RecvHalf; +use tracing::{error, info, instrument, warn}; +use xsalsa20poly1305::{aead::AeadInPlace, Nonce, Tag, XSalsa20Poly1305 as Cipher, TAG_SIZE}; + +#[derive(Debug)] +struct SsrcState { + silent_frame_count: u16, + decoder: OpusDecoder, + last_seq: u16, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum SpeakingDelta { + Same, + Start, + Stop, +} + +impl SsrcState { + fn new(pkt: RtpPacket<'_>) -> Self { + Self { + silent_frame_count: 5, // We do this to make the first speech packet fire an event. + decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo) + .expect("Failed to create new Opus decoder for source."), + last_seq: pkt.get_sequence().into(), + } + } + + fn process( + &mut self, + pkt: RtpPacket<'_>, + data_offset: usize, + ) -> Result<(SpeakingDelta, Vec)> { + let new_seq: u16 = pkt.get_sequence().into(); + + let extensions = pkt.get_extension() != 0; + let seq_delta = new_seq.wrapping_sub(self.last_seq); + Ok(if seq_delta >= (1 << 15) { + // Overflow, reordered (previously missing) packet. + (SpeakingDelta::Same, vec![]) + } else { + self.last_seq = new_seq; + let missed_packets = seq_delta.saturating_sub(1); + let (audio, pkt_size) = + self.scan_and_decode(&pkt.payload()[data_offset..], extensions, missed_packets)?; + + let delta = if pkt_size == SILENT_FRAME.len() { + // Frame is silent. + let old = self.silent_frame_count; + self.silent_frame_count = + self.silent_frame_count.saturating_add(1 + missed_packets); + + if self.silent_frame_count >= 5 && old < 5 { + SpeakingDelta::Stop + } else { + SpeakingDelta::Same + } + } else { + // Frame has meaningful audio. 
+ let out = if self.silent_frame_count >= 5 { + SpeakingDelta::Start + } else { + SpeakingDelta::Same + }; + self.silent_frame_count = 0; + out + }; + + (delta, audio) + }) + } + + fn scan_and_decode( + &mut self, + data: &[u8], + extension: bool, + missed_packets: u16, + ) -> Result<(Vec, usize)> { + let mut out = vec![0; STEREO_FRAME_SIZE]; + let start = if extension { + RtpExtensionPacket::new(data) + .map(|pkt| pkt.packet_size()) + .ok_or_else(|| { + error!("Extension packet indicated, but insufficient space."); + Error::IllegalVoicePacket + }) + } else { + Ok(0) + }?; + + for _ in 0..missed_packets { + let missing_frame: Option<&[u8]> = None; + if let Err(e) = self.decoder.decode(missing_frame, &mut out[..], false) { + warn!("Issue while decoding for missed packet: {:?}.", e); + } + } + + let audio_len = self + .decoder + .decode(Some(&data[start..]), &mut out[..], false) + .map_err(|e| { + error!("Failed to decode received packet: {:?}.", e); + e + })?; + + // Decoding to stereo: audio_len refers to sample count irrespective of channel count. + // => multiply by number of channels. + out.truncate(2 * audio_len); + + Ok((out, data.len() - start)) + } +} + +struct UdpRx { + cipher: Cipher, + decoder_map: HashMap, + #[allow(dead_code)] + mode: CryptoMode, // In future, this will allow crypto mode selection. + packet_buffer: [u8; VOICE_PACKET_MAX], + rx: Receiver, + udp_socket: RecvHalf, +} + +impl UdpRx { + #[instrument(skip(self))] + async fn run(&mut self, interconnect: &mut Interconnect) { + loop { + tokio::select! { + Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => { + self.process_udp_message(interconnect, len); + } + msg = self.rx.recv_async() => { + use UdpRxMessage::*; + match msg { + Ok(ReplaceInterconnect(i)) => { + *interconnect = i; + } + Ok(Poison) | Err(_) => break, + } + } + } + } + } + + fn process_udp_message(&mut self, interconnect: &Interconnect, len: usize) { + // NOTE: errors here (and in general for UDP) are not fatal to the connection. + // Panics should be avoided due to adversarial nature of rx'd packets, + // but correct handling should not prompt a reconnect. + // + // For simplicity, we nominate the mixing context to rebuild the event + // context if it fails (hence, the `let _ =` statements.), as it will try to + // make contact every 20ms. 
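+ //
+ // In short: demux the datagram, decrypt it in place, then hand the body on
+ // as a `VoicePacket`/`RtcpPacket` core event.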
+ let packet = &mut self.packet_buffer[..len]; + + match demux::demux_mut(packet) { + DemuxedMut::Rtp(mut rtp) => { + if !rtp_valid(rtp.to_immutable()) { + error!("Illegal RTP message received."); + return; + } + + let rtp_body_start = + decrypt_in_place(&mut rtp, &self.cipher).expect("RTP decryption failed."); + + let entry = self + .decoder_map + .entry(rtp.get_ssrc()) + .or_insert_with(|| SsrcState::new(rtp.to_immutable())); + + if let Ok((delta, audio)) = entry.process(rtp.to_immutable(), rtp_body_start) { + match delta { + SpeakingDelta::Start => { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::SpeakingUpdate { + ssrc: rtp.get_ssrc(), + speaking: true, + }, + )); + }, + SpeakingDelta::Stop => { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::SpeakingUpdate { + ssrc: rtp.get_ssrc(), + speaking: false, + }, + )); + }, + _ => {}, + } + + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::VoicePacket { + audio, + packet: rtp.from_packet(), + payload_offset: rtp_body_start, + }, + )); + } else { + warn!("RTP decoding/decrytion failed."); + } + }, + DemuxedMut::Rtcp(mut rtcp) => { + let rtcp_body_start = decrypt_in_place(&mut rtcp, &self.cipher); + + if let Ok(start) = rtcp_body_start { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::RtcpPacket { + packet: rtcp.from_packet(), + payload_offset: start, + }, + )); + } else { + warn!("RTCP decryption failed."); + } + }, + DemuxedMut::FailedParse(t) => { + warn!("Failed to parse message of type {:?}.", t); + }, + _ => { + warn!("Illegal UDP packet from voice server."); + }, + } + } +} + +#[instrument(skip(interconnect, rx, cipher))] +pub(crate) async fn runner( + mut interconnect: Interconnect, + rx: Receiver, + cipher: Cipher, + mode: CryptoMode, + udp_socket: RecvHalf, +) { + info!("UDP receive handle started."); + + let mut state = UdpRx { + cipher, + decoder_map: Default::default(), + mode, + packet_buffer: [0u8; VOICE_PACKET_MAX], + rx, + udp_socket, + }; + + state.run(&mut interconnect).await; + + info!("UDP receive handle stopped."); +} + +#[inline] +fn decrypt_in_place(packet: &mut impl MutablePacket, cipher: &Cipher) -> Result { + // Applies discord's cheapest. + // In future, might want to make a choice... + let header_len = packet.packet().len() - packet.payload().len(); + let mut nonce = Nonce::default(); + nonce[..header_len].copy_from_slice(&packet.packet()[..header_len]); + + let data = packet.payload_mut(); + let (tag_bytes, data_bytes) = data.split_at_mut(TAG_SIZE); + let tag = Tag::from_slice(tag_bytes); + + Ok(cipher + .decrypt_in_place_detached(&nonce, b"", data_bytes, tag) + .map(|_| TAG_SIZE)?) 
+} + +#[inline] +fn rtp_valid(packet: RtpPacket<'_>) -> bool { + packet.get_version() == RTP_VERSION && packet.get_payload_type() == RTP_PROFILE_TYPE +} diff --git a/src/driver/tasks/udp_tx.rs b/src/driver/tasks/udp_tx.rs new file mode 100644 index 0000000..7027a09 --- /dev/null +++ b/src/driver/tasks/udp_tx.rs @@ -0,0 +1,45 @@ +use super::message::*; +use crate::constants::*; +use discortp::discord::MutableKeepalivePacket; +use flume::Receiver; +use tokio::{ + net::udp::SendHalf, + time::{timeout_at, Elapsed, Instant}, +}; +use tracing::{error, info, instrument, trace}; + +#[instrument(skip(udp_msg_rx))] +pub(crate) async fn runner(udp_msg_rx: Receiver, ssrc: u32, mut udp_tx: SendHalf) { + info!("UDP transmit handle started."); + + let mut keepalive_bytes = [0u8; MutableKeepalivePacket::minimum_packet_size()]; + let mut ka = MutableKeepalivePacket::new(&mut keepalive_bytes[..]) + .expect("FATAL: Insufficient bytes given to keepalive packet."); + ka.set_ssrc(ssrc); + + let mut ka_time = Instant::now() + UDP_KEEPALIVE_GAP; + + loop { + use UdpTxMessage::*; + match timeout_at(ka_time, udp_msg_rx.recv_async()).await { + Err(Elapsed { .. }) => { + trace!("Sending UDP Keepalive."); + if let Err(e) = udp_tx.send(&keepalive_bytes[..]).await { + error!("Fatal UDP keepalive send error: {:?}.", e); + break; + } + ka_time += UDP_KEEPALIVE_GAP; + }, + Ok(Ok(Packet(p))) => + if let Err(e) = udp_tx.send(&p[..]).await { + error!("Fatal UDP packet send error: {:?}.", e); + break; + }, + Ok(Err(_)) | Ok(Ok(Poison)) => { + break; + }, + } + } + + info!("UDP transmit handle stopped."); +} diff --git a/src/driver/tasks/ws.rs b/src/driver/tasks/ws.rs new file mode 100644 index 0000000..6f9813c --- /dev/null +++ b/src/driver/tasks/ws.rs @@ -0,0 +1,205 @@ +use super::{error::Result, message::*}; +use crate::{ + events::CoreContext, + model::{ + payload::{Heartbeat, Speaking}, + Event as GatewayEvent, + SpeakingState, + }, + ws::{Error as WsError, ReceiverExt, SenderExt, WsStream}, +}; +use flume::Receiver; +use rand::random; +use std::time::Duration; +use tokio::time::{self, Instant}; +use tracing::{error, info, instrument, trace, warn}; + +struct AuxNetwork { + rx: Receiver, + ws_client: WsStream, + dont_send: bool, + + ssrc: u32, + heartbeat_interval: Duration, + + speaking: SpeakingState, + last_heartbeat_nonce: Option, +} + +impl AuxNetwork { + pub(crate) fn new( + evt_rx: Receiver, + ws_client: WsStream, + ssrc: u32, + heartbeat_interval: f64, + ) -> Self { + Self { + rx: evt_rx, + ws_client, + dont_send: false, + + ssrc, + heartbeat_interval: Duration::from_secs_f64(heartbeat_interval / 1000.0), + + speaking: SpeakingState::empty(), + last_heartbeat_nonce: None, + } + } + + #[instrument(skip(self))] + async fn run(&mut self, interconnect: &mut Interconnect) { + let mut next_heartbeat = Instant::now() + self.heartbeat_interval; + + loop { + let mut ws_error = false; + + let hb = time::delay_until(next_heartbeat); + + tokio::select! 
{ + _ = hb => { + ws_error = match self.send_heartbeat().await { + Err(e) => { + error!("Heartbeat send failure {:?}.", e); + true + }, + _ => false, + }; + next_heartbeat = self.next_heartbeat(); + } + ws_msg = self.ws_client.recv_json_no_timeout(), if !self.dont_send => { + ws_error = match ws_msg { + Err(WsError::Json(e)) => { + warn!("Unexpected JSON {:?}.", e); + false + }, + Err(e) => { + error!("Error processing ws {:?}.", e); + true + }, + Ok(Some(msg)) => { + self.process_ws(interconnect, msg); + false + }, + _ => false, + }; + } + inner_msg = self.rx.recv_async() => { + match inner_msg { + Ok(WsMessage::Ws(data)) => { + self.ws_client = *data; + next_heartbeat = self.next_heartbeat(); + self.dont_send = false; + }, + Ok(WsMessage::ReplaceInterconnect(i)) => { + *interconnect = i; + }, + Ok(WsMessage::SetKeepalive(keepalive)) => { + self.heartbeat_interval = Duration::from_secs_f64(keepalive / 1000.0); + next_heartbeat = self.next_heartbeat(); + }, + Ok(WsMessage::Speaking(is_speaking)) => { + if self.speaking.contains(SpeakingState::MICROPHONE) != is_speaking && !self.dont_send { + self.speaking.set(SpeakingState::MICROPHONE, is_speaking); + info!("Changing to {:?}", self.speaking); + + let ssu_status = self.ws_client + .send_json(&GatewayEvent::from(Speaking { + delay: Some(0), + speaking: self.speaking, + ssrc: self.ssrc, + user_id: None, + })) + .await; + + ws_error |= match ssu_status { + Err(e) => { + error!("Issue sending speaking update {:?}.", e); + true + }, + _ => false, + } + } + }, + Err(_) | Ok(WsMessage::Poison) => { + break; + }, + } + } + } + + if ws_error { + let _ = interconnect.core.send(CoreMessage::Reconnect); + self.dont_send = true; + } + } + } + + fn next_heartbeat(&self) -> Instant { + Instant::now() + self.heartbeat_interval + } + + async fn send_heartbeat(&mut self) -> Result<()> { + let nonce = random::(); + self.last_heartbeat_nonce = Some(nonce); + + trace!("Sent heartbeat {:?}", self.speaking); + + if !self.dont_send { + self.ws_client + .send_json(&GatewayEvent::from(Heartbeat { nonce })) + .await?; + } + + Ok(()) + } + + fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) { + match value { + GatewayEvent::Speaking(ev) => { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::SpeakingStateUpdate(ev), + )); + }, + GatewayEvent::ClientConnect(ev) => { + let _ = interconnect + .events + .send(EventMessage::FireCoreEvent(CoreContext::ClientConnect(ev))); + }, + GatewayEvent::ClientDisconnect(ev) => { + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::ClientDisconnect(ev), + )); + }, + GatewayEvent::HeartbeatAck(ev) => { + if let Some(nonce) = self.last_heartbeat_nonce.take() { + if ev.nonce == nonce { + trace!("Heartbeat ACK received."); + } else { + warn!( + "Heartbeat nonce mismatch! Expected {}, saw {}.", + nonce, ev.nonce + ); + } + } + }, + other => { + trace!("Received other websocket data: {:?}", other); + }, + } + } +} + +#[instrument(skip(interconnect, ws_client))] +pub(crate) async fn runner( + mut interconnect: Interconnect, + evt_rx: Receiver, + ws_client: WsStream, + ssrc: u32, + heartbeat_interval: f64, +) { + info!("WS thread started."); + let mut aux = AuxNetwork::new(evt_rx, ws_client, ssrc, heartbeat_interval); + + aux.run(&mut interconnect).await; + info!("WS thread finished."); +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..bfa4a4c --- /dev/null +++ b/src/error.rs @@ -0,0 +1,69 @@ +//! Driver and gateway error handling. 
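+//!
+//! A minimal sketch of handling the `JoinResult` returned by the gateway
+//! methods below (the `call` value and its `deafen` method are drawn from the
+//! wider crate API, purely for illustration):
+//!
+//! ```rust,ignore
+//! use songbird::error::JoinError;
+//!
+//! match call.deafen(true).await {
+//!     Ok(()) => println!("deafened."),
+//!     Err(JoinError::NoSender) => println!("no gateway sender available."),
+//!     Err(e) => println!("failed to update voice state: {}.", e),
+//! }
+//! ```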
+ +#[cfg(feature = "serenity")] +use futures::channel::mpsc::TrySendError; +#[cfg(feature = "serenity")] +use serenity::gateway::InterMessage; +#[cfg(feature = "gateway")] +use std::{error::Error, fmt}; +#[cfg(feature = "twilight")] +use twilight_gateway::shard::CommandError; + +#[cfg(feature = "gateway")] +#[derive(Debug)] +/// Error returned when a manager or call handler is +/// unable to send messages over Discord's gateway. +pub enum JoinError { + /// No available gateway connection was provided to send + /// voice state update messages. + NoSender, + /// Tried to leave a [`Call`] which was not found. + /// + /// [`Call`]: ../struct.Call.html + NoCall, + #[cfg(feature = "serenity")] + /// Serenity-specific WebSocket send error. + Serenity(TrySendError), + #[cfg(feature = "twilight")] + /// Twilight-specific WebSocket send error. + Twilight(CommandError), +} + +#[cfg(feature = "gateway")] +impl fmt::Display for JoinError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Failed to Join Voice channel: ")?; + match self { + JoinError::NoSender => write!(f, "no gateway destination."), + JoinError::NoCall => write!(f, "tried to leave a non-existent call."), + #[cfg(feature = "serenity")] + JoinError::Serenity(t) => write!(f, "serenity failure {}.", t), + #[cfg(feature = "twilight")] + JoinError::Twilight(t) => write!(f, "twilight failure {}.", t), + } + } +} + +#[cfg(feature = "gateway")] +impl Error for JoinError {} + +#[cfg(all(feature = "serenity", feature = "gateway"))] +impl From> for JoinError { + fn from(e: TrySendError) -> Self { + JoinError::Serenity(e) + } +} + +#[cfg(all(feature = "twilight", feature = "gateway"))] +impl From for JoinError { + fn from(e: CommandError) -> Self { + JoinError::Twilight(e) + } +} + +#[cfg(feature = "gateway")] +/// Convenience type for Discord gateway error handling. +pub type JoinResult = Result; + +#[cfg(feature = "driver")] +pub use crate::driver::connection::error::{Error as ConnectionError, Result as ConnectionResult}; diff --git a/src/events/context.rs b/src/events/context.rs new file mode 100644 index 0000000..004465f --- /dev/null +++ b/src/events/context.rs @@ -0,0 +1,137 @@ +use super::*; +use crate::{ + model::payload::{ClientConnect, ClientDisconnect, Speaking}, + tracks::{TrackHandle, TrackState}, +}; +use discortp::{rtcp::Rtcp, rtp::Rtp}; + +/// Information about which tracks or data fired an event. +/// +/// [`Track`] events may be local or global, and have no tracks +/// if fired on the global context via [`Handler::add_global_event`]. +/// +/// [`Track`]: ../tracks/struct.Track.html +/// [`Handler::add_global_event`]: ../struct.Handler.html#method.add_global_event +#[derive(Clone, Debug)] +pub enum EventContext<'a> { + /// Track event context, passed to events created via [`TrackHandle::add_event`], + /// [`EventStore::add_event`], or relevant global events. + /// + /// [`EventStore::add_event`]: struct.EventStore.html#method.add_event + /// [`TrackHandle::add_event`]: ../tracks/struct.TrackHandle.html#method.add_event + Track(&'a [(&'a TrackState, &'a TrackHandle)]), + /// Speaking state update, typically describing how another voice + /// user is transmitting audio data. Clients must send at least one such + /// packet to allow SSRC/UserID matching. + SpeakingStateUpdate(Speaking), + /// Speaking state transition, describing whether a given source has started/stopped + /// transmitting. This fires in response to a silent burst, or the first packet + /// breaking such a burst. 
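+ ///
+ /// A small sketch of matching this variant inside an `#[async_trait]`
+ /// `EventHandler::act` implementation (the surrounding handler type is
+ /// illustrative only):
+ ///
+ /// ```rust,ignore
+ /// async fn act(&self, ctx: &EventContext<'_>) -> Option<Event> {
+ ///     if let EventContext::SpeakingUpdate { ssrc, speaking } = ctx {
+ ///         println!("SSRC {} speaking: {}", ssrc, speaking);
+ ///     }
+ ///     None
+ /// }
+ /// ```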
+ SpeakingUpdate { + /// Synchronisation Source of the user who has begun speaking. + /// + /// This must be combined with another event class to map this back to + /// its original UserId. + ssrc: u32, + /// Whether this user is currently speaking. + speaking: bool, + }, + /// Opus audio packet, received from another stream (detailed in `packet`). + /// `payload_offset` contains the true payload location within the raw packet's `payload()`, + /// if extensions or raw packet data are required. + /// if `audio.len() == 0`, then this packet arrived out-of-order. + VoicePacket { + /// Decoded audio from this packet. + audio: &'a Vec, + /// Raw RTP packet data. + /// + /// Includes the SSRC (i.e., sender) of this packet. + packet: &'a Rtp, + /// Byte index into the packet for where the payload begins. + payload_offset: usize, + }, + /// Telemetry/statistics packet, received from another stream (detailed in `packet`). + /// `payload_offset` contains the true payload location within the raw packet's `payload()`, + /// to allow manual decoding of `Rtcp` packet bodies. + RtcpPacket { + /// Raw RTCP packet data. + packet: &'a Rtcp, + /// Byte index into the packet for where the payload begins. + payload_offset: usize, + }, + /// Fired whenever a client connects to a call for the first time, allowing SSRC/UserID + /// matching. + ClientConnect(ClientConnect), + /// Fired whenever a client disconnects. + ClientDisconnect(ClientDisconnect), +} + +#[derive(Clone, Debug)] +pub(crate) enum CoreContext { + SpeakingStateUpdate(Speaking), + SpeakingUpdate { + ssrc: u32, + speaking: bool, + }, + VoicePacket { + audio: Vec, + packet: Rtp, + payload_offset: usize, + }, + RtcpPacket { + packet: Rtcp, + payload_offset: usize, + }, + ClientConnect(ClientConnect), + ClientDisconnect(ClientDisconnect), +} + +impl<'a> CoreContext { + pub(crate) fn to_user_context(&'a self) -> EventContext<'a> { + use CoreContext::*; + + match self { + SpeakingStateUpdate(evt) => EventContext::SpeakingStateUpdate(*evt), + SpeakingUpdate { ssrc, speaking } => EventContext::SpeakingUpdate { + ssrc: *ssrc, + speaking: *speaking, + }, + VoicePacket { + audio, + packet, + payload_offset, + } => EventContext::VoicePacket { + audio, + packet, + payload_offset: *payload_offset, + }, + RtcpPacket { + packet, + payload_offset, + } => EventContext::RtcpPacket { + packet, + payload_offset: *payload_offset, + }, + ClientConnect(evt) => EventContext::ClientConnect(*evt), + ClientDisconnect(evt) => EventContext::ClientDisconnect(*evt), + } + } +} + +impl EventContext<'_> { + /// Retreive the event class for an event (i.e., when matching) + /// an event against the registered listeners. + pub fn to_core_event(&self) -> Option { + use EventContext::*; + + match self { + SpeakingStateUpdate { .. } => Some(CoreEvent::SpeakingStateUpdate), + SpeakingUpdate { .. } => Some(CoreEvent::SpeakingUpdate), + VoicePacket { .. } => Some(CoreEvent::VoicePacket), + RtcpPacket { .. } => Some(CoreEvent::RtcpPacket), + ClientConnect { .. } => Some(CoreEvent::ClientConnect), + ClientDisconnect { .. } => Some(CoreEvent::ClientDisconnect), + _ => None, + } + } +} diff --git a/src/events/core.rs b/src/events/core.rs new file mode 100644 index 0000000..df5eee4 --- /dev/null +++ b/src/events/core.rs @@ -0,0 +1,31 @@ +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +/// Voice core events occur on receipt of +/// voice packets and telemetry. +/// +/// Core events persist while the `action` in [`EventData`] +/// returns `None`. 
+/// +/// [`EventData`]: struct.EventData.html +pub enum CoreEvent { + /// Fired on receipt of a speaking state update from another host. + /// + /// Note: this will fire when a user starts speaking for the first time, + /// or changes their capabilities. + SpeakingStateUpdate, + /// Fires when a source starts speaking, or stops speaking + /// (*i.e.*, 5 consecutive silent frames). + SpeakingUpdate, + /// Fires on receipt of a voice packet from another stream in the voice call. + /// + /// As RTP packets do not map to Discord's notion of users, SSRCs must be mapped + /// back using the user IDs seen through client connection, disconnection, + /// or speaking state update. + VoicePacket, + /// Fires on receipt of an RTCP packet, containing various call stats + /// such as latency reports. + RtcpPacket, + /// Fires whenever a user connects to the same stream as the bot. + ClientConnect, + /// Fires whenever a user disconnects from the same stream as the bot. + ClientDisconnect, +} diff --git a/src/events/data.rs b/src/events/data.rs new file mode 100644 index 0000000..cd12c91 --- /dev/null +++ b/src/events/data.rs @@ -0,0 +1,88 @@ +use super::*; +use std::{cmp::Ordering, time::Duration}; + +/// Internal representation of an event, as handled by the audio context. +pub struct EventData { + pub(crate) event: Event, + pub(crate) fire_time: Option, + pub(crate) action: Box, +} + +impl EventData { + /// Create a representation of an event and its associated handler. + /// + /// An event handler, `action`, receives an [`EventContext`] and optionally + /// produces a new [`Event`] type for itself. Returning `None` will + /// maintain the same event type, while removing any [`Delayed`] entries. + /// Event handlers will be re-added with their new trigger condition, + /// or removed if [`Cancel`]led + /// + /// [`EventContext`]: enum.EventContext.html + /// [`Event`]: enum.Event.html + /// [`Delayed`]: enum.Event.html#variant.Delayed + /// [`Cancel`]: enum.Event.html#variant.Cancel + pub fn new(event: Event, action: F) -> Self { + Self { + event, + fire_time: None, + action: Box::new(action), + } + } + + /// Computes the next firing time for a timer event. + pub fn compute_activation(&mut self, now: Duration) { + match self.event { + Event::Periodic(period, phase) => { + self.fire_time = Some(now + phase.unwrap_or(period)); + }, + Event::Delayed(offset) => { + self.fire_time = Some(now + offset); + }, + _ => {}, + } + } +} + +impl std::fmt::Debug for EventData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "Event {{ event: {:?}, fire_time: {:?}, action: }}", + self.event, self.fire_time + ) + } +} + +/// Events are ordered/compared based on their firing time. +impl Ord for EventData { + fn cmp(&self, other: &Self) -> Ordering { + if self.fire_time.is_some() && other.fire_time.is_some() { + let t1 = self + .fire_time + .as_ref() + .expect("T1 known to be well-defined by above."); + let t2 = other + .fire_time + .as_ref() + .expect("T2 known to be well-defined by above."); + + t1.cmp(&t2) + } else { + Ordering::Equal + } + } +} + +impl PartialOrd for EventData { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for EventData { + fn eq(&self, other: &Self) -> bool { + self.fire_time == other.fire_time + } +} + +impl Eq for EventData {} diff --git a/src/events/mod.rs b/src/events/mod.rs new file mode 100644 index 0000000..b70961f --- /dev/null +++ b/src/events/mod.rs @@ -0,0 +1,91 @@ +//! 
Events relating to tracks, timing, and other callers. + +mod context; +mod core; +mod data; +mod store; +mod track; +mod untimed; + +pub use self::{context::*, core::*, data::*, store::*, track::*, untimed::*}; + +use async_trait::async_trait; +use std::time::Duration; + +#[async_trait] +/// Trait to handle an event which can be fired per-track, or globally. +/// +/// These may be feasibly reused between several event sources. +pub trait EventHandler: Send + Sync { + /// Respond to one received event. + async fn act(&self, ctx: &EventContext<'_>) -> Option; +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +/// Classes of event which may occur, triggering a handler +/// at the local (track-specific) or global level. +/// +/// Local time-based events rely upon the current playback +/// time of a track, and so will not fire if a track becomes paused +/// or stops. In case this is required, global events are a better +/// fit. +/// +/// Event handlers themselves are described in [`EventData::action`]. +/// +/// [`EventData::action`]: struct.EventData.html#method.action +pub enum Event { + /// Periodic events rely upon two parameters: a *period* + /// and an optional *phase*. + /// + /// If the *phase* is `None`, then the event will first fire + /// in one *period*. Periodic events repeat automatically + /// so long as the `action` in [`EventData`] returns `None`. + /// + /// [`EventData`]: struct.EventData.html + Periodic(Duration, Option), + /// Delayed events rely upon a *delay* parameter, and + /// fire one *delay* after the audio context processes them. + /// + /// Delayed events are automatically removed once fired, + /// so long as the `action` in [`EventData`] returns `None`. + /// + /// [`EventData`]: struct.EventData.html + Delayed(Duration), + /// Track events correspond to certain actions or changes + /// of state, such as a track finishing, looping, or being + /// manually stopped. + /// + /// Track events persist while the `action` in [`EventData`] + /// returns `None`. + /// + /// [`EventData`]: struct.EventData.html + Track(TrackEvent), + /// Core events + /// + /// Track events persist while the `action` in [`EventData`] + /// returns `None`. Core events **must** be applied globally, + /// as attaching them to a track is a no-op. + /// + /// [`EventData`]: struct.EventData.html + Core(CoreEvent), + /// Cancels the event, if it was intended to persist. + Cancel, +} + +impl Event { + pub(crate) fn is_global_only(&self) -> bool { + matches!(self, Self::Core(_)) + } +} + +impl From for Event { + fn from(evt: TrackEvent) -> Self { + Event::Track(evt) + } +} + +impl From for Event { + fn from(evt: CoreEvent) -> Self { + Event::Core(evt) + } +} diff --git a/src/events/store.rs b/src/events/store.rs new file mode 100644 index 0000000..6518ee2 --- /dev/null +++ b/src/events/store.rs @@ -0,0 +1,252 @@ +use super::*; +use crate::{ + constants::*, + tracks::{PlayMode, TrackHandle, TrackState}, +}; +use std::{ + collections::{BinaryHeap, HashMap}, + time::Duration, +}; +use tracing::info; + +#[derive(Debug, Default)] +/// Storage for [`EventData`], designed to be used for both local and global contexts. +/// +/// Timed events are stored in a binary heap for fast selection, and have custom `Eq`, +/// `Ord`, etc. implementations to support (only) this. +/// +/// [`EventData`]: struct.EventData.html +pub struct EventStore { + timed: BinaryHeap, + untimed: HashMap>, + local_only: bool, +} + +impl EventStore { + /// Creates a new event store to be used globally. 
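+ ///
+ /// A brief sketch of direct use; most callers let the driver manage stores,
+ /// and `MyHandler` stands in for any `EventHandler` implementation:
+ ///
+ /// ```rust,ignore
+ /// use std::time::Duration;
+ ///
+ /// let mut store = EventStore::new();
+ /// store.add_event(
+ ///     EventData::new(Event::Delayed(Duration::from_secs(5)), MyHandler),
+ ///     Duration::default(), // current stream time
+ /// );
+ /// ```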
+ pub fn new() -> Self { + Default::default() + } + + /// Creates a new event store to be used within a [`Track`]. + /// + /// This is usually automatically installed by the driver once + /// a track has been registered. + /// + /// [`Track`]: ../tracks/struct.Track.html + pub fn new_local() -> Self { + EventStore { + local_only: true, + ..Default::default() + } + } + + /// Add an event to this store. + /// + /// Updates `evt` according to [`EventData::compute_activation`]. + /// + /// [`EventData::compute_activation`]: struct.EventData.html#method.compute_activation + pub fn add_event(&mut self, mut evt: EventData, now: Duration) { + evt.compute_activation(now); + + if self.local_only && evt.event.is_global_only() { + return; + } + + use Event::*; + match evt.event { + Core(c) => { + self.untimed + .entry(c.into()) + .or_insert_with(Vec::new) + .push(evt); + }, + Track(t) => { + self.untimed + .entry(t.into()) + .or_insert_with(Vec::new) + .push(evt); + }, + Delayed(_) | Periodic(_, _) => { + self.timed.push(evt); + }, + _ => { + // Event cancelled. + }, + } + } + + /// Processes all events due up to and including `now`. + pub(crate) async fn process_timed(&mut self, now: Duration, ctx: EventContext<'_>) { + while let Some(evt) = self.timed.peek() { + if evt + .fire_time + .as_ref() + .expect("Timed event must have a fire_time.") + > &now + { + break; + } + let mut evt = self + .timed + .pop() + .expect("Can only succeed due to peek = Some(...)."); + + let old_evt_type = evt.event; + if let Some(new_evt_type) = evt.action.act(&ctx).await { + evt.event = new_evt_type; + self.add_event(evt, now); + } else if let Event::Periodic(d, _) = old_evt_type { + evt.event = Event::Periodic(d, None); + self.add_event(evt, now); + } + } + } + + /// Processes all events attached to the given track event. + pub(crate) async fn process_untimed( + &mut self, + now: Duration, + untimed_event: UntimedEvent, + ctx: EventContext<'_>, + ) { + // move a Vec in and out: not too expensive, but could be better. + // Although it's obvious that moving an event out of one vec and into + // another necessitates that they be different event types, thus entries, + // convincing the compiler of this is non-trivial without making them dedicated + // fields. + let events = self.untimed.remove(&untimed_event); + if let Some(mut events) = events { + // TODO: Possibly use tombstones to prevent realloc/memcpys? + // i.e., never shrink array, replace ended tracks with , + // maintain a "first-track" stack and freelist alongside. 
+ let mut i = 0; + while i < events.len() { + let evt = &mut events[i]; + // Only remove/readd if the event type changes (i.e., Some AND new != old) + if let Some(new_evt_type) = evt.action.act(&ctx).await { + if evt.event == new_evt_type { + let mut evt = events.remove(i); + + evt.event = new_evt_type; + self.add_event(evt, now); + } else { + i += 1; + } + } else { + i += 1; + }; + } + self.untimed.insert(untimed_event, events); + } + } +} + +#[derive(Debug, Default)] +pub(crate) struct GlobalEvents { + pub(crate) store: EventStore, + pub(crate) time: Duration, + pub(crate) awaiting_tick: HashMap>, +} + +impl GlobalEvents { + pub(crate) fn add_event(&mut self, evt: EventData) { + self.store.add_event(evt, self.time); + } + + pub(crate) async fn fire_core_event(&mut self, evt: CoreEvent, ctx: EventContext<'_>) { + self.store.process_untimed(self.time, evt.into(), ctx).await; + } + + pub(crate) fn fire_track_event(&mut self, evt: TrackEvent, index: usize) { + let holder = self.awaiting_tick.entry(evt).or_insert_with(Vec::new); + + holder.push(index); + } + + pub(crate) async fn tick( + &mut self, + events: &mut Vec, + states: &mut Vec, + handles: &mut Vec, + ) { + // Global timed events + self.time += TIMESTEP_LENGTH; + self.store + .process_timed(self.time, EventContext::Track(&[])) + .await; + + // Local timed events + for (i, state) in states.iter_mut().enumerate() { + if state.playing == PlayMode::Play { + state.step_frame(); + + let event_store = events + .get_mut(i) + .expect("Missing store index for Tick (local timed)."); + let handle = handles + .get_mut(i) + .expect("Missing handle index for Tick (local timed)."); + + event_store + .process_timed(state.play_time, EventContext::Track(&[(&state, &handle)])) + .await; + } + } + + for (evt, indices) in self.awaiting_tick.iter() { + let untimed = (*evt).into(); + + if !indices.is_empty() { + info!("Firing {:?} for {:?}", evt, indices); + } + + // Local untimed track events. + for &i in indices.iter() { + let event_store = events + .get_mut(i) + .expect("Missing store index for Tick (local untimed)."); + let handle = handles + .get_mut(i) + .expect("Missing handle index for Tick (local untimed)."); + let state = states + .get_mut(i) + .expect("Missing state index for Tick (local untimed)."); + + event_store + .process_untimed( + state.position, + untimed, + EventContext::Track(&[(&state, &handle)]), + ) + .await; + } + + // Global untimed track events. + if self.store.untimed.contains_key(&untimed) && !indices.is_empty() { + let global_ctx: Vec<(&TrackState, &TrackHandle)> = indices + .iter() + .map(|i| { + ( + states + .get(*i) + .expect("Missing state index for Tick (global untimed)"), + handles + .get(*i) + .expect("Missing handle index for Tick (global untimed)"), + ) + }) + .collect(); + + self.store + .process_untimed(self.time, untimed, EventContext::Track(&global_ctx[..])) + .await + } + } + + // Now drain vecs. + for (_evt, indices) in self.awaiting_tick.iter_mut() { + indices.clear(); + } + } +} diff --git a/src/events/track.rs b/src/events/track.rs new file mode 100644 index 0000000..df567a9 --- /dev/null +++ b/src/events/track.rs @@ -0,0 +1,16 @@ +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +/// Track events correspond to certain actions or changes +/// of state, such as a track finishing, looping, or being +/// manually stopped. Voice core events occur on receipt of +/// voice packets and telemetry. +/// +/// Track events persist while the `action` in [`EventData`] +/// returns `None`. 
+/// +/// [`EventData`]: struct.EventData.html +pub enum TrackEvent { + /// The attached track has ended. + End, + /// The attached track has looped. + Loop, +} diff --git a/src/events/untimed.rs b/src/events/untimed.rs new file mode 100644 index 0000000..4bb4899 --- /dev/null +++ b/src/events/untimed.rs @@ -0,0 +1,28 @@ +use super::*; + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +/// Track and voice core events. +/// +/// Untimed events persist while the `action` in [`EventData`] +/// returns `None`. +/// +/// [`EventData`]: struct.EventData.html +pub enum UntimedEvent { + /// Untimed events belonging to a track, such as state changes, end, or loops. + Track(TrackEvent), + /// Untimed events belonging to the global context, such as finished tracks, + /// client speaking updates, or RT(C)P voice and telemetry data. + Core(CoreEvent), +} + +impl From for UntimedEvent { + fn from(evt: TrackEvent) -> Self { + UntimedEvent::Track(evt) + } +} + +impl From for UntimedEvent { + fn from(evt: CoreEvent) -> Self { + UntimedEvent::Core(evt) + } +} diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..3ecb089 --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,301 @@ +#[cfg(feature = "driver")] +use crate::{driver::Driver, error::ConnectionResult}; +use crate::{ + error::{JoinError, JoinResult}, + id::{ChannelId, GuildId, UserId}, + info::{ConnectionInfo, ConnectionProgress}, + shards::Shard, +}; +use flume::{Receiver, Sender}; +use serde_json::json; +use tracing::instrument; + +#[cfg(feature = "driver")] +use std::ops::{Deref, DerefMut}; + +#[derive(Clone, Debug)] +enum Return { + Info(Sender), + #[cfg(feature = "driver")] + Conn(Sender>), +} + +/// The Call handler is responsible for a single voice connection, acting +/// as a clean API above the inner state and gateway message management. +/// +/// If the `"driver"` feature is enabled, then a Call exposes all control methods of +/// [`Driver`] via `Deref(Mut)`. +/// +/// [`Driver`]: driver/struct.Driver.html +/// [`Shard`]: ../gateway/struct.Shard.html +#[derive(Clone, Debug)] +pub struct Call { + connection: Option<(ChannelId, ConnectionProgress, Return)>, + + #[cfg(feature = "driver")] + /// The internal controller of the voice connection monitor thread. + driver: Driver, + + guild_id: GuildId, + /// Whether the current handler is set to deafen voice connections. + self_deaf: bool, + /// Whether the current handler is set to mute voice connections. + self_mute: bool, + user_id: UserId, + /// Will be set when a `Call` is made via the [`new`][`Call::new`] + /// method. + /// + /// When set via [`standalone`][`Call::standalone`], it will not be + /// present. + ws: Option, +} + +impl Call { + /// Creates a new Call, which will send out WebSocket messages via + /// the given shard. + #[inline] + #[instrument] + pub fn new(guild_id: GuildId, ws: Shard, user_id: UserId) -> Self { + Self::new_raw(guild_id, Some(ws), user_id) + } + + /// Creates a new, standalone Call which is not connected via + /// WebSocket to the Gateway. + /// + /// Actions such as muting, deafening, and switching channels will not + /// function through this Call and must be done through some other + /// method, as the values will only be internally updated. + /// + /// For most use cases you do not want this. 
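+ ///
+ /// A minimal sketch (IDs and strings are placeholders; the import path
+ /// assumes the usual crate-root re-export of `Call`):
+ ///
+ /// ```rust,ignore
+ /// use songbird::{id::{GuildId, UserId}, Call};
+ ///
+ /// let mut call = Call::standalone(GuildId(1), UserId(2));
+ ///
+ /// // Gateway data must then be fed in by hand:
+ /// call.update_state("session_id".into());
+ /// call.update_server("endpoint".into(), "token".into());
+ /// ```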
+ #[inline] + #[instrument] + pub fn standalone(guild_id: GuildId, user_id: UserId) -> Self { + Self::new_raw(guild_id, None, user_id) + } + + fn new_raw(guild_id: GuildId, ws: Option, user_id: UserId) -> Self { + Call { + connection: None, + #[cfg(feature = "driver")] + driver: Default::default(), + guild_id, + self_deaf: false, + self_mute: false, + user_id, + ws, + } + } + + #[instrument(skip(self))] + fn do_connect(&mut self) { + match &self.connection { + Some((_, ConnectionProgress::Complete(c), Return::Info(tx))) => { + // It's okay if the receiver hung up. + let _ = tx.send(c.clone()); + }, + #[cfg(feature = "driver")] + Some((_, ConnectionProgress::Complete(c), Return::Conn(tx))) => { + self.driver.raw_connect(c.clone(), tx.clone()); + }, + _ => {}, + } + } + + /// Sets whether the current connection is to be deafened. + /// + /// If there is no live voice connection, then this only acts as a settings + /// update for future connections. + /// + /// **Note**: Unlike in the official client, you _can_ be deafened while + /// not being muted. + /// + /// **Note**: If the `Call` was created via [`standalone`], then this + /// will _only_ update whether the connection is internally deafened. + /// + /// [`standalone`]: #method.standalone + #[instrument(skip(self))] + pub async fn deafen(&mut self, deaf: bool) -> JoinResult<()> { + self.self_deaf = deaf; + + self.update().await + } + + /// Returns whether the current connection is self-deafened in this server. + /// + /// This is purely cosmetic. + #[instrument(skip(self))] + pub fn is_deaf(&self) -> bool { + self.self_deaf + } + + #[cfg(feature = "driver")] + /// Connect or switch to the given voice channel by its Id. + #[instrument(skip(self))] + pub async fn join( + &mut self, + channel_id: ChannelId, + ) -> JoinResult>> { + let (tx, rx) = flume::unbounded(); + + self.connection = Some(( + channel_id, + ConnectionProgress::new(self.guild_id, self.user_id), + Return::Conn(tx), + )); + + self.update().await.map(|_| rx) + } + + /// Join the selected voice channel, *without* running/starting an RTP + /// session or running the driver. + /// + /// Use this if you require connection info for lavalink, + /// some other voice implementation, or don't want to use the driver for a given call. + #[instrument(skip(self))] + pub async fn join_gateway( + &mut self, + channel_id: ChannelId, + ) -> JoinResult> { + let (tx, rx) = flume::unbounded(); + + self.connection = Some(( + channel_id, + ConnectionProgress::new(self.guild_id, self.user_id), + Return::Info(tx), + )); + + self.update().await.map(|_| rx) + } + + /// Leaves the current voice channel, disconnecting from it. + /// + /// This does _not_ forget settings, like whether to be self-deafened or + /// self-muted. + /// + /// **Note**: If the `Call` was created via [`standalone`], then this + /// will _only_ update whether the connection is internally connected to a + /// voice channel. + /// + /// [`standalone`]: #method.standalone + #[instrument(skip(self))] + pub async fn leave(&mut self) -> JoinResult<()> { + // Only send an update if we were in a voice channel. + self.connection = None; + + #[cfg(feature = "driver")] + self.driver.leave(); + + self.update().await + } + + /// Sets whether the current connection is to be muted. + /// + /// If there is no live voice connection, then this only acts as a settings + /// update for future connections. 
+ /// + /// **Note**: If the `Call` was created via [`standalone`], then this + /// will _only_ update whether the connection is internally muted. + /// + /// [`standalone`]: #method.standalone + #[instrument(skip(self))] + pub async fn mute(&mut self, mute: bool) -> JoinResult<()> { + self.self_mute = mute; + + #[cfg(feature = "driver")] + self.driver.mute(mute); + + self.update().await + } + + /// Returns whether the current connection is self-muted in this server. + #[instrument(skip(self))] + pub fn is_mute(&self) -> bool { + self.self_mute + } + + /// Updates the voice server data. + /// + /// You should only need to use this if you initialized the `Call` via + /// [`standalone`]. + /// + /// Refer to the documentation for [`connect`] for when this will + /// automatically connect to a voice channel. + /// + /// [`connect`]: #method.connect + /// [`standalone`]: #method.standalone + #[instrument(skip(self, token))] + pub fn update_server(&mut self, endpoint: String, token: String) { + let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() { + progress.apply_server_update(endpoint, token) + } else { + false + }; + + if try_conn { + self.do_connect(); + } + } + + /// Updates the internal voice state of the current user. + /// + /// You should only need to use this if you initialized the `Call` via + /// [`standalone`]. + /// + /// refer to the documentation for [`connect`] for when this will + /// automatically connect to a voice channel. + /// + /// [`connect`]: #method.connect + /// [`standalone`]: #method.standalone + #[instrument(skip(self))] + pub fn update_state(&mut self, session_id: String) { + let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() { + progress.apply_state_update(session_id) + } else { + false + }; + + if try_conn { + self.do_connect(); + } + } + + /// Send an update for the current session over WS. + /// + /// Does nothing if initialized via [`standalone`]. + /// + /// [`standalone`]: #method.standalone + #[instrument(skip(self))] + async fn update(&mut self) -> JoinResult<()> { + if let Some(ws) = self.ws.as_mut() { + let map = json!({ + "op": 4, + "d": { + "channel_id": self.connection.as_ref().map(|c| c.0.0), + "guild_id": self.guild_id.0, + "self_deaf": self.self_deaf, + "self_mute": self.self_mute, + } + }); + + ws.send(map).await + } else { + Err(JoinError::NoSender) + } + } +} + +#[cfg(feature = "driver")] +impl Deref for Call { + type Target = Driver; + + fn deref(&self) -> &Self::Target { + &self.driver + } +} + +#[cfg(feature = "driver")] +impl DerefMut for Call { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.driver + } +} diff --git a/src/id.rs b/src/id.rs new file mode 100644 index 0000000..f28e108 --- /dev/null +++ b/src/id.rs @@ -0,0 +1,121 @@ +//! Newtypes around Discord IDs for library cross-compatibility. + +#[cfg(feature = "driver")] +use crate::model::id::{GuildId as DriverGuild, UserId as DriverUser}; +#[cfg(feature = "serenity")] +use serenity::model::id::{ + ChannelId as SerenityChannel, + GuildId as SerenityGuild, + UserId as SerenityUser, +}; +use std::fmt::{Display, Formatter, Result as FmtResult}; +#[cfg(feature = "twilight")] +use twilight_model::id::{ + ChannelId as TwilightChannel, + GuildId as TwilightGuild, + UserId as TwilightUser, +}; + +/// ID of a Discord voice/text channel. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct ChannelId(pub u64); + +/// ID of a Discord guild (colloquially, "server"). 
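+///
+/// A short sketch of the `From` conversions implemented below (the literal ID
+/// is a placeholder):
+///
+/// ```rust,ignore
+/// use songbird::id::GuildId;
+///
+/// let a = GuildId::from(81384788765712384_u64);
+/// let b: GuildId = 81384788765712384_u64.into();
+/// assert_eq!(a, b);
+/// ```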
+#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct GuildId(pub u64); + +/// ID of a Discord user. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct UserId(pub u64); + +impl Display for ChannelId { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + Display::fmt(&self.0, f) + } +} + +impl From for ChannelId { + fn from(id: u64) -> Self { + Self(id) + } +} + +#[cfg(feature = "serenity")] +impl From for ChannelId { + fn from(id: SerenityChannel) -> Self { + Self(id.0) + } +} + +#[cfg(feature = "twilight")] +impl From for ChannelId { + fn from(id: TwilightChannel) -> Self { + Self(id.0) + } +} + +impl Display for GuildId { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + Display::fmt(&self.0, f) + } +} + +impl From for GuildId { + fn from(id: u64) -> Self { + Self(id) + } +} + +#[cfg(feature = "serenity")] +impl From for GuildId { + fn from(id: SerenityGuild) -> Self { + Self(id.0) + } +} + +#[cfg(feature = "driver")] +impl From for DriverGuild { + fn from(id: GuildId) -> Self { + Self(id.0) + } +} + +#[cfg(feature = "twilight")] +impl From for GuildId { + fn from(id: TwilightGuild) -> Self { + Self(id.0) + } +} + +impl Display for UserId { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + Display::fmt(&self.0, f) + } +} + +impl From for UserId { + fn from(id: u64) -> Self { + Self(id) + } +} + +#[cfg(feature = "serenity")] +impl From for UserId { + fn from(id: SerenityUser) -> Self { + Self(id.0) + } +} + +#[cfg(feature = "driver")] +impl From for DriverUser { + fn from(id: UserId) -> Self { + Self(id.0) + } +} + +#[cfg(feature = "twilight")] +impl From for UserId { + fn from(id: TwilightUser) -> Self { + Self(id.0) + } +} diff --git a/src/info.rs b/src/info.rs new file mode 100644 index 0000000..8b3fdb3 --- /dev/null +++ b/src/info.rs @@ -0,0 +1,137 @@ +use crate::id::{GuildId, UserId}; +use std::fmt; + +#[derive(Clone, Debug)] +pub(crate) enum ConnectionProgress { + Complete(ConnectionInfo), + Incomplete(Partial), +} + +impl ConnectionProgress { + pub fn new(guild_id: GuildId, user_id: UserId) -> Self { + ConnectionProgress::Incomplete(Partial { + guild_id, + user_id, + ..Default::default() + }) + } + + pub(crate) fn apply_state_update(&mut self, session_id: String) -> bool { + use ConnectionProgress::*; + match self { + Complete(c) => { + let should_reconn = c.session_id != session_id; + c.session_id = session_id; + should_reconn + }, + Incomplete(i) => i + .apply_state_update(session_id) + .map(|info| { + *self = Complete(info); + }) + .is_some(), + } + } + + pub(crate) fn apply_server_update(&mut self, endpoint: String, token: String) -> bool { + use ConnectionProgress::*; + match self { + Complete(c) => { + let should_reconn = c.endpoint != endpoint || c.token != token; + + c.endpoint = endpoint; + c.token = token; + + should_reconn + }, + Incomplete(i) => i + .apply_server_update(endpoint, token) + .map(|info| { + *self = Complete(info); + }) + .is_some(), + } + } +} + +/// Parameters and information needed to start communicating with Discord's voice servers, either +/// with the Songbird driver, lavalink, or other system. +#[derive(Clone)] +pub struct ConnectionInfo { + /// URL of the voice websocket gateway server assigned to this call. + pub endpoint: String, + /// ID of the target voice channel's parent guild. + /// + /// Bots cannot connect to a guildless (i.e., direct message) voice call. + pub guild_id: GuildId, + /// Unique string describing this session for validation/authentication purposes. 
+ pub session_id: String, + /// Ephemeral secret used to validate the above session. + pub token: String, + /// UserID of this bot. + pub user_id: UserId, +} + +impl fmt::Debug for ConnectionInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConnectionInfo") + .field("endpoint", &self.endpoint) + .field("guild_id", &self.guild_id) + .field("session_id", &self.session_id) + .field("token", &"") + .field("user_id", &self.user_id) + .finish() + } +} + +#[derive(Clone, Default)] +pub(crate) struct Partial { + pub endpoint: Option, + pub guild_id: GuildId, + pub session_id: Option, + pub token: Option, + pub user_id: UserId, +} + +impl fmt::Debug for Partial { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Partial") + .field("endpoint", &self.endpoint) + .field("session_id", &self.session_id) + .field("token_is_some", &self.token.is_some()) + .finish() + } +} + +impl Partial { + fn finalise(&mut self) -> Option { + if self.endpoint.is_some() && self.session_id.is_some() && self.token.is_some() { + let endpoint = self.endpoint.take().unwrap(); + let session_id = self.session_id.take().unwrap(); + let token = self.token.take().unwrap(); + + Some(ConnectionInfo { + endpoint, + session_id, + token, + guild_id: self.guild_id, + user_id: self.user_id, + }) + } else { + None + } + } + + fn apply_state_update(&mut self, session_id: String) -> Option { + self.session_id = Some(session_id); + + self.finalise() + } + + fn apply_server_update(&mut self, endpoint: String, token: String) -> Option { + self.endpoint = Some(endpoint); + self.token = Some(token); + + self.finalise() + } +} diff --git a/src/input/cached/compressed.rs b/src/input/cached/compressed.rs new file mode 100644 index 0000000..183cba9 --- /dev/null +++ b/src/input/cached/compressed.rs @@ -0,0 +1,303 @@ +use super::{apply_length_hint, compressed_cost_per_sec, default_config}; +use crate::{ + constants::*, + input::{ + error::{Error, Result}, + CodecType, + Container, + Input, + Metadata, + Reader, + }, +}; +use audiopus::{ + coder::Encoder as OpusEncoder, + Application, + Bitrate, + Channels, + Error as OpusError, + ErrorCode as OpusErrorCode, + SampleRate, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::{ + convert::TryInto, + io::{Error as IoError, ErrorKind as IoErrorKind, Read, Result as IoResult}, + mem, + sync::atomic::{AtomicUsize, Ordering}, +}; +use streamcatcher::{Config, NeedsBytes, Stateful, Transform, TransformPosition, TxCatcher}; +use tracing::{debug, trace}; + +/// A wrapper around an existing [`Input`] which compresses +/// the input using the Opus codec before storing it in memory. +/// +/// The main purpose of this wrapper is to enable seeking on +/// incompatible sources (i.e., ffmpeg output) and to ease resource +/// consumption for commonly reused/shared tracks. [`Restartable`] +/// and [`Memory`] offer the same functionality with different +/// tradeoffs. +/// +/// This is intended for use with larger, repeatedly used audio +/// tracks shared between sources, and stores the sound data +/// retrieved as **compressed Opus audio**. There is an associated memory cost, +/// but this is far smaller than using a [`Memory`]. +/// +/// [`Input`]: ../struct.Input.html +/// [`Memory`]: struct.Memory.html +/// [`Restartable`]: ../struct.Restartable.html +#[derive(Clone, Debug)] +pub struct Compressed { + /// Inner shared bytestore. + pub raw: TxCatcher, OpusCompressor>, + /// Metadata moved out of the captured source. 
+ pub metadata: Metadata, + /// Stereo-ness of the captured source. + pub stereo: bool, +} + +impl Compressed { + /// Wrap an existing [`Input`] with an in-memory store, compressed using Opus. + /// + /// [`Input`]: ../struct.Input.html + /// [`Metadata.duration`]: ../struct.Metadata.html#structfield.duration + pub fn new(source: Input, bitrate: Bitrate) -> Result { + Self::with_config(source, bitrate, None) + } + + /// Wrap an existing [`Input`] with an in-memory store, compressed using Opus. + /// + /// `config.length_hint` may be used to control the size of the initial chunk, preventing + /// needless allocations and copies. If this is not present, the value specified in + /// `source`'s [`Metadata.duration`] will be used. + /// + /// [`Input`]: ../struct.Input.html + /// [`Metadata.duration`]: ../struct.Metadata.html#structfield.duration + pub fn with_config(source: Input, bitrate: Bitrate, config: Option) -> Result { + let channels = if source.stereo { + Channels::Stereo + } else { + Channels::Mono + }; + let mut encoder = OpusEncoder::new(SampleRate::Hz48000, channels, Application::Audio)?; + + encoder.set_bitrate(bitrate)?; + + Self::with_encoder(source, encoder, config) + } + + /// Wrap an existing [`Input`] with an in-memory store, compressed using a user-defined + /// Opus encoder. + /// + /// `length_hint` functions as in [`new`]. This function's behaviour is undefined if your encoder + /// has a different sample rate than 48kHz, and if the decoder has a different channel count from the source. + /// + /// [`Input`]: ../struct.Input.html + /// [`new`]: #method.new + pub fn with_encoder( + mut source: Input, + encoder: OpusEncoder, + config: Option, + ) -> Result { + let bitrate = encoder.bitrate()?; + let cost_per_sec = compressed_cost_per_sec(bitrate); + let stereo = source.stereo; + let metadata = source.metadata.take(); + + let mut config = config.unwrap_or_else(|| default_config(cost_per_sec)); + + // apply length hint. + if config.length_hint.is_none() { + if let Some(dur) = metadata.duration { + apply_length_hint(&mut config, dur, cost_per_sec); + } + } + + let raw = config + .build_tx(Box::new(source), OpusCompressor::new(encoder, stereo)) + .map_err(Error::Streamcatcher)?; + + Ok(Self { + raw, + metadata, + stereo, + }) + } + + /// Acquire a new handle to this object, creating a new + /// view of the existing cached data from the beginning. + pub fn new_handle(&self) -> Self { + Self { + raw: self.raw.new_handle(), + metadata: self.metadata.clone(), + stereo: self.stereo, + } + } +} + +impl From for Input { + fn from(src: Compressed) -> Self { + Input::new( + true, + Reader::Compressed(src.raw), + CodecType::Opus + .try_into() + .expect("Default decoder values are known to be valid."), + Container::Dca { first_frame: 0 }, + Some(src.metadata), + ) + } +} + +/// Transform applied inside [`Compressed`], converting a floating-point PCM +/// input stream into a DCA-framed Opus stream. +/// +/// Created and managed by [`Compressed`]. 
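+///
+/// Each frame it emits is a little-endian `i16` payload-length header followed by
+/// that many bytes of Opus data, i.e. the same framing used by `Container::Dca`.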
+/// +/// [`Compressed`]: struct.Compressed.html +#[derive(Debug)] +pub struct OpusCompressor { + encoder: OpusEncoder, + last_frame: Vec, + stereo_input: bool, + frame_pos: usize, + audio_bytes: AtomicUsize, +} + +impl OpusCompressor { + fn new(encoder: OpusEncoder, stereo_input: bool) -> Self { + Self { + encoder, + last_frame: Vec::with_capacity(4000), + stereo_input, + frame_pos: 0, + audio_bytes: Default::default(), + } + } +} + +impl Transform for OpusCompressor +where + T: Read, +{ + fn transform_read(&mut self, src: &mut T, buf: &mut [u8]) -> IoResult { + let output_start = mem::size_of::(); + let mut eof = false; + + let mut raw_len = 0; + let mut out = None; + let mut sample_buf = [0f32; STEREO_FRAME_SIZE]; + let samples_in_frame = if self.stereo_input { + STEREO_FRAME_SIZE + } else { + MONO_FRAME_SIZE + }; + + // Purge old frame and read new, if needed. + if self.frame_pos == self.last_frame.len() + output_start || self.last_frame.is_empty() { + self.last_frame.resize(self.last_frame.capacity(), 0); + + // We can't use `read_f32_into` because we can't guarantee the buffer will be filled. + for el in sample_buf[..samples_in_frame].iter_mut() { + match src.read_f32::() { + Ok(sample) => { + *el = sample; + raw_len += 1; + }, + Err(e) if e.kind() == IoErrorKind::UnexpectedEof => { + eof = true; + break; + }, + Err(e) => { + out = Some(Err(e)); + break; + }, + } + } + + if out.is_none() && raw_len > 0 { + loop { + // NOTE: we don't index by raw_len because the last frame can be too small + // to occupy a "whole packet". Zero-padding is the correct behaviour. + match self + .encoder + .encode_float(&sample_buf[..samples_in_frame], &mut self.last_frame[..]) + { + Ok(pkt_len) => { + trace!("Next packet to write has {:?}", pkt_len); + self.frame_pos = 0; + self.last_frame.truncate(pkt_len); + break; + }, + Err(OpusError::Opus(OpusErrorCode::BufferTooSmall)) => { + // If we need more capacity to encode this frame, then take it. + trace!("Resizing inner buffer (+256)."); + self.last_frame.resize(self.last_frame.len() + 256, 0); + }, + Err(e) => { + debug!("Read error {:?} {:?} {:?}.", e, out, raw_len); + out = Some(Err(IoError::new(IoErrorKind::Other, e))); + break; + }, + } + } + } + } + + if out.is_none() { + // Write from frame we have. + let start = if self.frame_pos < output_start { + (&mut buf[..output_start]) + .write_i16::(self.last_frame.len() as i16) + .expect( + "Minimum bytes requirement for Opus (2) should mean that an i16 \ + may always be written.", + ); + self.frame_pos += output_start; + + trace!("Wrote frame header: {}.", self.last_frame.len()); + + output_start + } else { + 0 + }; + + let out_pos = self.frame_pos - output_start; + let remaining = self.last_frame.len() - out_pos; + let write_len = remaining.min(buf.len() - start); + buf[start..start + write_len] + .copy_from_slice(&self.last_frame[out_pos..out_pos + write_len]); + self.frame_pos += write_len; + trace!("Appended {} to inner store", write_len); + out = Some(Ok(write_len + start)); + } + + // NOTE: use of raw_len here preserves true sample length even if + // stream is extended to 20ms boundary. 
+ out.unwrap_or_else(|| Err(IoError::new(IoErrorKind::Other, "Unclear."))) + .map(|compressed_sz| { + self.audio_bytes + .fetch_add(raw_len * mem::size_of::(), Ordering::Release); + + if eof { + TransformPosition::Finished + } else { + TransformPosition::Read(compressed_sz) + } + }) + } +} + +impl NeedsBytes for OpusCompressor { + fn min_bytes_required(&self) -> usize { + 2 + } +} + +impl Stateful for OpusCompressor { + type State = usize; + + fn state(&self) -> Self::State { + self.audio_bytes.load(Ordering::Acquire) + } +} diff --git a/src/input/cached/hint.rs b/src/input/cached/hint.rs new file mode 100644 index 0000000..b32fbce --- /dev/null +++ b/src/input/cached/hint.rs @@ -0,0 +1,40 @@ +use std::time::Duration; +use streamcatcher::Config; + +/// Expected amount of time that an input should last. +#[derive(Copy, Clone, Debug)] +pub enum LengthHint { + /// Estimate of a source's length in bytes. + Bytes(usize), + /// Estimate of a source's length in time. + /// + /// This will be converted to a bytecount at setup. + Time(Duration), +} + +impl From for LengthHint { + fn from(size: usize) -> Self { + LengthHint::Bytes(size) + } +} + +impl From for LengthHint { + fn from(size: Duration) -> Self { + LengthHint::Time(size) + } +} + +/// Modify the given cache configuration to initially allocate +/// enough bytes to store a length of audio at the given bitrate. +pub fn apply_length_hint(config: &mut Config, hint: H, cost_per_sec: usize) +where + H: Into, +{ + config.length_hint = Some(match hint.into() { + LengthHint::Bytes(a) => a, + LengthHint::Time(t) => { + let s = t.as_secs() + if t.subsec_millis() > 0 { 1 } else { 0 }; + (s as usize) * cost_per_sec + }, + }); +} diff --git a/src/input/cached/memory.rs b/src/input/cached/memory.rs new file mode 100644 index 0000000..92062cc --- /dev/null +++ b/src/input/cached/memory.rs @@ -0,0 +1,116 @@ +use super::{apply_length_hint, default_config, raw_cost_per_sec}; +use crate::input::{ + error::{Error, Result}, + CodecType, + Container, + Input, + Metadata, + Reader, +}; +use std::convert::{TryFrom, TryInto}; +use streamcatcher::{Catcher, Config}; + +/// A wrapper around an existing [`Input`] which caches +/// the decoded and converted audio data locally in memory. +/// +/// The main purpose of this wrapper is to enable seeking on +/// incompatible sources (i.e., ffmpeg output) and to ease resource +/// consumption for commonly reused/shared tracks. [`Restartable`] +/// and [`Compressed`] offer the same functionality with different +/// tradeoffs. +/// +/// This is intended for use with small, repeatedly used audio +/// tracks shared between sources, and stores the sound data +/// retrieved in **uncompressed floating point** form to minimise the +/// cost of audio processing. This is a significant *3 Mbps (375 kiB/s)*, +/// or 131 MiB of RAM for a 6 minute song. +/// +/// [`Input`]: ../struct.Input.html +/// [`Compressed`]: struct.Compressed.html +/// [`Restartable`]: ../struct.Restartable.html +#[derive(Clone, Debug)] +pub struct Memory { + /// Inner shared bytestore. + pub raw: Catcher>, + /// Metadata moved out of the captured source. + pub metadata: Metadata, + /// Codec used to read the inner bytestore. + pub kind: CodecType, + /// Stereo-ness of the captured source. + pub stereo: bool, + /// Framing mechanism for the inner bytestore. + pub container: Container, +} + +impl Memory { + /// Wrap an existing [`Input`] with an in-memory store with the same codec and framing. 
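+    ///
+    /// A minimal sketch (the file path is purely illustrative):
+    ///
+    /// ```rust,no_run
+    /// use songbird::input::{cached::Memory, ffmpeg};
+    ///
+    /// let source = futures::executor::block_on(ffmpeg("loop.wav"))
+    ///     .expect("this example assumes ffmpeg and a valid local file");
+    /// let cached = Memory::new(source).expect("source should be cacheable");
+    ///
+    /// // Handles are cheap views over the same shared store, each starting
+    /// // from the beginning of the cached audio.
+    /// let handle = cached.new_handle();
+    /// ```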
+ /// + /// [`Input`]: ../struct.Input.html + pub fn new(source: Input) -> Result { + Self::with_config(source, None) + } + + /// Wrap an existing [`Input`] with an in-memory store with the same codec and framing. + /// + /// `length_hint` may be used to control the size of the initial chunk, preventing + /// needless allocations and copies. If this is not present, the value specified in + /// `source`'s [`Metadata.duration`] will be used, assuming that the source is uncompressed. + /// + /// [`Input`]: ../struct.Input.html + /// [`Metadata.duration`]: ../struct.Metadata.html#structfield.duration + pub fn with_config(mut source: Input, config: Option) -> Result { + let stereo = source.stereo; + let kind = (&source.kind).into(); + let container = source.container; + let metadata = source.metadata.take(); + + let cost_per_sec = raw_cost_per_sec(stereo); + + let mut config = config.unwrap_or_else(|| default_config(cost_per_sec)); + + // apply length hint. + if config.length_hint.is_none() { + if let Some(dur) = metadata.duration { + apply_length_hint(&mut config, dur, cost_per_sec); + } + } + + let raw = config + .build(Box::new(source.reader)) + .map_err(Error::Streamcatcher)?; + + Ok(Self { + raw, + metadata, + kind, + stereo, + container, + }) + } + + /// Acquire a new handle to this object, creating a new + /// view of the existing cached data from the beginning. + pub fn new_handle(&self) -> Self { + Self { + raw: self.raw.new_handle(), + metadata: self.metadata.clone(), + kind: self.kind, + stereo: self.stereo, + container: self.container, + } + } +} + +impl TryFrom for Input { + type Error = Error; + + fn try_from(src: Memory) -> Result { + Ok(Input::new( + src.stereo, + Reader::Memory(src.raw), + src.kind.try_into()?, + src.container, + Some(src.metadata), + )) + } +} diff --git a/src/input/cached/mod.rs b/src/input/cached/mod.rs new file mode 100644 index 0000000..5983c81 --- /dev/null +++ b/src/input/cached/mod.rs @@ -0,0 +1,44 @@ +//! In-memory, shared input sources for reuse between calls, fast seeking, and +//! direct Opus frame passthrough. + +mod compressed; +mod hint; +mod memory; +#[cfg(test)] +mod tests; + +pub use self::{compressed::*, hint::*, memory::*}; + +use crate::constants::*; +use crate::input::utils; +use audiopus::Bitrate; +use std::{mem, time::Duration}; +use streamcatcher::{Config, GrowthStrategy}; + +/// Estimates the cost, in B/s, of audio data compressed at the given bitrate. +pub fn compressed_cost_per_sec(bitrate: Bitrate) -> usize { + let framing_cost_per_sec = AUDIO_FRAME_RATE * mem::size_of::(); + + let bitrate_raw = match bitrate { + Bitrate::BitsPerSecond(i) => i, + Bitrate::Auto => 64_000, + Bitrate::Max => 512_000, + } as usize; + + (bitrate_raw / 8) + framing_cost_per_sec +} + +/// Calculates the cost, in B/s, of raw floating-point audio data. +pub fn raw_cost_per_sec(stereo: bool) -> usize { + utils::timestamp_to_byte_count(Duration::from_secs(1), stereo) +} + +/// Provides the default config used by a cached source. +/// +/// This maps to the default configuration in [`streamcatcher`], using +/// a constant chunk size of 5s worth of audio at the given bitrate estimate. 
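+///
+/// As a rough illustration (assuming the usual 20 ms frames, i.e. 50 per second,
+/// each with a 2-byte length header): `compressed_cost_per_sec(Bitrate::BitsPerSecond(128_000))`
+/// estimates `128_000 / 8 + 50 * 2 = 16_100` B/s, so the default chunk size here
+/// would be about `5 * 16_100 = 80_500` bytes.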
+/// +/// [`streamcatcher`]: https://docs.rs/streamcatcher/0.1.0/streamcatcher/struct.Config.html +pub fn default_config(cost_per_sec: usize) -> Config { + Config::new().chunk_size(GrowthStrategy::Constant(5 * cost_per_sec)) +} diff --git a/src/input/cached/tests.rs b/src/input/cached/tests.rs new file mode 100644 index 0000000..d4a7021 --- /dev/null +++ b/src/input/cached/tests.rs @@ -0,0 +1,79 @@ +use super::*; +use crate::{ + constants::*, + input::{error::Error, ffmpeg, Codec, Container, Input, Reader}, + test_utils::*, +}; +use audiopus::{coder::Decoder, Bitrate, Channels, SampleRate}; +use byteorder::{LittleEndian, ReadBytesExt}; +use std::io::{Cursor, Read}; + +#[tokio::test] +async fn streamcatcher_preserves_file() { + let input = make_sine(50 * MONO_FRAME_SIZE, true); + let input_len = input.len(); + + let mut raw = default_config(raw_cost_per_sec(true)) + .build(Cursor::new(input.clone())) + .map_err(Error::Streamcatcher) + .unwrap(); + + let mut out_buf = vec![]; + let read = raw.read_to_end(&mut out_buf).unwrap(); + + assert_eq!(input_len, read); + + assert_eq!(input, out_buf); +} + +#[test] +fn compressed_scans_frames_decodes_mono() { + let data = one_s_compressed_sine(false); + run_through_dca(data.raw); +} + +#[test] +fn compressed_scans_frames_decodes_stereo() { + let data = one_s_compressed_sine(true); + run_through_dca(data.raw); +} + +#[test] +fn compressed_triggers_valid_passthrough() { + let mut input = Input::from(one_s_compressed_sine(true)); + + assert!(input.supports_passthrough()); + + let mut opus_buf = [0u8; 10_000]; + let mut signal_buf = [0i16; 1920]; + + let opus_len = input.read_opus_frame(&mut opus_buf[..]).unwrap(); + + let mut decoder = Decoder::new(SampleRate::Hz48000, Channels::Stereo).unwrap(); + decoder + .decode(Some(&opus_buf[..opus_len]), &mut signal_buf[..], false) + .unwrap(); +} + +fn one_s_compressed_sine(stereo: bool) -> Compressed { + let data = make_sine(50 * MONO_FRAME_SIZE, stereo); + + let input = Input::new(stereo, data.into(), Codec::FloatPcm, Container::Raw, None); + + Compressed::new(input, Bitrate::BitsPerSecond(128_000)).unwrap() +} + +fn run_through_dca(mut src: impl Read) { + let mut decoder = Decoder::new(SampleRate::Hz48000, Channels::Stereo).unwrap(); + + let mut pkt_space = [0u8; 10_000]; + let mut signals = [0i16; 1920]; + + while let Ok(frame_len) = src.read_i16::() { + let pkt_len = src.read(&mut pkt_space[..frame_len as usize]).unwrap(); + + decoder + .decode(Some(&pkt_space[..pkt_len]), &mut signals[..], false) + .unwrap(); + } +} diff --git a/src/input/child.rs b/src/input/child.rs new file mode 100644 index 0000000..47d57f9 --- /dev/null +++ b/src/input/child.rs @@ -0,0 +1,38 @@ +use super::*; +use std::{ + io::{BufReader, Read}, + process::Child, +}; +use tracing::debug; + +/// Handle for a child process which ensures that any subprocesses are properly closed +/// on drop. 
+#[derive(Debug)] +pub struct ChildContainer(Child); + +pub(crate) fn child_to_reader(child: Child) -> Reader { + Reader::Pipe(BufReader::with_capacity( + STEREO_FRAME_SIZE * mem::size_of::() * CHILD_BUFFER_LEN, + ChildContainer(child), + )) +} + +impl From for Reader { + fn from(container: Child) -> Self { + child_to_reader::(container) + } +} + +impl Read for ChildContainer { + fn read(&mut self, buffer: &mut [u8]) -> IoResult { + self.0.stdout.as_mut().unwrap().read(buffer) + } +} + +impl Drop for ChildContainer { + fn drop(&mut self) { + if let Err(e) = self.0.kill() { + debug!("Error awaiting child process: {:?}", e); + } + } +} diff --git a/src/input/codec/mod.rs b/src/input/codec/mod.rs new file mode 100644 index 0000000..ddd4113 --- /dev/null +++ b/src/input/codec/mod.rs @@ -0,0 +1,99 @@ +//! Decoding schemes for input audio bytestreams. + +mod opus; + +pub use self::opus::OpusDecoderState; + +use super::*; +use std::{fmt::Debug, mem}; + +/// State used to decode input bytes of an [`Input`]. +/// +/// [`Input`]: ../struct.Input.html +#[non_exhaustive] +#[derive(Clone, Debug)] +pub enum Codec { + /// The inner bytestream is encoded using the Opus codec, to be decoded + /// using the given state. + /// + /// Must be combined with a non-[`Raw`] container. + /// + /// [`Raw`]: ../enum.Container.html#variant.Raw + Opus(OpusDecoderState), + /// The inner bytestream is encoded using raw `i16` samples. + /// + /// Must be combined with a [`Raw`] container. + /// + /// [`Raw`]: ../enum.Container.html#variant.Raw + Pcm, + /// The inner bytestream is encoded using raw `f32` samples. + /// + /// Must be combined with a [`Raw`] container. + /// + /// [`Raw`]: ../enum.Container.html#variant.Raw + FloatPcm, +} + +impl From<&Codec> for CodecType { + fn from(f: &Codec) -> Self { + use Codec::*; + + match f { + Opus(_) => Self::Opus, + Pcm => Self::Pcm, + FloatPcm => Self::FloatPcm, + } + } +} + +/// Type of data being passed into an [`Input`]. +/// +/// [`Input`]: ../struct.Input.html +#[non_exhaustive] +#[derive(Copy, Clone, Debug)] +pub enum CodecType { + /// The inner bytestream is encoded using the Opus codec. + /// + /// Must be combined with a non-[`Raw`] container. + /// + /// [`Raw`]: ../enum.Container.html#variant.Raw + Opus, + /// The inner bytestream is encoded using raw `i16` samples. + /// + /// Must be combined with a [`Raw`] container. + /// + /// [`Raw`]: ../enum.Container.html#variant.Raw + Pcm, + /// The inner bytestream is encoded using raw `f32` samples. + /// + /// Must be combined with a [`Raw`] container. + /// + /// [`Raw`]: ../enum.Container.html#variant.Raw + FloatPcm, +} + +impl CodecType { + /// Returns the length of a single output sample, in bytes. 
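+    ///
+    /// For example, `CodecType::Pcm.sample_len()` is `2` (one `i16`), while
+    /// `FloatPcm` and decoded `Opus` both yield `4`-byte (`f32`) samples.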
+ pub fn sample_len(&self) -> usize { + use CodecType::*; + + match self { + Opus | FloatPcm => mem::size_of::(), + Pcm => mem::size_of::(), + } + } +} + +impl TryFrom for Codec { + type Error = Error; + + fn try_from(f: CodecType) -> Result { + use CodecType::*; + + match f { + Opus => Ok(Codec::Opus(OpusDecoderState::new()?)), + Pcm => Ok(Codec::Pcm), + FloatPcm => Ok(Codec::FloatPcm), + } + } +} diff --git a/src/input/codec/opus.rs b/src/input/codec/opus.rs new file mode 100644 index 0000000..1c002cf --- /dev/null +++ b/src/input/codec/opus.rs @@ -0,0 +1,43 @@ +use crate::constants::*; +use audiopus::{coder::Decoder as OpusDecoder, Channels, Error as OpusError}; +use parking_lot::Mutex; +use std::sync::Arc; + +#[derive(Clone, Debug)] +/// Inner state +pub struct OpusDecoderState { + /// Inner decoder used to convert opus frames into a stream of samples. + pub decoder: Arc>, + /// Controls whether this source allows direct Opus frame passthrough. + /// Defaults to `true`. + /// + /// Enabling this flag is a promise from the programmer to the audio core + /// that the source has been encoded at 48kHz, using 20ms long frames. + /// If you cannot guarantee this, disable this flag (or else risk nasal demons) + /// and bizarre audio behaviour. + pub allow_passthrough: bool, + pub(crate) current_frame: Vec, + pub(crate) frame_pos: usize, + pub(crate) should_reset: bool, +} + +impl OpusDecoderState { + /// Creates a new decoder, having stereo output at 48kHz. + pub fn new() -> Result { + Ok(Self::from_decoder(OpusDecoder::new( + SAMPLE_RATE, + Channels::Stereo, + )?)) + } + + /// Creates a new decoder pre-configured by the user. + pub fn from_decoder(decoder: OpusDecoder) -> Self { + Self { + decoder: Arc::new(Mutex::new(decoder)), + allow_passthrough: true, + current_frame: Vec::with_capacity(STEREO_FRAME_SIZE), + frame_pos: 0, + should_reset: false, + } + } +} diff --git a/src/input/container/frame.rs b/src/input/container/frame.rs new file mode 100644 index 0000000..fb5f0f4 --- /dev/null +++ b/src/input/container/frame.rs @@ -0,0 +1,8 @@ +/// Information used in audio frame detection. +#[derive(Clone, Copy, Debug)] +pub struct Frame { + /// Length of this frame's header, in bytes. + pub header_len: usize, + /// Payload length, in bytes. + pub frame_len: usize, +} diff --git a/src/input/container/mod.rs b/src/input/container/mod.rs new file mode 100644 index 0000000..f22b013 --- /dev/null +++ b/src/input/container/mod.rs @@ -0,0 +1,69 @@ +mod frame; + +pub use frame::*; + +use super::CodecType; +use byteorder::{LittleEndian, ReadBytesExt}; +use std::{ + fmt::Debug, + io::{Read, Result as IoResult}, + mem, +}; + +/// Marker and state for decoding framed input files. +#[non_exhaustive] +#[derive(Clone, Copy, Debug)] +pub enum Container { + /// Raw, unframed input. + Raw, + /// Framed input, beginning with a JSON header. + /// + /// Frames have the form `{ len: i16, payload: [u8; len]}`. + Dca { + /// Byte index of the first frame after the JSON header. + first_frame: usize, + }, +} + +impl Container { + /// Tries to read the header of the next frame from an input stream. + pub fn next_frame_length( + &mut self, + mut reader: impl Read, + input: CodecType, + ) -> IoResult { + use Container::*; + + match self { + Raw => Ok(Frame { + header_len: 0, + frame_len: input.sample_len(), + }), + Dca { .. 
} => reader.read_i16::().map(|frame_len| Frame { + header_len: mem::size_of::(), + frame_len: frame_len.max(0) as usize, + }), + } + } + + /// Tries to seek on an input directly using sample length, if the input + /// is unframed. + pub fn try_seek_trivial(&self, input: CodecType) -> Option { + use Container::*; + + match self { + Raw => Some(input.sample_len()), + _ => None, + } + } + + /// Returns the byte index of the first frame containing audio payload data. + pub fn input_start(&self) -> usize { + use Container::*; + + match self { + Raw => 0, + Dca { first_frame } => *first_frame, + } + } +} diff --git a/src/input/dca.rs b/src/input/dca.rs new file mode 100644 index 0000000..ea46331 --- /dev/null +++ b/src/input/dca.rs @@ -0,0 +1,137 @@ +use super::{codec::OpusDecoderState, error::DcaError, Codec, Container, Input, Metadata, Reader}; +use serde::Deserialize; +use std::{ffi::OsStr, io::BufReader, mem}; +use tokio::{fs::File as TokioFile, io::AsyncReadExt}; + +/// Creates a streamed audio source from a DCA file. +/// Currently only accepts the [DCA1 format](https://github.com/bwmarrin/dca). +pub async fn dca>(path: P) -> Result { + _dca(path.as_ref()).await +} + +async fn _dca(path: &OsStr) -> Result { + let mut reader = TokioFile::open(path).await.map_err(DcaError::IoError)?; + + let mut header = [0u8; 4]; + + // Read in the magic number to verify it's a DCA file. + reader + .read_exact(&mut header) + .await + .map_err(DcaError::IoError)?; + + if header != b"DCA1"[..] { + return Err(DcaError::InvalidHeader); + } + + let size = reader + .read_i32_le() + .await + .map_err(|_| DcaError::InvalidHeader)?; + + // Sanity check + if size < 2 { + return Err(DcaError::InvalidSize(size)); + } + + let mut raw_json = Vec::with_capacity(size as usize); + + let mut json_reader = reader.take(size as u64); + + json_reader + .read_to_end(&mut raw_json) + .await + .map_err(DcaError::IoError)?; + + let reader = BufReader::new(json_reader.into_inner().into_std().await); + + let metadata: Metadata = serde_json::from_slice::(raw_json.as_slice()) + .map_err(DcaError::InvalidMetadata)? 
+ .into(); + + let stereo = metadata.channels == Some(2); + + Ok(Input::new( + stereo, + Reader::File(reader), + Codec::Opus(OpusDecoderState::new().map_err(DcaError::Opus)?), + Container::Dca { + first_frame: (size as usize) + mem::size_of::() + header.len(), + }, + Some(metadata), + )) +} + +#[derive(Debug, Deserialize)] +pub(crate) struct DcaMetadata { + pub(crate) dca: Dca, + pub(crate) opus: Opus, + pub(crate) info: Option, + pub(crate) origin: Option, + pub(crate) extra: Option, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct Dca { + pub(crate) version: u64, + pub(crate) tool: Tool, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct Tool { + pub(crate) name: String, + pub(crate) version: String, + pub(crate) url: String, + pub(crate) author: String, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct Opus { + pub(crate) mode: String, + pub(crate) sample_rate: u32, + pub(crate) frame_size: u64, + pub(crate) abr: u64, + pub(crate) vbr: u64, + pub(crate) channels: u8, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct Info { + pub(crate) title: Option, + pub(crate) artist: Option, + pub(crate) album: Option, + pub(crate) genre: Option, + pub(crate) cover: Option, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct Origin { + pub(crate) source: Option, + pub(crate) abr: Option, + pub(crate) channels: Option, + pub(crate) encoding: Option, + pub(crate) url: Option, +} + +impl From for Metadata { + fn from(mut d: DcaMetadata) -> Self { + let (title, artist) = d + .info + .take() + .map(|mut m| (m.title.take(), m.artist.take())) + .unwrap_or_else(|| (None, None)); + + let channels = Some(d.opus.channels); + let sample_rate = Some(d.opus.sample_rate); + + Self { + title, + artist, + + channels, + sample_rate, + + ..Default::default() + } + } +} diff --git a/src/input/error.rs b/src/input/error.rs new file mode 100644 index 0000000..614249f --- /dev/null +++ b/src/input/error.rs @@ -0,0 +1,93 @@ +//! Errors caused by input creation. + +use audiopus::Error as OpusError; +use serde_json::{Error as JsonError, Value}; +use std::{io::Error as IoError, process::Output}; +use streamcatcher::CatcherError; + +/// An error returned when creating a new [`Input`]. +/// +/// [`Input`]: ../struct.Input.html +#[derive(Debug)] +#[non_exhaustive] +pub enum Error { + /// An error occurred while opening a new DCA source. + Dca(DcaError), + /// An error occurred while reading, or opening a file. + Io(IoError), + /// An error occurred while parsing JSON (i.e., during metadata/stereo detection). + Json(JsonError), + /// An error occurred within the Opus codec. + Opus(OpusError), + /// Failed to extract metadata from alternate pipe. + Metadata, + /// Apparently failed to create stdout. + Stdout, + /// An error occurred while checking if a path is stereo. + Streams, + /// Configuration error for a cached Input. + Streamcatcher(CatcherError), + /// An error occurred while processing the JSON output from `youtube-dl`. + /// + /// The JSON output is given. + YouTubeDLProcessing(Value), + /// An error occurred while running `youtube-dl`. + YouTubeDLRun(Output), + /// The `url` field of the `youtube-dl` JSON output was not present. + /// + /// The JSON output is given. 
+ YouTubeDLUrl(Value), +} + +impl From for Error { + fn from(e: CatcherError) -> Self { + Error::Streamcatcher(e) + } +} + +impl From for Error { + fn from(e: DcaError) -> Self { + Error::Dca(e) + } +} + +impl From for Error { + fn from(e: IoError) -> Error { + Error::Io(e) + } +} + +impl From for Error { + fn from(e: JsonError) -> Self { + Error::Json(e) + } +} + +impl From for Error { + fn from(e: OpusError) -> Error { + Error::Opus(e) + } +} + +/// An error returned from the [`dca`] method. +/// +/// [`dca`]: ../fn.dca.html +#[derive(Debug)] +#[non_exhaustive] +pub enum DcaError { + /// An error occurred while reading, or opening a file. + IoError(IoError), + /// The file opened did not have a valid DCA JSON header. + InvalidHeader, + /// The file's metadata block was invalid, or could not be parsed. + InvalidMetadata(JsonError), + /// The file's header reported an invalid metadata block size. + InvalidSize(i32), + /// An error was encountered while creating a new Opus decoder. + Opus(OpusError), +} + +/// Convenience type for fallible return of [`Input`]s. +/// +/// [`Input`]: ../struct.Input.html +pub type Result = std::result::Result; diff --git a/src/input/ffmpeg_src.rs b/src/input/ffmpeg_src.rs new file mode 100644 index 0000000..f430762 --- /dev/null +++ b/src/input/ffmpeg_src.rs @@ -0,0 +1,146 @@ +use super::{ + child_to_reader, + error::{Error, Result}, + Codec, + Container, + Input, + Metadata, +}; +use serde_json::Value; +use std::{ + ffi::OsStr, + process::{Command, Stdio}, +}; +use tokio::process::Command as TokioCommand; +use tracing::debug; + +/// Opens an audio file through `ffmpeg` and creates an audio source. +pub async fn ffmpeg>(path: P) -> Result { + _ffmpeg(path.as_ref()).await +} + +pub(crate) async fn _ffmpeg(path: &OsStr) -> Result { + // Will fail if the path is not to a file on the fs. Likely a YouTube URI. + let is_stereo = is_stereo(path) + .await + .unwrap_or_else(|_e| (false, Default::default())); + let stereo_val = if is_stereo.0 { "2" } else { "1" }; + + _ffmpeg_optioned( + path, + &[], + &[ + "-f", + "s16le", + "-ac", + stereo_val, + "-ar", + "48000", + "-acodec", + "pcm_f32le", + "-", + ], + Some(is_stereo), + ) + .await +} + +/// Opens an audio file through `ffmpeg` and creates an audio source, with +/// user-specified arguments to pass to ffmpeg. +/// +/// Note that this does _not_ build on the arguments passed by the [`ffmpeg`] +/// function. 
+/// +/// # Examples +/// +/// Pass options to create a custom ffmpeg streamer: +/// +/// ```rust,no_run +/// use songbird::input; +/// +/// let stereo_val = "2"; +/// +/// let streamer = futures::executor::block_on(input::ffmpeg_optioned("./some_file.mp3", &[], &[ +/// "-f", +/// "s16le", +/// "-ac", +/// stereo_val, +/// "-ar", +/// "48000", +/// "-acodec", +/// "pcm_s16le", +/// "-", +/// ])); +///``` +pub async fn ffmpeg_optioned>( + path: P, + pre_input_args: &[&str], + args: &[&str], +) -> Result { + _ffmpeg_optioned(path.as_ref(), pre_input_args, args, None).await +} + +pub(crate) async fn _ffmpeg_optioned( + path: &OsStr, + pre_input_args: &[&str], + args: &[&str], + is_stereo_known: Option<(bool, Metadata)>, +) -> Result { + let (is_stereo, metadata) = if let Some(vals) = is_stereo_known { + vals + } else { + is_stereo(path) + .await + .ok() + .unwrap_or_else(|| (false, Default::default())) + }; + + let command = Command::new("ffmpeg") + .args(pre_input_args) + .arg("-i") + .arg(path) + .args(args) + .stderr(Stdio::null()) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .spawn()?; + + Ok(Input::new( + is_stereo, + child_to_reader::(command), + Codec::FloatPcm, + Container::Raw, + Some(metadata), + )) +} + +pub(crate) async fn is_stereo(path: &OsStr) -> Result<(bool, Metadata)> { + let args = [ + "-v", + "quiet", + "-of", + "json", + "-show_format", + "-show_streams", + "-i", + ]; + + let out = TokioCommand::new("ffprobe") + .args(&args) + .arg(path) + .stdin(Stdio::null()) + .output() + .await?; + + let value: Value = serde_json::from_reader(&out.stdout[..])?; + + let metadata = Metadata::from_ffprobe_json(&value); + + debug!("FFprobe metadata {:?}", metadata); + + if let Some(count) = metadata.channels { + Ok((count == 2, metadata)) + } else { + Err(Error::Streams) + } +} diff --git a/src/input/metadata.rs b/src/input/metadata.rs new file mode 100644 index 0000000..4a47523 --- /dev/null +++ b/src/input/metadata.rs @@ -0,0 +1,166 @@ +use crate::constants::*; +use serde_json::Value; +use std::time::Duration; + +/// Information about an [`Input`] source. +/// +/// [`Input`]: struct.Input.html +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct Metadata { + /// The title of this stream. + pub title: Option, + /// The main artist of this stream. + pub artist: Option, + /// The date of creation of this stream. + pub date: Option, + + /// The number of audio channels in this stream. + /// + /// Any number `>= 2` is treated as stereo. + pub channels: Option, + /// The time at which the first true sample is played back. + /// + /// This occurs as an artefact of coder delay. + pub start_time: Option, + /// The reported duration of this stream. + pub duration: Option, + /// The sample rate of this stream. + pub sample_rate: Option, +} + +impl Metadata { + /// Extract metadata and details from the output of + /// `ffprobe`. 
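+    ///
+    /// A sketch of the JSON shape this expects (all values hypothetical):
+    ///
+    /// ```rust
+    /// use songbird::input::Metadata;
+    ///
+    /// let value = serde_json::json!({
+    ///     "format": {
+    ///         "duration": "182.3",
+    ///         "tags": { "title": "Some Title", "artist": "Some Artist" }
+    ///     },
+    ///     "streams": [
+    ///         { "codec_type": "audio", "channels": 2, "sample_rate": "48000" }
+    ///     ]
+    /// });
+    ///
+    /// let metadata = Metadata::from_ffprobe_json(&value);
+    /// assert_eq!(metadata.channels, Some(2));
+    /// assert_eq!(metadata.sample_rate, Some(48_000));
+    /// ```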
+ pub fn from_ffprobe_json(value: &Value) -> Self { + let format = value.as_object().and_then(|m| m.get("format")); + + let duration = format + .and_then(|m| m.get("duration")) + .and_then(Value::as_str) + .and_then(|v| v.parse::().ok()) + .map(Duration::from_secs_f64); + + let start_time = format + .and_then(|m| m.get("start_time")) + .and_then(Value::as_str) + .and_then(|v| v.parse::().ok()) + .map(Duration::from_secs_f64); + + let tags = format.and_then(|m| m.get("tags")); + + let title = tags + .and_then(|m| m.get("title")) + .and_then(Value::as_str) + .map(str::to_string); + + let artist = tags + .and_then(|m| m.get("artist")) + .and_then(Value::as_str) + .map(str::to_string); + + let date = tags + .and_then(|m| m.get("date")) + .and_then(Value::as_str) + .map(str::to_string); + + let stream = value + .as_object() + .and_then(|m| m.get("streams")) + .and_then(|v| v.as_array()) + .and_then(|v| { + v.iter() + .find(|line| line.get("codec_type").and_then(Value::as_str) == Some("audio")) + }); + + let channels = stream + .and_then(|m| m.get("channels")) + .and_then(Value::as_u64) + .map(|v| v as u8); + + let sample_rate = stream + .and_then(|m| m.get("sample_rate")) + .and_then(Value::as_str) + .and_then(|v| v.parse::().ok()) + .map(|v| v as u32); + + Self { + title, + artist, + date, + + channels, + start_time, + duration, + sample_rate, + } + } + + /// Use `youtube-dl` to extract metadata for an online resource. + pub fn from_ytdl_output(value: Value) -> Self { + let obj = value.as_object(); + + let track = obj + .and_then(|m| m.get("track")) + .and_then(Value::as_str) + .map(str::to_string); + + let title = track.or_else(|| { + obj.and_then(|m| m.get("title")) + .and_then(Value::as_str) + .map(str::to_string) + }); + + let true_artist = obj + .and_then(|m| m.get("artist")) + .and_then(Value::as_str) + .map(str::to_string); + + let artist = true_artist.or_else(|| { + obj.and_then(|m| m.get("uploader")) + .and_then(Value::as_str) + .map(str::to_string) + }); + + let r_date = obj + .and_then(|m| m.get("release_date")) + .and_then(Value::as_str) + .map(str::to_string); + + let date = r_date.or_else(|| { + obj.and_then(|m| m.get("upload_date")) + .and_then(Value::as_str) + .map(str::to_string) + }); + + let duration = obj + .and_then(|m| m.get("duration")) + .and_then(Value::as_f64) + .map(Duration::from_secs_f64); + + Self { + title, + artist, + date, + + channels: Some(2), + duration, + sample_rate: Some(SAMPLE_RATE_RAW as u32), + + ..Default::default() + } + } + + /// Move all fields from a `Metadata` object into a new one. + pub fn take(&mut self) -> Self { + Self { + title: self.title.take(), + artist: self.artist.take(), + date: self.date.take(), + + channels: self.channels.take(), + start_time: self.start_time.take(), + duration: self.duration.take(), + sample_rate: self.sample_rate.take(), + } + } +} diff --git a/src/input/mod.rs b/src/input/mod.rs new file mode 100644 index 0000000..8d10c26 --- /dev/null +++ b/src/input/mod.rs @@ -0,0 +1,596 @@ +//! Raw audio input data streams and sources. +//! +//! [`Input`] is handled in Songbird by combining metadata with: +//! * A 48kHz audio bytestream, via [`Reader`], +//! * A [`Container`] describing the framing mechanism of the bytestream, +//! * A [`Codec`], defining the format of audio frames. +//! +//! When used as a [`Read`], the output bytestream will be a floating-point +//! PCM stream at 48kHz, matching the channel count of the input source. +//! +//! ## Opus frame passthrough. +//! 
Some sources, such as [`Compressed`] or the output of [`dca`], support +//! direct frame passthrough to the driver. This lets you directly send the +//! audio data you have *without decoding, re-encoding, or mixing*. In many +//! cases, this can greatly reduce the processing/compute cost of the driver. +//! +//! This functionality requires that: +//! * only one track is active (including paused tracks), +//! * that track's input supports direct Opus frame reads, +//! * its [`Input`] [meets the promises described herein](codec/struct.OpusDecoderState.html#structfield.allow_passthrough), +//! * and that track's volume is set to `1.0`. +//! +//! [`Input`]: struct.Input.html +//! [`Reader`]: reader/enum.Reader.html +//! [`Container`]: enum.Container.html +//! [`Codec`]: codec/enum.Codec.html +//! [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html +//! [`Compressed`]: cached/struct.Compressed.html +//! [`dca`]: fn.dca.html + +pub mod cached; +mod child; +pub mod codec; +mod container; +mod dca; +pub mod error; +mod ffmpeg_src; +mod metadata; +pub mod reader; +pub mod restartable; +pub mod utils; +mod ytdl_src; + +pub use self::{ + child::*, + codec::{Codec, CodecType}, + container::{Container, Frame}, + dca::dca, + ffmpeg_src::*, + metadata::Metadata, + reader::Reader, + restartable::Restartable, + ytdl_src::*, +}; + +use crate::constants::*; +use audiopus::coder::GenericCtl; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use cached::OpusCompressor; +use error::{Error, Result}; +use tokio::runtime::Handle; + +use std::{ + convert::TryFrom, + io::{ + self, + Error as IoError, + ErrorKind as IoErrorKind, + Read, + Result as IoResult, + Seek, + SeekFrom, + }, + mem, + time::Duration, +}; +use tracing::{debug, error}; + +/// Data and metadata needed to correctly parse a [`Reader`]'s audio bytestream. +/// +/// See the [module root] for more information. +/// +/// [`Reader`]: enum.Reader.html +/// [module root]: index.html +#[derive(Debug)] +pub struct Input { + /// Information about the played source. + pub metadata: Metadata, + /// Indicates whether `source` is stereo or mono. + pub stereo: bool, + /// Underlying audio data bytestream. + pub reader: Reader, + /// Decoder used to parse the output of `reader`. + pub kind: Codec, + /// Framing strategy needed to identify frames of compressed audio. + pub container: Container, + pos: usize, +} + +impl Input { + /// Creates a floating-point PCM Input from a given reader. + pub fn float_pcm(is_stereo: bool, reader: Reader) -> Input { + Input { + metadata: Default::default(), + stereo: is_stereo, + reader, + kind: Codec::FloatPcm, + container: Container::Raw, + pos: 0, + } + } + + /// Creates a new Input using (at least) the given reader, codec, and container. + pub fn new( + stereo: bool, + reader: Reader, + kind: Codec, + container: Container, + metadata: Option, + ) -> Self { + Input { + metadata: metadata.unwrap_or_default(), + stereo, + reader, + kind, + container, + pos: 0, + } + } + + /// Returns whether the inner [`Reader`] implements [`Seek`]. + /// + /// [`Reader`]: reader/enum.Reader.html + /// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html + pub fn is_seekable(&self) -> bool { + self.reader.is_seekable() + } + + /// Returns whether the read audio signal is stereo (or mono). + pub fn is_stereo(&self) -> bool { + self.stereo + } + + /// Returns the type of the inner [`Codec`]. 
+ /// + /// [`Codec`]: codec/enum.Codec.html + pub fn get_type(&self) -> CodecType { + (&self.kind).into() + } + + /// Mixes the output of this stream into a 20ms stereo audio buffer. + #[inline] + pub fn mix(&mut self, float_buffer: &mut [f32; STEREO_FRAME_SIZE], volume: f32) -> usize { + match self.add_float_pcm_frame(float_buffer, self.stereo, volume) { + Some(len) => len, + None => 0, + } + } + + /// Seeks the stream to the given time, if possible. + /// + /// Returns the actual time reached. + pub fn seek_time(&mut self, time: Duration) -> Option { + let future_pos = utils::timestamp_to_byte_count(time, self.stereo); + Seek::seek(self, SeekFrom::Start(future_pos as u64)) + .ok() + .map(|a| utils::byte_count_to_timestamp(a as usize, self.stereo)) + } + + fn read_inner(&mut self, buffer: &mut [u8], ignore_decode: bool) -> IoResult { + // This implementation of Read converts the input stream + // to floating point output. + let sample_len = mem::size_of::(); + let float_space = buffer.len() / sample_len; + let mut written_floats = 0; + + // TODO: better decouple codec and container here. + // this is a little bit backwards, and assumes the bottom cases are always raw... + let out = match &mut self.kind { + Codec::Opus(decoder_state) => { + if matches!(self.container, Container::Raw) { + return Err(IoError::new( + IoErrorKind::InvalidInput, + "Raw container cannot demarcate Opus frames.", + )); + } + + if ignore_decode { + // If we're less than one frame away from the end of cheap seeking, + // then we must decode to make sure the next starting offset is correct. + + // Step one: use up the remainder of the frame. + let mut aud_skipped = + decoder_state.current_frame.len() - decoder_state.frame_pos; + + decoder_state.frame_pos = 0; + decoder_state.current_frame.truncate(0); + + // Step two: take frames if we can. + while buffer.len() - aud_skipped >= STEREO_FRAME_BYTE_SIZE { + decoder_state.should_reset = true; + + let frame = self + .container + .next_frame_length(&mut self.reader, CodecType::Opus)?; + self.reader.consume(frame.frame_len); + + aud_skipped += STEREO_FRAME_BYTE_SIZE; + } + + Ok(aud_skipped) + } else { + // get new frame *if needed* + if decoder_state.frame_pos == decoder_state.current_frame.len() { + let mut decoder = decoder_state.decoder.lock(); + + if decoder_state.should_reset { + decoder + .reset_state() + .expect("Critical failure resetting decoder."); + decoder_state.should_reset = false; + } + let frame = self + .container + .next_frame_length(&mut self.reader, CodecType::Opus)?; + + let mut opus_data_buffer = [0u8; 4000]; + + decoder_state + .current_frame + .resize(decoder_state.current_frame.capacity(), 0.0); + + let seen = + Read::read(&mut self.reader, &mut opus_data_buffer[..frame.frame_len])?; + + let samples = decoder + .decode_float( + Some(&opus_data_buffer[..seen]), + &mut decoder_state.current_frame[..], + false, + ) + .unwrap_or(0); + + decoder_state.current_frame.truncate(2 * samples); + decoder_state.frame_pos = 0; + } + + // read from frame which is present. 
+ let mut buffer = &mut buffer[..]; + + let start = decoder_state.frame_pos; + let to_write = float_space.min(decoder_state.current_frame.len() - start); + for val in &decoder_state.current_frame[start..start + float_space] { + buffer.write_f32::(*val)?; + } + decoder_state.frame_pos += to_write; + written_floats = to_write; + + Ok(written_floats * mem::size_of::()) + } + }, + Codec::Pcm => { + let mut buffer = &mut buffer[..]; + while written_floats < float_space { + if let Ok(signal) = self.reader.read_i16::() { + buffer.write_f32::(f32::from(signal) / 32768.0)?; + written_floats += 1; + } else { + break; + } + } + Ok(written_floats * mem::size_of::()) + }, + Codec::FloatPcm => Read::read(&mut self.reader, buffer), + }; + + out.map(|v| { + self.pos += v; + v + }) + } + + fn cheap_consume(&mut self, count: usize) -> IoResult { + let mut scratch = [0u8; STEREO_FRAME_BYTE_SIZE * 4]; + let len = scratch.len(); + let mut done = 0; + + loop { + let read = self.read_inner(&mut scratch[..len.min(count - done)], true)?; + if read == 0 { + break; + } + done += read; + } + + Ok(done) + } + + pub(crate) fn supports_passthrough(&self) -> bool { + match &self.kind { + Codec::Opus(state) => state.allow_passthrough, + _ => false, + } + } + + pub(crate) fn read_opus_frame(&mut self, buffer: &mut [u8]) -> IoResult { + // Called in event of opus passthrough. + if let Codec::Opus(state) = &mut self.kind { + // step 1: align to frame. + self.pos += state.current_frame.len() - state.frame_pos; + + state.frame_pos = 0; + state.current_frame.truncate(0); + + // step 2: read new header. + let frame = self + .container + .next_frame_length(&mut self.reader, CodecType::Opus)?; + + // step 3: read in bytes. + self.reader + .read_exact(&mut buffer[..frame.frame_len]) + .map(|_| { + self.pos += STEREO_FRAME_BYTE_SIZE; + frame.frame_len + }) + } else { + Err(IoError::new( + IoErrorKind::InvalidInput, + "Frame passthrough not supported for this file.", + )) + } + } + + pub(crate) fn prep_with_handle(&mut self, handle: Handle) { + self.reader.prep_with_handle(handle); + } +} + +impl Read for Input { + fn read(&mut self, buffer: &mut [u8]) -> IoResult { + self.read_inner(buffer, false) + } +} + +impl Seek for Input { + fn seek(&mut self, pos: SeekFrom) -> IoResult { + let mut target = self.pos; + match pos { + SeekFrom::Start(pos) => { + target = pos as usize; + }, + SeekFrom::Current(rel) => { + target = target.wrapping_add(rel as usize); + }, + SeekFrom::End(_pos) => unimplemented!(), + } + + debug!("Seeking to {:?}", pos); + + (if target == self.pos { + Ok(0) + } else if let Some(conversion) = self.container.try_seek_trivial(self.get_type()) { + let inside_target = (target * conversion) / mem::size_of::(); + Seek::seek(&mut self.reader, SeekFrom::Start(inside_target as u64)).map(|inner_dest| { + let outer_dest = ((inner_dest as usize) * mem::size_of::()) / conversion; + self.pos = outer_dest; + outer_dest + }) + } else if target > self.pos { + // seek in the next amount, disabling decoding if need be. + let shift = target - self.pos; + self.cheap_consume(shift) + } else { + // start from scratch, then seek in... + Seek::seek( + &mut self.reader, + SeekFrom::Start(self.container.input_start() as u64), + )?; + + self.cheap_consume(target) + }) + .map(|_| self.pos as u64) + } +} + +/// Extension trait to pull frames of audio from a byte source. 
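+///
+/// `add_float_pcm_frame` mixes up to one 20 ms frame from this source into
+/// `float_buffer`, scaling each sample by `volume` and duplicating mono sources
+/// across both channels; it returns the number of bytes mixed, or `None` if the
+/// source fails with a hard I/O error. `consume` simply discards `amt` bytes.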
+pub(crate) trait ReadAudioExt { + fn add_float_pcm_frame( + &mut self, + float_buffer: &mut [f32; STEREO_FRAME_SIZE], + true_stereo: bool, + volume: f32, + ) -> Option; + + fn consume(&mut self, amt: usize) -> usize + where + Self: Sized; +} + +impl ReadAudioExt for R { + fn add_float_pcm_frame( + &mut self, + float_buffer: &mut [f32; STEREO_FRAME_SIZE], + stereo: bool, + volume: f32, + ) -> Option { + // IDEA: Read in 8 floats at a time, then use iterator code + // to gently nudge the compiler into vectorising for us. + // Max SIMD float32 lanes is 8 on AVX, older archs use a divisor of this + // e.g., 4. + const SAMPLE_LEN: usize = mem::size_of::(); + const FLOAT_COUNT: usize = 512; + let mut simd_float_bytes = [0u8; FLOAT_COUNT * SAMPLE_LEN]; + let mut simd_float_buf = [0f32; FLOAT_COUNT]; + + let mut frame_pos = 0; + + // Code duplication here is because unifying these codepaths + // with a dynamic chunk size is not zero-cost. + if stereo { + let mut max_bytes = STEREO_FRAME_BYTE_SIZE; + + while frame_pos < float_buffer.len() { + let progress = self + .read(&mut simd_float_bytes[..max_bytes.min(FLOAT_COUNT * SAMPLE_LEN)]) + .and_then(|byte_len| { + let target = byte_len / SAMPLE_LEN; + (&simd_float_bytes[..byte_len]) + .read_f32_into::(&mut simd_float_buf[..target]) + .map(|_| target) + }) + .map(|f32_len| { + let new_pos = frame_pos + f32_len; + for (el, new_el) in float_buffer[frame_pos..new_pos] + .iter_mut() + .zip(&simd_float_buf[..f32_len]) + { + *el += volume * new_el; + } + (new_pos, f32_len) + }); + + match progress { + Ok((new_pos, delta)) => { + frame_pos = new_pos; + max_bytes -= delta * SAMPLE_LEN; + + if delta == 0 { + break; + } + }, + Err(ref e) => + return if e.kind() == IoErrorKind::UnexpectedEof { + error!("EOF unexpectedly: {:?}", e); + Some(frame_pos) + } else { + error!("Input died unexpectedly: {:?}", e); + None + }, + } + } + } else { + let mut max_bytes = MONO_FRAME_BYTE_SIZE; + + while frame_pos < float_buffer.len() { + let progress = self + .read(&mut simd_float_bytes[..max_bytes.min(FLOAT_COUNT * SAMPLE_LEN)]) + .and_then(|byte_len| { + let target = byte_len / SAMPLE_LEN; + (&simd_float_bytes[..byte_len]) + .read_f32_into::(&mut simd_float_buf[..target]) + .map(|_| target) + }) + .map(|f32_len| { + let new_pos = frame_pos + (2 * f32_len); + for (els, new_el) in float_buffer[frame_pos..new_pos] + .chunks_exact_mut(2) + .zip(&simd_float_buf[..f32_len]) + { + let sample = volume * new_el; + els[0] += sample; + els[1] += sample; + } + (new_pos, f32_len) + }); + + match progress { + Ok((new_pos, delta)) => { + frame_pos = new_pos; + max_bytes -= delta * SAMPLE_LEN; + + if delta == 0 { + break; + } + }, + Err(ref e) => + return if e.kind() == IoErrorKind::UnexpectedEof { + Some(frame_pos) + } else { + error!("Input died unexpectedly: {:?}", e); + None + }, + } + } + } + + Some(frame_pos * SAMPLE_LEN) + } + + fn consume(&mut self, amt: usize) -> usize { + io::copy(&mut self.by_ref().take(amt as u64), &mut io::sink()).unwrap_or(0) as usize + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::*; + + #[test] + fn float_pcm_input_unchanged_mono() { + let data = make_sine(50 * MONO_FRAME_SIZE, false); + let mut input = Input::new( + false, + data.clone().into(), + Codec::FloatPcm, + Container::Raw, + None, + ); + + let mut out_vec = vec![]; + + let len = input.read_to_end(&mut out_vec).unwrap(); + assert_eq!(out_vec[..len], data[..]); + } + + #[test] + fn float_pcm_input_unchanged_stereo() { + let data = make_sine(50 * MONO_FRAME_SIZE, true); + 
let mut input = Input::new( + true, + data.clone().into(), + Codec::FloatPcm, + Container::Raw, + None, + ); + + let mut out_vec = vec![]; + + let len = input.read_to_end(&mut out_vec).unwrap(); + assert_eq!(out_vec[..len], data[..]); + } + + #[test] + fn pcm_input_becomes_float_mono() { + let data = make_pcm_sine(50 * MONO_FRAME_SIZE, false); + let mut input = Input::new(false, data.clone().into(), Codec::Pcm, Container::Raw, None); + + let mut out_vec = vec![]; + let len = input.read_to_end(&mut out_vec).unwrap(); + + let mut i16_window = &data[..]; + let mut float_window = &out_vec[..]; + + while i16_window.len() != 0 { + let before = i16_window.read_i16::().unwrap() as f32; + let after = float_window.read_f32::().unwrap(); + + let diff = (before / 32768.0) - after; + + assert!(diff.abs() < f32::EPSILON); + } + } + + #[test] + fn pcm_input_becomes_float_stereo() { + let data = make_pcm_sine(50 * MONO_FRAME_SIZE, true); + let mut input = Input::new(true, data.clone().into(), Codec::Pcm, Container::Raw, None); + + let mut out_vec = vec![]; + let len = input.read_to_end(&mut out_vec).unwrap(); + + let mut i16_window = &data[..]; + let mut float_window = &out_vec[..]; + + while i16_window.len() != 0 { + let before = i16_window.read_i16::().unwrap() as f32; + let after = float_window.read_f32::().unwrap(); + + let diff = (before / 32768.0) - after; + + assert!(diff.abs() < f32::EPSILON); + } + } +} diff --git a/src/input/reader.rs b/src/input/reader.rs new file mode 100644 index 0000000..030dac3 --- /dev/null +++ b/src/input/reader.rs @@ -0,0 +1,180 @@ +//! Raw handlers for input bytestreams. + +use super::*; +use std::{ + fmt::{Debug, Error as FormatError, Formatter}, + fs::File, + io::{ + BufReader, + Cursor, + Error as IoError, + ErrorKind as IoErrorKind, + Read, + Result as IoResult, + Seek, + SeekFrom, + }, + result::Result as StdResult, +}; +use streamcatcher::{Catcher, TxCatcher}; + +/// Usable data/byte sources for an audio stream. +/// +/// Users may define their own data sources using [`Extension`] +/// and [`ExtensionSeek`]. +/// +/// [`Extension`]: #variant.Extension +/// [`ExtensionSeek`]: #variant.ExtensionSeek +pub enum Reader { + /// Piped output of another program (i.e., [`ffmpeg`]). + /// + /// Does not support seeking. + /// + /// [`ffmpeg`]: ../fn.ffmpeg.html + Pipe(BufReader), + /// A cached, raw in-memory store, provided by Songbird. + /// + /// Supports seeking. + Memory(Catcher>), + /// A cached, Opus-compressed in-memory store, provided by Songbird. + /// + /// Supports seeking. + Compressed(TxCatcher, OpusCompressor>), + /// A source which supports seeking by recreating its inout stream. + /// + /// Supports seeking. + Restartable(Restartable), + /// A source contained in a local file. + /// + /// Supports seeking. + File(BufReader), + /// A source contained as an array in memory. + /// + /// Supports seeking. + Vec(Cursor>), + /// A basic user-provided source. + /// + /// Does not support seeking. + Extension(Box), + /// A user-provided source which also implements [`Seek`]. + /// + /// Supports seeking. + /// + /// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html + ExtensionSeek(Box), +} + +impl Reader { + /// Returns whether the given source implements [`Seek`]. 
+ /// + /// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html + pub fn is_seekable(&self) -> bool { + use Reader::*; + match self { + Restartable(_) | Compressed(_) | Memory(_) => true, + Extension(_) => false, + ExtensionSeek(_) => true, + _ => false, + } + } + + #[allow(clippy::single_match)] + pub(crate) fn prep_with_handle(&mut self, handle: Handle) { + use Reader::*; + match self { + Restartable(r) => r.prep_with_handle(handle), + _ => {}, + } + } +} + +impl Read for Reader { + fn read(&mut self, buffer: &mut [u8]) -> IoResult { + use Reader::*; + match self { + Pipe(a) => Read::read(a, buffer), + Memory(a) => Read::read(a, buffer), + Compressed(a) => Read::read(a, buffer), + Restartable(a) => Read::read(a, buffer), + File(a) => Read::read(a, buffer), + Vec(a) => Read::read(a, buffer), + Extension(a) => a.read(buffer), + ExtensionSeek(a) => a.read(buffer), + } + } +} + +impl Seek for Reader { + fn seek(&mut self, pos: SeekFrom) -> IoResult { + use Reader::*; + match self { + Pipe(_) | Extension(_) => Err(IoError::new( + IoErrorKind::InvalidInput, + "Seeking not supported on Reader of this type.", + )), + Memory(a) => Seek::seek(a, pos), + Compressed(a) => Seek::seek(a, pos), + File(a) => Seek::seek(a, pos), + Restartable(a) => Seek::seek(a, pos), + Vec(a) => Seek::seek(a, pos), + ExtensionSeek(a) => a.seek(pos), + } + } +} + +impl Debug for Reader { + fn fmt(&self, f: &mut Formatter<'_>) -> StdResult<(), FormatError> { + use Reader::*; + let field = match self { + Pipe(a) => format!("{:?}", a), + Memory(a) => format!("{:?}", a), + Compressed(a) => format!("{:?}", a), + Restartable(a) => format!("{:?}", a), + File(a) => format!("{:?}", a), + Vec(a) => format!("{:?}", a), + Extension(_) => "Extension".to_string(), + ExtensionSeek(_) => "ExtensionSeek".to_string(), + }; + f.debug_tuple("Reader").field(&field).finish() + } +} + +impl From> for Reader { + fn from(val: Vec) -> Reader { + Reader::Vec(Cursor::new(val)) + } +} + +/// Fusion trait for custom input sources which allow seeking. +pub trait ReadSeek { + /// See [`Read::read`]. + /// + /// [`Read::read`]: https://doc.rust-lang.org/nightly/std/io/trait.Read.html#tymethod.read + fn read(&mut self, buf: &mut [u8]) -> IoResult; + /// See [`Seek::seek`]. + /// + /// [`Seek::seek`]: https://doc.rust-lang.org/nightly/std/io/trait.Seek.html#tymethod.seek + fn seek(&mut self, pos: SeekFrom) -> IoResult; +} + +impl Read for dyn ReadSeek { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + ReadSeek::read(self, buf) + } +} + +impl Seek for dyn ReadSeek { + fn seek(&mut self, pos: SeekFrom) -> IoResult { + ReadSeek::seek(self, pos) + } +} + +impl ReadSeek for R { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + Read::read(self, buf) + } + + fn seek(&mut self, pos: SeekFrom) -> IoResult { + Seek::seek(self, pos) + } +} diff --git a/src/input/restartable.rs b/src/input/restartable.rs new file mode 100644 index 0000000..6965e54 --- /dev/null +++ b/src/input/restartable.rs @@ -0,0 +1,294 @@ +//! A source which supports seeking by recreating its input stream. +//! +//! This is intended for use with single-use audio tracks which +//! may require looping or seeking, but where additional memory +//! cannot be spared. Forward seeks will drain the track until reaching +//! the desired timestamp. +//! +//! Restarting occurs by temporarily pausing the track, running the restart +//! mechanism, and then passing the handle back to the mixer thread. Until +//! success/failure is confirmed, the track produces silence. 
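+//!
+//! A minimal sketch of creating such a source from a local file (the path is
+//! illustrative); backward seeks will re-run `ffmpeg` behind the scenes:
+//!
+//! ```rust,no_run
+//! use songbird::input::restartable::Restartable;
+//!
+//! let source = Restartable::ffmpeg("my_mix.mp3")
+//!     .expect("this example assumes ffmpeg and a valid local file");
+//! ```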
+ +use super::*; +use flume::{Receiver, TryRecvError}; +use futures::executor; +use std::{ + ffi::OsStr, + fmt::{Debug, Error as FormatError, Formatter}, + io::{Error as IoError, ErrorKind as IoErrorKind, Read, Result as IoResult, Seek, SeekFrom}, + result::Result as StdResult, + time::Duration, +}; + +type Recreator = Box; +type RecreateChannel = Receiver, Recreator)>>; + +/// A wrapper around a method to create a new [`Input`] which +/// seeks backward by recreating the source. +/// +/// The main purpose of this wrapper is to enable seeking on +/// incompatible sources (i.e., ffmpeg output) and to ease resource +/// consumption for commonly reused/shared tracks. [`Compressed`] +/// and [`Memory`] offer the same functionality with different +/// tradeoffs. +/// +/// This is intended for use with single-use audio tracks which +/// may require looping or seeking, but where additional memory +/// cannot be spared. Forward seeks will drain the track until reaching +/// the desired timestamp. +/// +/// [`Input`]: struct.Input.html +/// [`Memory`]: cached/struct.Memory.html +/// [`Compressed`]: cached/struct.Compressed.html +pub struct Restartable { + async_handle: Option, + awaiting_source: Option, + position: usize, + recreator: Option, + source: Box, +} + +impl Restartable { + /// Create a new source, which can be restarted using a `recreator` function. + pub fn new(mut recreator: impl Restart + Send + 'static) -> Result { + recreator.call_restart(None).map(move |source| Self { + async_handle: None, + awaiting_source: None, + position: 0, + recreator: Some(Box::new(recreator)), + source: Box::new(source), + }) + } + + /// Create a new restartable ffmpeg source for a local file. + pub fn ffmpeg + Send + Clone + 'static>(path: P) -> Result { + Self::new(FfmpegRestarter { path }) + } + + /// Create a new restartable ytdl source. + /// + /// The cost of restarting and seeking will probably be *very* high: + /// expect a pause if you seek backwards. + pub fn ytdl + Send + Clone + 'static>(uri: P) -> Result { + Self::new(move |time: Option| { + if let Some(time) = time { + let ts = format!("{}.{}", time.as_secs(), time.subsec_millis()); + + executor::block_on(_ytdl(uri.as_ref(), &["-ss", &ts])) + } else { + executor::block_on(ytdl(uri.as_ref())) + } + }) + } + + /// Create a new restartable ytdl source, using the first result of a youtube search. + /// + /// The cost of restarting and seeking will probably be *very* high: + /// expect a pause if you seek backwards. + pub fn ytdl_search(name: &str) -> Result { + Self::ytdl(format!("ytsearch1:{}", name)) + } + + pub(crate) fn prep_with_handle(&mut self, handle: Handle) { + self.async_handle = Some(handle); + } +} + +/// Trait used to create an instance of a [`Reader`] at instantiation and when +/// a backwards seek is needed. +/// +/// Many closures derive this automatically. +/// +/// [`Reader`]: ../reader/enum.Reader.html +pub trait Restart { + /// Tries to create a replacement source. + fn call_restart(&mut self, time: Option) -> Result; +} + +struct FfmpegRestarter
<P> +where + P: AsRef<OsStr> + Send, +{ + path: P, +} + +impl<P> Restart for FfmpegRestarter<P>
+where + P: AsRef + Send, +{ + fn call_restart(&mut self, time: Option) -> Result { + executor::block_on(async { + if let Some(time) = time { + let is_stereo = is_stereo(self.path.as_ref()) + .await + .unwrap_or_else(|_e| (false, Default::default())); + let stereo_val = if is_stereo.0 { "2" } else { "1" }; + + let ts = format!("{}.{}", time.as_secs(), time.subsec_millis()); + _ffmpeg_optioned( + self.path.as_ref(), + &["-ss", &ts], + &[ + "-f", + "s16le", + "-ac", + stereo_val, + "-ar", + "48000", + "-acodec", + "pcm_f32le", + "-", + ], + Some(is_stereo), + ) + .await + } else { + ffmpeg(self.path.as_ref()).await + } + }) + } +} + +impl
<P>
Restart for P +where + P: FnMut(Option) -> Result + Send + 'static, +{ + fn call_restart(&mut self, time: Option) -> Result { + (self)(time) + } +} + +impl Debug for Restartable { + fn fmt(&self, f: &mut Formatter<'_>) -> StdResult<(), FormatError> { + f.debug_struct("Restartable") + .field("async_handle", &self.async_handle) + .field("awaiting_source", &self.awaiting_source) + .field("position", &self.position) + .field("recreator", &"") + .field("source", &self.source) + .finish() + } +} + +impl From for Input { + fn from(mut src: Restartable) -> Self { + let kind = src.source.kind.clone(); + let meta = Some(src.source.metadata.take()); + let stereo = src.source.stereo; + let container = src.source.container; + Input::new(stereo, Reader::Restartable(src), kind, container, meta) + } +} + +// How do these work at a high level? +// If you need to restart, send a request to do this to the async context. +// if a request is pending, then just output all zeroes. + +impl Read for Restartable { + fn read(&mut self, buffer: &mut [u8]) -> IoResult { + let (out_val, march_pos, remove_async) = if let Some(chan) = &self.awaiting_source { + match chan.try_recv() { + Ok(Ok((new_source, recreator))) => { + self.source = new_source; + self.recreator = Some(recreator); + + (Read::read(&mut self.source, buffer), true, true) + }, + Ok(Err(source_error)) => { + let e = Err(IoError::new( + IoErrorKind::UnexpectedEof, + format!("Failed to create new reader: {:?}.", source_error), + )); + (e, false, true) + }, + Err(TryRecvError::Empty) => { + // Output all zeroes. + for el in buffer.iter_mut() { + *el = 0; + } + (Ok(buffer.len()), false, false) + }, + Err(_) => { + let e = Err(IoError::new( + IoErrorKind::UnexpectedEof, + "Failed to create new reader: dropped.", + )); + (e, false, true) + }, + } + } else { + // already have a good, valid source. + (Read::read(&mut self.source, buffer), true, false) + }; + + if remove_async { + self.awaiting_source = None; + } + + if march_pos { + out_val.map(|a| { + self.position += a; + a + }) + } else { + out_val + } + } +} + +impl Seek for Restartable { + fn seek(&mut self, pos: SeekFrom) -> IoResult { + let _local_pos = self.position as u64; + + use SeekFrom::*; + match pos { + Start(offset) => { + let stereo = self.source.stereo; + let _current_ts = utils::byte_count_to_timestamp(self.position, stereo); + let offset = offset as usize; + + if offset < self.position { + // We're going back in time. 
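// Rewinding the underlying stream is impossible, so the recreator is handed off
// to the async runtime via `async_handle`; until it sends a fresh source back
// over `awaiting_source`, `read` keeps emitting silence. If a restart is already
// in flight (the recreator has been taken), or no runtime handle has been
// provided yet, the seek fails with an error instead.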
+ if let Some(handle) = self.async_handle.as_ref() { + let (tx, rx) = flume::bounded(1); + + self.awaiting_source = Some(rx); + + let recreator = self.recreator.take(); + + if let Some(mut rec) = recreator { + handle.spawn(async move { + let ret_val = rec.call_restart(Some( + utils::byte_count_to_timestamp(offset, stereo), + )); + + let _ = tx.send(ret_val.map(Box::new).map(|v| (v, rec))); + }); + } else { + return Err(IoError::new( + IoErrorKind::Interrupted, + "Previous seek in progress.", + )); + } + + self.position = offset; + } else { + return Err(IoError::new( + IoErrorKind::Interrupted, + "Cannot safely call seek until provided an async context handle.", + )); + } + } else { + self.position += self.source.consume(offset - self.position); + } + + Ok(offset as u64) + }, + End(_offset) => Err(IoError::new( + IoErrorKind::InvalidInput, + "End point for Restartables is not known.", + )), + Current(_offset) => unimplemented!(), + } + } +} diff --git a/src/input/utils.rs b/src/input/utils.rs new file mode 100644 index 0000000..d6072da --- /dev/null +++ b/src/input/utils.rs @@ -0,0 +1,41 @@ +//! Utility methods for seeking or decoding. + +use crate::constants::*; +use audiopus::{coder::Decoder, Channels, Result as OpusResult, SampleRate}; +use std::{mem, time::Duration}; + +/// Calculates the sample position in a FloatPCM stream from a timestamp. +pub fn timestamp_to_sample_count(timestamp: Duration, stereo: bool) -> usize { + ((timestamp.as_millis() as usize) * (MONO_FRAME_SIZE / FRAME_LEN_MS)) << stereo as usize +} + +/// Calculates the time position in a FloatPCM stream from a sample index. +pub fn sample_count_to_timestamp(amt: usize, stereo: bool) -> Duration { + Duration::from_millis((((amt * FRAME_LEN_MS) / MONO_FRAME_SIZE) as u64) >> stereo as u64) +} + +/// Calculates the byte position in a FloatPCM stream from a timestamp. +/// +/// Each sample is sized by `mem::size_of::() == 4usize`. +pub fn timestamp_to_byte_count(timestamp: Duration, stereo: bool) -> usize { + timestamp_to_sample_count(timestamp, stereo) * mem::size_of::() +} + +/// Calculates the time position in a FloatPCM stream from a byte index. +/// +/// Each sample is sized by `mem::size_of::() == 4usize`. +pub fn byte_count_to_timestamp(amt: usize, stereo: bool) -> Duration { + sample_count_to_timestamp(amt / mem::size_of::(), stereo) +} + +/// Create an Opus decoder outputting at a sample rate of 48kHz. +pub fn decoder(stereo: bool) -> OpusResult { + Decoder::new( + SampleRate::Hz48000, + if stereo { + Channels::Stereo + } else { + Channels::Mono + }, + ) +} diff --git a/src/input/ytdl_src.rs b/src/input/ytdl_src.rs new file mode 100644 index 0000000..1de3880 --- /dev/null +++ b/src/input/ytdl_src.rs @@ -0,0 +1,107 @@ +use super::{ + child_to_reader, + error::{Error, Result}, + Codec, + Container, + Input, + Metadata, +}; +use serde_json::Value; +use std::{ + io::{BufRead, BufReader, Read}, + process::{Command, Stdio}, +}; +use tokio::task; +use tracing::trace; + +/// Creates a streamed audio source with `youtube-dl` and `ffmpeg`. 
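// A minimal usage sketch, assuming `youtube-dl` and `ffmpeg` are installed and on
// the PATH, an async context, and a driver (or `Call`) named `driver`:
//
//     let source = ytdl("https://www.youtube.com/watch?v=some-video-id").await?;
//     let (track, _handle) = crate::tracks::create_player(source);
//     driver.play(track);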
+pub async fn ytdl(uri: &str) -> Result { + _ytdl(uri, &[]).await +} + +pub(crate) async fn _ytdl(uri: &str, pre_args: &[&str]) -> Result { + let ytdl_args = [ + "--print-json", + "-f", + "webm[abr>0]/bestaudio/best", + "-R", + "infinite", + "--no-playlist", + "--ignore-config", + uri, + "-o", + "-", + ]; + + let ffmpeg_args = [ + "-f", + "s16le", + "-ac", + "2", + "-ar", + "48000", + "-acodec", + "pcm_f32le", + "-", + ]; + + let mut youtube_dl = Command::new("youtube-dl") + .args(&ytdl_args) + .stdin(Stdio::null()) + .stderr(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + + let stderr = youtube_dl.stderr.take(); + + let (returned_stderr, value) = task::spawn_blocking(move || { + if let Some(mut s) = stderr { + let out: Option = { + let mut o_vec = vec![]; + let mut serde_read = BufReader::new(s.by_ref()); + // Newline... + if let Ok(len) = serde_read.read_until(0xA, &mut o_vec) { + serde_json::from_slice(&o_vec[..len]).ok() + } else { + None + } + }; + + (Some(s), out) + } else { + (None, None) + } + }) + .await + .map_err(|_| Error::Metadata)?; + + youtube_dl.stderr = returned_stderr; + + let ffmpeg = Command::new("ffmpeg") + .args(pre_args) + .arg("-i") + .arg("-") + .args(&ffmpeg_args) + .stdin(youtube_dl.stdout.ok_or(Error::Stdout)?) + .stderr(Stdio::null()) + .stdout(Stdio::piped()) + .spawn()?; + + let metadata = Metadata::from_ytdl_output(value.unwrap_or_default()); + + trace!("ytdl metadata {:?}", metadata); + + Ok(Input::new( + true, + child_to_reader::(ffmpeg), + Codec::FloatPcm, + Container::Raw, + Some(metadata), + )) +} + +/// Creates a streamed audio source from YouTube search results with `youtube-dl`,`ffmpeg`, and `ytsearch`. +/// Takes the first video listed from the YouTube search. +pub async fn ytdl_search(name: &str) -> Result { + ytdl(&format!("ytsearch1:{}", name)).await +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..99e53c7 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,84 @@ +#![doc( + html_logo_url = "https://raw.githubusercontent.com/FelixMcFelix/serenity/voice-rework/songbird/songbird.png", + html_favicon_url = "https://raw.githubusercontent.com/FelixMcFelix/serenity/voice-rework/songbird/songbird-ico.png" +)] +#![deny(missing_docs)] +//! ![project logo][logo] +//! +//! Songbird is an async, cross-library compatible voice system for Discord, written in Rust. +//! The library offers: +//! * A standalone gateway frontend compatible with [serenity] and [twilight] using the +//! `"gateway"` and `"[serenity/twilight]-[rustls/native]"` features. You can even run +//! driverless, to help manage your [lavalink] sessions. +//! * A standalone driver for voice calls, via the `"driver"` feature. If you can create +//! a [`ConnectionInfo`] using any other gateway, or language for your bot, then you +//! can run the songbird voice driver. +//! * And, by default, a fully featured voice system featuring events, queues, RT(C)P packet +//! handling, seeking on compatible streams, shared multithreaded audio stream caches, +//! and direct Opus data passthrough from DCA files. +//! +//! ## Examples +//! Full examples showing various types of functionality and integrations can be found as part of [serenity's examples], +//! and in [this crate's examples directory]. +//! +//! ## Attribution +//! +//! Songbird's logo is based upon the copyright-free image ["Black-Capped Chickadee"] by George Gorgas White. +//! +//! [logo]: https://raw.githubusercontent.com/FelixMcFelix/serenity/voice-rework/songbird/songbird.png +//! 
[serenity]: https://github.com/serenity-rs/serenity +//! [twilight]: https://github.com/twilight-rs/twilight +//! [serenity's examples]: https://github.com/serenity-rs/serenity/tree/current/examples +//! [this crate's examples directory]: https://github.com/serenity-rs/serenity/tree/current/songbird/examples +//! ["Black-Capped Chickadee"]: https://www.oldbookillustrations.com/illustrations/black-capped-chickadee/ +//! [`ConnectionInfo`]: struct.ConnectionInfo.html +//! [lavalink]: https://github.com/Frederikam/Lavalink + +pub mod constants; +#[cfg(feature = "driver")] +pub mod driver; +pub mod error; +#[cfg(feature = "driver")] +pub mod events; +#[cfg(feature = "gateway")] +mod handler; +pub mod id; +pub(crate) mod info; +#[cfg(feature = "driver")] +pub mod input; +#[cfg(feature = "gateway")] +mod manager; +#[cfg(feature = "serenity")] +pub mod serenity; +#[cfg(feature = "gateway")] +pub mod shards; +#[cfg(feature = "driver")] +pub mod tracks; +#[cfg(feature = "driver")] +mod ws; + +#[cfg(feature = "driver")] +pub use audiopus::{self as opus, Bitrate}; +#[cfg(feature = "driver")] +pub use discortp as packet; +#[cfg(feature = "driver")] +pub use serenity_voice_model as model; + +#[cfg(test)] +use utils as test_utils; + +#[cfg(feature = "driver")] +pub use crate::{ + driver::Driver, + events::{CoreEvent, Event, EventContext, EventHandler, TrackEvent}, + input::{ffmpeg, ytdl}, + tracks::create_player, +}; + +#[cfg(feature = "gateway")] +pub use crate::{handler::Call, manager::Songbird}; + +#[cfg(feature = "serenity")] +pub use crate::serenity::*; + +pub use info::ConnectionInfo; diff --git a/src/manager.rs b/src/manager.rs new file mode 100644 index 0000000..7543975 --- /dev/null +++ b/src/manager.rs @@ -0,0 +1,353 @@ +#[cfg(feature = "driver")] +use crate::error::ConnectionResult; +use crate::{ + error::{JoinError, JoinResult}, + id::{ChannelId, GuildId, UserId}, + shards::Sharder, + Call, + ConnectionInfo, +}; +#[cfg(feature = "serenity")] +use async_trait::async_trait; +use flume::Receiver; +#[cfg(feature = "serenity")] +use futures::channel::mpsc::UnboundedSender as Sender; +use parking_lot::RwLock as PRwLock; +#[cfg(feature = "serenity")] +use serenity::{ + client::bridge::voice::VoiceGatewayManager, + gateway::InterMessage, + model::{ + id::{GuildId as SerenityGuild, UserId as SerenityUser}, + voice::VoiceState, + }, +}; +use std::{collections::HashMap, sync::Arc}; +use tokio::sync::Mutex; +#[cfg(feature = "twilight")] +use twilight_gateway::Cluster; +#[cfg(feature = "twilight")] +use twilight_model::gateway::event::Event as TwilightEvent; + +#[derive(Clone, Copy, Debug, Default)] +struct ClientData { + shard_count: u64, + initialised: bool, + user_id: UserId, +} + +/// A shard-aware struct responsible for managing [`Call`]s. +/// +/// This manager transparently maps guild state and a source of shard information +/// into individual calls, and forwards state updates which affect call state. +/// +/// [`Call`]: struct.Call.html +#[derive(Debug)] +pub struct Songbird { + client_data: PRwLock, + calls: PRwLock>>>, + sharder: Sharder, +} + +impl Songbird { + #[cfg(feature = "serenity")] + /// Create a new Songbird instance for serenity. + /// + /// This must be [registered] after creation. 
+ /// + /// [registered]: serenity/fn.register_with.html + pub fn serenity() -> Arc { + Arc::new(Self { + client_data: Default::default(), + calls: Default::default(), + sharder: Sharder::Serenity(Default::default()), + }) + } + + #[cfg(feature = "twilight")] + /// Create a new Songbird instance for twilight. + /// + /// Twilight handlers do not need to be registered, but + /// users are responsible for passing in any events using + /// [`process`]. + /// + /// [`process`]: #method.process + pub fn twilight(cluster: Cluster, shard_count: u64, user_id: U) -> Arc + where + U: Into, + { + Arc::new(Self { + client_data: PRwLock::new(ClientData { + shard_count, + initialised: true, + user_id: user_id.into(), + }), + calls: Default::default(), + sharder: Sharder::Twilight(cluster), + }) + } + + /// Set the bot's user, and the number of shards in use. + /// + /// If this struct is already initialised (e.g., from [`::twilight`]), + /// or a previous call, then this function is a no-op. + /// + /// [`::twilight`]: #method.twilight + pub fn initialise_client_data>(&self, shard_count: u64, user_id: U) { + let mut client_data = self.client_data.write(); + + if client_data.initialised { + return; + } + + client_data.shard_count = shard_count; + client_data.user_id = user_id.into(); + client_data.initialised = true; + } + + /// Retreives a [`Call`] for the given guild, if one already exists. + /// + /// [`Call`]: struct.Call.html + pub fn get>(&self, guild_id: G) -> Option>> { + let map_read = self.calls.read(); + map_read.get(&guild_id.into()).cloned() + } + + /// Retreives a [`Call`] for the given guild, creating a new one if + /// none is found. + /// + /// This will not join any calls, or cause connection state to change. + /// + /// [`Call`]: struct.Call.html + pub fn get_or_insert(&self, guild_id: GuildId) -> Arc> { + self.get(guild_id).unwrap_or_else(|| { + let mut map_read = self.calls.write(); + + map_read + .entry(guild_id) + .or_insert_with(|| { + let info = self.manager_info(); + let shard = shard_id(guild_id.0, info.shard_count); + let shard_handle = self + .sharder + .get_shard(shard) + .expect("Failed to get shard handle: shard_count incorrect?"); + + Arc::new(Mutex::new(Call::new(guild_id, shard_handle, info.user_id))) + }) + .clone() + }) + } + + fn manager_info(&self) -> ClientData { + let client_data = self.client_data.write(); + + *client_data + } + + #[cfg(feature = "driver")] + /// Connects to a target by retrieving its relevant [`Call`] and + /// connecting, or creating the handler if required. + /// + /// This can also switch to the given channel, if a handler already exists + /// for the target and the current connected channel is not equal to the + /// given channel. + /// + /// The provided channel ID is used as a connection target. The + /// channel _must_ be in the provided guild. This is _not_ checked by the + /// library, and will result in an error. If there is already a connected + /// handler for the guild, _and_ the provided channel is different from the + /// channel that the connection is already connected to, then the handler + /// will switch the connection to the provided channel. + /// + /// If you _only_ need to retrieve the handler for a target, then use + /// [`get`]. 
+ /// + /// [`Call`]: struct.Call.html + /// [`get`]: #method.get + #[inline] + pub async fn join( + &self, + guild_id: G, + channel_id: C, + ) -> (Arc>, JoinResult>>) + where + C: Into, + G: Into, + { + self._join(guild_id.into(), channel_id.into()).await + } + + #[cfg(feature = "driver")] + async fn _join( + &self, + guild_id: GuildId, + channel_id: ChannelId, + ) -> (Arc>, JoinResult>>) { + let call = self.get_or_insert(guild_id); + + let result = { + let mut handler = call.lock().await; + handler.join(channel_id).await + }; + + (call, result) + } + + /// Partially connects to a target by retrieving its relevant [`Call`] and + /// connecting, or creating the handler if required. + /// + /// This method returns the handle and the connection info needed for other libraries + /// or drivers, such as lavalink, and does not actually start or run a voice call. + /// + /// [`Call`]: struct.Call.html + #[inline] + pub async fn join_gateway( + &self, + guild_id: G, + channel_id: C, + ) -> (Arc>, JoinResult>) + where + C: Into, + G: Into, + { + self._join_gateway(guild_id.into(), channel_id.into()).await + } + + async fn _join_gateway( + &self, + guild_id: GuildId, + channel_id: ChannelId, + ) -> (Arc>, JoinResult>) { + let call = self.get_or_insert(guild_id); + + let result = { + let mut handler = call.lock().await; + handler.join_gateway(channel_id).await + }; + + (call, result) + } + + /// Retrieves the [handler][`Call`] for the given target and leaves the + /// associated voice channel, if connected. + /// + /// This will _not_ drop the handler, and will preserve it and its settings. + /// + /// This is a wrapper around [getting][`get`] a handler and calling + /// [`leave`] on it. + /// + /// [`Call`]: struct.Call.html + /// [`get`]: #method.get + /// [`leave`]: struct.Call.html#method.leave + #[inline] + pub async fn leave>(&self, guild_id: G) -> JoinResult<()> { + self._leave(guild_id.into()).await + } + + async fn _leave(&self, guild_id: GuildId) -> JoinResult<()> { + if let Some(call) = self.get(guild_id) { + let mut handler = call.lock().await; + handler.leave().await + } else { + Err(JoinError::NoCall) + } + } + + /// Retrieves the [`Call`] for the given target and leaves the associated + /// voice channel, if connected. + /// + /// The handler is then dropped, removing settings for the target. + /// + /// An Err(...) value implies that the gateway could not be contacted, + /// and that leaving should be attempted again later (i.e., after reconnect). + /// + /// [`Call`]: struct.Call.html + #[inline] + pub async fn remove>(&self, guild_id: G) -> JoinResult<()> { + self._remove(guild_id.into()).await + } + + async fn _remove(&self, guild_id: GuildId) -> JoinResult<()> { + self.leave(guild_id).await?; + let mut calls = self.calls.write(); + calls.remove(&guild_id); + Ok(()) + } +} + +#[cfg(feature = "twilight")] +impl Songbird { + /// Handle events received on the cluster. + /// + /// When using twilight, you are required to call this with all inbound + /// (voice) events, *i.e.*, at least `VoiceStateUpdate`s and `VoiceServerUpdate`s. 
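// A minimal forwarding loop, assuming `songbird: Arc<Songbird>` was built via
// `Songbird::twilight` and `cluster` is a running `twilight_gateway::Cluster`,
// with the `futures` stream combinators in scope:
//
//     use futures::StreamExt;
//
//     let mut events = cluster.events();
//     while let Some((_shard_id, event)) = events.next().await {
//         songbird.process(&event).await;
//         // ... hand `event` to the rest of the bot here.
//     }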
+ pub async fn process(&self, event: &TwilightEvent) { + match event { + TwilightEvent::VoiceServerUpdate(v) => { + let call = v.guild_id.map(GuildId::from).and_then(|id| self.get(id)); + + if let Some(call) = call { + let mut handler = call.lock().await; + if let Some(endpoint) = &v.endpoint { + handler.update_server(endpoint.clone(), v.token.clone()); + } + } + }, + TwilightEvent::VoiceStateUpdate(v) => { + if v.0.user_id.0 != self.client_data.read().user_id.0 { + return; + } + + let call = v.0.guild_id.map(GuildId::from).and_then(|id| self.get(id)); + + if let Some(call) = call { + let mut handler = call.lock().await; + handler.update_state(v.0.session_id.clone()); + } + }, + _ => {}, + } + } +} + +#[cfg(feature = "serenity")] +#[async_trait] +impl VoiceGatewayManager for Songbird { + async fn initialise(&self, shard_count: u64, user_id: SerenityUser) { + self.initialise_client_data(shard_count, user_id); + } + + async fn register_shard(&self, shard_id: u64, sender: Sender) { + self.sharder.register_shard_handle(shard_id, sender); + } + + async fn deregister_shard(&self, shard_id: u64) { + self.sharder.deregister_shard_handle(shard_id); + } + + async fn server_update(&self, guild_id: SerenityGuild, endpoint: &Option, token: &str) { + if let Some(call) = self.get(guild_id) { + let mut handler = call.lock().await; + if let Some(endpoint) = endpoint { + handler.update_server(endpoint.clone(), token.to_string()); + } + } + } + + async fn state_update(&self, guild_id: SerenityGuild, voice_state: &VoiceState) { + if voice_state.user_id.0 != self.client_data.read().user_id.0 { + return; + } + + if let Some(call) = self.get(guild_id) { + let mut handler = call.lock().await; + handler.update_state(voice_state.session_id.clone()); + } + } +} + +#[inline] +fn shard_id(guild_id: u64, shard_count: u64) -> u64 { + (guild_id >> 22) % shard_count +} diff --git a/src/serenity.rs b/src/serenity.rs new file mode 100644 index 0000000..87b6d32 --- /dev/null +++ b/src/serenity.rs @@ -0,0 +1,71 @@ +//! Compatability and convenience methods for working with [serenity]. +//! Requires the `"serenity-rustls"` or `"serenity-native"` features. +//! +//! [serenity]: https://crates.io/crates/serenity/0.9.0-rc.2 + +use crate::manager::Songbird; +use serenity::{ + client::{ClientBuilder, Context}, + prelude::TypeMapKey, +}; +use std::sync::Arc; + +/// Zero-size type used to retrieve the registered [`Songbird`] instance +/// from serenity's inner TypeMap. +/// +/// [`Songbird`]: ../struct.Songbird.html +pub struct SongbirdKey; + +impl TypeMapKey for SongbirdKey { + type Value = Arc; +} + +/// Installs a new songbird instance into the serenity client. +/// +/// This should be called after any uses of `ClientBuilder::type_map`. +pub fn register(client_builder: ClientBuilder) -> ClientBuilder { + let voice = Songbird::serenity(); + register_with(client_builder, voice) +} + +/// Installs a given songbird instance into the serenity client. +/// +/// This should be called after any uses of `ClientBuilder::type_map`. +pub fn register_with(client_builder: ClientBuilder, voice: Arc) -> ClientBuilder { + client_builder + .voice_manager_arc(voice.clone()) + .type_map_insert::(voice) +} + +/// Retrieve the Songbird voice client from a serenity context's +/// shared key-value store. +pub async fn get(ctx: &Context) -> Option> { + let data = ctx.data.read().await; + + data.get::().cloned() +} + +/// Helper trait to add installation/creation methods to serenity's +/// `ClientBuilder`. 
+/// +/// These install the client to receive gateway voice events, and +/// store an easily accessible reference to Songbird's managers. +pub trait SerenityInit { + /// Registers a new Songbird voice system with serenity, storing it for easy + /// access via [`get`]. + /// + /// [`get`]: fn.get.html + fn register_songbird(self) -> Self; + /// Registers a given Songbird voice system with serenity, as above. + fn register_songbird_with(self, voice: Arc) -> Self; +} + +impl SerenityInit for ClientBuilder<'_> { + fn register_songbird(self) -> Self { + register(self) + } + + fn register_songbird_with(self, voice: Arc) -> Self { + register_with(self, voice) + } +} diff --git a/src/shards.rs b/src/shards.rs new file mode 100644 index 0000000..7577b8f --- /dev/null +++ b/src/shards.rs @@ -0,0 +1,168 @@ +//! Handlers for sending packets over sharded connections. + +use crate::error::{JoinError, JoinResult}; +#[cfg(feature = "serenity")] +use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender}; +#[cfg(feature = "serenity")] +use parking_lot::{lock_api::RwLockWriteGuard, Mutex as PMutex, RwLock as PRwLock}; +use serde_json::Value; +#[cfg(feature = "serenity")] +use serenity::gateway::InterMessage; +#[cfg(feature = "serenity")] +use std::{collections::HashMap, result::Result as StdResult, sync::Arc}; +use tracing::error; +#[cfg(feature = "twilight")] +use twilight_gateway::{Cluster, Shard as TwilightShard}; + +#[derive(Debug)] +#[non_exhaustive] +/// Source of individual shard connection handles. +pub enum Sharder { + #[cfg(feature = "serenity")] + /// Serenity-specific wrapper for sharder state initialised by the library. + Serenity(SerenitySharder), + #[cfg(feature = "twilight")] + /// Twilight-specific wrapper for sharder state initialised by the user. + Twilight(Cluster), +} + +impl Sharder { + #[allow(unreachable_patterns)] + /// Returns a new handle to the required inner shard. + pub fn get_shard(&self, shard_id: u64) -> Option { + match self { + #[cfg(feature = "serenity")] + Sharder::Serenity(s) => Some(Shard::Serenity(s.get_or_insert_shard_handle(shard_id))), + #[cfg(feature = "twilight")] + Sharder::Twilight(t) => t.shard(shard_id).map(Shard::Twilight), + _ => None, + } + } +} + +#[cfg(feature = "serenity")] +impl Sharder { + #[allow(unreachable_patterns)] + pub(crate) fn register_shard_handle(&self, shard_id: u64, sender: Sender) { + match self { + Sharder::Serenity(s) => s.register_shard_handle(shard_id, sender), + _ => error!("Called serenity management function on a non-serenity Songbird instance."), + } + } + + #[allow(unreachable_patterns)] + pub(crate) fn deregister_shard_handle(&self, shard_id: u64) { + match self { + Sharder::Serenity(s) => s.deregister_shard_handle(shard_id), + _ => error!("Called serenity management function on a non-serenity Songbird instance."), + } + } +} + +#[cfg(feature = "serenity")] +#[derive(Debug, Default)] +/// Serenity-specific wrapper for sharder state initialised by the library. +/// +/// This is updated and maintained by the library, and is designed to prevent +/// message loss during rebalances and reconnects. 
+pub struct SerenitySharder(PRwLock>>); + +#[cfg(feature = "serenity")] +impl SerenitySharder { + fn get_or_insert_shard_handle(&self, shard_id: u64) -> Arc { + ({ + let map_read = self.0.read(); + map_read.get(&shard_id).cloned() + }) + .unwrap_or_else(|| { + let mut map_read = self.0.write(); + map_read.entry(shard_id).or_default().clone() + }) + } + + fn register_shard_handle(&self, shard_id: u64, sender: Sender) { + // Write locks are only used to add new entries to the map. + let handle = self.get_or_insert_shard_handle(shard_id); + + handle.register(sender); + } + + fn deregister_shard_handle(&self, shard_id: u64) { + // Write locks are only used to add new entries to the map. + let handle = self.get_or_insert_shard_handle(shard_id); + + handle.deregister(); + } +} + +#[derive(Clone, Debug)] +#[non_exhaustive] +/// A reference to an individual websocket connection. +pub enum Shard { + #[cfg(feature = "serenity")] + /// Handle to one of serenity's shard runners. + Serenity(Arc), + #[cfg(feature = "twilight")] + /// Handle to a twilight shard spawned from a cluster. + Twilight(TwilightShard), +} + +impl Shard { + #[allow(unreachable_patterns)] + /// Send a JSON message to the inner shard handle. + pub async fn send(&mut self, msg: Value) -> JoinResult<()> { + match self { + #[cfg(feature = "serenity")] + Shard::Serenity(s) => s.send(InterMessage::Json(msg))?, + #[cfg(feature = "twilight")] + Shard::Twilight(t) => t.command(&msg).await?, + _ => return Err(JoinError::NoSender), + } + Ok(()) + } +} + +#[cfg(feature = "serenity")] +/// Handle to an individual shard designed to buffer unsent messages while +/// a reconnect/rebalance is ongoing. +#[derive(Debug, Default)] +pub struct SerenityShardHandle { + sender: PRwLock>>, + queue: PMutex>, +} + +#[cfg(feature = "serenity")] +impl SerenityShardHandle { + fn register(&self, sender: Sender) { + let mut sender_lock = self.sender.write(); + *sender_lock = Some(sender); + + let sender_lock = RwLockWriteGuard::downgrade(sender_lock); + let mut messages_lock = self.queue.lock(); + + if let Some(sender) = &*sender_lock { + for msg in messages_lock.drain(..) { + if let Err(e) = sender.unbounded_send(msg) { + error!("Error while clearing gateway message queue: {:?}", e); + break; + } + } + } + } + + fn deregister(&self) { + let mut sender_lock = self.sender.write(); + *sender_lock = None; + } + + fn send(&self, message: InterMessage) -> StdResult<(), TrySendError> { + let sender_lock = self.sender.read(); + if let Some(sender) = &*sender_lock { + sender.unbounded_send(message) + } else { + let mut messages_lock = self.queue.lock(); + messages_lock.push(message); + Ok(()) + } + } +} diff --git a/src/tracks/command.rs b/src/tracks/command.rs new file mode 100644 index 0000000..4dc3ef6 --- /dev/null +++ b/src/tracks/command.rs @@ -0,0 +1,53 @@ +use super::*; +use crate::events::EventData; +use std::time::Duration; +use tokio::sync::oneshot::Sender as OneshotSender; + +/// A request from external code using a [`TrackHandle`] to modify +/// or act upon an [`Track`] object. +/// +/// [`Track`]: struct.Track.html +/// [`TrackHandle`]: struct.TrackHandle.html +pub enum TrackCommand { + /// Set the track's play_mode to play/resume. + Play, + /// Set the track's play_mode to pause. + Pause, + /// Stop the target track. This cannot be undone. + Stop, + /// Set the track's volume. + Volume(f32), + /// Seek to the given duration. + /// + /// On unsupported input types, this can be fatal. + Seek(Duration), + /// Register an event on this track. 
+ AddEvent(EventData), + /// Run some closure on this track, with direct access to the core object. + Do(Box), + /// Request a read-only view of this track's state. + Request(OneshotSender>), + /// Change the loop count/strategy of this track. + Loop(LoopState), +} + +impl std::fmt::Debug for TrackCommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + use TrackCommand::*; + write!( + f, + "TrackCommand::{}", + match self { + Play => "Play".to_string(), + Pause => "Pause".to_string(), + Stop => "Stop".to_string(), + Volume(vol) => format!("Volume({})", vol), + Seek(d) => format!("Seek({:?})", d), + AddEvent(evt) => format!("AddEvent({:?})", evt), + Do(_f) => "Do([function])".to_string(), + Request(tx) => format!("Request({:?})", tx), + Loop(loops) => format!("Loop({:?})", loops), + } + ) + } +} diff --git a/src/tracks/handle.rs b/src/tracks/handle.rs new file mode 100644 index 0000000..effa703 --- /dev/null +++ b/src/tracks/handle.rs @@ -0,0 +1,159 @@ +use super::*; +use crate::events::{Event, EventData, EventHandler}; +use std::time::Duration; +use tokio::sync::{ + mpsc::{error::SendError, UnboundedSender}, + oneshot, +}; + +#[derive(Clone, Debug)] +/// Handle for safe control of a [`Track`] track from other threads, outside +/// of the audio mixing and voice handling context. +/// +/// Almost all method calls here are fallible; in most cases, this will be because +/// the underlying [`Track`] object has been discarded. Those which aren't refer +/// to immutable properties of the underlying stream. +/// +/// [`Track`]: struct.Track.html +pub struct TrackHandle { + command_channel: UnboundedSender, + seekable: bool, +} + +impl TrackHandle { + /// Creates a new handle, using the given command sink and hint as to whether + /// the underlying [`Input`] supports seek operations. + /// + /// [`Input`]: ../input/struct.Input.html + pub fn new(command_channel: UnboundedSender, seekable: bool) -> Self { + Self { + command_channel, + seekable, + } + } + + /// Unpauses an audio track. + pub fn play(&self) -> TrackResult { + self.send(TrackCommand::Play) + } + + /// Pauses an audio track. + pub fn pause(&self) -> TrackResult { + self.send(TrackCommand::Pause) + } + + /// Stops an audio track. + /// + /// This is *final*, and will cause the audio context to fire + /// a [`TrackEvent::End`] event. + /// + /// [`TrackEvent::End`]: ../events/enum.TrackEvent.html#variant.End + pub fn stop(&self) -> TrackResult { + self.send(TrackCommand::Stop) + } + + /// Sets the volume of an audio track. + pub fn set_volume(&self, volume: f32) -> TrackResult { + self.send(TrackCommand::Volume(volume)) + } + + /// Denotes whether the underlying [`Input`] stream is compatible with arbitrary seeking. + /// + /// If this returns `false`, all calls to [`seek`] will fail, and the track is + /// incapable of looping. + /// + /// [`seek`]: #method.seek + /// [`Input`]: ../input/struct.Input.html + pub fn is_seekable(&self) -> bool { + self.seekable + } + + /// Seeks along the track to the specified position. + /// + /// If the underlying [`Input`] does not support this behaviour, + /// then all calls will fail. + /// + /// [`Input`]: ../input/struct.Input.html + pub fn seek_time(&self, position: Duration) -> TrackResult { + if self.seekable { + self.send(TrackCommand::Seek(position)) + } else { + Err(SendError(TrackCommand::Seek(position))) + } + } + + /// Attach an event handler to an audio track. These will receive [`EventContext::Track`]. 
+ /// + /// Users **must** ensure that no costly work or blocking occurs + /// within the supplied function or closure. *Taking excess time could prevent + /// timely sending of packets, causing audio glitches and delays*. + /// + /// [`Track`]: struct.Track.html + /// [`EventContext::Track`]: ../events/enum.EventContext.html#variant.Track + pub fn add_event(&self, event: Event, action: F) -> TrackResult { + let cmd = TrackCommand::AddEvent(EventData::new(event, action)); + if event.is_global_only() { + Err(SendError(cmd)) + } else { + self.send(cmd) + } + } + + /// Perform an arbitrary action on a raw [`Track`] object. + /// + /// Users **must** ensure that no costly work or blocking occurs + /// within the supplied function or closure. *Taking excess time could prevent + /// timely sending of packets, causing audio glitches and delays*. + /// + /// [`Track`]: struct.Track.html + pub fn action(&self, action: F) -> TrackResult + where + F: FnOnce(&mut Track) + Send + Sync + 'static, + { + self.send(TrackCommand::Do(Box::new(action))) + } + + /// Request playback information and state from the audio context. + /// + /// Crucially, the audio thread will respond *at a later time*: + /// It is up to the user when or how this should be read from the returned channel. + pub fn get_info(&self) -> TrackQueryResult { + let (tx, rx) = oneshot::channel(); + self.send(TrackCommand::Request(tx)).map(move |_| rx) + } + + /// Set an audio track to loop indefinitely. + pub fn enable_loop(&self) -> TrackResult { + if self.seekable { + self.send(TrackCommand::Loop(LoopState::Infinite)) + } else { + Err(SendError(TrackCommand::Loop(LoopState::Infinite))) + } + } + + /// Set an audio track to no longer loop. + pub fn disable_loop(&self) -> TrackResult { + if self.seekable { + self.send(TrackCommand::Loop(LoopState::Finite(0))) + } else { + Err(SendError(TrackCommand::Loop(LoopState::Finite(0)))) + } + } + + /// Set an audio track to loop a set number of times. + pub fn loop_for(&self, count: usize) -> TrackResult { + if self.seekable { + self.send(TrackCommand::Loop(LoopState::Finite(count))) + } else { + Err(SendError(TrackCommand::Loop(LoopState::Finite(count)))) + } + } + + #[inline] + /// Send a raw command to the [`Track`] object. + /// + /// [`Track`]: struct.Track.html + pub fn send(&self, cmd: TrackCommand) -> TrackResult { + self.command_channel.send(cmd) + } +} diff --git a/src/tracks/looping.rs b/src/tracks/looping.rs new file mode 100644 index 0000000..0e57d0a --- /dev/null +++ b/src/tracks/looping.rs @@ -0,0 +1,22 @@ +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +/// Looping behaviour for a [`Track`]. +/// +/// [`Track`]: struct.Track.html +pub enum LoopState { + /// Track will loop endlessly until loop state is changed or + /// manually stopped. + Infinite, + + /// Track will loop `n` more times. + /// + /// `Finite(0)` is the `Default`, stopping the track once its [`Input`] ends. + /// + /// [`Input`]: ../input/struct.Input.html + Finite(usize), +} + +impl Default for LoopState { + fn default() -> Self { + Self::Finite(0) + } +} diff --git a/src/tracks/mod.rs b/src/tracks/mod.rs new file mode 100644 index 0000000..d60f867 --- /dev/null +++ b/src/tracks/mod.rs @@ -0,0 +1,379 @@ +//! Live, controllable audio instances. +//! +//! Tracks add control and event data around the bytestreams offered by [`Input`], +//! where each represents a live audio source inside of the driver's mixer. +//! +//! To prevent locking and stalling of the driver, tracks are controlled from your bot using a +//! 
[`TrackHandle`]. These handles remotely send commands from your bot's (a)sync +//! context to control playback, register events, and execute synchronous closures. +//! +//! If you want a new track from an [`Input`], i.e., for direct control before +//! playing your source on the driver, use [`create_player`]. +//! +//! [`Input`]: ../input/struct.Input.html +//! [`TrackHandle`]: struct.TrackHandle.html +//! [`create_player`]: fn.create_player.html + +mod command; +mod handle; +mod looping; +mod mode; +mod queue; +mod state; + +pub use self::{command::*, handle::*, looping::*, mode::*, queue::*, state::*}; + +use crate::{constants::*, driver::tasks::message::*, events::EventStore, input::Input}; +use std::time::Duration; +use tokio::sync::{ + mpsc::{ + self, + error::{SendError, TryRecvError}, + UnboundedReceiver, + }, + oneshot::Receiver as OneshotReceiver, +}; + +/// Control object for audio playback. +/// +/// Accessed by both commands and the playback code -- as such, access from user code is +/// almost always guarded via a [`TrackHandle`]. You should expect to receive +/// access to a raw object of this type via [`create_player`], for use in +/// [`Driver::play`] or [`Driver::play_only`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use songbird::{driver::Driver, ffmpeg, tracks::create_player}; +/// +/// # async { +/// // A Call is also valid here! +/// let mut handler: Driver = Default::default(); +/// let source = ffmpeg("../audio/my-favourite-song.mp3") +/// .await +/// .expect("This might fail: handle this error!"); +/// let (mut audio, audio_handle) = create_player(source); +/// +/// audio.set_volume(0.5); +/// +/// handler.play_only(audio); +/// +/// // Future access occurs via audio_handle. +/// # }; +/// ``` +/// +/// [`Driver::play_only`]: ../struct.Driver.html#method.play_only +/// [`Driver::play`]: ../struct.Driver.html#method.play +/// [`TrackHandle`]: struct.TrackHandle.html +/// [`create_player`]: fn.create_player.html +#[derive(Debug)] +pub struct Track { + /// Whether or not this sound is currently playing. + /// + /// Can be controlled with [`play`] or [`pause`] if chaining is desired. + /// + /// [`play`]: #method.play + /// [`pause`]: #method.pause + pub(crate) playing: PlayMode, + + /// The desired volume for playback. + /// + /// Sensible values fall between `0.0` and `1.0`. + /// + /// Can be controlled with [`volume`] if chaining is desired. + /// + /// [`volume`]: #method.volume + pub(crate) volume: f32, + + /// Underlying data access object. + /// + /// *Calling code is not expected to use this.* + pub(crate) source: Input, + + /// The current playback position in the track. + pub(crate) position: Duration, + + /// The total length of time this track has been active. + pub(crate) play_time: Duration, + + /// List of events attached to this audio track. + /// + /// This may be used to add additional events to a track + /// before it is sent to the audio context for playing. + pub events: Option, + + /// Channel from which commands are received. + /// + /// Track commands are sent in this manner to ensure that access + /// occurs in a thread-safe manner, without allowing any external + /// code to lock access to audio objects and block packet generation. + pub(crate) commands: UnboundedReceiver, + + /// Handle for safe control of this audio track from other threads. + /// + /// Typically, this is used by internal code to supply context information + /// to event handlers, though more may be cloned from this handle. 
+ pub handle: TrackHandle, + + /// Count of remaining loops. + pub loops: LoopState, +} + +impl Track { + /// Create a new track directly from an input, command source, + /// and handle. + /// + /// In general, you should probably use [`create_player`]. + /// + /// [`create_player`]: fn.create_player.html + pub fn new_raw( + source: Input, + commands: UnboundedReceiver, + handle: TrackHandle, + ) -> Self { + Self { + playing: Default::default(), + volume: 1.0, + source, + position: Default::default(), + play_time: Default::default(), + events: Some(EventStore::new_local()), + commands, + handle, + loops: LoopState::Finite(0), + } + } + + /// Sets a track to playing if it is paused. + pub fn play(&mut self) -> &mut Self { + self.set_playing(PlayMode::Play) + } + + /// Pauses a track if it is playing. + pub fn pause(&mut self) -> &mut Self { + self.set_playing(PlayMode::Pause) + } + + /// Manually stops a track. + /// + /// This will cause the audio track to be removed, with any relevant events triggered. + /// Stopped/ended tracks cannot be restarted. + pub fn stop(&mut self) -> &mut Self { + self.set_playing(PlayMode::Stop) + } + + pub(crate) fn end(&mut self) -> &mut Self { + self.set_playing(PlayMode::End) + } + + #[inline] + fn set_playing(&mut self, new_state: PlayMode) -> &mut Self { + self.playing = self.playing.change_to(new_state); + + self + } + + /// Returns the current play status of this track. + pub fn playing(&self) -> PlayMode { + self.playing + } + + /// Sets [`volume`] in a manner that allows method chaining. + /// + /// [`volume`]: #structfield.volume + pub fn set_volume(&mut self, volume: f32) -> &mut Self { + self.volume = volume; + + self + } + + /// Returns the current playback position. + pub fn volume(&self) -> f32 { + self.volume + } + + /// Returns the current playback position. + pub fn position(&self) -> Duration { + self.position + } + + /// Returns the total length of time this track has been active. + pub fn play_time(&self) -> Duration { + self.play_time + } + + /// Sets [`loops`] in a manner that allows method chaining. + /// + /// [`loops`]: #structfield.loops + pub fn set_loops(&mut self, loops: LoopState) -> &mut Self { + self.loops = loops; + self + } + + pub(crate) fn do_loop(&mut self) -> bool { + match self.loops { + LoopState::Infinite => true, + LoopState::Finite(0) => false, + LoopState::Finite(ref mut n) => { + *n -= 1; + true + }, + } + } + + /// Steps playback location forward by one frame. + pub(crate) fn step_frame(&mut self) { + self.position += TIMESTEP_LENGTH; + self.play_time += TIMESTEP_LENGTH; + } + + /// Receives and acts upon any commands forwarded by [`TrackHandle`]s. + /// + /// *Used internally*, this should not be exposed to users. + /// + /// [`TrackHandle`]: struct.TrackHandle.html + pub(crate) fn process_commands(&mut self, index: usize, ic: &Interconnect) { + // Note: disconnection and an empty channel are both valid, + // and should allow the audio object to keep running as intended. + + // Note that interconnect failures are not currently errors. + // In correct operation, the event thread should never panic, + // but it receiving status updates is secondary do actually + // doing the work. 
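// Most commands below mutate the track locally and then mirror the change to the
// event thread over `ic.events`, so event handlers always observe an up-to-date
// `TrackState`; `AddEvent` is forwarded directly, and `Request` is answered over
// its own oneshot channel.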
+ loop { + match self.commands.try_recv() { + Ok(cmd) => { + use TrackCommand::*; + match cmd { + Play => { + self.play(); + let _ = ic.events.send(EventMessage::ChangeState( + index, + TrackStateChange::Mode(self.playing), + )); + }, + Pause => { + self.pause(); + let _ = ic.events.send(EventMessage::ChangeState( + index, + TrackStateChange::Mode(self.playing), + )); + }, + Stop => { + self.stop(); + let _ = ic.events.send(EventMessage::ChangeState( + index, + TrackStateChange::Mode(self.playing), + )); + }, + Volume(vol) => { + self.set_volume(vol); + let _ = ic.events.send(EventMessage::ChangeState( + index, + TrackStateChange::Volume(self.volume), + )); + }, + Seek(time) => { + self.seek_time(time); + let _ = ic.events.send(EventMessage::ChangeState( + index, + TrackStateChange::Position(self.position), + )); + }, + AddEvent(evt) => { + let _ = ic.events.send(EventMessage::AddTrackEvent(index, evt)); + }, + Do(action) => { + action(self); + let _ = ic.events.send(EventMessage::ChangeState( + index, + TrackStateChange::Total(self.state()), + )); + }, + Request(tx) => { + let _ = tx.send(Box::new(self.state())); + }, + Loop(loops) => { + self.set_loops(loops); + let _ = ic.events.send(EventMessage::ChangeState( + index, + TrackStateChange::Loops(self.loops, true), + )); + }, + } + }, + Err(TryRecvError::Closed) => { + // this branch will never be visited. + break; + }, + Err(TryRecvError::Empty) => { + break; + }, + } + } + } + + /// Creates a read-only copy of the audio track's state. + /// + /// The primary use-case of this is sending information across + /// threads in response to a [`TrackHandle`]. + /// + /// [`TrackHandle`]: struct.TrackHandle.html + pub fn state(&self) -> TrackState { + TrackState { + playing: self.playing, + volume: self.volume, + position: self.position, + play_time: self.play_time, + loops: self.loops, + } + } + + /// Seek to a specific point in the track. + /// + /// Returns `None` if unsupported. + pub fn seek_time(&mut self, pos: Duration) -> Option { + let out = self.source.seek_time(pos); + + if let Some(t) = out { + self.position = t; + } + + out + } +} + +/// Creates a [`Track`] object to pass into the audio context, and a [`TrackHandle`] +/// for safe, lock-free access in external code. +/// +/// Typically, this would be used if you wished to directly work on or configure +/// the [`Track`] object before it is passed over to the driver. +/// +/// [`Track`]: struct.Track.html +/// [`TrackHandle`]: struct.TrackHandle.html +pub fn create_player(source: Input) -> (Track, TrackHandle) { + let (tx, rx) = mpsc::unbounded_channel(); + let can_seek = source.is_seekable(); + let player = Track::new_raw(source, rx, TrackHandle::new(tx.clone(), can_seek)); + + (player, TrackHandle::new(tx, can_seek)) +} + +/// Alias for most result-free calls to a [`TrackHandle`]. +/// +/// Failure indicates that the accessed audio object has been +/// removed or deleted by the audio context. +/// +/// [`TrackHandle`]: struct.TrackHandle.html +pub type TrackResult = Result<(), SendError>; + +/// Alias for return value from calls to [`TrackHandle::get_info`]. +/// +/// Crucially, the audio thread will respond *at a later time*: +/// It is up to the user when or how this should be read from the returned channel. +/// +/// Failure indicates that the accessed audio object has been +/// removed or deleted by the audio context. 
+/// +/// [`TrackHandle::get_info`]: struct.TrackHandle.html#method.get_info +pub type TrackQueryResult = Result>, SendError>; diff --git a/src/tracks/mode.rs b/src/tracks/mode.rs new file mode 100644 index 0000000..80dd101 --- /dev/null +++ b/src/tracks/mode.rs @@ -0,0 +1,37 @@ +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +/// Playback status of a track. +pub enum PlayMode { + /// The track is currently playing. + Play, + /// The track is currently paused, and may be resumed. + Pause, + /// The track has been manually stopped, and cannot be restarted. + Stop, + /// The track has naturally ended, and cannot be restarted. + End, +} + +impl PlayMode { + /// Returns whether the track has irreversibly stopped. + pub fn is_done(self) -> bool { + matches!(self, PlayMode::Stop | PlayMode::End) + } + + pub(crate) fn change_to(self, other: Self) -> PlayMode { + use PlayMode::*; + + // Idea: a finished track cannot be restarted -- this action is final. + // We may want to change this in future so that seekable tracks can uncancel + // themselves, perhaps, but this requires a bit more machinery to readd... + match self { + Play | Pause => other, + state => state, + } + } +} + +impl Default for PlayMode { + fn default() -> Self { + PlayMode::Play + } +} diff --git a/src/tracks/queue.rs b/src/tracks/queue.rs new file mode 100644 index 0000000..349bc3d --- /dev/null +++ b/src/tracks/queue.rs @@ -0,0 +1,213 @@ +use crate::{ + driver::Driver, + events::{Event, EventContext, EventData, EventHandler, TrackEvent}, + input::Input, + tracks::{self, Track, TrackHandle, TrackResult}, +}; +use async_trait::async_trait; +use parking_lot::Mutex; +use std::{collections::VecDeque, sync::Arc}; +use tracing::{info, warn}; + +#[derive(Default)] +/// A simple queue for several audio sources, designed to +/// play in sequence. +/// +/// This makes use of [`TrackEvent`]s to determine when the current +/// song or audio file has finished before playing the next entry. +/// +/// `examples/e16_voice_events` demonstrates how a user might manage, +/// track and use this to run a song queue in many guilds in parallel. +/// This code is trivial to extend if extra functionality is needed. +/// +/// # Example +/// +/// ```rust,no_run +/// use songbird::{ +/// driver::Driver, +/// id::GuildId, +/// ffmpeg, +/// tracks::{create_player, TrackQueue}, +/// }; +/// use std::collections::HashMap; +/// +/// # async { +/// let guild = GuildId(0); +/// // A Call is also valid here! +/// let mut driver: Driver = Default::default(); +/// +/// let mut queues: HashMap = Default::default(); +/// +/// let source = ffmpeg("../audio/my-favourite-song.mp3") +/// .await +/// .expect("This might fail: handle this error!"); +/// +/// // We need to ensure that this guild has a TrackQueue created for it. +/// let queue = queues.entry(guild) +/// .or_default(); +/// +/// // Queueing a track is this easy! +/// queue.add_source(source, &mut driver); +/// # }; +/// ``` + +/// +/// [`TrackEvent`]: ../events/enum.TrackEvent.html +pub struct TrackQueue { + // NOTE: the choice of a parking lot mutex is quite deliberate + inner: Arc>, +} + +#[derive(Default)] +/// Inner portion of a [`TrackQueue`]. +/// +/// This abstracts away thread-safety from the user, +/// and offers a convenient location to store further state if required. 
+/// +/// [`TrackQueue`]: struct.TrackQueue.html +struct TrackQueueCore { + tracks: VecDeque, +} + +struct QueueHandler { + remote_lock: Arc>, +} + +#[async_trait] +impl EventHandler for QueueHandler { + async fn act(&self, ctx: &EventContext<'_>) -> Option { + let mut inner = self.remote_lock.lock(); + let _old = inner.tracks.pop_front(); + + info!("Queued track ended: {:?}.", ctx); + info!("{} tracks remain.", inner.tracks.len()); + + // If any audio files die unexpectedly, then keep going until we + // find one which works, or we run out. + let mut keep_looking = true; + while keep_looking && !inner.tracks.is_empty() { + if let Some(new) = inner.tracks.front() { + keep_looking = new.play().is_err(); + + // Discard files which cannot be used for whatever reason. + if keep_looking { + warn!("Track in Queue couldn't be played..."); + let _ = inner.tracks.pop_front(); + } + } + } + + None + } +} + +impl TrackQueue { + /// Create a new, empty, track queue. + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(TrackQueueCore { + tracks: VecDeque::new(), + })), + } + } + + /// Adds an audio source to the queue, to be played in the channel managed by `handler`. + pub fn add_source(&self, source: Input, handler: &mut Driver) { + let (audio, audio_handle) = tracks::create_player(source); + self.add(audio, audio_handle, handler); + } + + /// Adds a [`Track`] object to the queue, to be played in the channel managed by `handler`. + /// + /// This is used with [`voice::create_player`] if additional configuration or event handlers + /// are required before enqueueing the audio track. + /// + /// [`Track`]: struct.Track.html + /// [`voice::create_player`]: fn.create_player.html + pub fn add(&self, mut track: Track, track_handle: TrackHandle, handler: &mut Driver) { + info!("Track added to queue."); + let remote_lock = self.inner.clone(); + let mut inner = self.inner.lock(); + + if !inner.tracks.is_empty() { + track.pause(); + } + + track + .events + .as_mut() + .expect("Queue inspecting EventStore on new Track: did not exist.") + .add_event( + EventData::new(Event::Track(TrackEvent::End), QueueHandler { remote_lock }), + track.position, + ); + + handler.play(track); + inner.tracks.push_back(track_handle); + } + + /// Returns the number of tracks currently in the queue. + pub fn len(&self) -> usize { + let inner = self.inner.lock(); + + inner.tracks.len() + } + + /// Returns whether there are no tracks currently in the queue. + pub fn is_empty(&self) -> bool { + let inner = self.inner.lock(); + + inner.tracks.is_empty() + } + + /// Pause the track at the head of the queue. + pub fn pause(&self) -> TrackResult { + let inner = self.inner.lock(); + + if let Some(handle) = inner.tracks.front() { + handle.pause() + } else { + Ok(()) + } + } + + /// Resume the track at the head of the queue. + pub fn resume(&self) -> TrackResult { + let inner = self.inner.lock(); + + if let Some(handle) = inner.tracks.front() { + handle.play() + } else { + Ok(()) + } + } + + /// Stop the currently playing track, and clears the queue. + pub fn stop(&self) -> TrackResult { + let mut inner = self.inner.lock(); + + let out = inner.stop_current(); + + inner.tracks.clear(); + + out + } + + /// Skip to the next track in the queue, if it exists. + pub fn skip(&self) -> TrackResult { + let inner = self.inner.lock(); + + inner.stop_current() + } +} + +impl TrackQueueCore { + /// Skip to the next track in the queue, if it exists. 
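// Stopping the head track fires `TrackEvent::End`, which `QueueHandler::act` uses
// to pop it and start the next entry, so this doubles as the queue's `skip`.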
+ fn stop_current(&self) -> TrackResult { + if let Some(handle) = self.tracks.front() { + handle.stop() + } else { + Ok(()) + } + } +} diff --git a/src/tracks/state.rs b/src/tracks/state.rs new file mode 100644 index 0000000..b0650fb --- /dev/null +++ b/src/tracks/state.rs @@ -0,0 +1,31 @@ +use super::*; + +/// State of an [`Track`] object, designed to be passed to event handlers +/// and retrieved remotely via [`TrackHandle::get_info`] or +/// [`TrackHandle::get_info_blocking`]. +/// +/// [`Track`]: struct.Track.html +/// [`TrackHandle::get_info`]: struct.TrackHandle.html#method.get_info +/// [`TrackHandle::get_info_blocking`]: struct.TrackHandle.html#method.get_info_blocking +#[derive(Copy, Clone, Debug, Default, PartialEq)] +pub struct TrackState { + /// Play status (e.g., active, paused, stopped) of this track. + pub playing: PlayMode, + /// Current volume of this track. + pub volume: f32, + /// Current playback position in the source. + /// + /// This is altered by loops and seeks + pub position: Duration, + /// Total playback time, increasing monotonically. + pub play_time: Duration, + /// Remaining loops on this track. + pub loops: LoopState, +} + +impl TrackState { + pub(crate) fn step_frame(&mut self) { + self.position += TIMESTEP_LENGTH; + self.play_time += TIMESTEP_LENGTH; + } +} diff --git a/src/ws.rs b/src/ws.rs new file mode 100644 index 0000000..f0100e2 --- /dev/null +++ b/src/ws.rs @@ -0,0 +1,208 @@ +// FIXME: this is copied from serenity/src/internal/ws_impl.rs +// To prevent this duplication, we either need to expose this on serenity's API +// (not desirable) or break the common WS elements into a subcrate. +// I believe that decisions is outside of the scope of the voice subcrate PR. + +use crate::model::Event; + +use async_trait::async_trait; +use async_tungstenite::{ + tokio::ConnectStream, + tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message}, + WebSocketStream, +}; +use futures::{SinkExt, StreamExt, TryStreamExt}; +use serde_json::Error as JsonError; +use tokio::time::timeout; +use tracing::{instrument, warn}; + +pub type WsStream = WebSocketStream; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum Error { + Json(JsonError), + #[cfg(all(feature = "rustls", not(feature = "native")))] + Tls(RustlsError), + + /// The discord voice gateway does not support or offer zlib compression. + /// As a result, only text messages are expected. 
diff --git a/src/ws.rs b/src/ws.rs
new file mode 100644
index 0000000..f0100e2
--- /dev/null
+++ b/src/ws.rs
@@ -0,0 +1,208 @@
+// FIXME: this is copied from serenity/src/internal/ws_impl.rs
+// To prevent this duplication, we either need to expose this on serenity's API
+// (not desirable) or break the common WS elements into a subcrate.
+// I believe that decision is outside the scope of the voice subcrate PR.
+
+use crate::model::Event;
+
+use async_trait::async_trait;
+use async_tungstenite::{
+    tokio::ConnectStream,
+    tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message},
+    WebSocketStream,
+};
+use futures::{SinkExt, StreamExt, TryStreamExt};
+use serde_json::Error as JsonError;
+use tokio::time::timeout;
+use tracing::{instrument, warn};
+
+pub type WsStream = WebSocketStream<ConnectStream>;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug)]
+pub enum Error {
+    Json(JsonError),
+    #[cfg(all(feature = "rustls", not(feature = "native")))]
+    Tls(RustlsError),
+
+    /// The Discord voice gateway does not support or offer zlib compression.
+    /// As a result, only text messages are expected.
+    UnexpectedBinaryMessage(Vec<u8>),
+
+    Ws(TungsteniteError),
+
+    WsClosed(Option<CloseFrame<'static>>),
+}
+
+impl From<JsonError> for Error {
+    fn from(e: JsonError) -> Error {
+        Error::Json(e)
+    }
+}
+
+#[cfg(all(feature = "rustls", not(feature = "native")))]
+impl From<RustlsError> for Error {
+    fn from(e: RustlsError) -> Error {
+        Error::Tls(e)
+    }
+}
+
+impl From<TungsteniteError> for Error {
+    fn from(e: TungsteniteError) -> Error {
+        Error::Ws(e)
+    }
+}
+
+use futures::stream::SplitSink;
+#[cfg(all(feature = "rustls", not(feature = "native")))]
+use std::{
+    error::Error as StdError,
+    fmt::{Display, Formatter, Result as FmtResult},
+    io::Error as IoError,
+};
+use url::Url;
+
+#[async_trait]
+pub trait ReceiverExt {
+    async fn recv_json(&mut self) -> Result<Option<Event>>;
+    async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>>;
+}
+
+#[async_trait]
+pub trait SenderExt {
+    async fn send_json(&mut self, value: &Event) -> Result<()>;
+}
+
+#[async_trait]
+impl ReceiverExt for WsStream {
+    async fn recv_json(&mut self) -> Result<Option<Event>> {
+        const TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_millis(500);
+
+        let ws_message = match timeout(TIMEOUT, self.next()).await {
+            Ok(Some(Ok(v))) => Some(v),
+            Ok(Some(Err(e))) => return Err(e.into()),
+            Ok(None) | Err(_) => None,
+        };
+
+        convert_ws_message(ws_message)
+    }
+
+    async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
+        convert_ws_message(self.try_next().await.ok().flatten())
+    }
+}
+
+#[async_trait]
+impl SenderExt for SplitSink<WsStream, Message> {
+    async fn send_json(&mut self, value: &Event) -> Result<()> {
+        Ok(serde_json::to_string(value)
+            .map(Message::Text)
+            .map_err(Error::from)
+            .map(|m| self.send(m))?
+            .await?)
+    }
+}
+
+#[async_trait]
+impl SenderExt for WsStream {
+    async fn send_json(&mut self, value: &Event) -> Result<()> {
+        Ok(serde_json::to_string(value)
+            .map(Message::Text)
+            .map_err(Error::from)
+            .map(|m| self.send(m))?
+            .await?)
+    }
+}
+
+#[inline]
+pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
+    Ok(match message {
+        Some(Message::Text(payload)) =>
+            serde_json::from_str(&payload).map(Some).map_err(|why| {
+                warn!("Err deserializing text: {:?}; text: {}", why, payload);
+
+                why
+            })?,
+        Some(Message::Binary(bytes)) => {
+            return Err(Error::UnexpectedBinaryMessage(bytes));
+        },
+        Some(Message::Close(Some(frame))) => {
+            return Err(Error::WsClosed(Some(frame)));
+        },
+        // Ping/Pong message behaviour is internally handled by tungstenite.
+        _ => None,
+    })
+}
+
+/// An error that occurred while connecting over rustls.
+#[derive(Debug)]
+#[non_exhaustive]
+#[cfg(all(feature = "rustls", not(feature = "native")))]
+pub enum RustlsError {
+    /// An error with the handshake in tungstenite.
+    HandshakeError,
+    /// Standard IO error happening while creating the TCP stream.
+    Io(IoError),
+}
+
+#[cfg(all(feature = "rustls", not(feature = "native")))]
+impl From<IoError> for RustlsError {
+    fn from(e: IoError) -> Self {
+        RustlsError::Io(e)
+    }
+}
+
+#[cfg(all(feature = "rustls", not(feature = "native")))]
+impl Display for RustlsError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
+        match self {
+            RustlsError::HandshakeError =>
+                f.write_str("TLS handshake failed when making the websocket connection"),
+            RustlsError::Io(inner) => Display::fmt(&inner, f),
+        }
+    }
+}
+
+#[cfg(all(feature = "rustls", not(feature = "native")))]
+impl StdError for RustlsError {
+    fn source(&self) -> Option<&(dyn StdError + 'static)> {
+        match self {
+            RustlsError::Io(inner) => Some(inner),
+            _ => None,
+        }
+    }
+}
+
+#[cfg(all(feature = "rustls", not(feature = "native")))]
+#[instrument]
+pub(crate) async fn create_rustls_client(url: Url) -> Result<WsStream> {
+    let (stream, _) = async_tungstenite::tokio::connect_async_with_config::<Url>(
+        url,
+        Some(async_tungstenite::tungstenite::protocol::WebSocketConfig {
+            max_message_size: None,
+            max_frame_size: None,
+            max_send_queue: None,
+        }),
+    )
+    .await
+    .map_err(|_| RustlsError::HandshakeError)?;
+
+    Ok(stream)
+}
+
+#[cfg(feature = "native")]
+#[instrument]
+pub(crate) async fn create_native_tls_client(url: Url) -> Result<WsStream> {
+    let (stream, _) = async_tungstenite::tokio::connect_async_with_config::<Url>(
+        url,
+        Some(async_tungstenite::tungstenite::protocol::WebSocketConfig {
+            max_message_size: None,
+            max_frame_size: None,
+            max_send_queue: None,
+        }),
+    )
+    .await?;
+
+    Ok(stream)
+}
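The extension traits in `ws.rs` are intended to be driven in a receive loop; a hedged sketch of that pattern follows, where `ws_loop` and `handle_event` are hypothetical stand-ins for the driver's dispatch logic and are not part of this diff:

```rust
// Sketch only: pump a voice-gateway websocket via the extension traits.
async fn ws_loop(mut ws: WsStream) -> Result<()> {
    loop {
        // `recv_json` yields Ok(None) both on the 500ms timeout and when the
        // stream ends cleanly, so a None is simply skipped here.
        if let Some(event) = ws.recv_json().await? {
            handle_event(event);
        }
    }
}
```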
diff --git a/utils/Cargo.toml b/utils/Cargo.toml
new file mode 100644
index 0000000..9bb2698
--- /dev/null
+++ b/utils/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "utils"
+version = "0.1.0"
+authors = ["Kyle Simpson "]
+edition = "2018"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+byteorder = "1"
diff --git a/utils/README.md b/utils/README.md
new file mode 100644
index 0000000..fcb910f
--- /dev/null
+++ b/utils/README.md
@@ -0,0 +1 @@
+Test utilities for testing and benchmarking songbird.
diff --git a/utils/src/lib.rs b/utils/src/lib.rs
new file mode 100644
index 0000000..35bcf34
--- /dev/null
+++ b/utils/src/lib.rs
@@ -0,0 +1,67 @@
+use byteorder::{LittleEndian, WriteBytesExt};
+use std::mem;
+
+pub fn make_sine(float_len: usize, stereo: bool) -> Vec<u8> {
+    let sample_len = mem::size_of::<f32>();
+    let byte_len = float_len * sample_len;
+
+    // Set the period to 100 samples == 480Hz sine (at 48kHz).
+
+    let mut out = vec![0u8; byte_len];
+    let mut byte_slice = &mut out[..];
+
+    for i in 0..float_len {
+        let x_val = (i as f32) * std::f32::consts::PI / 50.0;
+        byte_slice.write_f32::<LittleEndian>(x_val.sin()).unwrap();
+    }
+
+    if stereo {
+        let mut new_out = vec![0u8; byte_len * 2];
+
+        for (mono_chunk, stereo_chunk) in out[..]
+            .chunks(sample_len)
+            .zip(new_out[..].chunks_mut(2 * sample_len))
+        {
+            stereo_chunk[..sample_len].copy_from_slice(mono_chunk);
+            stereo_chunk[sample_len..].copy_from_slice(mono_chunk);
+        }
+
+        new_out
+    } else {
+        out
+    }
+}
+
+pub fn make_pcm_sine(i16_len: usize, stereo: bool) -> Vec<u8> {
+    let sample_len = mem::size_of::<i16>();
+    let byte_len = i16_len * sample_len;
+
+    // Set the period to 100 samples == 480Hz sine (at 48kHz).
+    // Amplitude = 10_000.
+
+    let mut out = vec![0u8; byte_len];
+    let mut byte_slice = &mut out[..];
+
+    for i in 0..i16_len {
+        let x_val = (i as f32) * std::f32::consts::PI / 50.0;
+        byte_slice
+            .write_i16::<LittleEndian>((x_val.sin() * 10_000.0) as i16)
+            .unwrap();
+    }
+
+    if stereo {
+        let mut new_out = vec![0u8; byte_len * 2];
+
+        for (mono_chunk, stereo_chunk) in out[..]
+            .chunks(sample_len)
+            .zip(new_out[..].chunks_mut(2 * sample_len))
+        {
+            stereo_chunk[..sample_len].copy_from_slice(mono_chunk);
+            stereo_chunk[sample_len..].copy_from_slice(mono_chunk);
+        }
+
+        new_out
+    } else {
+        out
+    }
+}
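The helpers above are exposed by the dev-only `utils` crate for benchmarking; a minimal Criterion harness exercising `make_sine` might look like the following (the benchmark name and buffer length are illustrative and not taken from the crate's existing benchmarks):

```rust
use criterion::{criterion_group, criterion_main, Criterion};

// Illustrative benchmark: 960 samples is one 20ms frame at 48kHz, and
// `stereo: true` duplicates it into both channels.
fn sine_generation(c: &mut Criterion) {
    c.bench_function("make_sine 960 stereo", |b| b.iter(|| utils::make_sine(960, true)));
}

criterion_group!(benches, sine_generation);
criterion_main!(benches);
```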