Voice Rework -- Events, Track Queues (#806)

This implements a proof-of-concept for an improved audio frontend. The largest change is the introduction of events and event handling: both by time elapsed and by track events, such as ending or looping. Following on from this, the library now includes a basic, event-driven track queue system (which people seem to ask for unusually often). A new sample, `examples/13_voice_events`, demonstrates both the `TrackQueue` system and some basic events via the `~queue` and `~play_fade` commands.

Locks are removed from around the control of `Audio` objects, which should allow the backend to be moved to a more granular futures-based backend solution in a cleaner way.
This commit is contained in:
Kyle Simpson
2020-10-29 20:25:20 +00:00
committed by Alex M. M
commit 7e4392ae68
76 changed files with 8756 additions and 0 deletions

155
Cargo.toml Normal file
View File

@@ -0,0 +1,155 @@
[package]
authors = ["Kyle Simpson <kyleandrew.simpson@gmail.com>"]
description = "An async Rust library for the Discord voice API."
documentation = "https://docs.rs/songbird"
edition = "2018"
homepage = "https://github.com/serenity-rs/serenity"
include = ["src/**/*.rs", "Cargo.toml"]
keywords = ["discord", "api", "rtp", "audio"]
license = "ISC"
name = "songbird"
readme = "README.md"
repository = "https://github.com/serenity-rs/serenity.git"
version = "0.1.0"

# Dependencies which are always enabled.
[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tracing = "0.1"
tracing-futures = "0.2"

# Optional dependencies below are pulled in by the `driver`, `gateway`,
# and library-adapter features declared under [features].
[dependencies.async-trait]
optional = true
version = "0.1"

[dependencies.async-tungstenite]
default-features = false
features = ["tokio-runtime"]
optional = true
version = "0.9"

[dependencies.audiopus]
optional = true
version = "0.2"

[dependencies.byteorder]
optional = true
version = "1"

[dependencies.discortp]
features = ["discord-full"]
optional = true
version = "0.2"

[dependencies.flume]
optional = true
version = "0.9"

[dependencies.futures]
version = "0.3"

[dependencies.parking_lot]
optional = true
version = "0.11"

[dependencies.rand]
optional = true
version = "0.7"

# Sibling crates, resolved by path inside the serenity workspace.
[dependencies.serenity]
optional = true
features = ["voice", "gateway"]
path = "../"
version = "0.9.0-rc.2"

[dependencies.serenity-voice-model]
optional = true
path = "../voice-model"
version = "0.9.0-rc.2"

[dependencies.spin_sleep]
optional = true
version = "1"

[dependencies.streamcatcher]
optional = true
version = "0.1"

[dependencies.tokio]
optional = true
version = "0.2"
default-features = false

[dependencies.twilight-gateway]
optional = true
version = "0.1"
default-features = false

[dependencies.twilight-model]
optional = true
version = "0.1"
default-features = false

[dependencies.url]
optional = true
version = "2"

[dependencies.xsalsa20poly1305]
optional = true
version = "0.5"

[dev-dependencies]
criterion = "0.3"
utils = { path = "utils" }

[features]
# Default install: serenity integration over rustls, plus the full audio driver.
default = [
"serenity-rustls",
"driver",
"gateway",
]
# Shard/call tracking without the audio driver (e.g., for driverless lavalink use).
gateway = [
"flume",
"parking_lot",
"tokio/sync",
]
# The standalone voice driver: websocket signalling plus UDP audio transmission.
driver = [
"async-trait",
"async-tungstenite",
"audiopus",
"byteorder",
"discortp",
"flume",
"parking_lot",
"rand",
"serenity-voice-model",
"spin_sleep",
"streamcatcher",
"tokio/fs",
"tokio/io-util",
"tokio/net",
"tokio/rt-core",
"tokio/time",
"tokio/process",
"tokio/sync",
"url",
"xsalsa20poly1305",
]
# TLS backend selection for the driver's websocket connection (see build.rs).
rustls = ["async-tungstenite/tokio-rustls"]
native = ["async-tungstenite/tokio-native-tls"]
# Convenience bundles pairing a Discord library adapter with a TLS backend.
serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"]
serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"]
twilight-rustls = ["twilight", "twilight-gateway/rustls", "rustls", "gateway"]
twilight-native = ["twilight", "twilight-gateway/native", "native", "gateway"]
twilight = ["twilight-model"]
# Twilight additionally requires an explicit zlib backend (enforced in build.rs).
simd-zlib = ["twilight-gateway/simd-zlib"]
stock-zlib = ["twilight-gateway/stock-zlib"]
serenity-deps = ["async-trait"]

# `harness = false` lets Criterion supply its own benchmark entry point.
[[bench]]
name = "mixing"
path = "benches/mixing.rs"
harness = false

# Build docs.rs documentation with every feature-gated API visible.
[package.metadata.docs.rs]
all-features = true

29
README.md Normal file
View File

@@ -0,0 +1,29 @@
# Songbird
![](songbird.png)
Songbird is an async, cross-library compatible voice system for Discord, written in Rust.
The library offers:
* A standalone gateway frontend compatible with [serenity] and [twilight] using the
`"gateway"` and `"[serenity/twilight]-[rustls/native]"` features. You can even run
driverless, to help manage your [lavalink] sessions.
* A standalone driver for voice calls, via the `"driver"` feature. If you can create
a `ConnectionInfo` using any other gateway, or language for your bot, then you
can run the songbird voice driver.
* And, by default, a fully featured voice system with events, queues, RT(C)P packet
handling, seeking on compatible streams, shared multithreaded audio stream caches,
and direct Opus data passthrough from DCA files.
## Examples
Full examples showing various types of functionality and integrations can be found as part of [serenity's examples], and in [this crate's examples directory].
## Attribution
Songbird's logo is based upon the copyright-free image ["Black-Capped Chickadee"] by George Gorgas White.
[serenity]: https://github.com/serenity-rs/serenity
[twilight]: https://github.com/twilight-rs/twilight
["Black-Capped Chickadee"]: https://www.oldbookillustrations.com/illustrations/black-capped-chickadee/
[lavalink]: https://github.com/Frederikam/Lavalink
[serenity's examples]: https://github.com/serenity-rs/serenity/tree/current/examples
[this crate's examples directory]: https://github.com/serenity-rs/serenity/tree/current/songbird/examples

30
benches/mixing.rs Normal file
View File

@@ -0,0 +1,30 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use songbird::{constants::*, input::Input};
/// Benchmarks mixing one frame of float-PCM audio into a stereo buffer,
/// for both stereo and mono sources.
pub fn mix_one_frame(c: &mut Criterion) {
    let floats = utils::make_sine(STEREO_FRAME_SIZE, true);
    let mut raw_buf = [0f32; STEREO_FRAME_SIZE];

    // The stereo and mono benchmarks differed only in the boolean handed to
    // `Input::float_pcm` and the displayed name, so drive both from one table.
    for &(name, stereo) in &[("Mix stereo source", true), ("Mix mono source", false)] {
        c.bench_function(name, |b| {
            b.iter_batched_ref(
                // Rebuild the input per batch so mixing never hits a drained source.
                || black_box(Input::float_pcm(stereo, floats.clone().into())),
                |input| {
                    input.mix(black_box(&mut raw_buf), black_box(1.0));
                },
                BatchSize::SmallInput,
            )
        });
    }
}

criterion_group!(benches, mix_one_frame);
criterion_main!(benches);

23
build.rs Normal file
View File

@@ -0,0 +1,23 @@
// Feature-combination guards: turn invalid `Cargo.toml` feature selections into
// readable compile-time errors rather than obscure downstream build failures.

// The driver needs a websocket TLS backend; neither is on by default.
// Line-continuation `\` added where it was missing so the message renders as
// flowing text instead of embedding raw newlines and indentation; also fixes
// the "implemenation" typo.
#[cfg(all(feature = "driver", not(any(feature = "rustls", feature = "native"))))]
compile_error!(
    "You have the `driver` feature enabled: \
    either the `rustls` or `native` feature must be \
    selected to let Songbird's driver use websockets.\n\
    - `rustls` uses Rustls, a pure Rust TLS-implementation.\n\
    - `native` uses SChannel on Windows, Secure Transport on macOS, \
    and OpenSSL on other platforms.\n\
    If you are unsure, go with `rustls`."
);

// Twilight's gateway requires choosing a zlib implementation explicitly.
#[cfg(all(
    feature = "twilight",
    not(any(feature = "simd-zlib", feature = "stock-zlib"))
))]
compile_error!(
    "Twilight requires you to specify a zlib backend: \
    either the `simd-zlib` or `stock-zlib` feature must be \
    selected.\n\
    If you are unsure, go with `stock-zlib`."
);

// No build-time code generation is needed; this script exists only for the
// `compile_error!` guards above.
fn main() {}

3
examples/README.md Normal file
View File

@@ -0,0 +1,3 @@
# Songbird examples
These examples show more advanced use of Songbird, or how to include Songbird in bots built on other libraries, such as twilight.

View File

@@ -0,0 +1,21 @@
[package]
name = "basic-twilight-bot"
version = "0.1.0"
authors = ["Twilight and Serenity Contributors"]
edition = "2018"

[dependencies]
futures = "0.3"
tracing = "0.1"
tracing-subscriber = "0.2"
serde_json = { version = "1" }
tokio = { features = ["macros", "rt-threaded", "sync"], version = "0.2" }
twilight-gateway = "0.1"
twilight-http = "0.1"
twilight-model = "0.1"
twilight-standby = "0.1"

# Songbird from this local checkout, configured for twilight over rustls.
# `stock-zlib` satisfies the zlib-backend requirement enforced by build.rs.
[dependencies.songbird]
path = "../.."
default-features = false
features = ["twilight-rustls", "gateway", "driver", "stock-zlib"]

View File

@@ -0,0 +1,378 @@
//! This example adapts Twilight's [basic lavalink bot] to use Songbird as its voice driver.
//!
//! # Twilight-rs attribution
//! ISC License (ISC)
//!
//! Copyright (c) 2019, 2020 (c) The Twilight Contributors
//!
//! Permission to use, copy, modify, and/or distribute this software for any purpose
//! with or without fee is hereby granted, provided that the above copyright notice
//! and this permission notice appear in all copies.
//!
//! THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
//! REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
//! FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
//! INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
//! OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
//! TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF
//! THIS SOFTWARE.
//!
//!
//! [basic lavalink bot]: https://github.com/twilight-rs/twilight/tree/trunk/lavalink/examples/basic-lavalink-bot
use futures::StreamExt;
use std::{collections::HashMap, env, error::Error, future::Future, sync::Arc};
use songbird::{input::{Input, Restartable}, tracks::{PlayMode, TrackHandle}, Songbird};
use tokio::sync::RwLock;
use twilight_gateway::{Cluster, Event};
use twilight_http::Client as HttpClient;
use twilight_model::{channel::Message, gateway::payload::MessageCreate, id::GuildId};
use twilight_standby::Standby;
/// Shared state handed (by clone) to every command handler.
#[derive(Clone, Debug)]
struct State {
    // Gateway cluster: source of events, also driven by Songbird for voice updates.
    cluster: Cluster,
    // REST client used to send reply messages to channels.
    http: HttpClient,
    // Most recent track handle per guild, used by pause/seek/volume commands.
    trackdata: Arc<RwLock<HashMap<GuildId, TrackHandle>>>,
    // Voice manager shared across all handlers.
    songbird: Arc<Songbird>,
    // Awaits follow-up messages from command authors (channel id / URL prompts).
    standby: Standby,
}
fn spawn(
fut: impl Future<Output = Result<(), Box<dyn Error + Send + Sync + 'static>>> + Send + 'static,
) {
tokio::spawn(async move {
if let Err(why) = fut.await {
tracing::debug!("handler error: {:?}", why);
}
});
}
/// Entry point: builds the shared state, brings the gateway cluster up, then
/// feeds every gateway event to Standby and Songbird, dispatching `!` commands.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    // Initialize the tracing subscriber.
    tracing_subscriber::fmt::init();

    let state = {
        let token = env::var("DISCORD_TOKEN")?;
        let http = HttpClient::new(&token);
        // The bot's own user id is required by Songbird for voice state updates.
        let user_id = http.current_user().await?.id;

        let cluster = Cluster::new(token).await?;
        let shard_count = cluster.shards().len();

        let songbird = Songbird::twilight(cluster.clone(), shard_count as u64, user_id);

        cluster.up().await;

        State {
            cluster,
            http,
            trackdata: Default::default(),
            songbird,
            standby: Standby::new(),
        }
    };

    let mut events = state.cluster.events();

    while let Some(event) = events.next().await {
        // Both Standby and Songbird must observe every gateway event.
        state.standby.process(&event.1);
        state.songbird.process(&event.1).await;

        if let Event::MessageCreate(msg) = event.1 {
            // Only handle guild messages that use the `!` prefix.
            if msg.guild_id.is_none() || !msg.content.starts_with('!') {
                continue;
            }

            // Dispatch on the first whitespace-separated word of the message.
            match msg.content.splitn(2, ' ').next() {
                Some("!join") => spawn(join(msg.0, state.clone())),
                Some("!leave") => spawn(leave(msg.0, state.clone())),
                Some("!pause") => spawn(pause(msg.0, state.clone())),
                Some("!play") => spawn(play(msg.0, state.clone())),
                Some("!seek") => spawn(seek(msg.0, state.clone())),
                Some("!stop") => spawn(stop(msg.0, state.clone())),
                Some("!volume") => spawn(volume(msg.0, state.clone())),
                _ => continue,
            }
        }
    }

    Ok(())
}
/// Prompts the author for a voice channel id, then joins that channel and
/// reports the outcome back to the text channel.
async fn join(msg: Message, state: State) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    state
        .http
        .create_message(msg.channel_id)
        .content("What's the channel ID you want me to join?")?
        .await?;

    let author_id = msg.author.id;
    // Wait for the same author's next message in this channel.
    let msg = state
        .standby
        .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| {
            new_msg.author.id == author_id
        })
        .await?;
    let channel_id = msg.content.parse::<u64>()?;
    let guild_id = msg.guild_id.ok_or("Can't join a non-guild channel.")?;

    let (_handle, success) = state
        .songbird
        .join(guild_id, channel_id)
        .await;

    // The join outcome arrives asynchronously on the returned channel.
    let content = match success?.recv_async().await {
        Ok(Ok(())) => format!("Joined <#{}>!", channel_id),
        Ok(Err(e)) => format!("Failed to join <#{}>! Why: {:?}", channel_id, e),
        _ => format!("Failed to join <#{}>: Gateway error!", channel_id),
    };

    state
        .http
        .create_message(msg.channel_id)
        .content(content)?
        .await?;

    Ok(())
}
/// Disconnects the bot from the voice channel of the message's guild.
async fn leave(msg: Message, state: State) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    tracing::debug!(
        "leave command in channel {} by {}",
        msg.channel_id,
        msg.author.name
    );

    let guild_id = msg.guild_id.unwrap();
    state.songbird.leave(guild_id).await?;

    let reply = state.http.create_message(msg.channel_id);
    reply.content("Left the channel")?.await?;

    Ok(())
}
/// Prompts the author for a URL, starts playback in the guild's call, and
/// stores the track handle so later commands (pause/seek/volume) can use it.
async fn play(msg: Message, state: State) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    tracing::debug!(
        "play command in channel {} by {}",
        msg.channel_id,
        msg.author.name
    );

    state
        .http
        .create_message(msg.channel_id)
        .content("What's the URL of the audio to play?")?
        .await?;

    let author_id = msg.author.id;
    // Wait for the same author's next message in this channel.
    let msg = state
        .standby
        .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| {
            new_msg.author.id == author_id
        })
        .await?;
    let guild_id = msg.guild_id.unwrap();

    if let Ok(song) = Restartable::ytdl(msg.content.clone()) {
        let input = Input::from(song);

        // `as_deref` avoids allocating a fresh `String` just to hold the
        // fallback text; `{:?}` on `&str` renders identically to `String`.
        let content = format!(
            "Playing **{:?}** by **{:?}**",
            input.metadata.title.as_deref().unwrap_or("<UNKNOWN>"),
            input.metadata.artist.as_deref().unwrap_or("<UNKNOWN>"),
        );

        state
            .http
            .create_message(msg.channel_id)
            .content(content)?
            .await?;

        if let Some(call_lock) = state.songbird.get(guild_id) {
            let mut call = call_lock.lock().await;
            let handle = call.play_source(input);

            // Remember the handle so control commands can find this track.
            let mut store = state.trackdata.write().await;
            store.insert(guild_id, handle);
        }
    } else {
        state
            .http
            .create_message(msg.channel_id)
            .content("Didn't find any results")?
            .await?;
    }

    Ok(())
}
/// Toggles pause/resume on the stored track for the message's guild.
async fn pause(msg: Message, state: State) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    tracing::debug!(
        "pause command in channel {} by {}",
        msg.channel_id,
        msg.author.name
    );

    let guild_id = msg.guild_id.unwrap();

    let store = state.trackdata.read().await;

    let content = if let Some(handle) = store.get(&guild_id) {
        let info = handle.get_info()?
            .await?;

        // Toggle: pause when currently playing, otherwise resume.
        let paused = match info.playing {
            PlayMode::Play => {
                let _success = handle.pause();
                false
            },
            _ => {
                let _success = handle.play();
                true
            },
        };

        let action = if paused { "Unpaused" } else { "Paused" };

        format!("{} the track", action)
    } else {
        // `format!` with no interpolation is just an allocation; use `to_string`.
        "No track to (un)pause!".to_string()
    };

    state
        .http
        .create_message(msg.channel_id)
        .content(content)?
        .await?;

    Ok(())
}
/// Prompts the author for a position in seconds, then seeks the stored track.
async fn seek(msg: Message, state: State) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    tracing::debug!(
        "seek command in channel {} by {}",
        msg.channel_id,
        msg.author.name
    );

    state
        .http
        .create_message(msg.channel_id)
        .content("Where in the track do you want to seek to (in seconds)?")?
        .await?;

    let author_id = msg.author.id;
    // Wait for the same author's next message in this channel.
    let msg = state
        .standby
        .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| {
            new_msg.author.id == author_id
        })
        .await?;
    let guild_id = msg.guild_id.unwrap();
    let position = msg.content.parse::<u64>()?;

    let store = state.trackdata.read().await;

    let content = if let Some(handle) = store.get(&guild_id) {
        if handle.is_seekable() {
            let _success = handle.seek_time(std::time::Duration::from_secs(position));
            format!("Seeked to {}s", position)
        } else {
            // Argument-free `format!` replaced with `to_string` (clippy: useless_format).
            "Track is not compatible with seeking!".to_string()
        }
    } else {
        "No track to seek over!".to_string()
    };

    state
        .http
        .create_message(msg.channel_id)
        .content(content)?
        .await?;

    Ok(())
}
/// Halts playback on the guild's call, if the bot currently has one.
async fn stop(msg: Message, state: State) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    tracing::debug!(
        "stop command in channel {} by {}",
        msg.channel_id,
        msg.author.name
    );

    let guild_id = msg.guild_id.unwrap();

    if let Some(call_lock) = state.songbird.get(guild_id) {
        let mut guard = call_lock.lock().await;
        let _ = guard.stop();
    }

    state
        .http
        .create_message(msg.channel_id)
        .content("Stopped the track")?
        .await?;

    Ok(())
}
/// Prompts the author for a volume multiplier and applies it to the stored track.
async fn volume(msg: Message, state: State) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    tracing::debug!(
        "volume command in channel {} by {}",
        msg.channel_id,
        msg.author.name
    );

    state
        .http
        .create_message(msg.channel_id)
        .content("What's the volume you want to set (0.0-10.0, 1.0 being the default)?")?
        .await?;

    let author_id = msg.author.id;
    // Wait for the same author's next message in this channel.
    let msg = state
        .standby
        .wait_for_message(msg.channel_id, move |new_msg: &MessageCreate| {
            new_msg.author.id == author_id
        })
        .await?;
    let guild_id = msg.guild_id.unwrap();
    let volume = msg.content.parse::<f64>()?;

    // `contains` on the inclusive range rejects NaN and infinities as well as
    // out-of-range values, matching the original finite+bounds check.
    if !(0.0..=10.0).contains(&volume) {
        state
            .http
            .create_message(msg.channel_id)
            .content("Invalid volume!")?
            .await?;

        return Ok(());
    }

    let store = state.trackdata.read().await;

    let content = if let Some(handle) = store.get(&guild_id) {
        let _success = handle.set_volume(volume as f32);
        format!("Set the volume to {}", volume)
    } else {
        // Argument-free `format!` replaced with `to_string` (clippy: useless_format).
        "No track to change volume!".to_string()
    };

    state
        .http
        .create_message(msg.channel_id)
        .content(content)?
        .await?;

    Ok(())
}

6
rustfmt.toml Normal file
View File

@@ -0,0 +1,6 @@
imports_layout = "HorizontalVertical"
match_arm_blocks = false
match_block_trailing_comma = true
newline_style = "Unix"
use_field_init_shorthand = true
use_try_shorthand = true

BIN
songbird-ico.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.9 KiB

BIN
songbird.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 73 KiB

22
songbird.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 220 KiB

75
src/constants.rs Normal file
View File

@@ -0,0 +1,75 @@
//! Constants affecting driver function and API handling.

#[cfg(feature = "driver")]
use audiopus::{Bitrate, SampleRate};
#[cfg(feature = "driver")]
use discortp::rtp::RtpType;
use std::time::Duration;

#[cfg(feature = "driver")]
/// The voice gateway version used by the library.
pub const VOICE_GATEWAY_VERSION: u8 = crate::model::constants::GATEWAY_VERSION;

#[cfg(feature = "driver")]
/// Sample rate of audio to be sent to Discord.
pub const SAMPLE_RATE: SampleRate = SampleRate::Hz48000;

/// Sample rate of audio to be sent to Discord.
pub const SAMPLE_RATE_RAW: usize = 48_000;

/// Number of audio frames/packets to be sent per second.
pub const AUDIO_FRAME_RATE: usize = 50;

/// Length of time between any two audio frames.
// 50 frames/s => one frame every 20ms.
pub const TIMESTEP_LENGTH: Duration = Duration::from_millis(1000 / AUDIO_FRAME_RATE as u64);

#[cfg(feature = "driver")]
/// Default bitrate for audio.
pub const DEFAULT_BITRATE: Bitrate = Bitrate::BitsPerSecond(128_000);

/// Number of samples in one complete frame of audio per channel.
///
/// This is equally the number of stereo (joint) samples in an audio frame.
// 48_000 / 50 = 960 samples per channel per frame.
pub const MONO_FRAME_SIZE: usize = SAMPLE_RATE_RAW / AUDIO_FRAME_RATE;

/// Number of individual samples in one complete frame of stereo audio.
// 2 * 960 = 1920 interleaved samples.
pub const STEREO_FRAME_SIZE: usize = 2 * MONO_FRAME_SIZE;

/// Number of bytes in one complete frame of raw `f32`-encoded mono audio.
pub const MONO_FRAME_BYTE_SIZE: usize = MONO_FRAME_SIZE * std::mem::size_of::<f32>();

/// Number of bytes in one complete frame of raw `f32`-encoded stereo audio.
pub const STEREO_FRAME_BYTE_SIZE: usize = STEREO_FRAME_SIZE * std::mem::size_of::<f32>();

/// Length (in milliseconds) of any audio frame.
pub const FRAME_LEN_MS: usize = 1000 / AUDIO_FRAME_RATE;

/// Maximum number of audio frames/packets to be sent per second to be buffered.
pub const CHILD_BUFFER_LEN: usize = AUDIO_FRAME_RATE / 2;

/// Maximum packet size for a voice packet.
///
/// Set a safe amount below the Ethernet MTU to avoid fragmentation/rejection.
pub const VOICE_PACKET_MAX: usize = 1460;

/// Delay between sends of UDP keepalive frames.
///
/// Passive monitoring of Discord itself shows that these fire every 5 seconds
/// irrespective of outgoing UDP traffic.
pub const UDP_KEEPALIVE_GAP_MS: u64 = 5_000;

/// Type-converted delay between sends of UDP keepalive frames.
///
/// Passive monitoring of Discord itself shows that these fire every 5 seconds
/// irrespective of outgoing UDP traffic.
pub const UDP_KEEPALIVE_GAP: Duration = Duration::from_millis(UDP_KEEPALIVE_GAP_MS);

/// Opus silent frame, used to signal speech start and end (and prevent audio glitching).
pub const SILENT_FRAME: [u8; 3] = [0xf8, 0xff, 0xfe];

/// The one (and only) RTP version.
pub const RTP_VERSION: u8 = 2;

#[cfg(feature = "driver")]
/// Profile type used by Discord's Opus audio traffic.
pub const RTP_PROFILE_TYPE: RtpType = RtpType::Dynamic(120);

10
src/driver/config.rs Normal file
View File

@@ -0,0 +1,10 @@
use super::CryptoMode;

/// Configuration for the inner Driver.
///
/// At present, this cannot be changed.
#[derive(Clone, Debug, Default)]
pub struct Config {
    /// Selected tagging mode for voice packet encryption.
    ///
    /// `None` (the `Default`) is treated as [`CryptoMode::Normal`] by the
    /// connection code.
    pub crypto_mode: Option<CryptoMode>,
}

View File

@@ -0,0 +1,105 @@
//! Connection errors and convenience types.
use crate::{
driver::tasks::{error::Recipient, message::*},
ws::Error as WsError,
};
use flume::SendError;
use serde_json::Error as JsonError;
use std::{error::Error as ErrorTrait, fmt, io::Error as IoError};
use xsalsa20poly1305::aead::Error as CryptoError;
/// Errors encountered while connecting to a Discord voice server over the driver.
///
/// Lower-level error types are lifted into this enum via the `From` impls
/// below, so `?` can be used throughout connection code.
#[derive(Debug)]
pub enum Error {
    /// An error occurred during [en/de]cryption of voice packets or key generation.
    Crypto(CryptoError),
    /// Server did not return the expected crypto mode during negotiation.
    CryptoModeInvalid,
    /// Selected crypto mode was not offered by server.
    CryptoModeUnavailable,
    /// An indicator that an endpoint URL was invalid.
    EndpointUrl,
    /// Discord hello/ready handshake was violated.
    ExpectedHandshake,
    /// Discord failed to correctly respond to IP discovery.
    IllegalDiscoveryResponse,
    /// Could not parse Discord's view of our IP.
    IllegalIp,
    /// Miscellaneous I/O error.
    Io(IoError),
    /// JSON (de)serialization error.
    Json(JsonError),
    /// Failed to message other background tasks after connection establishment.
    ///
    /// The contained [`Recipient`] names which task could not be reached.
    InterconnectFailure(Recipient),
    /// Error communicating with gateway server over WebSocket.
    Ws(WsError),
}
impl From<CryptoError> for Error {
fn from(e: CryptoError) -> Self {
Error::Crypto(e)
}
}
impl From<IoError> for Error {
fn from(e: IoError) -> Error {
Error::Io(e)
}
}
impl From<JsonError> for Error {
fn from(e: JsonError) -> Error {
Error::Json(e)
}
}
impl From<SendError<WsMessage>> for Error {
fn from(_e: SendError<WsMessage>) -> Error {
Error::InterconnectFailure(Recipient::AuxNetwork)
}
}
impl From<SendError<EventMessage>> for Error {
fn from(_e: SendError<EventMessage>) -> Error {
Error::InterconnectFailure(Recipient::Event)
}
}
impl From<SendError<MixerMessage>> for Error {
fn from(_e: SendError<MixerMessage>) -> Error {
Error::InterconnectFailure(Recipient::Mixer)
}
}
impl From<WsError> for Error {
fn from(e: WsError) -> Error {
Error::Ws(e)
}
}
impl fmt::Display for Error {
    /// Writes a common "failed to connect" prefix followed by a per-variant detail.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Failed to connect to Discord RTP server: ")?;
        use Error::*;
        match self {
            Crypto(c) => write!(f, "cryptography error {}.", c),
            CryptoModeInvalid => write!(f, "server changed negotiated encryption mode."),
            CryptoModeUnavailable => write!(f, "server did not offer chosen encryption mode."),
            EndpointUrl => write!(f, "endpoint URL received from gateway was invalid."),
            ExpectedHandshake => write!(f, "voice initialisation protocol was violated."),
            IllegalDiscoveryResponse =>
                write!(f, "IP discovery/NAT punching response was invalid."),
            IllegalIp => write!(f, "IP discovery/NAT punching response had bad IP value."),
            Io(i) => write!(f, "I/O failure ({}).", i),
            Json(j) => write!(f, "JSON (de)serialization issue ({}).", j),
            // Trailing '.' added: this was the only arm missing the terminator.
            InterconnectFailure(r) => write!(f, "failed to contact other task ({:?}).", r),
            Ws(w) => write!(f, "websocket issue ({:?}).", w),
        }
    }
}

impl ErrorTrait for Error {}

/// Convenience type for Discord voice/driver connection error handling.
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,321 @@
pub mod error;
use super::{
tasks::{message::*, udp_rx, udp_tx, ws as ws_task},
Config,
CryptoMode,
};
use crate::{
constants::*,
model::{
payload::{Identify, Resume, SelectProtocol},
Event as GatewayEvent,
ProtocolData,
},
ws::{self, ReceiverExt, SenderExt, WsStream},
ConnectionInfo,
};
use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket};
use error::{Error, Result};
use flume::Sender;
use std::{net::IpAddr, str::FromStr};
use tokio::net::UdpSocket;
use tracing::{debug, info, instrument};
use url::Url;
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};
#[cfg(all(feature = "rustls", not(feature = "native")))]
use ws::create_rustls_client;
#[cfg(feature = "native")]
use ws::create_native_tls_client;
/// Live driver-side connection to a Discord voice server.
pub(crate) struct Connection {
    // Connection parameters (guild, endpoint, session, token); reused by `reconnect`.
    pub(crate) info: ConnectionInfo,
    // Control channel into the websocket task servicing this connection.
    pub(crate) ws: Sender<WsMessage>,
}
impl Connection {
pub(crate) async fn new(
mut info: ConnectionInfo,
interconnect: &Interconnect,
config: &Config,
) -> Result<Connection> {
let crypto_mode = config.crypto_mode.unwrap_or(CryptoMode::Normal);
let url = generate_url(&mut info.endpoint)?;
#[cfg(all(feature = "rustls", not(feature = "native")))]
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native")]
let mut client = create_native_tls_client(url).await?;
let mut hello = None;
let mut ready = None;
client
.send_json(&GatewayEvent::from(Identify {
server_id: info.guild_id.into(),
session_id: info.session_id.clone(),
token: info.token.clone(),
user_id: info.user_id.into(),
}))
.await?;
loop {
let value = match client.recv_json().await? {
Some(value) => value,
None => continue,
};
match value {
GatewayEvent::Ready(r) => {
ready = Some(r);
if hello.is_some() {
break;
}
},
GatewayEvent::Hello(h) => {
hello = Some(h);
if ready.is_some() {
break;
}
},
other => {
debug!("Expected ready/hello; got: {:?}", other);
return Err(Error::ExpectedHandshake);
},
}
}
let hello =
hello.expect("Hello packet expected in connection initialisation, but not found.");
let ready =
ready.expect("Ready packet expected in connection initialisation, but not found.");
if !has_valid_mode(&ready.modes, crypto_mode) {
return Err(Error::CryptoModeUnavailable);
}
let mut udp = UdpSocket::bind("0.0.0.0:0").await?;
udp.connect((ready.ip, ready.port)).await?;
// Follow Discord's IP Discovery procedures, in case NAT tunnelling is needed.
let mut bytes = [0; IpDiscoveryPacket::const_packet_size()];
{
let mut view = MutableIpDiscoveryPacket::new(&mut bytes[..]).expect(
"Too few bytes in 'bytes' for IPDiscovery packet.\
(Blame: IpDiscoveryPacket::const_packet_size()?)",
);
view.set_pkt_type(IpDiscoveryType::Request);
view.set_length(70);
view.set_ssrc(ready.ssrc);
}
udp.send(&bytes).await?;
let (len, _addr) = udp.recv_from(&mut bytes).await?;
{
let view =
IpDiscoveryPacket::new(&bytes[..len]).ok_or(Error::IllegalDiscoveryResponse)?;
if view.get_pkt_type() != IpDiscoveryType::Response {
return Err(Error::IllegalDiscoveryResponse);
}
// We could do something clever like binary search,
// but possibility of UDP spoofing preclueds us from
// making the assumption we can find a "left edge" of '\0's.
let nul_byte_index = view
.get_address_raw()
.iter()
.position(|&b| b == 0)
.ok_or(Error::IllegalIp)?;
let address_str = std::str::from_utf8(&view.get_address_raw()[..nul_byte_index])
.map_err(|_| Error::IllegalIp)?;
let address = IpAddr::from_str(&address_str).map_err(|e| {
println!("{:?}", e);
Error::IllegalIp
})?;
client
.send_json(&GatewayEvent::from(SelectProtocol {
protocol: "udp".into(),
data: ProtocolData {
address,
mode: crypto_mode.to_request_str().into(),
port: view.get_port(),
},
}))
.await?;
}
let cipher = init_cipher(&mut client, crypto_mode).await?;
info!("Connected to: {}", info.endpoint);
info!("WS heartbeat duration {}ms.", hello.heartbeat_interval,);
let (ws_msg_tx, ws_msg_rx) = flume::unbounded();
let (udp_sender_msg_tx, udp_sender_msg_rx) = flume::unbounded();
let (udp_receiver_msg_tx, udp_receiver_msg_rx) = flume::unbounded();
let (udp_rx, udp_tx) = udp.split();
let ssrc = ready.ssrc;
let mix_conn = MixerConnection {
cipher: cipher.clone(),
udp_rx: udp_receiver_msg_tx,
udp_tx: udp_sender_msg_tx,
};
interconnect
.mixer
.send(MixerMessage::Ws(Some(ws_msg_tx.clone())))?;
interconnect
.mixer
.send(MixerMessage::SetConn(mix_conn, ready.ssrc))?;
tokio::spawn(ws_task::runner(
interconnect.clone(),
ws_msg_rx,
client,
ssrc,
hello.heartbeat_interval,
));
tokio::spawn(udp_rx::runner(
interconnect.clone(),
udp_receiver_msg_rx,
cipher,
crypto_mode,
udp_rx,
));
tokio::spawn(udp_tx::runner(udp_sender_msg_rx, ssrc, udp_tx));
Ok(Connection {
info,
ws: ws_msg_tx,
})
}
#[instrument(skip(self))]
pub async fn reconnect(&mut self) -> Result<()> {
let url = generate_url(&mut self.info.endpoint)?;
// Thread may have died, we want to send to prompt a clean exit
// (if at all possible) and then proceed as normal.
#[cfg(all(feature = "rustls", not(feature = "native")))]
let mut client = create_rustls_client(url).await?;
#[cfg(feature = "native")]
let mut client = create_native_tls_client(url).await?;
client
.send_json(&GatewayEvent::from(Resume {
server_id: self.info.guild_id.into(),
session_id: self.info.session_id.clone(),
token: self.info.token.clone(),
}))
.await?;
let mut hello = None;
let mut resumed = None;
loop {
let value = match client.recv_json().await? {
Some(value) => value,
None => continue,
};
match value {
GatewayEvent::Resumed => {
resumed = Some(());
if hello.is_some() {
break;
}
},
GatewayEvent::Hello(h) => {
hello = Some(h);
if resumed.is_some() {
break;
}
},
other => {
debug!("Expected resumed/hello; got: {:?}", other);
return Err(Error::ExpectedHandshake);
},
}
}
let hello =
hello.expect("Hello packet expected in connection initialisation, but not found.");
self.ws
.send(WsMessage::SetKeepalive(hello.heartbeat_interval))?;
self.ws.send(WsMessage::Ws(Box::new(client)))?;
info!("Reconnected to: {}", &self.info.endpoint);
Ok(())
}
}
impl Drop for Connection {
    fn drop(&mut self) {
        // Purely diagnostic: record when the connection object is torn down.
        info!("Disconnected");
    }
}
fn generate_url(endpoint: &mut String) -> Result<Url> {
if endpoint.ends_with(":80") {
let len = endpoint.len();
endpoint.truncate(len - 3);
}
Url::parse(&format!("wss://{}/?v={}", endpoint, VOICE_GATEWAY_VERSION))
.or(Err(Error::EndpointUrl))
}
#[inline]
/// Waits for the voice gateway's session description, validates that the
/// server kept the negotiated crypto mode, and builds the packet cipher
/// from the delivered secret key.
async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result<Cipher> {
    loop {
        let value = match client.recv_json().await? {
            Some(value) => value,
            None => continue,
        };

        match value {
            GatewayEvent::SessionDescription(desc) => {
                if desc.mode != mode.to_request_str() {
                    return Err(Error::CryptoModeInvalid);
                }

                return Ok(Cipher::new_varkey(&desc.secret_key)?);
            },
            other => {
                // Interim events are logged and skipped until the key arrives.
                debug!(
                    "Expected ready for key; got: op{}/v{:?}",
                    other.kind() as u8,
                    other
                );
            },
        }
    }
}
#[inline]
/// Checks whether the server's advertised mode list contains our chosen crypto mode.
fn has_valid_mode<T, It>(modes: It, mode: CryptoMode) -> bool
where
    T: for<'a> PartialEq<&'a str>,
    It: IntoIterator<Item = T>,
{
    let wanted = mode.to_request_str();
    for offered in modes {
        if offered == wanted {
            return true;
        }
    }
    false
}

38
src/driver/crypto.rs Normal file
View File

@@ -0,0 +1,38 @@
//! Encryption schemes supported by Discord's secure RTP negotiation.
/// Variants of the XSalsa20Poly1305 encryption scheme.
///
/// At present, only `Normal` is supported or selectable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Mode {
    /// The RTP header is used as the source of nonce bytes for the packet.
    ///
    /// Equivalent to a nonce of at most 48b (6B) at no extra packet overhead:
    /// the RTP sequence number and timestamp are the varying quantities.
    Normal,
    /// An additional random 24B suffix is used as the source of nonce bytes for the packet.
    ///
    /// Full nonce width of 24B (192b), at an extra 24B per packet (~1.2 kB/s).
    Suffix,
    /// An additional 4B suffix is used as the source of nonce bytes for the packet.
    ///
    /// Nonce width of 4B (32b), at an extra 4B per packet (~0.2 kB/s).
    Lite,
}

impl Mode {
    /// Returns the name of a mode as it will appear during negotiation
    /// (i.e., the `mode` string sent in the voice gateway's Select Protocol).
    pub fn to_request_str(self) -> &'static str {
        match self {
            Mode::Normal => "xsalsa20_poly1305",
            Mode::Suffix => "xsalsa20_poly1305_suffix",
            Mode::Lite => "xsalsa20_poly1305_lite",
        }
    }
}

// TODO: implement encrypt + decrypt + nonce selection for each.
// This will probably need some research into correct handling of
// padding, reported length, SRTP profiles, and so on.

233
src/driver/mod.rs Normal file
View File

@@ -0,0 +1,233 @@
//! Runner for a voice connection.
//!
//! Songbird's driver is a mixed-sync system, using:
//! * Asynchronous connection management, event-handling, and gateway integration.
//! * Synchronous audio mixing, packet generation, and encoding.
//!
//! This splits up work according to its IO/compute bound nature, preventing packet
//! generation from being slowed down past its deadline, or from affecting other
//! asynchronous tasks your bot must handle.
mod config;
pub(crate) mod connection;
mod crypto;
pub(crate) mod tasks;
pub use config::Config;
use connection::error::Result;
pub use crypto::Mode as CryptoMode;
use crate::{
events::EventData,
input::Input,
tracks::{Track, TrackHandle},
ConnectionInfo,
Event,
EventHandler,
};
use audiopus::Bitrate;
use flume::{Receiver, SendError, Sender};
use tasks::message::CoreMessage;
use tracing::instrument;
/// The control object for a Discord voice connection, handling connection,
/// mixing, encoding, en/decryption, and event generation.
#[derive(Clone, Debug)]
pub struct Driver {
    /// Settings applied to this (and any restarted) set of inner tasks.
    config: Config,
    /// Last requested self-mute state; reapplied when the task set restarts.
    self_mute: bool,
    /// Control channel into the backing core task.
    sender: Sender<CoreMessage>,
}
impl Driver {
    /// Creates a new voice driver.
    ///
    /// This will create the core voice tasks in the background.
    #[inline]
    pub fn new(config: Config) -> Self {
        let sender = Self::start_inner(config.clone());
        Driver {
            config,
            self_mute: false,
            sender,
        }
    }

    // Spawns the backing task set, returning the control channel into it.
    fn start_inner(config: Config) -> Sender<CoreMessage> {
        let (tx, rx) = flume::unbounded();
        tasks::start(config, rx, tx.clone());
        tx
    }

    // Relaunches the backing tasks after a failed send, reapplying the
    // persisted self-mute setting.
    fn restart_inner(&mut self) {
        self.sender = Self::start_inner(self.config.clone());
        self.mute(self.self_mute);
    }

    /// Connects to a voice channel using the specified server.
    ///
    /// The returned receiver yields the outcome of the connection attempt.
    #[instrument(skip(self))]
    pub fn connect(&mut self, info: ConnectionInfo) -> Receiver<Result<()>> {
        let (tx, rx) = flume::bounded(1);
        self.raw_connect(info, tx);
        rx
    }

    /// Connects to a voice channel using the specified server.
    #[instrument(skip(self))]
    pub(crate) fn raw_connect(&mut self, info: ConnectionInfo, tx: Sender<Result<()>>) {
        self.send(CoreMessage::ConnectWithResult(info, tx));
    }

    /// Leaves the current voice channel, disconnecting from it.
    ///
    /// This does *not* forget settings, like whether to be self-deafened or
    /// self-muted.
    #[instrument(skip(self))]
    pub fn leave(&mut self) {
        self.send(CoreMessage::Disconnect);
    }

    /// Sets whether the current connection is to be muted.
    ///
    /// If there is no live voice connection, then this only acts as a settings
    /// update for future connections.
    #[instrument(skip(self))]
    pub fn mute(&mut self, mute: bool) {
        self.self_mute = mute;
        self.send(CoreMessage::Mute(mute));
    }

    /// Returns whether the driver is muted (i.e., processes audio internally
    /// but submits none).
    #[instrument(skip(self))]
    pub fn is_mute(&self) -> bool {
        self.self_mute
    }

    /// Plays audio from a source, returning a handle for further control.
    ///
    /// This can be a source created via [`ffmpeg`] or [`ytdl`].
    ///
    /// [`ffmpeg`]: ../input/fn.ffmpeg.html
    /// [`ytdl`]: ../input/fn.ytdl.html
    #[instrument(skip(self))]
    pub fn play_source(&mut self, source: Input) -> TrackHandle {
        let (player, handle) = super::create_player(source);
        self.send(CoreMessage::AddTrack(player));
        handle
    }

    /// Plays audio from a source, returning a handle for further control.
    ///
    /// Unlike [`play_source`], this stops all other sources attached
    /// to the channel.
    ///
    /// [`play_source`]: #method.play_source
    #[instrument(skip(self))]
    pub fn play_only_source(&mut self, source: Input) -> TrackHandle {
        let (player, handle) = super::create_player(source);
        self.send(CoreMessage::SetTrack(Some(player)));
        handle
    }

    /// Plays audio from a [`Track`] object.
    ///
    /// This will be one half of the return value of [`create_player`].
    /// The main difference between this function and [`play_source`] is
    /// that this allows for direct manipulation of the [`Track`] object
    /// before it is passed over to the voice and mixing contexts.
    ///
    /// [`create_player`]: ../tracks/fn.create_player.html
    /// [`Track`]: ../tracks/struct.Track.html
    /// [`play_source`]: #method.play_source
    #[instrument(skip(self))]
    pub fn play(&mut self, track: Track) {
        self.send(CoreMessage::AddTrack(track));
    }

    /// Exclusively plays audio from a [`Track`] object.
    ///
    /// This will be one half of the return value of [`create_player`].
    /// As in [`play_only_source`], this stops all other sources attached to the
    /// channel. Like [`play`], however, this allows for direct manipulation of the
    /// [`Track`] object before it is passed over to the voice and mixing contexts.
    ///
    /// [`create_player`]: ../tracks/fn.create_player.html
    /// [`Track`]: ../tracks/struct.Track.html
    /// [`play_only_source`]: #method.play_only_source
    /// [`play`]: #method.play
    #[instrument(skip(self))]
    pub fn play_only(&mut self, track: Track) {
        self.send(CoreMessage::SetTrack(Some(track)));
    }

    /// Sets the bitrate for encoding Opus packets sent along
    /// the channel being managed.
    ///
    /// The default rate is 128 kbps.
    /// Sensible values range between `Bits(512)` and `Bits(512_000)`
    /// bits per second.
    /// Alternatively, `Auto` and `Max` remain available.
    #[instrument(skip(self))]
    pub fn set_bitrate(&mut self, bitrate: Bitrate) {
        self.send(CoreMessage::SetBitrate(bitrate))
    }

    /// Stops playing audio from all sources, if any are set.
    #[instrument(skip(self))]
    pub fn stop(&mut self) {
        self.send(CoreMessage::SetTrack(None))
    }

    /// Attach a global event handler to an audio context. Global events may receive
    /// any [`EventContext`].
    ///
    /// Global timing events will tick regardless of whether audio is playing,
    /// so long as the bot is connected to a voice channel, and are independent
    /// of any individual track. [`TrackEvent`]s will respond to all relevant
    /// tracks, giving access to their state and handles.
    ///
    /// Users **must** ensure that no costly work or blocking occurs
    /// within the supplied function or closure. *Taking excess time could prevent
    /// timely sending of packets, causing audio glitches and delays*.
    ///
    /// [`Track`]: ../tracks/struct.Track.html
    /// [`TrackEvent`]: ../events/enum.TrackEvent.html
    /// [`EventContext`]: ../events/enum.EventContext.html
    #[instrument(skip(self, action))]
    pub fn add_global_event<F: EventHandler + 'static>(&mut self, event: Event, action: F) {
        self.send(CoreMessage::AddEvent(EventData::new(event, action)));
    }

    /// Sends a message to the inner tasks, restarting it if necessary.
    fn send(&mut self, status: CoreMessage) {
        // Restart thread if it errored.
        if let Err(SendError(status)) = self.sender.send(status) {
            self.restart_inner();

            // The freshly spawned task set must accept this, or we cannot recover.
            self.sender.send(status).unwrap();
        }
    }
}
impl Default for Driver {
fn default() -> Self {
Self::new(Default::default())
}
}
impl Drop for Driver {
    /// Leaves the current connected voice channel, if connected to one, and
    /// forgets all configurations relevant to this Handler.
    fn drop(&mut self) {
        self.leave();
        // Best-effort shutdown: the backing task may already have exited.
        let _ = self.sender.send(CoreMessage::Poison);
    }
}

97
src/driver/tasks/error.rs Normal file
View File

@@ -0,0 +1,97 @@
use super::message::*;
use crate::ws::Error as WsError;
use audiopus::Error as OpusError;
use flume::SendError;
use std::io::Error as IoError;
use xsalsa20poly1305::aead::Error as CryptoError;
/// Internal task which could not be reached over its channel.
#[derive(Debug)]
pub enum Recipient {
    /// Auxiliary networking task (handles the voice websocket).
    AuxNetwork,
    /// Event-handling task.
    Event,
    /// Audio mixing thread.
    Mixer,
    /// UDP receive task.
    UdpRx,
    /// UDP transmit task.
    UdpTx,
}
/// Convenience alias for results produced inside the driver's internal tasks.
pub type Result<T> = std::result::Result<T, Error>;

/// Errors arising inside the driver's internal (network/mixer/event) tasks.
#[derive(Debug)]
pub enum Error {
    /// En/decryption or tag-verification failure.
    Crypto(CryptoError),
    /// Received an illegal voice packet on the voice UDP socket.
    IllegalVoicePacket,
    /// A message could not be delivered to the named internal task.
    InterconnectFailure(Recipient),
    /// Underlying I/O failure.
    Io(IoError),
    /// Opus encode/decode failure.
    Opus(OpusError),
    /// Websocket failure.
    Ws(WsError),
}
impl Error {
pub(crate) fn should_trigger_connect(&self) -> bool {
matches!(
self,
Error::InterconnectFailure(Recipient::AuxNetwork)
| Error::InterconnectFailure(Recipient::UdpRx)
| Error::InterconnectFailure(Recipient::UdpTx)
)
}
pub(crate) fn should_trigger_interconnect_rebuild(&self) -> bool {
matches!(self, Error::InterconnectFailure(Recipient::Event))
}
}
// Lossless wrapping conversions from lower-level error types, enabling `?`.
impl From<CryptoError> for Error {
    fn from(e: CryptoError) -> Self {
        Error::Crypto(e)
    }
}

impl From<IoError> for Error {
    fn from(e: IoError) -> Error {
        Error::Io(e)
    }
}

impl From<OpusError> for Error {
    fn from(e: OpusError) -> Error {
        Error::Opus(e)
    }
}

// A failed send to an internal task is reported as the loss of that task;
// the payload itself is dropped.
impl From<SendError<WsMessage>> for Error {
    fn from(_e: SendError<WsMessage>) -> Error {
        Error::InterconnectFailure(Recipient::AuxNetwork)
    }
}

impl From<SendError<EventMessage>> for Error {
    fn from(_e: SendError<EventMessage>) -> Error {
        Error::InterconnectFailure(Recipient::Event)
    }
}

impl From<SendError<MixerMessage>> for Error {
    fn from(_e: SendError<MixerMessage>) -> Error {
        Error::InterconnectFailure(Recipient::Mixer)
    }
}

impl From<SendError<UdpRxMessage>> for Error {
    fn from(_e: SendError<UdpRxMessage>) -> Error {
        Error::InterconnectFailure(Recipient::UdpRx)
    }
}

impl From<SendError<UdpTxMessage>> for Error {
    fn from(_e: SendError<UdpTxMessage>) -> Error {
        Error::InterconnectFailure(Recipient::UdpTx)
    }
}

impl From<WsError> for Error {
    fn from(e: WsError) -> Error {
        Error::Ws(e)
    }
}

118
src/driver/tasks/events.rs Normal file
View File

@@ -0,0 +1,118 @@
use super::message::*;
use crate::{
events::{EventStore, GlobalEvents, TrackEvent},
tracks::{TrackHandle, TrackState},
};
use flume::Receiver;
use tracing::{debug, info, instrument, trace};
/// Body of the event-handling task.
///
/// Owns the global handler set plus three vectors that mirror the mixer's
/// track list. The three vectors are kept index-aligned: entry `i` in each
/// refers to the same track, matching the index the mixer sends in messages.
#[instrument(skip(_interconnect, evt_rx))]
pub(crate) async fn runner(_interconnect: Interconnect, evt_rx: Receiver<EventMessage>) {
    let mut global = GlobalEvents::default();

    let mut events: Vec<EventStore> = vec![];
    let mut states: Vec<TrackState> = vec![];
    let mut handles: Vec<TrackHandle> = vec![];

    loop {
        use EventMessage::*;
        match evt_rx.recv_async().await {
            Ok(AddGlobalEvent(data)) => {
                info!("Global event added.");
                global.add_event(data);
            },
            Ok(AddTrackEvent(i, data)) => {
                info!("Adding event to track {}.", i);
                let event_store = events
                    .get_mut(i)
                    .expect("Event thread was given an illegal store index for AddTrackEvent.");
                let state = states
                    .get_mut(i)
                    .expect("Event thread was given an illegal state index for AddTrackEvent.");
                // Timed events are scheduled relative to the track's current position.
                event_store.add_event(data, state.position);
            },
            Ok(FireCoreEvent(ctx)) => {
                let ctx = ctx.to_user_context();
                let evt = ctx
                    .to_core_event()
                    .expect("Event thread was passed a non-core event in FireCoreEvent.");
                trace!("Firing core event {:?}.", evt);
                global.fire_core_event(evt, ctx).await;
            },
            Ok(AddTrack(store, state, handle)) => {
                // Push to all three mirrors together to preserve index alignment.
                events.push(store);
                states.push(state);
                handles.push(handle);
                info!("Event state for track {} added", events.len());
            },
            Ok(ChangeState(i, change)) => {
                use TrackStateChange::*;
                let max_states = states.len();
                debug!(
                    "Changing state for track {} of {}: {:?}",
                    i, max_states, change
                );
                let state = states
                    .get_mut(i)
                    .expect("Event thread was given an illegal state index for ChangeState.");
                match change {
                    Mode(mode) => {
                        let old = state.playing;
                        state.playing = mode;
                        // Only a *transition* into a done state fires End.
                        if old != mode && mode.is_done() {
                            global.fire_track_event(TrackEvent::End, i);
                        }
                    },
                    Volume(vol) => {
                        state.volume = vol;
                    },
                    Position(pos) => {
                        // Currently, only Tick should fire time events.
                        state.position = pos;
                    },
                    Loops(loops, user_set) => {
                        state.loops = loops;
                        // Mixer-driven loops (not user seeks) fire Loop events.
                        if !user_set {
                            global.fire_track_event(TrackEvent::Loop, i);
                        }
                    },
                    Total(new) => {
                        // Massive, unprecedented state changes.
                        *state = new;
                    },
                }
            },
            Ok(RemoveTrack(i)) => {
                info!("Event state for track {} of {} removed.", i, events.len());
                // Remove from all three mirrors to preserve index alignment.
                events.remove(i);
                states.remove(i);
                handles.remove(i);
            },
            Ok(RemoveAllTracks) => {
                info!("Event state for all tracks removed.");
                events.clear();
                states.clear();
                handles.clear();
            },
            Ok(Tick) => {
                // NOTE: this should fire saved up blocks of state change evts.
                global.tick(&mut events, &mut states, &mut handles).await;
            },
            Err(_) | Ok(Poison) => {
                break;
            },
        }
    }
    info!("Event thread exited.");
}

View File

@@ -0,0 +1,24 @@
use crate::{
driver::connection::error::Error,
events::EventData,
tracks::Track,
Bitrate,
ConnectionInfo,
};
use flume::Sender;
/// Control messages accepted by the driver's core task.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum CoreMessage {
    /// Connect using the given server info; the outcome is replied over the sender.
    ConnectWithResult(ConnectionInfo, Sender<Result<(), Error>>),
    /// Drop the current voice connection, if any.
    Disconnect,
    /// Replace all mixed tracks with the given track (or none).
    SetTrack(Option<Track>),
    /// Add another track to the mixer.
    AddTrack(Track),
    /// Change the Opus encoder's target bitrate.
    SetBitrate(Bitrate),
    /// Install a global event handler.
    AddEvent(EventData),
    /// Update the self-mute state.
    Mute(bool),
    /// Attempt a partial reconnect, falling back to a full one on failure.
    Reconnect,
    /// Tear down and fully re-establish the connection.
    FullReconnect,
    /// Rebuild inter-task channels after an internal task has died.
    RebuildInterconnect,
    /// Terminate the core task.
    Poison,
}

View File

@@ -0,0 +1,31 @@
use crate::{
events::{CoreContext, EventData, EventStore},
tracks::{LoopState, PlayMode, TrackHandle, TrackState},
};
use std::time::Duration;
/// Messages consumed by the event-handling task.
pub(crate) enum EventMessage {
    // Event related.
    // Track events should fire off the back of state changes.
    /// Install a handler for a global (non-track-specific) event.
    AddGlobalEvent(EventData),
    /// Install a handler on the track at the given index.
    AddTrackEvent(usize, EventData),
    /// Fire a core (driver-level) event immediately.
    FireCoreEvent(CoreContext),
    /// Register event store, state snapshot, and handle for a new track.
    AddTrack(EventStore, TrackState, TrackHandle),
    /// Apply a single state change to the track at the given index.
    ChangeState(usize, TrackStateChange),
    /// Remove the event state of the track at the given index.
    RemoveTrack(usize),
    /// Remove the event state of every track.
    RemoveAllTracks,
    /// Periodic tick from the mixer's send cycle; fires time-based events.
    Tick,
    /// Terminate the event task.
    Poison,
}
/// A single mutation of one track's observable state, mirrored to the
/// event task so its state snapshots stay in sync with the mixer.
#[derive(Debug)]
pub enum TrackStateChange {
    /// Play status changed (e.g., playing, paused, done).
    Mode(PlayMode),
    /// Volume changed.
    Volume(f32),
    /// Playback position changed.
    Position(Duration),
    // Bool indicates user-set.
    /// Loop counter changed.
    Loops(LoopState, bool),
    /// Wholesale replacement of the track's state.
    Total(TrackState),
}

View File

@@ -0,0 +1,32 @@
use super::{Interconnect, UdpRxMessage, UdpTxMessage, WsMessage};
use crate::{tracks::Track, Bitrate};
use flume::Sender;
use xsalsa20poly1305::XSalsa20Poly1305 as Cipher;
/// Live connection state owned by the mixer: the packet cipher plus
/// handles into the UDP receive/transmit tasks.
pub(crate) struct MixerConnection {
    /// Cipher used to encrypt outgoing voice packets.
    pub cipher: Cipher,
    /// Channel to the UDP receive task.
    pub udp_rx: Sender<UdpRxMessage>,
    /// Channel to the UDP transmit task.
    pub udp_tx: Sender<UdpTxMessage>,
}
impl Drop for MixerConnection {
    /// Shut down both UDP tasks alongside the connection; sends are
    /// best-effort since either task may already have exited.
    fn drop(&mut self) {
        let _ = self.udp_tx.send(UdpTxMessage::Poison);
        let _ = self.udp_rx.send(UdpRxMessage::Poison);
    }
}
/// Messages consumed by the synchronous mixer thread.
pub(crate) enum MixerMessage {
    /// Mix in another track.
    AddTrack(Track),
    /// Replace all tracks with the given track (or none).
    SetTrack(Option<Track>),
    /// Change the Opus encoder's target bitrate.
    SetBitrate(Bitrate),
    /// Set self-mute (audio is still mixed, but not sent).
    SetMute(bool),
    /// Install a new live connection (cipher + UDP handles) with the given SSRC.
    SetConn(MixerConnection, u32),
    /// Drop the active connection.
    DropConn,
    /// Adopt a rebuilt set of inter-task channels.
    ReplaceInterconnect(Interconnect),
    /// Recreate the Opus encoder (used after disconnects).
    RebuildEncoder,
    /// Replace the handle used for websocket speaking updates.
    Ws(Option<Sender<WsMessage>>),
    /// Terminate the mixer thread.
    Poison,
}

View File

@@ -0,0 +1,49 @@
mod core;
mod events;
mod mixer;
mod udp_rx;
mod udp_tx;
mod ws;
pub(crate) use self::{core::*, events::*, mixer::*, udp_rx::*, udp_tx::*, ws::*};
use flume::Sender;
use tracing::info;
/// Handles to each of the driver's long-lived internal tasks.
#[derive(Clone, Debug)]
pub(crate) struct Interconnect {
    /// Channel to the core control task.
    pub core: Sender<CoreMessage>,
    /// Channel to the event-handling task.
    pub events: Sender<EventMessage>,
    /// Channel to the mixer thread.
    pub mixer: Sender<MixerMessage>,
}
impl Interconnect {
    /// Signals the event task alone to shut down.
    pub fn poison(&self) {
        let _ = self.events.send(EventMessage::Poison);
    }

    /// Signals every task owned by this interconnect to shut down.
    pub fn poison_all(&self) {
        self.poison();
        let _ = self.mixer.send(MixerMessage::Poison);
    }

    /// Kills and relaunches the event task, then points the mixer at the
    /// replacement channel.
    pub fn restart_volatile_internals(&mut self) {
        self.poison();

        let (evt_tx, evt_rx) = flume::unbounded();
        self.events = evt_tx;

        let replacement = self.clone();
        tokio::spawn(async move {
            info!("Event processor restarted.");
            super::events::runner(replacement, evt_rx).await;
            info!("Event processor finished.");
        });

        // Make mixer aware of new targets...
        let _ = self
            .mixer
            .send(MixerMessage::ReplaceInterconnect(self.clone()));
    }
}

View File

@@ -0,0 +1,7 @@
use super::Interconnect;
/// Messages consumed by the UDP receive task.
pub(crate) enum UdpRxMessage {
    /// Adopt a rebuilt set of inter-task channels.
    ReplaceInterconnect(Interconnect),
    /// Terminate the task.
    Poison,
}

View File

@@ -0,0 +1,4 @@
/// Messages consumed by the UDP transmit task.
pub enum UdpTxMessage {
    /// An encrypted voice packet to place on the wire.
    Packet(Vec<u8>), // TODO: do something cheaper.
    /// Terminate the task.
    Poison,
}

View File

@@ -0,0 +1,12 @@
use super::Interconnect;
use crate::ws::WsStream;
/// Messages consumed by the voice-gateway websocket task.
#[allow(dead_code)]
pub(crate) enum WsMessage {
    /// Hand the task a (re)established websocket stream.
    Ws(Box<WsStream>),
    /// Adopt a rebuilt set of inter-task channels.
    ReplaceInterconnect(Interconnect),
    /// Update the keepalive/heartbeat interval.
    /// NOTE(review): units not shown here — presumably milliseconds; confirm
    /// against the ws task.
    SetKeepalive(f64),
    /// Update the "speaking" flag reported over the gateway.
    Speaking(bool),
    /// Terminate the task.
    Poison,
}

516
src/driver/tasks/mixer.rs Normal file
View File

@@ -0,0 +1,516 @@
use super::{error::Result, message::*};
use crate::{
constants::*,
tracks::{PlayMode, Track},
};
use audiopus::{
coder::Encoder as OpusEncoder,
softclip::SoftClip,
Application as CodingMode,
Bitrate,
Channels,
};
use discortp::{
rtp::{MutableRtpPacket, RtpPacket},
MutablePacket,
Packet,
};
use flume::{Receiver, Sender, TryRecvError};
use rand::random;
use spin_sleep::SpinSleeper;
use std::time::Instant;
use tokio::runtime::Handle;
use tracing::{error, instrument};
use xsalsa20poly1305::{aead::AeadInPlace, Nonce, TAG_SIZE};
/// State for the synchronous audio mixing thread.
struct Mixer {
    /// Handle into the async runtime, for inputs needing async preparation.
    async_handle: Handle,
    /// Target bitrate for the Opus encoder.
    bitrate: Bitrate,
    /// Cipher + UDP handles while a voice connection is live.
    conn_active: Option<MixerConnection>,
    /// Instant at which the next packet is due to be sent.
    deadline: Instant,
    /// Encoder used for the mixed (non-passthrough) output frame.
    encoder: OpusEncoder,
    /// Channels to the other driver tasks.
    interconnect: Interconnect,
    /// Inbound control-message channel.
    mix_rx: Receiver<MixerMessage>,
    /// When true, audio is mixed but not transmitted.
    muted: bool,
    /// Reusable packet buffer; the RTP header portion persists across sends.
    packet: [u8; VOICE_PACKET_MAX],
    /// Set once the event task is known dead, to avoid rebuilding it repeatedly.
    prevent_events: bool,
    /// Remaining explicit silence frames to send before stopping speech.
    silence_frames: u8,
    /// High-accuracy sleeper used to hit the packet deadline.
    sleeper: SpinSleeper,
    /// Soft clipper applied to the mixed float samples.
    soft_clip: SoftClip,
    /// Tracks currently being mixed.
    tracks: Vec<Track>,
    /// Channel for websocket speaking updates, when available.
    ws: Option<Sender<WsMessage>>,
}
/// Builds a stereo Opus audio encoder at the requested bitrate.
fn new_encoder(bitrate: Bitrate) -> Result<OpusEncoder> {
    let mut enc = OpusEncoder::new(SAMPLE_RATE, Channels::Stereo, CodingMode::Audio)?;
    enc.set_bitrate(bitrate)?;

    Ok(enc)
}
impl Mixer {
    /// Builds a fresh mixer with no live connection and a randomised RTP
    /// sequence number and timestamp.
    fn new(
        mix_rx: Receiver<MixerMessage>,
        async_handle: Handle,
        interconnect: Interconnect,
    ) -> Self {
        let bitrate = DEFAULT_BITRATE;
        let encoder = new_encoder(bitrate)
            .expect("Failed to create encoder in mixing thread with known-good values.");
        let soft_clip = SoftClip::new(Channels::Stereo);
        // The RTP header is written once here; later cycles only bump the
        // sequence number and timestamp in place.
        let mut packet = [0u8; VOICE_PACKET_MAX];
        let mut rtp = MutableRtpPacket::new(&mut packet[..]).expect(
            "FATAL: Too few bytes in self.packet for RTP header.\
                (Blame: VOICE_PACKET_MAX?)",
        );
        rtp.set_version(RTP_VERSION);
        rtp.set_payload_type(RTP_PROFILE_TYPE);
        rtp.set_sequence(random::<u16>().into());
        rtp.set_timestamp(random::<u32>().into());
        Self {
            async_handle,
            bitrate,
            conn_active: None,
            deadline: Instant::now(),
            encoder,
            interconnect,
            mix_rx,
            muted: false,
            packet,
            prevent_events: false,
            silence_frames: 0,
            sleeper: Default::default(),
            soft_clip,
            tracks: vec![],
            ws: None,
        }
    }

    /// Main loop: drains all pending control messages, then mixes and sends
    /// one frame, repeating until disconnected or poisoned.
    fn run(&mut self) {
        let mut events_failure = false;
        let mut conn_failure = false;
        'runner: loop {
            // Apply every queued control message before mixing this frame.
            loop {
                use MixerMessage::*;
                let error = match self.mix_rx.try_recv() {
                    Ok(AddTrack(mut t)) => {
                        t.source.prep_with_handle(self.async_handle.clone());
                        self.add_track(t)
                    },
                    Ok(SetTrack(t)) => {
                        self.tracks.clear();
                        let mut out = self.fire_event(EventMessage::RemoveAllTracks);
                        if let Some(mut t) = t {
                            t.source.prep_with_handle(self.async_handle.clone());
                            // Do this unconditionally: this affects local state infallibly,
                            // with the event installation being the remote part.
                            if let Err(e) = self.add_track(t) {
                                out = Err(e);
                            }
                        }
                        out
                    },
                    Ok(SetBitrate(b)) => {
                        self.bitrate = b;
                        if let Err(e) = self.set_bitrate(b) {
                            error!("Failed to update bitrate {:?}", e);
                        }
                        Ok(())
                    },
                    Ok(SetMute(m)) => {
                        self.muted = m;
                        Ok(())
                    },
                    Ok(SetConn(conn, ssrc)) => {
                        self.conn_active = Some(conn);
                        // Adopt the SSRC assigned for this session.
                        let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect(
                            "Too few bytes in self.packet for RTP header.\
                                (Blame: VOICE_PACKET_MAX?)",
                        );
                        rtp.set_ssrc(ssrc);
                        // Reset the send clock so we don't burst packets.
                        self.deadline = Instant::now();
                        Ok(())
                    },
                    Ok(DropConn) => {
                        self.conn_active = None;
                        Ok(())
                    },
                    Ok(ReplaceInterconnect(i)) => {
                        self.prevent_events = false;
                        // Forward the new channels to the tasks we hold handles to.
                        if let Some(ws) = &self.ws {
                            conn_failure |=
                                ws.send(WsMessage::ReplaceInterconnect(i.clone())).is_err();
                        }
                        if let Some(conn) = &self.conn_active {
                            conn_failure |= conn
                                .udp_rx
                                .send(UdpRxMessage::ReplaceInterconnect(i.clone()))
                                .is_err();
                        }
                        self.interconnect = i;
                        self.rebuild_tracks()
                    },
                    Ok(RebuildEncoder) => match new_encoder(self.bitrate) {
                        Ok(encoder) => {
                            self.encoder = encoder;
                            Ok(())
                        },
                        Err(e) => {
                            error!("Failed to rebuild encoder. Resetting bitrate. {:?}", e);
                            self.bitrate = DEFAULT_BITRATE;
                            self.encoder = new_encoder(self.bitrate)
                                .expect("Failed fallback rebuild of OpusEncoder with safe inputs.");
                            Ok(())
                        },
                    },
                    Ok(Ws(new_ws_handle)) => {
                        self.ws = new_ws_handle;
                        Ok(())
                    },
                    Err(TryRecvError::Disconnected) | Ok(Poison) => {
                        break 'runner;
                    },
                    Err(TryRecvError::Empty) => {
                        break;
                    },
                };
                if let Err(e) = error {
                    events_failure |= e.should_trigger_interconnect_rebuild();
                    conn_failure |= e.should_trigger_connect();
                }
            }
            if let Err(e) = self.cycle().and_then(|_| self.audio_commands_events()) {
                events_failure |= e.should_trigger_interconnect_rebuild();
                conn_failure |= e.should_trigger_connect();
                error!("Mixer thread cycle: {:?}", e);
            }
            // event failure? rebuild interconnect.
            // ws or udp failure? full connect
            // (soft reconnect is covered by the ws task.)
            if events_failure {
                self.prevent_events = true;
                self.interconnect
                    .core
                    .send(CoreMessage::RebuildInterconnect)
                    .expect("FATAL: No way to rebuild driver core from mixer.");
                events_failure = false;
            }
            if conn_failure {
                self.interconnect
                    .core
                    .send(CoreMessage::FullReconnect)
                    .expect("FATAL: No way to rebuild driver core from mixer.");
                conn_failure = false;
            }
        }
    }

    /// Forwards an event to the event task, unless that task is known dead.
    #[inline]
    fn fire_event(&self, event: EventMessage) -> Result<()> {
        // As this task is responsible for noticing the potential death of an event context,
        // it's responsible for not forcibly recreating said context repeatedly.
        if !self.prevent_events {
            self.interconnect.events.send(event)?;
            Ok(())
        } else {
            Ok(())
        }
    }

    /// Adds a track to the local mix list and registers its event state
    /// with the event task.
    #[inline]
    fn add_track(&mut self, mut track: Track) -> Result<()> {
        let evts = track.events.take().unwrap_or_default();
        let state = track.state();
        let handle = track.handle.clone();
        self.tracks.push(track);
        self.interconnect
            .events
            .send(EventMessage::AddTrack(evts, state, handle))?;
        Ok(())
    }

    // rebuilds the event thread's view of each track, in event of a full rebuild.
    #[inline]
    fn rebuild_tracks(&mut self) -> Result<()> {
        for track in self.tracks.iter_mut() {
            let evts = track.events.take().unwrap_or_default();
            let state = track.state();
            let handle = track.handle.clone();
            self.interconnect
                .events
                .send(EventMessage::AddTrack(evts, state, handle))?;
        }
        Ok(())
    }

    /// Mixes every playing track into `mix_buffer`, or returns a raw Opus
    /// frame directly when passthrough applies.
    ///
    /// Returns the largest mixed sample count and the (possibly empty)
    /// passthrough frame slice.
    #[inline]
    fn mix_tracks<'a>(
        &mut self,
        opus_frame: &'a mut [u8],
        mix_buffer: &mut [f32; STEREO_FRAME_SIZE],
    ) -> Result<(usize, &'a [u8])> {
        let mut len = 0;
        // Opus frame passthrough.
        // This requires that we have only one track, who has volume 1.0, and an
        // Opus codec type.
        let do_passthrough = self.tracks.len() == 1 && {
            let track = &self.tracks[0];
            (track.volume - 1.0).abs() < f32::EPSILON && track.source.supports_passthrough()
        };
        for (i, track) in self.tracks.iter_mut().enumerate() {
            let vol = track.volume;
            let stream = &mut track.source;
            if track.playing != PlayMode::Play {
                continue;
            }
            let (temp_len, opus_len) = if do_passthrough {
                (0, track.source.read_opus_frame(opus_frame).ok())
            } else {
                (stream.mix(mix_buffer, vol), None)
            };
            len = len.max(temp_len);
            if temp_len > 0 || opus_len.is_some() {
                track.step_frame();
            } else if track.do_loop() {
                // Track produced no audio: restart it if it should loop.
                if let Some(time) = track.seek_time(Default::default()) {
                    // have to reproduce self.fire_event here
                    // to circumvent the borrow checker's lack of knowledge.
                    //
                    // In event of error, one of the later event calls will
                    // trigger the event thread rebuild: it is more prudent that
                    // the mixer works as normal right now.
                    if !self.prevent_events {
                        let _ = self.interconnect.events.send(EventMessage::ChangeState(
                            i,
                            TrackStateChange::Position(time),
                        ));
                        let _ = self.interconnect.events.send(EventMessage::ChangeState(
                            i,
                            TrackStateChange::Loops(track.loops, false),
                        ));
                    }
                }
            } else {
                track.end();
            }
            // Passthrough succeeded: short-circuit with the raw frame.
            if let Some(opus_len) = opus_len {
                return Ok((STEREO_FRAME_SIZE, &opus_frame[..opus_len]));
            }
        }
        Ok((len, &opus_frame[..0]))
    }

    /// Applies queued user commands, prunes finished tracks, and drives the
    /// event task's timer tick.
    #[inline]
    fn audio_commands_events(&mut self) -> Result<()> {
        // Apply user commands.
        for (i, track) in self.tracks.iter_mut().enumerate() {
            // This causes fallible event system changes,
            // but if the event thread has died then we'll certainly
            // detect that on the tick later.
            // Changes to play state etc. MUST all be handled.
            track.process_commands(i, &self.interconnect);
        }
        // TODO: do without vec?
        let mut i = 0;
        let mut to_remove = Vec::with_capacity(self.tracks.len());
        while i < self.tracks.len() {
            let track = self
                .tracks
                .get_mut(i)
                .expect("Tried to remove an illegal track index.");
            if track.playing.is_done() {
                let p_state = track.playing();
                self.tracks.remove(i);
                to_remove.push(i);
                self.fire_event(EventMessage::ChangeState(
                    i,
                    TrackStateChange::Mode(p_state),
                ))?;
            } else {
                i += 1;
            }
        }
        // Tick
        self.fire_event(EventMessage::Tick)?;
        // Then do removals.
        for i in &to_remove[..] {
            self.fire_event(EventMessage::RemoveTrack(*i))?;
        }
        Ok(())
    }

    /// Sleeps until the current packet deadline, then advances it by one
    /// timestep.
    #[inline]
    fn march_deadline(&mut self) {
        self.sleeper
            .sleep(self.deadline.saturating_duration_since(Instant::now()));
        self.deadline += TIMESTEP_LENGTH;
    }

    /// Mixes, (optionally) encodes, and sends one frame, handling the run of
    /// explicit silence frames sent before speech stops.
    fn cycle(&mut self) -> Result<()> {
        // No connection: keep the clock marching, send nothing.
        if self.conn_active.is_none() {
            self.march_deadline();
            return Ok(());
        }
        // TODO: can we make opus_frame_backing *actually* a view over
        // some region of self.packet, derived using the encryption mode?
        // This saves a copy on Opus passthrough.
        let mut opus_frame_backing = [0u8; STEREO_FRAME_SIZE];
        let mut mix_buffer = [0f32; STEREO_FRAME_SIZE];
        // Slice which mix tracks may use to passthrough direct Opus frames.
        let mut opus_space = &mut opus_frame_backing[..];
        // Walk over all the audio files, combining into one audio frame according
        // to volume, play state, etc.
        let (mut len, mut opus_frame) = self.mix_tracks(&mut opus_space, &mut mix_buffer)?;
        self.soft_clip.apply(&mut mix_buffer[..])?;
        if self.muted {
            len = 0;
        }
        if len == 0 {
            if self.silence_frames > 0 {
                self.silence_frames -= 1;
                // Explicit "Silence" frame.
                opus_frame = &SILENT_FRAME[..];
            } else {
                // Per official guidelines, send 5x silence BEFORE we stop speaking.
                if let Some(ws) = &self.ws {
                    // NOTE: this should prevent a catastrophic thread pileup.
                    // A full reconnect might cause an inner closed connection.
                    // It's safer to leave the central task to clean this up and
                    // pass the mixer a new channel.
                    let _ = ws.send(WsMessage::Speaking(false));
                }
                self.march_deadline();
                return Ok(());
            }
        } else {
            self.silence_frames = 5;
        }
        if let Some(ws) = &self.ws {
            ws.send(WsMessage::Speaking(true))?;
        }
        self.march_deadline();
        self.prep_and_send_packet(mix_buffer, opus_frame)?;
        Ok(())
    }

    /// Applies a new bitrate to the live encoder.
    fn set_bitrate(&mut self, bitrate: Bitrate) -> Result<()> {
        self.encoder.set_bitrate(bitrate).map_err(Into::into)
    }

    /// Encodes (or copies a passthrough frame), encrypts in place, transmits,
    /// then advances the RTP sequence number and timestamp for the next frame.
    ///
    /// NOTE(review): `buffer` is a large array taken by value — consider
    /// borrowing to avoid the copy.
    fn prep_and_send_packet(&mut self, buffer: [f32; 1920], opus_frame: &[u8]) -> Result<()> {
        let conn = self
            .conn_active
            .as_mut()
            .expect("Shouldn't be mixing packets without access to a cipher + UDP dest.");
        let mut nonce = Nonce::default();
        let index = {
            let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect(
                "FATAL: Too few bytes in self.packet for RTP header.\
                    (Blame: VOICE_PACKET_MAX?)",
            );
            // The RTP header doubles as the nonce (matches CryptoMode::Normal).
            let pkt = rtp.packet();
            let rtp_len = RtpPacket::minimum_packet_size();
            nonce[..rtp_len].copy_from_slice(&pkt[..rtp_len]);
            let payload = rtp.payload_mut();
            // The first TAG_SIZE payload bytes are reserved for the crypto tag;
            // audio data is written after them.
            let payload_len = if opus_frame.is_empty() {
                self.encoder
                    .encode_float(&buffer[..STEREO_FRAME_SIZE], &mut payload[TAG_SIZE..])?
            } else {
                let len = opus_frame.len();
                payload[TAG_SIZE..TAG_SIZE + len].clone_from_slice(opus_frame);
                len
            };
            let final_payload_size = TAG_SIZE + payload_len;
            let tag = conn.cipher.encrypt_in_place_detached(
                &nonce,
                b"",
                &mut payload[TAG_SIZE..final_payload_size],
            )?;
            payload[..TAG_SIZE].copy_from_slice(&tag[..]);
            rtp_len + final_payload_size
        };
        // TODO: This is dog slow, don't do this.
        // Can we replace this with a shared ring buffer + semaphore?
        // i.e., do something like double/triple buffering in graphics.
        conn.udp_tx
            .send(UdpTxMessage::Packet(self.packet[..index].to_vec()))?;
        let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect(
            "FATAL: Too few bytes in self.packet for RTP header.\
                (Blame: VOICE_PACKET_MAX?)",
        );
        rtp.set_sequence(rtp.get_sequence() + 1);
        rtp.set_timestamp(rtp.get_timestamp() + MONO_FRAME_SIZE as u32);
        Ok(())
    }
}
/// The mixing thread is a synchronous context due to its compute-bound nature.
///
/// We pass in an async handle for the benefit of some Input classes (e.g., restartables)
/// who need to run their restart code elsewhere and return blank data until such time.
#[instrument(skip(interconnect, mix_rx, async_handle))]
pub(crate) fn runner(
    interconnect: Interconnect,
    mix_rx: Receiver<MixerMessage>,
    async_handle: Handle,
) {
    Mixer::new(mix_rx, async_handle, interconnect).run();
}

155
src/driver/tasks/mod.rs Normal file
View File

@@ -0,0 +1,155 @@
pub mod error;
mod events;
pub(crate) mod message;
mod mixer;
pub(crate) mod udp_rx;
pub(crate) mod udp_tx;
pub(crate) mod ws;
use super::{
connection::{error::Error as ConnectionError, Connection},
Config,
};
use flume::{Receiver, RecvError, Sender};
use message::*;
use tokio::runtime::Handle;
use tracing::{error, info, instrument};
/// Spawns the driver's core control task onto the tokio runtime.
pub(crate) fn start(config: Config, rx: Receiver<CoreMessage>, tx: Sender<CoreMessage>) {
    let core_loop = async move {
        info!("Driver started.");
        runner(config, rx, tx).await;
        info!("Driver finished.");
    };

    tokio::spawn(core_loop);
}
/// Launches the event task (async) and mixer (dedicated OS thread), returning
/// the interconnect that links them to the core task.
fn start_internals(core: Sender<CoreMessage>) -> Interconnect {
    let (evt_tx, evt_rx) = flume::unbounded();
    let (mix_tx, mix_rx) = flume::unbounded();

    let interconnect = Interconnect {
        core,
        events: evt_tx,
        mixer: mix_tx,
    };

    let event_ic = interconnect.clone();
    tokio::spawn(async move {
        info!("Event processor started.");
        events::runner(event_ic, evt_rx).await;
        info!("Event processor finished.");
    });

    // The mixer is compute-bound, so it lives on its own thread and is handed
    // a runtime handle for any async work its inputs require.
    let mixer_ic = interconnect.clone();
    let handle = Handle::current();
    std::thread::spawn(move || {
        info!("Mixer started.");
        mixer::runner(mixer_ic, mix_rx, handle);
        info!("Mixer finished.");
    });

    interconnect
}
/// Core control loop: reacts to driver commands, owns the (optional) live
/// connection, and coordinates reconnection and internal-task rebuilds.
#[instrument(skip(rx, tx))]
async fn runner(config: Config, rx: Receiver<CoreMessage>, tx: Sender<CoreMessage>) {
    let mut connection = None;
    let mut interconnect = start_internals(tx);
    loop {
        match rx.recv_async().await {
            Ok(CoreMessage::ConnectWithResult(info, tx)) => {
                connection = match Connection::new(info, &interconnect, &config).await {
                    Ok(connection) => {
                        // Other side may not be listening: this is fine.
                        let _ = tx.send(Ok(()));
                        Some(connection)
                    },
                    Err(why) => {
                        // See above.
                        let _ = tx.send(Err(why));
                        None
                    },
                };
            },
            Ok(CoreMessage::Disconnect) => {
                connection = None;
                // Reset the mixer's connection state and codec for future use.
                let _ = interconnect.mixer.send(MixerMessage::DropConn);
                let _ = interconnect.mixer.send(MixerMessage::RebuildEncoder);
            },
            Ok(CoreMessage::SetTrack(s)) => {
                let _ = interconnect.mixer.send(MixerMessage::SetTrack(s));
            },
            Ok(CoreMessage::AddTrack(s)) => {
                let _ = interconnect.mixer.send(MixerMessage::AddTrack(s));
            },
            Ok(CoreMessage::SetBitrate(b)) => {
                let _ = interconnect.mixer.send(MixerMessage::SetBitrate(b));
            },
            Ok(CoreMessage::AddEvent(evt)) => {
                let _ = interconnect.events.send(EventMessage::AddGlobalEvent(evt));
            },
            Ok(CoreMessage::Mute(m)) => {
                let _ = interconnect.mixer.send(MixerMessage::SetMute(m));
            },
            Ok(CoreMessage::Reconnect) => {
                if let Some(mut conn) = connection.take() {
                    // try once: if interconnect, try again.
                    // if still issue, full connect.
                    let info = conn.info.clone();
                    let full_connect = match conn.reconnect().await {
                        Ok(()) => {
                            connection = Some(conn);
                            false
                        },
                        Err(ConnectionError::InterconnectFailure(_)) => {
                            // An internal task died: rebuild it and retry once.
                            interconnect.restart_volatile_internals();
                            match conn.reconnect().await {
                                Ok(()) => {
                                    connection = Some(conn);
                                    false
                                },
                                _ => true,
                            }
                        },
                        _ => true,
                    };
                    if full_connect {
                        connection = Connection::new(info, &interconnect, &config)
                            .await
                            .map_err(|e| {
                                error!("Catastrophic connection failure. Stopping. {:?}", e);
                                e
                            })
                            .ok();
                    }
                }
            },
            Ok(CoreMessage::FullReconnect) =>
                if let Some(conn) = connection.take() {
                    let info = conn.info.clone();
                    connection = Connection::new(info, &interconnect, &config)
                        .await
                        .map_err(|e| {
                            error!("Catastrophic connection failure. Stopping. {:?}", e);
                            e
                        })
                        .ok();
                },
            Ok(CoreMessage::RebuildInterconnect) => {
                interconnect.restart_volatile_internals();
            },
            Err(RecvError::Disconnected) | Ok(CoreMessage::Poison) => {
                break;
            },
        }
    }
    info!("Main thread exited");
    // Ensure the internal tasks shut down with us.
    interconnect.poison_all();
}

286
src/driver/tasks/udp_rx.rs Normal file
View File

@@ -0,0 +1,286 @@
use super::{
error::{Error, Result},
message::*,
};
use crate::{constants::*, driver::CryptoMode, events::CoreContext};
use audiopus::{coder::Decoder as OpusDecoder, Channels};
use discortp::{
demux::{self, DemuxedMut},
rtp::{RtpExtensionPacket, RtpPacket},
FromPacket,
MutablePacket,
Packet,
PacketSize,
};
use flume::Receiver;
use std::collections::HashMap;
use tokio::net::udp::RecvHalf;
use tracing::{error, info, instrument, warn};
use xsalsa20poly1305::{aead::AeadInPlace, Nonce, Tag, XSalsa20Poly1305 as Cipher, TAG_SIZE};
/// Per-SSRC (i.e., per-source) receive state.
#[derive(Debug)]
struct SsrcState {
    /// Consecutive silent frames observed; >= 5 is treated as "not speaking".
    silent_frame_count: u16,
    /// Dedicated Opus decoder for this source.
    decoder: OpusDecoder,
    /// Last RTP sequence number seen, for loss/reorder detection.
    last_seq: u16,
}
/// Change in a source's speaking status derived from one packet.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SpeakingDelta {
    /// No change in speaking state.
    Same,
    /// Source transitioned from silence to speech.
    Start,
    /// Source transitioned from speech to silence.
    Stop,
}
impl SsrcState {
    /// Initialises receive state from the first packet seen for this SSRC.
    fn new(pkt: RtpPacket<'_>) -> Self {
        Self {
            silent_frame_count: 5, // We do this to make the first speech packet fire an event.
            decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo)
                .expect("Failed to create new Opus decoder for source."),
            last_seq: pkt.get_sequence().into(),
        }
    }

    /// Decodes one packet's audio and computes any speaking-state transition.
    ///
    /// `data_offset` is the number of leading payload bytes to skip before
    /// the Opus data.
    fn process(
        &mut self,
        pkt: RtpPacket<'_>,
        data_offset: usize,
    ) -> Result<(SpeakingDelta, Vec<i16>)> {
        let new_seq: u16 = pkt.get_sequence().into();
        let extensions = pkt.get_extension() != 0;
        let seq_delta = new_seq.wrapping_sub(self.last_seq);
        Ok(if seq_delta >= (1 << 15) {
            // Overflow, reordered (previously missing) packet.
            (SpeakingDelta::Same, vec![])
        } else {
            self.last_seq = new_seq;
            let missed_packets = seq_delta.saturating_sub(1);
            let (audio, pkt_size) =
                self.scan_and_decode(&pkt.payload()[data_offset..], extensions, missed_packets)?;
            // Five consecutive silent frames mark the stop boundary; any
            // non-silent frame after that marks a start.
            let delta = if pkt_size == SILENT_FRAME.len() {
                // Frame is silent.
                let old = self.silent_frame_count;
                self.silent_frame_count =
                    self.silent_frame_count.saturating_add(1 + missed_packets);
                if self.silent_frame_count >= 5 && old < 5 {
                    SpeakingDelta::Stop
                } else {
                    SpeakingDelta::Same
                }
            } else {
                // Frame has meaningful audio.
                let out = if self.silent_frame_count >= 5 {
                    SpeakingDelta::Start
                } else {
                    SpeakingDelta::Same
                };
                self.silent_frame_count = 0;
                out
            };
            (delta, audio)
        })
    }

    /// Skips any RTP extension header, runs loss concealment for missed
    /// packets, then decodes the Opus payload to interleaved stereo samples.
    ///
    /// Returns the decoded samples and the Opus body length in bytes.
    fn scan_and_decode(
        &mut self,
        data: &[u8],
        extension: bool,
        missed_packets: u16,
    ) -> Result<(Vec<i16>, usize)> {
        let mut out = vec![0; STEREO_FRAME_SIZE];
        let start = if extension {
            RtpExtensionPacket::new(data)
                .map(|pkt| pkt.packet_size())
                .ok_or_else(|| {
                    error!("Extension packet indicated, but insufficient space.");
                    Error::IllegalVoicePacket
                })
        } else {
            Ok(0)
        }?;
        // Feeding the decoder `None` per missed packet runs its concealment.
        for _ in 0..missed_packets {
            let missing_frame: Option<&[u8]> = None;
            if let Err(e) = self.decoder.decode(missing_frame, &mut out[..], false) {
                warn!("Issue while decoding for missed packet: {:?}.", e);
            }
        }
        let audio_len = self
            .decoder
            .decode(Some(&data[start..]), &mut out[..], false)
            .map_err(|e| {
                error!("Failed to decode received packet: {:?}.", e);
                e
            })?;
        // Decoding to stereo: audio_len refers to sample count irrespective of channel count.
        // => multiply by number of channels.
        out.truncate(2 * audio_len);
        Ok((out, data.len() - start))
    }
}
/// State for the receive half of the voice UDP socket.
struct UdpRx {
    /// Key material for decrypting received RTP/RTCP packets.
    cipher: Cipher,
    /// Per-SSRC decoder state, created lazily as new sources are heard.
    decoder_map: HashMap<u32, SsrcState>,
    #[allow(dead_code)]
    mode: CryptoMode, // In future, this will allow crypto mode selection.
    /// Reusable scratch buffer, sized for the largest legal voice packet.
    packet_buffer: [u8; VOICE_PACKET_MAX],
    /// Control messages from the driver core.
    rx: Receiver<UdpRxMessage>,
    /// Receive half of the voice UDP socket.
    udp_socket: RecvHalf,
}
impl UdpRx {
    /// Event loop: alternates between receiving voice packets from the socket
    /// and servicing control messages until poisoned or disconnected.
    #[instrument(skip(self))]
    async fn run(&mut self, interconnect: &mut Interconnect) {
        loop {
            tokio::select! {
                Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => {
                    self.process_udp_message(interconnect, len);
                }
                msg = self.rx.recv_async() => {
                    use UdpRxMessage::*;
                    match msg {
                        Ok(ReplaceInterconnect(i)) => {
                            *interconnect = i;
                        }
                        Ok(Poison) | Err(_) => break,
                    }
                }
            }
        }
    }
    /// Demuxes, decrypts, and decodes one received datagram of `len` bytes,
    /// firing the relevant core events (speaking transitions, voice/RTCP data).
    fn process_udp_message(&mut self, interconnect: &Interconnect, len: usize) {
        // NOTE: errors here (and in general for UDP) are not fatal to the connection.
        // Panics should be avoided due to adversarial nature of rx'd packets,
        // but correct handling should not prompt a reconnect.
        //
        // For simplicity, we nominate the mixing context to rebuild the event
        // context if it fails (hence, the `let _ =` statements.), as it will try to
        // make contact every 20ms.
        let packet = &mut self.packet_buffer[..len];
        match demux::demux_mut(packet) {
            DemuxedMut::Rtp(mut rtp) => {
                if !rtp_valid(rtp.to_immutable()) {
                    error!("Illegal RTP message received.");
                    return;
                }
                // FIX: decryption failure previously `expect`ed (panicked).
                // A forged/corrupt packet must only be dropped, per the NOTE above.
                let rtp_body_start = match decrypt_in_place(&mut rtp, &self.cipher) {
                    Ok(start) => start,
                    Err(e) => {
                        warn!("RTP decryption failed: {:?}.", e);
                        return;
                    },
                };
                let entry = self
                    .decoder_map
                    .entry(rtp.get_ssrc())
                    .or_insert_with(|| SsrcState::new(rtp.to_immutable()));
                if let Ok((delta, audio)) = entry.process(rtp.to_immutable(), rtp_body_start) {
                    match delta {
                        SpeakingDelta::Start => {
                            let _ = interconnect.events.send(EventMessage::FireCoreEvent(
                                CoreContext::SpeakingUpdate {
                                    ssrc: rtp.get_ssrc(),
                                    speaking: true,
                                },
                            ));
                        },
                        SpeakingDelta::Stop => {
                            let _ = interconnect.events.send(EventMessage::FireCoreEvent(
                                CoreContext::SpeakingUpdate {
                                    ssrc: rtp.get_ssrc(),
                                    speaking: false,
                                },
                            ));
                        },
                        _ => {},
                    }
                    let _ = interconnect.events.send(EventMessage::FireCoreEvent(
                        CoreContext::VoicePacket {
                            audio,
                            packet: rtp.from_packet(),
                            payload_offset: rtp_body_start,
                        },
                    ));
                } else {
                    // Decryption already succeeded above, so this is a decode failure.
                    warn!("RTP decoding failed.");
                }
            },
            DemuxedMut::Rtcp(mut rtcp) => {
                let rtcp_body_start = decrypt_in_place(&mut rtcp, &self.cipher);
                if let Ok(start) = rtcp_body_start {
                    let _ = interconnect.events.send(EventMessage::FireCoreEvent(
                        CoreContext::RtcpPacket {
                            packet: rtcp.from_packet(),
                            payload_offset: start,
                        },
                    ));
                } else {
                    warn!("RTCP decryption failed.");
                }
            },
            DemuxedMut::FailedParse(t) => {
                warn!("Failed to parse message of type {:?}.", t);
            },
            _ => {
                warn!("Illegal UDP packet from voice server.");
            },
        }
    }
}
#[instrument(skip(interconnect, rx, cipher))]
pub(crate) async fn runner(
mut interconnect: Interconnect,
rx: Receiver<UdpRxMessage>,
cipher: Cipher,
mode: CryptoMode,
udp_socket: RecvHalf,
) {
info!("UDP receive handle started.");
let mut state = UdpRx {
cipher,
decoder_map: Default::default(),
mode,
packet_buffer: [0u8; VOICE_PACKET_MAX],
rx,
udp_socket,
};
state.run(&mut interconnect).await;
info!("UDP receive handle stopped.");
}
/// Decrypts a voice packet's payload in place, returning the offset of the
/// plaintext body within the payload (i.e., the tag length).
///
/// The nonce is the packet header itself, zero-padded; the leading `TAG_SIZE`
/// bytes of the payload hold the Poly1305 tag.
#[inline]
fn decrypt_in_place(packet: &mut impl MutablePacket, cipher: &Cipher) -> Result<usize> {
    // Applies discord's cheapest.
    // In future, might want to make a choice...
    let total_len = packet.packet().len();
    let header_len = total_len - packet.payload().len();
    let mut nonce = Nonce::default();
    nonce[..header_len].copy_from_slice(&packet.packet()[..header_len]);
    let (tag_slice, ciphertext) = packet.payload_mut().split_at_mut(TAG_SIZE);
    let tag = Tag::from_slice(tag_slice);
    cipher.decrypt_in_place_detached(&nonce, b"", ciphertext, tag)?;
    Ok(TAG_SIZE)
}
/// Sanity-checks an RTP packet's version and payload type before decryption.
#[inline]
fn rtp_valid(packet: RtpPacket<'_>) -> bool {
    if packet.get_version() != RTP_VERSION {
        return false;
    }
    packet.get_payload_type() == RTP_PROFILE_TYPE
}

View File

@@ -0,0 +1,45 @@
use super::message::*;
use crate::constants::*;
use discortp::discord::MutableKeepalivePacket;
use flume::Receiver;
use tokio::{
net::udp::SendHalf,
time::{timeout_at, Elapsed, Instant},
};
use tracing::{error, info, instrument, trace};
/// Entry point for the UDP transmit task.
///
/// Forwards outbound voice packets to the socket, and sends a keepalive frame
/// whenever no packet arrives within `UDP_KEEPALIVE_GAP`. Exits on send
/// failure, poisoning, or channel disconnect.
#[instrument(skip(udp_msg_rx))]
pub(crate) async fn runner(udp_msg_rx: Receiver<UdpTxMessage>, ssrc: u32, mut udp_tx: SendHalf) {
    info!("UDP transmit handle started.");
    // Serialise the keepalive frame once up front: only the SSRC varies.
    let mut keepalive_bytes = [0u8; MutableKeepalivePacket::minimum_packet_size()];
    MutableKeepalivePacket::new(&mut keepalive_bytes[..])
        .expect("FATAL: Insufficient bytes given to keepalive packet.")
        .set_ssrc(ssrc);
    let mut next_keepalive = Instant::now() + UDP_KEEPALIVE_GAP;
    loop {
        use UdpTxMessage::*;
        match timeout_at(next_keepalive, udp_msg_rx.recv_async()).await {
            Err(Elapsed { .. }) => {
                // Deadline hit with no traffic: keep the NAT mapping alive.
                trace!("Sending UDP Keepalive.");
                if let Err(e) = udp_tx.send(&keepalive_bytes[..]).await {
                    error!("Fatal UDP keepalive send error: {:?}.", e);
                    break;
                }
                next_keepalive += UDP_KEEPALIVE_GAP;
            },
            Ok(Ok(Packet(p))) => {
                if let Err(e) = udp_tx.send(&p[..]).await {
                    error!("Fatal UDP packet send error: {:?}.", e);
                    break;
                }
            },
            Ok(Err(_)) | Ok(Ok(Poison)) => break,
        }
    }
    info!("UDP transmit handle stopped.");
}

205
src/driver/tasks/ws.rs Normal file
View File

@@ -0,0 +1,205 @@
use super::{error::Result, message::*};
use crate::{
events::CoreContext,
model::{
payload::{Heartbeat, Speaking},
Event as GatewayEvent,
SpeakingState,
},
ws::{Error as WsError, ReceiverExt, SenderExt, WsStream},
};
use flume::Receiver;
use rand::random;
use std::time::Duration;
use tokio::time::{self, Instant};
use tracing::{error, info, instrument, trace, warn};
/// State for the voice gateway websocket task: heartbeats, speaking-state
/// updates, and dispatch of inbound gateway events.
struct AuxNetwork {
    /// Control messages from the driver core.
    rx: Receiver<WsMessage>,
    /// Live websocket connection to the voice gateway.
    ws_client: WsStream,
    /// Set after a websocket error; suppresses sends/receives until the core
    /// supplies a replacement connection.
    dont_send: bool,
    /// This connection's synchronisation source, sent in speaking updates.
    ssrc: u32,
    /// Interval between heartbeats, as negotiated with the gateway.
    heartbeat_interval: Duration,
    /// Current advertised speaking flags.
    speaking: SpeakingState,
    /// Nonce of the most recent heartbeat, checked against the next ACK.
    last_heartbeat_nonce: Option<u64>,
}
impl AuxNetwork {
    /// Builds task state around an established websocket connection.
    ///
    /// `heartbeat_interval` is given in milliseconds, as received from the gateway.
    pub(crate) fn new(
        evt_rx: Receiver<WsMessage>,
        ws_client: WsStream,
        ssrc: u32,
        heartbeat_interval: f64,
    ) -> Self {
        Self {
            rx: evt_rx,
            ws_client,
            dont_send: false,
            ssrc,
            heartbeat_interval: Duration::from_secs_f64(heartbeat_interval / 1000.0),
            speaking: SpeakingState::empty(),
            last_heartbeat_nonce: None,
        }
    }
    /// Event loop: races the heartbeat timer, inbound gateway traffic, and
    /// control messages. On websocket failure, asks the core to reconnect and
    /// mutes itself (`dont_send`) until a fresh connection arrives.
    #[instrument(skip(self))]
    async fn run(&mut self, interconnect: &mut Interconnect) {
        let mut next_heartbeat = Instant::now() + self.heartbeat_interval;
        loop {
            let mut ws_error = false;
            let hb = time::delay_until(next_heartbeat);
            tokio::select! {
                _ = hb => {
                    ws_error = match self.send_heartbeat().await {
                        Err(e) => {
                            error!("Heartbeat send failure {:?}.", e);
                            true
                        },
                        _ => false,
                    };
                    next_heartbeat = self.next_heartbeat();
                }
                // Inbound traffic is only polled while the connection is believed healthy.
                ws_msg = self.ws_client.recv_json_no_timeout(), if !self.dont_send => {
                    ws_error = match ws_msg {
                        Err(WsError::Json(e)) => {
                            // Unknown/extra payloads are tolerated, not fatal.
                            warn!("Unexpected JSON {:?}.", e);
                            false
                        },
                        Err(e) => {
                            error!("Error processing ws {:?}.", e);
                            true
                        },
                        Ok(Some(msg)) => {
                            self.process_ws(interconnect, msg);
                            false
                        },
                        _ => false,
                    };
                }
                inner_msg = self.rx.recv_async() => {
                    match inner_msg {
                        Ok(WsMessage::Ws(data)) => {
                            // Replacement connection after a reconnect: resume sending.
                            self.ws_client = *data;
                            next_heartbeat = self.next_heartbeat();
                            self.dont_send = false;
                        },
                        Ok(WsMessage::ReplaceInterconnect(i)) => {
                            *interconnect = i;
                        },
                        Ok(WsMessage::SetKeepalive(keepalive)) => {
                            self.heartbeat_interval = Duration::from_secs_f64(keepalive / 1000.0);
                            next_heartbeat = self.next_heartbeat();
                        },
                        Ok(WsMessage::Speaking(is_speaking)) => {
                            // Only send an update when the MICROPHONE flag actually changes.
                            if self.speaking.contains(SpeakingState::MICROPHONE) != is_speaking && !self.dont_send {
                                self.speaking.set(SpeakingState::MICROPHONE, is_speaking);
                                info!("Changing to {:?}", self.speaking);
                                let ssu_status = self.ws_client
                                    .send_json(&GatewayEvent::from(Speaking {
                                        delay: Some(0),
                                        speaking: self.speaking,
                                        ssrc: self.ssrc,
                                        user_id: None,
                                    }))
                                    .await;
                                ws_error |= match ssu_status {
                                    Err(e) => {
                                        error!("Issue sending speaking update {:?}.", e);
                                        true
                                    },
                                    _ => false,
                                }
                            }
                        },
                        Err(_) | Ok(WsMessage::Poison) => {
                            break;
                        },
                    }
                }
            }
            if ws_error {
                // Hand recovery to the core; stay quiet until it re-arms us.
                let _ = interconnect.core.send(CoreMessage::Reconnect);
                self.dont_send = true;
            }
        }
    }
    /// Computes the deadline for the next heartbeat.
    fn next_heartbeat(&self) -> Instant {
        Instant::now() + self.heartbeat_interval
    }
    /// Sends a heartbeat with a fresh random nonce, recording the nonce for
    /// ACK matching. Skips the actual send while `dont_send` is set.
    async fn send_heartbeat(&mut self) -> Result<()> {
        let nonce = random::<u64>();
        self.last_heartbeat_nonce = Some(nonce);
        trace!("Sent heartbeat {:?}", self.speaking);
        if !self.dont_send {
            self.ws_client
                .send_json(&GatewayEvent::from(Heartbeat { nonce }))
                .await?;
        }
        Ok(())
    }
    /// Translates one inbound gateway event into a core event (via the
    /// interconnect), or validates a heartbeat ACK.
    fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) {
        match value {
            GatewayEvent::Speaking(ev) => {
                let _ = interconnect.events.send(EventMessage::FireCoreEvent(
                    CoreContext::SpeakingStateUpdate(ev),
                ));
            },
            GatewayEvent::ClientConnect(ev) => {
                let _ = interconnect
                    .events
                    .send(EventMessage::FireCoreEvent(CoreContext::ClientConnect(ev)));
            },
            GatewayEvent::ClientDisconnect(ev) => {
                let _ = interconnect.events.send(EventMessage::FireCoreEvent(
                    CoreContext::ClientDisconnect(ev),
                ));
            },
            GatewayEvent::HeartbeatAck(ev) => {
                if let Some(nonce) = self.last_heartbeat_nonce.take() {
                    if ev.nonce == nonce {
                        trace!("Heartbeat ACK received.");
                    } else {
                        warn!(
                            "Heartbeat nonce mismatch! Expected {}, saw {}.",
                            nonce, ev.nonce
                        );
                    }
                }
            },
            other => {
                trace!("Received other websocket data: {:?}", other);
            },
        }
    }
}
/// Entry point for the voice gateway websocket task.
///
/// Wraps the connection in an `AuxNetwork` and runs its event loop to
/// completion. `heartbeat_interval` is in milliseconds.
#[instrument(skip(interconnect, ws_client))]
pub(crate) async fn runner(
    mut interconnect: Interconnect,
    evt_rx: Receiver<WsMessage>,
    ws_client: WsStream,
    ssrc: u32,
    heartbeat_interval: f64,
) {
    info!("WS thread started.");
    let mut network = AuxNetwork::new(evt_rx, ws_client, ssrc, heartbeat_interval);
    network.run(&mut interconnect).await;
    info!("WS thread finished.");
}

69
src/error.rs Normal file
View File

@@ -0,0 +1,69 @@
//! Driver and gateway error handling.
#[cfg(feature = "serenity")]
use futures::channel::mpsc::TrySendError;
#[cfg(feature = "serenity")]
use serenity::gateway::InterMessage;
#[cfg(feature = "gateway")]
use std::{error::Error, fmt};
#[cfg(feature = "twilight")]
use twilight_gateway::shard::CommandError;
#[cfg(feature = "gateway")]
#[derive(Debug)]
/// Error returned when a manager or call handler is
/// unable to send messages over Discord's gateway.
pub enum JoinError {
    /// No available gateway connection was provided to send
    /// voice state update messages.
    NoSender,
    /// Tried to leave a [`Call`] which was not found.
    ///
    /// [`Call`]: ../struct.Call.html
    NoCall,
    #[cfg(feature = "serenity")]
    /// Serenity-specific WebSocket send error.
    ///
    /// Wraps the failed shard-channel send attempt.
    Serenity(TrySendError<InterMessage>),
    #[cfg(feature = "twilight")]
    /// Twilight-specific WebSocket send error.
    ///
    /// Wraps the failed shard command.
    Twilight(CommandError),
}
#[cfg(feature = "gateway")]
impl fmt::Display for JoinError {
    /// Formats the error as a shared prefix followed by a variant-specific tail.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Failed to Join Voice channel: ")?;
        match self {
            JoinError::NoSender => f.write_str("no gateway destination."),
            JoinError::NoCall => f.write_str("tried to leave a non-existent call."),
            #[cfg(feature = "serenity")]
            JoinError::Serenity(t) => write!(f, "serenity failure {}.", t),
            #[cfg(feature = "twilight")]
            JoinError::Twilight(t) => write!(f, "twilight failure {}.", t),
        }
    }
}
// `JoinError` carries no source; the Display/Debug impls suffice.
#[cfg(feature = "gateway")]
impl Error for JoinError {}
// Allow `?` conversion from serenity's shard-messaging send error.
#[cfg(all(feature = "serenity", feature = "gateway"))]
impl From<TrySendError<InterMessage>> for JoinError {
    fn from(e: TrySendError<InterMessage>) -> Self {
        JoinError::Serenity(e)
    }
}
// Allow `?` conversion from twilight's shard command error.
#[cfg(all(feature = "twilight", feature = "gateway"))]
impl From<CommandError> for JoinError {
    fn from(e: CommandError) -> Self {
        JoinError::Twilight(e)
    }
}
#[cfg(feature = "gateway")]
/// Convenience type for Discord gateway error handling.
pub type JoinResult<T> = Result<T, JoinError>;
#[cfg(feature = "driver")]
pub use crate::driver::connection::error::{Error as ConnectionError, Result as ConnectionResult};

137
src/events/context.rs Normal file
View File

@@ -0,0 +1,137 @@
use super::*;
use crate::{
model::payload::{ClientConnect, ClientDisconnect, Speaking},
tracks::{TrackHandle, TrackState},
};
use discortp::{rtcp::Rtcp, rtp::Rtp};
/// Information about which tracks or data fired an event.
///
/// [`Track`] events may be local or global, and have no tracks
/// if fired on the global context via [`Handler::add_global_event`].
///
/// [`Track`]: ../tracks/struct.Track.html
/// [`Handler::add_global_event`]: ../struct.Handler.html#method.add_global_event
#[derive(Clone, Debug)]
pub enum EventContext<'a> {
    /// Track event context, passed to events created via [`TrackHandle::add_event`],
    /// [`EventStore::add_event`], or relevant global events.
    ///
    /// The slice is empty when fired on the global context's timer.
    ///
    /// [`EventStore::add_event`]: struct.EventStore.html#method.add_event
    /// [`TrackHandle::add_event`]: ../tracks/struct.TrackHandle.html#method.add_event
    Track(&'a [(&'a TrackState, &'a TrackHandle)]),
    /// Speaking state update, typically describing how another voice
    /// user is transmitting audio data. Clients must send at least one such
    /// packet to allow SSRC/UserID matching.
    SpeakingStateUpdate(Speaking),
    /// Speaking state transition, describing whether a given source has started/stopped
    /// transmitting. This fires in response to a silent burst, or the first packet
    /// breaking such a burst.
    SpeakingUpdate {
        /// Synchronisation Source of the user who has begun speaking.
        ///
        /// This must be combined with another event class to map this back to
        /// its original UserId.
        ssrc: u32,
        /// Whether this user is currently speaking.
        speaking: bool,
    },
    /// Opus audio packet, received from another stream (detailed in `packet`).
    /// `payload_offset` contains the true payload location within the raw packet's `payload()`,
    /// if extensions or raw packet data are required.
    /// if `audio.len() == 0`, then this packet arrived out-of-order.
    VoicePacket {
        /// Decoded audio from this packet.
        audio: &'a Vec<i16>,
        /// Raw RTP packet data.
        ///
        /// Includes the SSRC (i.e., sender) of this packet.
        packet: &'a Rtp,
        /// Byte index into the packet for where the payload begins.
        payload_offset: usize,
    },
    /// Telemetry/statistics packet, received from another stream (detailed in `packet`).
    /// `payload_offset` contains the true payload location within the raw packet's `payload()`,
    /// to allow manual decoding of `Rtcp` packet bodies.
    RtcpPacket {
        /// Raw RTCP packet data.
        packet: &'a Rtcp,
        /// Byte index into the packet for where the payload begins.
        payload_offset: usize,
    },
    /// Fired whenever a client connects to a call for the first time, allowing SSRC/UserID
    /// matching.
    ClientConnect(ClientConnect),
    /// Fired whenever a client disconnects.
    ClientDisconnect(ClientDisconnect),
}
#[derive(Clone, Debug)]
/// Owned internal mirror of [`EventContext`]'s non-track variants, as sent
/// between driver tasks; converted to a borrowed `EventContext` for handlers
/// via `to_user_context`.
///
/// [`EventContext`]: enum.EventContext.html
pub(crate) enum CoreContext {
    /// See [`EventContext::SpeakingStateUpdate`].
    SpeakingStateUpdate(Speaking),
    /// See [`EventContext::SpeakingUpdate`].
    SpeakingUpdate {
        // SSRC of the source whose speaking state changed.
        ssrc: u32,
        // Whether that source is now speaking.
        speaking: bool,
    },
    /// See [`EventContext::VoicePacket`]; holds the decoded audio by value.
    VoicePacket {
        audio: Vec<i16>,
        packet: Rtp,
        payload_offset: usize,
    },
    /// See [`EventContext::RtcpPacket`].
    RtcpPacket {
        packet: Rtcp,
        payload_offset: usize,
    },
    /// See [`EventContext::ClientConnect`].
    ClientConnect(ClientConnect),
    /// See [`EventContext::ClientDisconnect`].
    ClientDisconnect(ClientDisconnect),
}
impl<'a> CoreContext {
    /// Borrows this owned context as the user-facing [`EventContext`],
    /// copying the small `Copy` payloads and referencing the large ones.
    ///
    /// [`EventContext`]: enum.EventContext.html
    pub(crate) fn to_user_context(&'a self) -> EventContext<'a> {
        use CoreContext::*;
        match self {
            SpeakingStateUpdate(evt) => EventContext::SpeakingStateUpdate(*evt),
            SpeakingUpdate { ssrc, speaking } => EventContext::SpeakingUpdate {
                ssrc: *ssrc,
                speaking: *speaking,
            },
            VoicePacket {
                audio,
                packet,
                payload_offset,
            } => EventContext::VoicePacket {
                audio,
                packet,
                payload_offset: *payload_offset,
            },
            RtcpPacket {
                packet,
                payload_offset,
            } => EventContext::RtcpPacket {
                packet,
                payload_offset: *payload_offset,
            },
            ClientConnect(evt) => EventContext::ClientConnect(*evt),
            ClientDisconnect(evt) => EventContext::ClientDisconnect(*evt),
        }
    }
}
impl EventContext<'_> {
/// Retreive the event class for an event (i.e., when matching)
/// an event against the registered listeners.
pub fn to_core_event(&self) -> Option<CoreEvent> {
use EventContext::*;
match self {
SpeakingStateUpdate { .. } => Some(CoreEvent::SpeakingStateUpdate),
SpeakingUpdate { .. } => Some(CoreEvent::SpeakingUpdate),
VoicePacket { .. } => Some(CoreEvent::VoicePacket),
RtcpPacket { .. } => Some(CoreEvent::RtcpPacket),
ClientConnect { .. } => Some(CoreEvent::ClientConnect),
ClientDisconnect { .. } => Some(CoreEvent::ClientDisconnect),
_ => None,
}
}
}

31
src/events/core.rs Normal file
View File

@@ -0,0 +1,31 @@
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
/// Voice core events occur on receipt of
/// voice packets and telemetry.
///
/// Core events persist while the `action` in [`EventData`]
/// returns `None`.
///
/// [`EventData`]: struct.EventData.html
pub enum CoreEvent {
    /// Fired on receipt of a speaking state update from another host.
    ///
    /// Note: this will fire when a user starts speaking for the first time,
    /// or changes their capabilities.
    SpeakingStateUpdate,
    /// Fires when a source starts speaking, or stops speaking
    /// (*i.e.*, 5 consecutive silent frames).
    SpeakingUpdate,
    /// Fires on receipt of a voice packet from another stream in the voice call.
    ///
    /// As RTP packets do not map to Discord's notion of users, SSRCs must be mapped
    /// back using the user IDs seen through client connection, disconnection,
    /// or speaking state update.
    VoicePacket,
    /// Fires on receipt of an RTCP packet, containing various call stats
    /// such as latency reports.
    RtcpPacket,
    /// Fires whenever a user connects to the same stream as the bot.
    ClientConnect,
    /// Fires whenever a user disconnects from the same stream as the bot.
    ClientDisconnect,
}

88
src/events/data.rs Normal file
View File

@@ -0,0 +1,88 @@
use super::*;
use std::{cmp::Ordering, time::Duration};
/// Internal representation of an event, as handled by the audio context.
pub struct EventData {
    // Event class and (for timed events) its scheduling parameters.
    pub(crate) event: Event,
    // Absolute activation time, set by `compute_activation`;
    // `None` for untimed (track/core) events.
    pub(crate) fire_time: Option<Duration>,
    // User-supplied handler invoked when the event fires.
    pub(crate) action: Box<dyn EventHandler>,
}
impl EventData {
    /// Create a representation of an event and its associated handler.
    ///
    /// An event handler, `action`, receives an [`EventContext`] and optionally
    /// produces a new [`Event`] type for itself. Returning `None` will
    /// maintain the same event type, while removing any [`Delayed`] entries.
    /// Event handlers will be re-added with their new trigger condition,
    /// or removed if [`Cancel`]led
    ///
    /// [`EventContext`]: enum.EventContext.html
    /// [`Event`]: enum.Event.html
    /// [`Delayed`]: enum.Event.html#variant.Delayed
    /// [`Cancel`]: enum.Event.html#variant.Cancel
    pub fn new<F: EventHandler + 'static>(event: Event, action: F) -> Self {
        let action: Box<dyn EventHandler> = Box::new(action);
        Self {
            event,
            fire_time: None,
            action,
        }
    }
    /// Computes the next firing time for a timer event.
    ///
    /// Untimed events are left untouched.
    pub fn compute_activation(&mut self, now: Duration) {
        let next_fire = match self.event {
            // First firing is one phase (or, failing that, one period) away.
            Event::Periodic(period, phase) => Some(now + phase.unwrap_or(period)),
            Event::Delayed(offset) => Some(now + offset),
            _ => None,
        };
        if next_fire.is_some() {
            self.fire_time = next_fire;
        }
    }
}
impl std::fmt::Debug for EventData {
    // Manual impl: `action` is a trait object with no `Debug` bound, so it is
    // rendered as the placeholder `<fn>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        write!(
            f,
            "Event {{ event: {:?}, fire_time: {:?}, action: <fn> }}",
            self.event, self.fire_time
        )
    }
}
/// Events are ordered/compared based on their firing time.
///
/// The comparison is deliberately *reversed* (an earlier `fire_time` compares
/// as `Greater`): `EventStore` keeps these in a `BinaryHeap`, which is a
/// max-heap, while `process_timed` peeks/pops expecting the *soonest* event
/// on top. With the natural ordering, a near-term event would sit beneath a
/// far-future one and not fire until that later deadline passed.
impl Ord for EventData {
    fn cmp(&self, other: &Self) -> Ordering {
        if self.fire_time.is_some() && other.fire_time.is_some() {
            let t1 = self
                .fire_time
                .as_ref()
                .expect("T1 known to be well-defined by above.");
            let t2 = other
                .fire_time
                .as_ref()
                .expect("T2 known to be well-defined by above.");
            // Reversed so the smallest fire_time ranks highest in the heap.
            t2.cmp(&t1)
        } else {
            // Untimed events have no meaningful relative order.
            Ordering::Equal
        }
    }
}
impl PartialOrd for EventData {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // A total order always exists; delegate to `Ord`.
        Some(self.cmp(other))
    }
}
impl PartialEq for EventData {
    // Equality considers only the firing time, mirroring the heap ordering:
    // two distinct handlers due at the same instant compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.fire_time == other.fire_time
    }
}
impl Eq for EventData {}

91
src/events/mod.rs Normal file
View File

@@ -0,0 +1,91 @@
//! Events relating to tracks, timing, and other callers.
mod context;
mod core;
mod data;
mod store;
mod track;
mod untimed;
pub use self::{context::*, core::*, data::*, store::*, track::*, untimed::*};
use async_trait::async_trait;
use std::time::Duration;
#[async_trait]
/// Trait to handle an event which can be fired per-track, or globally.
///
/// These may be feasibly reused between several event sources.
pub trait EventHandler: Send + Sync {
    /// Respond to one received event.
    ///
    /// Returning `Some(event)` re-registers this handler under the returned
    /// event type (`Event::Cancel` removes it); `None` keeps the current type.
    async fn act(&self, ctx: &EventContext<'_>) -> Option<Event>;
}
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
/// Classes of event which may occur, triggering a handler
/// at the local (track-specific) or global level.
///
/// Local time-based events rely upon the current playback
/// time of a track, and so will not fire if a track becomes paused
/// or stops. In case this is required, global events are a better
/// fit.
///
/// Event handlers themselves are described in [`EventData::action`].
///
/// [`EventData::action`]: struct.EventData.html#method.action
pub enum Event {
    /// Periodic events rely upon two parameters: a *period*
    /// and an optional *phase*.
    ///
    /// If the *phase* is `None`, then the event will first fire
    /// in one *period*. Periodic events repeat automatically
    /// so long as the `action` in [`EventData`] returns `None`.
    ///
    /// [`EventData`]: struct.EventData.html
    Periodic(Duration, Option<Duration>),
    /// Delayed events rely upon a *delay* parameter, and
    /// fire one *delay* after the audio context processes them.
    ///
    /// Delayed events are automatically removed once fired,
    /// so long as the `action` in [`EventData`] returns `None`.
    ///
    /// [`EventData`]: struct.EventData.html
    Delayed(Duration),
    /// Track events correspond to certain actions or changes
    /// of state, such as a track finishing, looping, or being
    /// manually stopped.
    ///
    /// Track events persist while the `action` in [`EventData`]
    /// returns `None`.
    ///
    /// [`EventData`]: struct.EventData.html
    Track(TrackEvent),
    /// Core events
    ///
    /// Core events persist while the `action` in [`EventData`]
    /// returns `None`. Core events **must** be applied globally,
    /// as attaching them to a track is a no-op.
    ///
    /// [`EventData`]: struct.EventData.html
    Core(CoreEvent),
    /// Cancels the event, if it was intended to persist.
    Cancel,
}
impl Event {
pub(crate) fn is_global_only(&self) -> bool {
matches!(self, Self::Core(_))
}
}
impl From<TrackEvent> for Event {
fn from(evt: TrackEvent) -> Self {
Event::Track(evt)
}
}
impl From<CoreEvent> for Event {
fn from(evt: CoreEvent) -> Self {
Event::Core(evt)
}
}

252
src/events/store.rs Normal file
View File

@@ -0,0 +1,252 @@
use super::*;
use crate::{
constants::*,
tracks::{PlayMode, TrackHandle, TrackState},
};
use std::{
collections::{BinaryHeap, HashMap},
time::Duration,
};
use tracing::info;
#[derive(Debug, Default)]
/// Storage for [`EventData`], designed to be used for both local and global contexts.
///
/// Timed events are stored in a binary heap for fast selection, and have custom `Eq`,
/// `Ord`, etc. implementations to support (only) this.
///
/// [`EventData`]: struct.EventData.html
pub struct EventStore {
    // Periodic/Delayed events, ordered by firing time.
    timed: BinaryHeap<EventData>,
    // Track/core event handlers, keyed by the event class they listen to.
    untimed: HashMap<UntimedEvent, Vec<EventData>>,
    // When true (track-local store), global-only events are rejected on add.
    local_only: bool,
}
impl EventStore {
    /// Creates a new event store to be used globally.
    pub fn new() -> Self {
        Default::default()
    }
    /// Creates a new event store to be used within a [`Track`].
    ///
    /// This is usually automatically installed by the driver once
    /// a track has been registered.
    ///
    /// [`Track`]: ../tracks/struct.Track.html
    pub fn new_local() -> Self {
        EventStore {
            local_only: true,
            ..Default::default()
        }
    }
    /// Add an event to this store.
    ///
    /// Updates `evt` according to [`EventData::compute_activation`].
    /// Global-only events are silently dropped by local stores, and
    /// `Cancel` events are never stored.
    ///
    /// [`EventData::compute_activation`]: struct.EventData.html#method.compute_activation
    pub fn add_event(&mut self, mut evt: EventData, now: Duration) {
        evt.compute_activation(now);
        if self.local_only && evt.event.is_global_only() {
            return;
        }
        use Event::*;
        match evt.event {
            Core(c) => {
                self.untimed
                    .entry(c.into())
                    .or_insert_with(Vec::new)
                    .push(evt);
            },
            Track(t) => {
                self.untimed
                    .entry(t.into())
                    .or_insert_with(Vec::new)
                    .push(evt);
            },
            Delayed(_) | Periodic(_, _) => {
                self.timed.push(evt);
            },
            _ => {
                // Event cancelled.
            },
        }
    }
    /// Processes all events due up to and including `now`.
    pub(crate) async fn process_timed(&mut self, now: Duration, ctx: EventContext<'_>) {
        while let Some(evt) = self.timed.peek() {
            if evt
                .fire_time
                .as_ref()
                .expect("Timed event must have a fire_time.")
                > &now
            {
                break;
            }
            let mut evt = self
                .timed
                .pop()
                .expect("Can only succeed due to peek = Some(...).");
            let old_evt_type = evt.event;
            if let Some(new_evt_type) = evt.action.act(&ctx).await {
                // Handler requested a new event type: re-register as that type
                // (`Cancel` is discarded inside `add_event`).
                evt.event = new_evt_type;
                self.add_event(evt, now);
            } else if let Event::Periodic(d, _) = old_evt_type {
                // Periodic events repeat; drop any phase so the next firing
                // is exactly one period from now.
                evt.event = Event::Periodic(d, None);
                self.add_event(evt, now);
            }
        }
    }
    /// Processes all events attached to the given track event.
    pub(crate) async fn process_untimed(
        &mut self,
        now: Duration,
        untimed_event: UntimedEvent,
        ctx: EventContext<'_>,
    ) {
        // move a Vec in and out: not too expensive, but could be better.
        // Although it's obvious that moving an event out of one vec and into
        // another necessitates that they be different event types, thus entries,
        // convincing the compiler of this is non-trivial without making them dedicated
        // fields.
        let events = self.untimed.remove(&untimed_event);
        if let Some(mut events) = events {
            // TODO: Possibly use tombstones to prevent realloc/memcpys?
            // i.e., never shrink array, replace ended tracks with <DEAD>,
            // maintain a "first-track" stack and freelist alongside.
            let mut i = 0;
            while i < events.len() {
                let evt = &mut events[i];
                // Only remove/readd if the event type changes (i.e., Some AND new != old)
                if let Some(new_evt_type) = evt.action.act(&ctx).await {
                    if evt.event == new_evt_type {
                        // Unchanged type: leave the handler where it is.
                        i += 1;
                    } else {
                        // FIX: these branches were previously inverted — a changed
                        // event type (including `Cancel`) was discarded, so handlers
                        // could never cancel or retarget themselves, while unchanged
                        // handlers were pointlessly removed and re-added. Re-register
                        // under the new type; `add_event` drops `Cancel`led entries.
                        let mut evt = events.remove(i);
                        evt.event = new_evt_type;
                        self.add_event(evt, now);
                    }
                } else {
                    i += 1;
                };
            }
            self.untimed.insert(untimed_event, events);
        }
    }
}
#[derive(Debug, Default)]
/// Event state owned by the driver's global context: the global handler
/// store, the driver-wide clock, and track events queued for the next tick.
pub(crate) struct GlobalEvents {
    // Globally registered event handlers.
    pub(crate) store: EventStore,
    // Driver-wide clock, advanced by `TIMESTEP_LENGTH` each tick.
    pub(crate) time: Duration,
    // Track events fired since the last tick, keyed by event class,
    // holding the indices of the affected tracks.
    pub(crate) awaiting_tick: HashMap<TrackEvent, Vec<usize>>,
}
impl GlobalEvents {
    /// Registers a global event handler, activating it at the current time.
    pub(crate) fn add_event(&mut self, evt: EventData) {
        self.store.add_event(evt, self.time);
    }
    /// Immediately runs all global handlers registered for the given core event.
    pub(crate) async fn fire_core_event(&mut self, evt: CoreEvent, ctx: EventContext<'_>) {
        self.store.process_untimed(self.time, evt.into(), ctx).await;
    }
    /// Queues a track event (for track `index`) to be handled on the next tick.
    pub(crate) fn fire_track_event(&mut self, evt: TrackEvent, index: usize) {
        let holder = self.awaiting_tick.entry(evt).or_insert_with(Vec::new);
        holder.push(index);
    }
    /// Advances the clock by one timestep and fires, in order: global timed
    /// events, per-track timed events, then all queued track events (locally
    /// and globally).
    ///
    /// `events`, `states`, and `handles` are parallel vectors indexed by track.
    pub(crate) async fn tick(
        &mut self,
        events: &mut Vec<EventStore>,
        states: &mut Vec<TrackState>,
        handles: &mut Vec<TrackHandle>,
    ) {
        // Global timed events
        self.time += TIMESTEP_LENGTH;
        self.store
            .process_timed(self.time, EventContext::Track(&[]))
            .await;
        // Local timed events
        for (i, state) in states.iter_mut().enumerate() {
            // Only actively-playing tracks accrue playback time.
            if state.playing == PlayMode::Play {
                state.step_frame();
                let event_store = events
                    .get_mut(i)
                    .expect("Missing store index for Tick (local timed).");
                let handle = handles
                    .get_mut(i)
                    .expect("Missing handle index for Tick (local timed).");
                event_store
                    .process_timed(state.play_time, EventContext::Track(&[(&state, &handle)]))
                    .await;
            }
        }
        for (evt, indices) in self.awaiting_tick.iter() {
            let untimed = (*evt).into();
            if !indices.is_empty() {
                info!("Firing {:?} for {:?}", evt, indices);
            }
            // Local untimed track events.
            for &i in indices.iter() {
                let event_store = events
                    .get_mut(i)
                    .expect("Missing store index for Tick (local untimed).");
                let handle = handles
                    .get_mut(i)
                    .expect("Missing handle index for Tick (local untimed).");
                let state = states
                    .get_mut(i)
                    .expect("Missing state index for Tick (local untimed).");
                event_store
                    .process_untimed(
                        state.position,
                        untimed,
                        EventContext::Track(&[(&state, &handle)]),
                    )
                    .await;
            }
            // Global untimed track events.
            if self.store.untimed.contains_key(&untimed) && !indices.is_empty() {
                // Global handlers see every affected track at once.
                let global_ctx: Vec<(&TrackState, &TrackHandle)> = indices
                    .iter()
                    .map(|i| {
                        (
                            states
                                .get(*i)
                                .expect("Missing state index for Tick (global untimed)"),
                            handles
                                .get(*i)
                                .expect("Missing handle index for Tick (global untimed)"),
                        )
                    })
                    .collect();
                self.store
                    .process_untimed(self.time, untimed, EventContext::Track(&global_ctx[..]))
                    .await
            }
        }
        // Now drain vecs.
        for (_evt, indices) in self.awaiting_tick.iter_mut() {
            indices.clear();
        }
    }
}

16
src/events/track.rs Normal file
View File

@@ -0,0 +1,16 @@
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
/// Track events correspond to certain actions or changes
/// of state, such as a track finishing, looping, or being
/// manually stopped. Voice core events occur on receipt of
/// voice packets and telemetry.
///
/// Track events persist while the `action` in [`EventData`]
/// returns `None`.
///
/// [`EventData`]: struct.EventData.html
pub enum TrackEvent {
    /// The attached track has ended.
    End,
    /// The attached track has looped.
    Loop,
}

28
src/events/untimed.rs Normal file
View File

@@ -0,0 +1,28 @@
use super::*;
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
/// Track and voice core events.
///
/// Unifies [`TrackEvent`] and [`CoreEvent`] into one key type, used by the
/// event store to index its untimed handler map.
///
/// Untimed events persist while the `action` in [`EventData`]
/// returns `None`.
///
/// [`EventData`]: struct.EventData.html
pub enum UntimedEvent {
    /// Untimed events belonging to a track, such as state changes, end, or loops.
    Track(TrackEvent),
    /// Untimed events belonging to the global context, such as finished tracks,
    /// client speaking updates, or RT(C)P voice and telemetry data.
    Core(CoreEvent),
}
// Allow track events to be used directly as untimed-store keys.
impl From<TrackEvent> for UntimedEvent {
    fn from(evt: TrackEvent) -> Self {
        UntimedEvent::Track(evt)
    }
}
// Allow core events to be used directly as untimed-store keys.
impl From<CoreEvent> for UntimedEvent {
    fn from(evt: CoreEvent) -> Self {
        UntimedEvent::Core(evt)
    }
}

301
src/handler.rs Normal file
View File

@@ -0,0 +1,301 @@
#[cfg(feature = "driver")]
use crate::{driver::Driver, error::ConnectionResult};
use crate::{
error::{JoinError, JoinResult},
id::{ChannelId, GuildId, UserId},
info::{ConnectionInfo, ConnectionProgress},
shards::Shard,
};
use flume::{Receiver, Sender};
use serde_json::json;
use tracing::instrument;
#[cfg(feature = "driver")]
use std::ops::{Deref, DerefMut};
#[derive(Clone, Debug)]
/// Reply channel for a pending join: either the raw connection info, or
/// (with the driver enabled) the result of a full driver connection attempt.
enum Return {
    // Caller wants the negotiated connection info only.
    Info(Sender<ConnectionInfo>),
    #[cfg(feature = "driver")]
    // Caller wants the driver connected, and the outcome reported.
    Conn(Sender<ConnectionResult<()>>),
}
/// The Call handler is responsible for a single voice connection, acting
/// as a clean API above the inner state and gateway message management.
///
/// If the `"driver"` feature is enabled, then a Call exposes all control methods of
/// [`Driver`] via `Deref(Mut)`.
///
/// [`Driver`]: driver/struct.Driver.html
/// [`Shard`]: ../gateway/struct.Shard.html
#[derive(Clone, Debug)]
pub struct Call {
    /// Target channel, progress, and reply channel of the connection attempt
    /// in flight (if any); resolved by `do_connect` once complete.
    connection: Option<(ChannelId, ConnectionProgress, Return)>,
    #[cfg(feature = "driver")]
    /// The internal controller of the voice connection monitor thread.
    driver: Driver,
    /// Guild whose voice channels this handler manages.
    guild_id: GuildId,
    /// Whether the current handler is set to deafen voice connections.
    self_deaf: bool,
    /// Whether the current handler is set to mute voice connections.
    self_mute: bool,
    /// Bot user ID, used when negotiating the voice connection.
    user_id: UserId,
    /// Will be set when a `Call` is made via the [`new`][`Call::new`]
    /// method.
    ///
    /// When set via [`standalone`][`Call::standalone`], it will not be
    /// present.
    ws: Option<Shard>,
}
impl Call {
    /// Creates a new Call, which will send out WebSocket messages via
    /// the given shard.
    #[inline]
    #[instrument]
    pub fn new(guild_id: GuildId, ws: Shard, user_id: UserId) -> Self {
        Self::new_raw(guild_id, Some(ws), user_id)
    }
    /// Creates a new, standalone Call which is not connected via
    /// WebSocket to the Gateway.
    ///
    /// Actions such as muting, deafening, and switching channels will not
    /// function through this Call and must be done through some other
    /// method, as the values will only be internally updated.
    ///
    /// For most use cases you do not want this.
    #[inline]
    #[instrument]
    pub fn standalone(guild_id: GuildId, user_id: UserId) -> Self {
        Self::new_raw(guild_id, None, user_id)
    }
    // Shared constructor backing both `new` and `standalone`.
    fn new_raw(guild_id: GuildId, ws: Option<Shard>, user_id: UserId) -> Self {
        Call {
            connection: None,
            #[cfg(feature = "driver")]
            driver: Default::default(),
            guild_id,
            self_deaf: false,
            self_mute: false,
            user_id,
            ws,
        }
    }
    // Delivers a *completed* connection to whichever receiver was registered
    // by `join`/`join_gateway`. Does nothing while the handshake is partial.
    #[instrument(skip(self))]
    fn do_connect(&mut self) {
        match &self.connection {
            Some((_, ConnectionProgress::Complete(c), Return::Info(tx))) => {
                // It's okay if the receiver hung up.
                let _ = tx.send(c.clone());
            },
            #[cfg(feature = "driver")]
            Some((_, ConnectionProgress::Complete(c), Return::Conn(tx))) => {
                self.driver.raw_connect(c.clone(), tx.clone());
            },
            _ => {},
        }
    }
    /// Sets whether the current connection is to be deafened.
    ///
    /// If there is no live voice connection, then this only acts as a settings
    /// update for future connections.
    ///
    /// **Note**: Unlike in the official client, you _can_ be deafened while
    /// not being muted.
    ///
    /// **Note**: If the `Call` was created via [`standalone`], then this
    /// will _only_ update whether the connection is internally deafened.
    ///
    /// [`standalone`]: #method.standalone
    #[instrument(skip(self))]
    pub async fn deafen(&mut self, deaf: bool) -> JoinResult<()> {
        self.self_deaf = deaf;
        // NOTE(review): unlike `mute`, no driver-side call is made here; the
        // deafen state is only communicated via the gateway update below.
        self.update().await
    }
    /// Returns whether the current connection is self-deafened in this server.
    ///
    /// This is purely cosmetic.
    #[instrument(skip(self))]
    pub fn is_deaf(&self) -> bool {
        self.self_deaf
    }
    #[cfg(feature = "driver")]
    /// Connect or switch to the given voice channel by its Id.
    #[instrument(skip(self))]
    pub async fn join(
        &mut self,
        channel_id: ChannelId,
    ) -> JoinResult<Receiver<ConnectionResult<()>>> {
        let (tx, rx) = flume::unbounded();
        // Record where the driver should report its connection result;
        // `do_connect` fires once the gateway handshake completes.
        self.connection = Some((
            channel_id,
            ConnectionProgress::new(self.guild_id, self.user_id),
            Return::Conn(tx),
        ));
        self.update().await.map(|_| rx)
    }
    /// Join the selected voice channel, *without* running/starting an RTP
    /// session or running the driver.
    ///
    /// Use this if you require connection info for lavalink,
    /// some other voice implementation, or don't want to use the driver for a given call.
    #[instrument(skip(self))]
    pub async fn join_gateway(
        &mut self,
        channel_id: ChannelId,
    ) -> JoinResult<Receiver<ConnectionInfo>> {
        let (tx, rx) = flume::unbounded();
        self.connection = Some((
            channel_id,
            ConnectionProgress::new(self.guild_id, self.user_id),
            Return::Info(tx),
        ));
        self.update().await.map(|_| rx)
    }
    /// Leaves the current voice channel, disconnecting from it.
    ///
    /// This does _not_ forget settings, like whether to be self-deafened or
    /// self-muted.
    ///
    /// **Note**: If the `Call` was created via [`standalone`], then this
    /// will _only_ update whether the connection is internally connected to a
    /// voice channel.
    ///
    /// [`standalone`]: #method.standalone
    #[instrument(skip(self))]
    pub async fn leave(&mut self) -> JoinResult<()> {
        // Only send an update if we were in a voice channel.
        self.connection = None;
        #[cfg(feature = "driver")]
        self.driver.leave();
        self.update().await
    }
    /// Sets whether the current connection is to be muted.
    ///
    /// If there is no live voice connection, then this only acts as a settings
    /// update for future connections.
    ///
    /// **Note**: If the `Call` was created via [`standalone`], then this
    /// will _only_ update whether the connection is internally muted.
    ///
    /// [`standalone`]: #method.standalone
    #[instrument(skip(self))]
    pub async fn mute(&mut self, mute: bool) -> JoinResult<()> {
        self.self_mute = mute;
        #[cfg(feature = "driver")]
        self.driver.mute(mute);
        self.update().await
    }
    /// Returns whether the current connection is self-muted in this server.
    #[instrument(skip(self))]
    pub fn is_mute(&self) -> bool {
        self.self_mute
    }
    /// Updates the voice server data.
    ///
    /// You should only need to use this if you initialized the `Call` via
    /// [`standalone`].
    ///
    /// Refer to the documentation for [`connect`] for when this will
    /// automatically connect to a voice channel.
    ///
    /// [`connect`]: #method.connect
    /// [`standalone`]: #method.standalone
    #[instrument(skip(self, token))]
    pub fn update_server(&mut self, endpoint: String, token: String) {
        // `apply_server_update` returns true once the handshake becomes
        // complete (or the endpoint/token changed on a live connection).
        let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
            progress.apply_server_update(endpoint, token)
        } else {
            false
        };
        if try_conn {
            self.do_connect();
        }
    }
    /// Updates the internal voice state of the current user.
    ///
    /// You should only need to use this if you initialized the `Call` via
    /// [`standalone`].
    ///
    /// refer to the documentation for [`connect`] for when this will
    /// automatically connect to a voice channel.
    ///
    /// [`connect`]: #method.connect
    /// [`standalone`]: #method.standalone
    #[instrument(skip(self))]
    pub fn update_state(&mut self, session_id: String) {
        let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
            progress.apply_state_update(session_id)
        } else {
            false
        };
        if try_conn {
            self.do_connect();
        }
    }
    /// Send an update for the current session over WS.
    ///
    /// Does nothing if initialized via [`standalone`].
    ///
    /// [`standalone`]: #method.standalone
    #[instrument(skip(self))]
    async fn update(&mut self) -> JoinResult<()> {
        if let Some(ws) = self.ws.as_mut() {
            // Gateway opcode 4: Voice State Update. A `null` channel_id
            // signals disconnection.
            let map = json!({
                "op": 4,
                "d": {
                    "channel_id": self.connection.as_ref().map(|c| c.0.0),
                    "guild_id": self.guild_id.0,
                    "self_deaf": self.self_deaf,
                    "self_mute": self.self_mute,
                }
            });
            ws.send(map).await
        } else {
            Err(JoinError::NoSender)
        }
    }
}
#[cfg(feature = "driver")]
// Expose every `Driver` control method directly on `Call` when the driver
// feature is active (e.g., `call.play(...)`).
impl Deref for Call {
    type Target = Driver;
    fn deref(&self) -> &Self::Target {
        &self.driver
    }
}
#[cfg(feature = "driver")]
impl DerefMut for Call {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.driver
    }
}

121
src/id.rs Normal file
View File

@@ -0,0 +1,121 @@
//! Newtypes around Discord IDs for library cross-compatibility.
#[cfg(feature = "driver")]
use crate::model::id::{GuildId as DriverGuild, UserId as DriverUser};
#[cfg(feature = "serenity")]
use serenity::model::id::{
ChannelId as SerenityChannel,
GuildId as SerenityGuild,
UserId as SerenityUser,
};
use std::fmt::{Display, Formatter, Result as FmtResult};
#[cfg(feature = "twilight")]
use twilight_model::id::{
ChannelId as TwilightChannel,
GuildId as TwilightGuild,
UserId as TwilightUser,
};
/// ID of a Discord voice/text channel.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ChannelId(pub u64);
/// ID of a Discord guild (colloquially, "server").
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct GuildId(pub u64);
/// ID of a Discord user.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct UserId(pub u64);
impl Display for ChannelId {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
Display::fmt(&self.0, f)
}
}
impl From<u64> for ChannelId {
fn from(id: u64) -> Self {
Self(id)
}
}
#[cfg(feature = "serenity")]
impl From<SerenityChannel> for ChannelId {
fn from(id: SerenityChannel) -> Self {
Self(id.0)
}
}
#[cfg(feature = "twilight")]
impl From<TwilightChannel> for ChannelId {
fn from(id: TwilightChannel) -> Self {
Self(id.0)
}
}
impl Display for GuildId {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
Display::fmt(&self.0, f)
}
}
impl From<u64> for GuildId {
fn from(id: u64) -> Self {
Self(id)
}
}
#[cfg(feature = "serenity")]
impl From<SerenityGuild> for GuildId {
fn from(id: SerenityGuild) -> Self {
Self(id.0)
}
}
#[cfg(feature = "driver")]
impl From<GuildId> for DriverGuild {
fn from(id: GuildId) -> Self {
Self(id.0)
}
}
#[cfg(feature = "twilight")]
impl From<TwilightGuild> for GuildId {
fn from(id: TwilightGuild) -> Self {
Self(id.0)
}
}
impl Display for UserId {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
Display::fmt(&self.0, f)
}
}
impl From<u64> for UserId {
fn from(id: u64) -> Self {
Self(id)
}
}
#[cfg(feature = "serenity")]
impl From<SerenityUser> for UserId {
fn from(id: SerenityUser) -> Self {
Self(id.0)
}
}
#[cfg(feature = "driver")]
impl From<UserId> for DriverUser {
fn from(id: UserId) -> Self {
Self(id.0)
}
}
#[cfg(feature = "twilight")]
impl From<TwilightUser> for UserId {
fn from(id: TwilightUser) -> Self {
Self(id.0)
}
}

137
src/info.rs Normal file
View File

@@ -0,0 +1,137 @@
use crate::id::{GuildId, UserId};
use std::fmt;
#[derive(Clone, Debug)]
/// State machine for the two-message voice handshake: a connection is
/// `Incomplete` until both the gateway state update and the server update
/// have arrived, after which it holds a full `ConnectionInfo`.
pub(crate) enum ConnectionProgress {
    /// All parameters received; the connection can be established.
    Complete(ConnectionInfo),
    /// Still waiting on one or more of endpoint, session ID, or token.
    Incomplete(Partial),
}
impl ConnectionProgress {
pub fn new(guild_id: GuildId, user_id: UserId) -> Self {
ConnectionProgress::Incomplete(Partial {
guild_id,
user_id,
..Default::default()
})
}
pub(crate) fn apply_state_update(&mut self, session_id: String) -> bool {
use ConnectionProgress::*;
match self {
Complete(c) => {
let should_reconn = c.session_id != session_id;
c.session_id = session_id;
should_reconn
},
Incomplete(i) => i
.apply_state_update(session_id)
.map(|info| {
*self = Complete(info);
})
.is_some(),
}
}
pub(crate) fn apply_server_update(&mut self, endpoint: String, token: String) -> bool {
use ConnectionProgress::*;
match self {
Complete(c) => {
let should_reconn = c.endpoint != endpoint || c.token != token;
c.endpoint = endpoint;
c.token = token;
should_reconn
},
Incomplete(i) => i
.apply_server_update(endpoint, token)
.map(|info| {
*self = Complete(info);
})
.is_some(),
}
}
}
/// Parameters and information needed to start communicating with Discord's voice servers, either
/// with the Songbird driver, lavalink, or other system.
#[derive(Clone)]
pub struct ConnectionInfo {
    /// URL of the voice websocket gateway server assigned to this call.
    pub endpoint: String,
    /// ID of the target voice channel's parent guild.
    ///
    /// Bots cannot connect to a guildless (i.e., direct message) voice call.
    pub guild_id: GuildId,
    /// Unique string describing this session for validation/authentication purposes.
    pub session_id: String,
    /// Ephemeral secret used to validate the above session.
    pub token: String,
    /// UserID of this bot.
    pub user_id: UserId,
}
// Manual Debug impl: the token is a credential, so it is redacted rather
// than printed.
impl fmt::Debug for ConnectionInfo {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ConnectionInfo")
            .field("endpoint", &self.endpoint)
            .field("guild_id", &self.guild_id)
            .field("session_id", &self.session_id)
            .field("token", &"<secret>")
            .field("user_id", &self.user_id)
            .finish()
    }
}
#[derive(Clone, Default)]
/// Accumulator for handshake parameters which arrive across two separate
/// gateway messages; promoted to `ConnectionInfo` once all are present.
pub(crate) struct Partial {
    // From the voice server update.
    pub endpoint: Option<String>,
    pub guild_id: GuildId,
    // From the voice state update.
    pub session_id: Option<String>,
    // From the voice server update; a credential.
    pub token: Option<String>,
    pub user_id: UserId,
}
// Manual Debug impl: only reveals *whether* the token has arrived, never
// its value.
impl fmt::Debug for Partial {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Partial")
            .field("endpoint", &self.endpoint)
            .field("session_id", &self.session_id)
            .field("token_is_some", &self.token.is_some())
            .finish()
    }
}
impl Partial {
    /// Attempts to promote this partial record into a full `ConnectionInfo`.
    ///
    /// Succeeds only once the endpoint, session ID, and token have all
    /// arrived; on success, those fields are moved out of `self`.
    fn finalise(&mut self) -> Option<ConnectionInfo> {
        if self.endpoint.is_none() || self.session_id.is_none() || self.token.is_none() {
            return None;
        }
        Some(ConnectionInfo {
            endpoint: self.endpoint.take().unwrap(),
            session_id: self.session_id.take().unwrap(),
            token: self.token.take().unwrap(),
            guild_id: self.guild_id,
            user_id: self.user_id,
        })
    }
    /// Stores the session ID from a voice state update, then tries to finalise.
    fn apply_state_update(&mut self, session_id: String) -> Option<ConnectionInfo> {
        self.session_id = Some(session_id);
        self.finalise()
    }
    /// Stores the endpoint/token from a voice server update, then tries to finalise.
    fn apply_server_update(&mut self, endpoint: String, token: String) -> Option<ConnectionInfo> {
        self.endpoint = Some(endpoint);
        self.token = Some(token);
        self.finalise()
    }
}

View File

@@ -0,0 +1,303 @@
use super::{apply_length_hint, compressed_cost_per_sec, default_config};
use crate::{
constants::*,
input::{
error::{Error, Result},
CodecType,
Container,
Input,
Metadata,
Reader,
},
};
use audiopus::{
coder::Encoder as OpusEncoder,
Application,
Bitrate,
Channels,
Error as OpusError,
ErrorCode as OpusErrorCode,
SampleRate,
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::{
convert::TryInto,
io::{Error as IoError, ErrorKind as IoErrorKind, Read, Result as IoResult},
mem,
sync::atomic::{AtomicUsize, Ordering},
};
use streamcatcher::{Config, NeedsBytes, Stateful, Transform, TransformPosition, TxCatcher};
use tracing::{debug, trace};
/// A wrapper around an existing [`Input`] which compresses
/// the input using the Opus codec before storing it in memory.
///
/// The main purpose of this wrapper is to enable seeking on
/// incompatible sources (i.e., ffmpeg output) and to ease resource
/// consumption for commonly reused/shared tracks. [`Restartable`]
/// and [`Memory`] offer the same functionality with different
/// tradeoffs.
///
/// This is intended for use with larger, repeatedly used audio
/// tracks shared between sources, and stores the sound data
/// retrieved as **compressed Opus audio**. There is an associated memory cost,
/// but this is far smaller than using a [`Memory`].
///
/// [`Input`]: ../struct.Input.html
/// [`Memory`]: struct.Memory.html
/// [`Restartable`]: ../struct.Restartable.html
#[derive(Clone, Debug)]
pub struct Compressed {
    /// Inner shared bytestore.
    pub raw: TxCatcher<Box<Input>, OpusCompressor>,
    /// Metadata moved out of the captured source.
    pub metadata: Metadata,
    /// Stereo-ness of the captured source.
    pub stereo: bool,
}
impl Compressed {
    /// Wrap an existing [`Input`] with an in-memory store, compressed using Opus.
    ///
    /// The initial chunk size is derived from `source`'s [`Metadata.duration`],
    /// when available.
    ///
    /// [`Input`]: ../struct.Input.html
    /// [`Metadata.duration`]: ../struct.Metadata.html#structfield.duration
    pub fn new(source: Input, bitrate: Bitrate) -> Result<Self> {
        Self::with_config(source, bitrate, None)
    }
    /// Wrap an existing [`Input`] with an in-memory store, compressed using Opus.
    ///
    /// `config.length_hint` may be used to control the size of the initial chunk, preventing
    /// needless allocations and copies. If this is not present, the value specified in
    /// `source`'s [`Metadata.duration`] will be used.
    ///
    /// [`Input`]: ../struct.Input.html
    /// [`Metadata.duration`]: ../struct.Metadata.html#structfield.duration
    pub fn with_config(source: Input, bitrate: Bitrate, config: Option<Config>) -> Result<Self> {
        let channels = if source.stereo {
            Channels::Stereo
        } else {
            Channels::Mono
        };
        // 48 kHz is the sample rate used throughout; see `with_encoder`'s
        // note on custom encoders.
        let mut encoder = OpusEncoder::new(SampleRate::Hz48000, channels, Application::Audio)?;
        encoder.set_bitrate(bitrate)?;
        Self::with_encoder(source, encoder, config)
    }
    /// Wrap an existing [`Input`] with an in-memory store, compressed using a user-defined
    /// Opus encoder.
    ///
    /// `length_hint` functions as in [`new`]. This function's behaviour is undefined if your encoder
    /// has a different sample rate than 48kHz, and if the decoder has a different channel count from the source.
    ///
    /// [`Input`]: ../struct.Input.html
    /// [`new`]: #method.new
    pub fn with_encoder(
        mut source: Input,
        encoder: OpusEncoder,
        config: Option<Config>,
    ) -> Result<Self> {
        let bitrate = encoder.bitrate()?;
        // Estimated storage cost (B/s) of the Opus stream plus DCA framing.
        let cost_per_sec = compressed_cost_per_sec(bitrate);
        let stereo = source.stereo;
        let metadata = source.metadata.take();
        let mut config = config.unwrap_or_else(|| default_config(cost_per_sec));
        // apply length hint.
        if config.length_hint.is_none() {
            if let Some(dur) = metadata.duration {
                apply_length_hint(&mut config, dur, cost_per_sec);
            }
        }
        let raw = config
            .build_tx(Box::new(source), OpusCompressor::new(encoder, stereo))
            .map_err(Error::Streamcatcher)?;
        Ok(Self {
            raw,
            metadata,
            stereo,
        })
    }
    /// Acquire a new handle to this object, creating a new
    /// view of the existing cached data from the beginning.
    pub fn new_handle(&self) -> Self {
        Self {
            raw: self.raw.new_handle(),
            metadata: self.metadata.clone(),
            stereo: self.stereo,
        }
    }
}
impl From<Compressed> for Input {
    /// Converts the cached store into a playable `Input`.
    ///
    /// The store is always stereo Opus framed as headerless DCA, hence the
    /// fixed codec/container choice below.
    fn from(src: Compressed) -> Self {
        Input::new(
            true,
            Reader::Compressed(src.raw),
            CodecType::Opus
                .try_into()
                .expect("Default decoder values are known to be valid."),
            Container::Dca { first_frame: 0 },
            Some(src.metadata),
        )
    }
}
/// Transform applied inside [`Compressed`], converting a floating-point PCM
/// input stream into a DCA-framed Opus stream.
///
/// Created and managed by [`Compressed`].
///
/// [`Compressed`]: struct.Compressed.html
#[derive(Debug)]
pub struct OpusCompressor {
    // Encoder used to compress incoming PCM frames.
    encoder: OpusEncoder,
    // Most recently encoded Opus packet, replayed into output buffers.
    last_frame: Vec<u8>,
    // Whether the source PCM stream is stereo (affects samples per frame).
    stereo_input: bool,
    // Write progress through the current frame (header + payload bytes).
    frame_pos: usize,
    // Running count of raw PCM bytes consumed, exposed via `Stateful`.
    audio_bytes: AtomicUsize,
}
impl OpusCompressor {
fn new(encoder: OpusEncoder, stereo_input: bool) -> Self {
Self {
encoder,
last_frame: Vec::with_capacity(4000),
stereo_input,
frame_pos: 0,
audio_bytes: Default::default(),
}
}
}
impl<T> Transform<T> for OpusCompressor
where
    T: Read,
{
    // Reads PCM from `src`, encodes one Opus packet at a time, and emits it
    // into `buf` as a DCA frame: an `i16` little-endian length header
    // followed by the packet payload. A single call may emit only part of
    // the current frame; `frame_pos` tracks progress across calls.
    fn transform_read(&mut self, src: &mut T, buf: &mut [u8]) -> IoResult<TransformPosition> {
        // Size of the DCA length header preceding each packet.
        let output_start = mem::size_of::<u16>();
        let mut eof = false;
        let mut raw_len = 0;
        let mut out = None;
        let mut sample_buf = [0f32; STEREO_FRAME_SIZE];
        let samples_in_frame = if self.stereo_input {
            STEREO_FRAME_SIZE
        } else {
            MONO_FRAME_SIZE
        };
        // Purge old frame and read new, if needed.
        if self.frame_pos == self.last_frame.len() + output_start || self.last_frame.is_empty() {
            self.last_frame.resize(self.last_frame.capacity(), 0);
            // We can't use `read_f32_into` because we can't guarantee the buffer will be filled.
            for el in sample_buf[..samples_in_frame].iter_mut() {
                match src.read_f32::<LittleEndian>() {
                    Ok(sample) => {
                        *el = sample;
                        raw_len += 1;
                    },
                    Err(e) if e.kind() == IoErrorKind::UnexpectedEof => {
                        eof = true;
                        break;
                    },
                    Err(e) => {
                        out = Some(Err(e));
                        break;
                    },
                }
            }
            if out.is_none() && raw_len > 0 {
                loop {
                    // NOTE: we don't index by raw_len because the last frame can be too small
                    // to occupy a "whole packet". Zero-padding is the correct behaviour.
                    match self
                        .encoder
                        .encode_float(&sample_buf[..samples_in_frame], &mut self.last_frame[..])
                    {
                        Ok(pkt_len) => {
                            trace!("Next packet to write has {:?}", pkt_len);
                            self.frame_pos = 0;
                            self.last_frame.truncate(pkt_len);
                            break;
                        },
                        Err(OpusError::Opus(OpusErrorCode::BufferTooSmall)) => {
                            // If we need more capacity to encode this frame, then take it.
                            trace!("Resizing inner buffer (+256).");
                            self.last_frame.resize(self.last_frame.len() + 256, 0);
                        },
                        Err(e) => {
                            debug!("Read error {:?} {:?} {:?}.", e, out, raw_len);
                            out = Some(Err(IoError::new(IoErrorKind::Other, e)));
                            break;
                        },
                    }
                }
            }
        }
        if out.is_none() {
            // Write from frame we have.
            // Emit the DCA length header once per frame, before any payload.
            let start = if self.frame_pos < output_start {
                (&mut buf[..output_start])
                    .write_i16::<LittleEndian>(self.last_frame.len() as i16)
                    .expect(
                        "Minimum bytes requirement for Opus (2) should mean that an i16 \
                        may always be written.",
                    );
                self.frame_pos += output_start;
                trace!("Wrote frame header: {}.", self.last_frame.len());
                output_start
            } else {
                0
            };
            // Copy as much of the remaining packet payload as `buf` can take.
            let out_pos = self.frame_pos - output_start;
            let remaining = self.last_frame.len() - out_pos;
            let write_len = remaining.min(buf.len() - start);
            buf[start..start + write_len]
                .copy_from_slice(&self.last_frame[out_pos..out_pos + write_len]);
            self.frame_pos += write_len;
            trace!("Appended {} to inner store", write_len);
            out = Some(Ok(write_len + start));
        }
        // NOTE: use of raw_len here preserves true sample length even if
        // stream is extended to 20ms boundary.
        out.unwrap_or_else(|| Err(IoError::new(IoErrorKind::Other, "Unclear.")))
            .map(|compressed_sz| {
                self.audio_bytes
                    .fetch_add(raw_len * mem::size_of::<f32>(), Ordering::Release);
                if eof {
                    TransformPosition::Finished
                } else {
                    TransformPosition::Read(compressed_sz)
                }
            })
    }
}
impl NeedsBytes for OpusCompressor {
    // Two bytes are always needed: the i16 DCA frame-length header must be
    // written atomically.
    fn min_bytes_required(&self) -> usize {
        2
    }
}
impl Stateful for OpusCompressor {
    type State = usize;
    // Reports the number of raw PCM bytes consumed so far.
    fn state(&self) -> Self::State {
        self.audio_bytes.load(Ordering::Acquire)
    }
}

40
src/input/cached/hint.rs Normal file
View File

@@ -0,0 +1,40 @@
use std::time::Duration;
use streamcatcher::Config;
/// Expected amount of time that an input should last.
#[derive(Copy, Clone, Debug)]
pub enum LengthHint {
    /// Estimate of a source's length in bytes.
    Bytes(usize),
    /// Estimate of a source's length in time.
    ///
    /// This will be converted to a bytecount at setup.
    Time(Duration),
}
impl From<usize> for LengthHint {
    /// Treats a plain integer as a byte-count hint.
    fn from(size: usize) -> Self {
        Self::Bytes(size)
    }
}
impl From<Duration> for LengthHint {
    /// Treats a duration as a time-based hint.
    fn from(size: Duration) -> Self {
        Self::Time(size)
    }
}
/// Modify the given cache configuration to initially allocate
/// enough bytes to store a length of audio at the given bitrate.
///
/// Time-based hints are rounded *up* to the next whole second, so the
/// store never starts undersized.
pub fn apply_length_hint<H>(config: &mut Config, hint: H, cost_per_sec: usize)
where
    H: Into<LengthHint>,
{
    config.length_hint = Some(match hint.into() {
        LengthHint::Bytes(a) => a,
        LengthHint::Time(t) => {
            // Round up on *any* fractional second. Checking `subsec_millis`
            // alone would miss sub-millisecond remainders (e.g., 500µs) and
            // undersize the hint by one second's worth of bytes.
            let s = t.as_secs() + if t.subsec_nanos() > 0 { 1 } else { 0 };
            (s as usize) * cost_per_sec
        },
    });
}

116
src/input/cached/memory.rs Normal file
View File

@@ -0,0 +1,116 @@
use super::{apply_length_hint, default_config, raw_cost_per_sec};
use crate::input::{
error::{Error, Result},
CodecType,
Container,
Input,
Metadata,
Reader,
};
use std::convert::{TryFrom, TryInto};
use streamcatcher::{Catcher, Config};
/// A wrapper around an existing [`Input`] which caches
/// the decoded and converted audio data locally in memory.
///
/// The main purpose of this wrapper is to enable seeking on
/// incompatible sources (i.e., ffmpeg output) and to ease resource
/// consumption for commonly reused/shared tracks. [`Restartable`]
/// and [`Compressed`] offer the same functionality with different
/// tradeoffs.
///
/// This is intended for use with small, repeatedly used audio
/// tracks shared between sources, and stores the sound data
/// retrieved in **uncompressed floating point** form to minimise the
/// cost of audio processing. This comes at a significant cost: *3 Mbps (375 kiB/s)*,
/// or 131 MiB of RAM for a 6 minute song.
///
/// [`Input`]: ../struct.Input.html
/// [`Compressed`]: struct.Compressed.html
/// [`Restartable`]: ../struct.Restartable.html
#[derive(Clone, Debug)]
pub struct Memory {
    /// Inner shared bytestore.
    pub raw: Catcher<Box<Reader>>,
    /// Metadata moved out of the captured source.
    pub metadata: Metadata,
    /// Codec used to read the inner bytestore.
    pub kind: CodecType,
    /// Stereo-ness of the captured source.
    pub stereo: bool,
    /// Framing mechanism for the inner bytestore.
    pub container: Container,
}
impl Memory {
    /// Wrap an existing [`Input`] with an in-memory store with the same codec and framing.
    ///
    /// [`Input`]: ../struct.Input.html
    pub fn new(source: Input) -> Result<Self> {
        Self::with_config(source, None)
    }
    /// Wrap an existing [`Input`] with an in-memory store with the same codec and framing.
    ///
    /// `length_hint` may be used to control the size of the initial chunk, preventing
    /// needless allocations and copies. If this is not present, the value specified in
    /// `source`'s [`Metadata.duration`] will be used, assuming that the source is uncompressed.
    ///
    /// [`Input`]: ../struct.Input.html
    /// [`Metadata.duration`]: ../struct.Metadata.html#structfield.duration
    pub fn with_config(mut source: Input, config: Option<Config>) -> Result<Self> {
        // Capture the source's parameters before its reader is moved into
        // the store; bytes are cached verbatim (no transcoding).
        let stereo = source.stereo;
        let kind = (&source.kind).into();
        let container = source.container;
        let metadata = source.metadata.take();
        let cost_per_sec = raw_cost_per_sec(stereo);
        let mut config = config.unwrap_or_else(|| default_config(cost_per_sec));
        // apply length hint.
        if config.length_hint.is_none() {
            if let Some(dur) = metadata.duration {
                apply_length_hint(&mut config, dur, cost_per_sec);
            }
        }
        let raw = config
            .build(Box::new(source.reader))
            .map_err(Error::Streamcatcher)?;
        Ok(Self {
            raw,
            metadata,
            kind,
            stereo,
            container,
        })
    }
    /// Acquire a new handle to this object, creating a new
    /// view of the existing cached data from the beginning.
    pub fn new_handle(&self) -> Self {
        Self {
            raw: self.raw.new_handle(),
            metadata: self.metadata.clone(),
            kind: self.kind,
            stereo: self.stereo,
            container: self.container,
        }
    }
}
impl TryFrom<Memory> for Input {
    type Error = Error;
    /// Converts the cached store back into a playable `Input`.
    ///
    /// Fails only if decoder state cannot be built for the stored codec.
    fn try_from(src: Memory) -> Result<Self> {
        let codec = src.kind.try_into()?;
        Ok(Input::new(
            src.stereo,
            Reader::Memory(src.raw),
            codec,
            src.container,
            Some(src.metadata),
        ))
    }
}

44
src/input/cached/mod.rs Normal file
View File

@@ -0,0 +1,44 @@
//! In-memory, shared input sources for reuse between calls, fast seeking, and
//! direct Opus frame passthrough.
mod compressed;
mod hint;
mod memory;
#[cfg(test)]
mod tests;
pub use self::{compressed::*, hint::*, memory::*};
use crate::constants::*;
use crate::input::utils;
use audiopus::Bitrate;
use std::{mem, time::Duration};
use streamcatcher::{Config, GrowthStrategy};
/// Estimates the cost, in B/s, of audio data compressed at the given bitrate.
pub fn compressed_cost_per_sec(bitrate: Bitrate) -> usize {
    // DCA-style framing adds a two-byte length header per audio frame.
    let framing_cost_per_sec = AUDIO_FRAME_RATE * mem::size_of::<u16>();
    // Resolve the library's symbolic bitrates to concrete bits-per-second.
    let bits_per_sec = match bitrate {
        Bitrate::BitsPerSecond(i) => i,
        Bitrate::Auto => 64_000,
        Bitrate::Max => 512_000,
    };
    (bits_per_sec as usize / 8) + framing_cost_per_sec
}
/// Calculates the cost, in B/s, of raw floating-point audio data.
pub fn raw_cost_per_sec(stereo: bool) -> usize {
    // One second of f32 PCM at the library's fixed sample rate.
    utils::timestamp_to_byte_count(Duration::from_secs(1), stereo)
}
/// Provides the default config used by a cached source.
///
/// This maps to the default configuration in [`streamcatcher`], using
/// a constant chunk size of 5s worth of audio at the given bitrate estimate.
///
/// [`streamcatcher`]: https://docs.rs/streamcatcher/0.1.0/streamcatcher/struct.Config.html
pub fn default_config(cost_per_sec: usize) -> Config {
    Config::new().chunk_size(GrowthStrategy::Constant(5 * cost_per_sec))
}

79
src/input/cached/tests.rs Normal file
View File

@@ -0,0 +1,79 @@
use super::*;
use crate::{
constants::*,
input::{error::Error, ffmpeg, Codec, Container, Input, Reader},
test_utils::*,
};
use audiopus::{coder::Decoder, Bitrate, Channels, SampleRate};
use byteorder::{LittleEndian, ReadBytesExt};
use std::io::{Cursor, Read};
#[tokio::test]
// Round-trip sanity check: reading a raw cached store back out must
// reproduce the input bytes exactly, with no truncation or padding.
async fn streamcatcher_preserves_file() {
    let input = make_sine(50 * MONO_FRAME_SIZE, true);
    let input_len = input.len();
    let mut raw = default_config(raw_cost_per_sec(true))
        .build(Cursor::new(input.clone()))
        .map_err(Error::Streamcatcher)
        .unwrap();
    let mut out_buf = vec![];
    let read = raw.read_to_end(&mut out_buf).unwrap();
    assert_eq!(input_len, read);
    assert_eq!(input, out_buf);
}
#[test]
// A mono compressed store must produce DCA frames a stereo decoder accepts.
fn compressed_scans_frames_decodes_mono() {
    let data = one_s_compressed_sine(false);
    run_through_dca(data.raw);
}
#[test]
// As above, for a stereo source.
fn compressed_scans_frames_decodes_stereo() {
    let data = one_s_compressed_sine(true);
    run_through_dca(data.raw);
}
#[test]
// Compressed (Opus) stores should advertise direct Opus passthrough, and
// the frame handed out must be decodable by a standalone Opus decoder.
fn compressed_triggers_valid_passthrough() {
    let mut input = Input::from(one_s_compressed_sine(true));
    assert!(input.supports_passthrough());
    let mut opus_buf = [0u8; 10_000];
    let mut signal_buf = [0i16; 1920];
    let opus_len = input.read_opus_frame(&mut opus_buf[..]).unwrap();
    let mut decoder = Decoder::new(SampleRate::Hz48000, Channels::Stereo).unwrap();
    decoder
        .decode(Some(&opus_buf[..opus_len]), &mut signal_buf[..], false)
        .unwrap();
}
// Builds a one-second (50 x 20ms frames) sine-wave source compressed at
// 128 kbps, as shared fixture for the tests above.
fn one_s_compressed_sine(stereo: bool) -> Compressed {
    let data = make_sine(50 * MONO_FRAME_SIZE, stereo);
    let input = Input::new(stereo, data.into(), Codec::FloatPcm, Container::Raw, None);
    Compressed::new(input, Bitrate::BitsPerSecond(128_000)).unwrap()
}
// Walks a DCA-framed Opus stream (i16 LE length header + payload per frame),
// decoding every packet to verify the framing is intact.
fn run_through_dca(mut src: impl Read) {
    let mut decoder = Decoder::new(SampleRate::Hz48000, Channels::Stereo).unwrap();
    let mut pkt_space = [0u8; 10_000];
    let mut signals = [0i16; 1920];
    while let Ok(frame_len) = src.read_i16::<LittleEndian>() {
        // Guard against a (malformed) negative header, which would wrap to a
        // huge slice length when cast.
        let frame_len = frame_len.max(0) as usize;
        // `Read::read` may legally return fewer bytes than requested; a DCA
        // frame must be consumed whole, so use `read_exact` instead.
        src.read_exact(&mut pkt_space[..frame_len]).unwrap();
        decoder
            .decode(Some(&pkt_space[..frame_len]), &mut signals[..], false)
            .unwrap();
    }
}

38
src/input/child.rs Normal file
View File

@@ -0,0 +1,38 @@
use super::*;
use std::{
io::{BufReader, Read},
process::Child,
};
use tracing::debug;
/// Handle for a child process which ensures that any subprocesses are properly closed
/// on drop.
#[derive(Debug)]
pub struct ChildContainer(Child);
/// Wraps a child process's stdout in a buffered `Reader`, with the buffer
/// sized to hold `CHILD_BUFFER_LEN` stereo frames of sample type `T`.
pub(crate) fn child_to_reader<T>(child: Child) -> Reader {
    Reader::Pipe(BufReader::with_capacity(
        STEREO_FRAME_SIZE * mem::size_of::<T>() * CHILD_BUFFER_LEN,
        ChildContainer(child),
    ))
}
impl From<Child> for Reader {
    // Children are assumed to emit f32 PCM, matching the library's
    // processing format.
    fn from(container: Child) -> Self {
        child_to_reader::<f32>(container)
    }
}
impl Read for ChildContainer {
    // NOTE(review): this panics if the child was spawned without
    // `Stdio::piped()` stdout — callers constructing the child must
    // guarantee a piped stdout. TODO confirm at spawn sites.
    fn read(&mut self, buffer: &mut [u8]) -> IoResult<usize> {
        self.0.stdout.as_mut().unwrap().read(buffer)
    }
}
impl Drop for ChildContainer {
fn drop(&mut self) {
if let Err(e) = self.0.kill() {
debug!("Error awaiting child process: {:?}", e);
}
}
}

99
src/input/codec/mod.rs Normal file
View File

@@ -0,0 +1,99 @@
//! Decoding schemes for input audio bytestreams.
mod opus;
pub use self::opus::OpusDecoderState;
use super::*;
use std::{fmt::Debug, mem};
/// State used to decode input bytes of an [`Input`].
///
/// Unlike [`CodecType`], variants here may carry live decoder state.
///
/// [`Input`]: ../struct.Input.html
/// [`CodecType`]: enum.CodecType.html
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum Codec {
    /// The inner bytestream is encoded using the Opus codec, to be decoded
    /// using the given state.
    ///
    /// Must be combined with a non-[`Raw`] container.
    ///
    /// [`Raw`]: ../enum.Container.html#variant.Raw
    Opus(OpusDecoderState),
    /// The inner bytestream is encoded using raw `i16` samples.
    ///
    /// Must be combined with a [`Raw`] container.
    ///
    /// [`Raw`]: ../enum.Container.html#variant.Raw
    Pcm,
    /// The inner bytestream is encoded using raw `f32` samples.
    ///
    /// Must be combined with a [`Raw`] container.
    ///
    /// [`Raw`]: ../enum.Container.html#variant.Raw
    FloatPcm,
}
impl From<&Codec> for CodecType {
fn from(f: &Codec) -> Self {
use Codec::*;
match f {
Opus(_) => Self::Opus,
Pcm => Self::Pcm,
FloatPcm => Self::FloatPcm,
}
}
}
/// Type of data being passed into an [`Input`].
///
/// Stateless counterpart of [`Codec`].
///
/// [`Input`]: ../struct.Input.html
/// [`Codec`]: enum.Codec.html
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
pub enum CodecType {
    /// The inner bytestream is encoded using the Opus codec.
    ///
    /// Must be combined with a non-[`Raw`] container.
    ///
    /// [`Raw`]: ../enum.Container.html#variant.Raw
    Opus,
    /// The inner bytestream is encoded using raw `i16` samples.
    ///
    /// Must be combined with a [`Raw`] container.
    ///
    /// [`Raw`]: ../enum.Container.html#variant.Raw
    Pcm,
    /// The inner bytestream is encoded using raw `f32` samples.
    ///
    /// Must be combined with a [`Raw`] container.
    ///
    /// [`Raw`]: ../enum.Container.html#variant.Raw
    FloatPcm,
}
impl CodecType {
/// Returns the length of a single output sample, in bytes.
pub fn sample_len(&self) -> usize {
use CodecType::*;
match self {
Opus | FloatPcm => mem::size_of::<f32>(),
Pcm => mem::size_of::<i16>(),
}
}
}
impl TryFrom<CodecType> for Codec {
type Error = Error;
fn try_from(f: CodecType) -> Result<Self> {
use CodecType::*;
match f {
Opus => Ok(Codec::Opus(OpusDecoderState::new()?)),
Pcm => Ok(Codec::Pcm),
FloatPcm => Ok(Codec::FloatPcm),
}
}
}

43
src/input/codec/opus.rs Normal file
View File

@@ -0,0 +1,43 @@
use crate::constants::*;
use audiopus::{coder::Decoder as OpusDecoder, Channels, Error as OpusError};
use parking_lot::Mutex;
use std::sync::Arc;
#[derive(Clone, Debug)]
/// Inner state needed to decode and (optionally) pass through Opus frames.
pub struct OpusDecoderState {
    /// Inner decoder used to convert opus frames into a stream of samples.
    pub decoder: Arc<Mutex<OpusDecoder>>,
    /// Controls whether this source allows direct Opus frame passthrough.
    /// Defaults to `true`.
    ///
    /// Enabling this flag is a promise from the programmer to the audio core
    /// that the source has been encoded at 48kHz, using 20ms long frames.
    /// If you cannot guarantee this, disable this flag (or else risk nasal demons)
    /// and bizarre audio behaviour.
    pub allow_passthrough: bool,
    // Most recently decoded frame's f32 samples, drained by the reader.
    pub(crate) current_frame: Vec<f32>,
    // Read cursor into `current_frame`.
    pub(crate) frame_pos: usize,
    // When set, the decoder's internal state is reset before the next decode
    // (used after frames were skipped during a cheap seek).
    pub(crate) should_reset: bool,
}
impl OpusDecoderState {
    /// Creates a new decoder, having stereo output at 48kHz.
    pub fn new() -> Result<Self, OpusError> {
        let decoder = OpusDecoder::new(SAMPLE_RATE, Channels::Stereo)?;
        Ok(Self::from_decoder(decoder))
    }

    /// Creates a new decoder pre-configured by the user.
    ///
    /// Passthrough is enabled by default, and the frame buffer is
    /// pre-sized to hold one 20ms stereo frame.
    pub fn from_decoder(decoder: OpusDecoder) -> Self {
        Self {
            allow_passthrough: true,
            current_frame: Vec::with_capacity(STEREO_FRAME_SIZE),
            frame_pos: 0,
            should_reset: false,
            decoder: Arc::new(Mutex::new(decoder)),
        }
    }
}

View File

@@ -0,0 +1,8 @@
/// Information used in audio frame detection.
///
/// Returned by `Container::next_frame_length` when parsing framed input.
#[derive(Clone, Copy, Debug)]
pub struct Frame {
    /// Length of this frame's header, in bytes.
    ///
    /// Zero for unframed (raw) input.
    pub header_len: usize,
    /// Payload length, in bytes (excludes the header).
    pub frame_len: usize,
}

View File

@@ -0,0 +1,69 @@
mod frame;
pub use frame::*;
use super::CodecType;
use byteorder::{LittleEndian, ReadBytesExt};
use std::{
fmt::Debug,
io::{Read, Result as IoResult},
mem,
};
/// Marker and state for decoding framed input files.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub enum Container {
    /// Raw, unframed input.
    Raw,
    /// Framed input, beginning with a JSON header.
    ///
    /// Frames have the form `{ len: i16, payload: [u8; len]}`.
    Dca {
        /// Byte index of the first frame after the JSON header.
        first_frame: usize,
    },
}
impl Container {
/// Tries to read the header of the next frame from an input stream.
pub fn next_frame_length(
&mut self,
mut reader: impl Read,
input: CodecType,
) -> IoResult<Frame> {
use Container::*;
match self {
Raw => Ok(Frame {
header_len: 0,
frame_len: input.sample_len(),
}),
Dca { .. } => reader.read_i16::<LittleEndian>().map(|frame_len| Frame {
header_len: mem::size_of::<i16>(),
frame_len: frame_len.max(0) as usize,
}),
}
}
/// Tries to seek on an input directly using sample length, if the input
/// is unframed.
pub fn try_seek_trivial(&self, input: CodecType) -> Option<usize> {
use Container::*;
match self {
Raw => Some(input.sample_len()),
_ => None,
}
}
/// Returns the byte index of the first frame containing audio payload data.
pub fn input_start(&self) -> usize {
use Container::*;
match self {
Raw => 0,
Dca { first_frame } => *first_frame,
}
}
}

137
src/input/dca.rs Normal file
View File

@@ -0,0 +1,137 @@
use super::{codec::OpusDecoderState, error::DcaError, Codec, Container, Input, Metadata, Reader};
use serde::Deserialize;
use std::{ffi::OsStr, io::BufReader, mem};
use tokio::{fs::File as TokioFile, io::AsyncReadExt};
/// Creates a streamed audio source from a DCA file.
/// Currently only accepts the [DCA1 format](https://github.com/bwmarrin/dca).
///
/// Thin monomorphisation shim: converts the path once, then delegates to `_dca`.
pub async fn dca<P: AsRef<OsStr>>(path: P) -> Result<Input, DcaError> {
    _dca(path.as_ref()).await
}
// Opens and validates a DCA1 file: magic number, i32-LE metadata length,
// JSON metadata block, then hands the remaining bytes to `Input` as
// Opus-in-DCA frames.
async fn _dca(path: &OsStr) -> Result<Input, DcaError> {
    let mut reader = TokioFile::open(path).await.map_err(DcaError::IoError)?;
    let mut header = [0u8; 4];
    // Read in the magic number to verify it's a DCA file.
    reader
        .read_exact(&mut header)
        .await
        .map_err(DcaError::IoError)?;
    if header != b"DCA1"[..] {
        return Err(DcaError::InvalidHeader);
    }
    // Little-endian i32 length of the JSON metadata block.
    let size = reader
        .read_i32_le()
        .await
        .map_err(|_| DcaError::InvalidHeader)?;
    // Sanity check
    if size < 2 {
        return Err(DcaError::InvalidSize(size));
    }
    let mut raw_json = Vec::with_capacity(size as usize);
    // `take` bounds the read so trailing audio bytes are not consumed as JSON.
    let mut json_reader = reader.take(size as u64);
    json_reader
        .read_to_end(&mut raw_json)
        .await
        .map_err(DcaError::IoError)?;
    // Audio frames are read synchronously downstream: convert to a std File.
    let reader = BufReader::new(json_reader.into_inner().into_std().await);
    let metadata: Metadata = serde_json::from_slice::<DcaMetadata>(raw_json.as_slice())
        .map_err(DcaError::InvalidMetadata)?
        .into();
    let stereo = metadata.channels == Some(2);
    Ok(Input::new(
        stereo,
        Reader::File(reader),
        Codec::Opus(OpusDecoderState::new().map_err(DcaError::Opus)?),
        Container::Dca {
            // First frame sits after: JSON block + i32 size field + 4-byte magic.
            first_frame: (size as usize) + mem::size_of::<i32>() + header.len(),
        },
        Some(metadata),
    ))
}
// Top-level shape of the DCA1 JSON metadata header.
#[derive(Debug, Deserialize)]
pub(crate) struct DcaMetadata {
    pub(crate) dca: Dca,
    pub(crate) opus: Opus,
    pub(crate) info: Option<Info>,
    pub(crate) origin: Option<Origin>,
    // Free-form extension data; preserved but unused here.
    pub(crate) extra: Option<serde_json::Value>,
}
// `dca` section of the header: format version and producing tool.
#[derive(Debug, Deserialize)]
pub(crate) struct Dca {
    pub(crate) version: u64,
    pub(crate) tool: Tool,
}
// Identifies the encoder which produced the DCA file.
#[derive(Debug, Deserialize)]
pub(crate) struct Tool {
    pub(crate) name: String,
    pub(crate) version: String,
    pub(crate) url: String,
    pub(crate) author: String,
}
// `opus` section: encoding parameters of the contained audio.
// `channels` and `sample_rate` feed directly into `Metadata`.
#[derive(Debug, Deserialize)]
pub(crate) struct Opus {
    pub(crate) mode: String,
    pub(crate) sample_rate: u32,
    pub(crate) frame_size: u64,
    pub(crate) abr: u64,
    pub(crate) vbr: u64,
    pub(crate) channels: u8,
}
// Optional `info` section: human-readable track metadata.
#[derive(Debug, Deserialize)]
pub(crate) struct Info {
    pub(crate) title: Option<String>,
    pub(crate) artist: Option<String>,
    pub(crate) album: Option<String>,
    pub(crate) genre: Option<String>,
    pub(crate) cover: Option<String>,
}
// Optional `origin` section: where/how the source audio was obtained.
#[derive(Debug, Deserialize)]
pub(crate) struct Origin {
    pub(crate) source: Option<String>,
    pub(crate) abr: Option<u64>,
    pub(crate) channels: Option<u8>,
    pub(crate) encoding: Option<String>,
    pub(crate) url: Option<String>,
}
impl From<DcaMetadata> for Metadata {
    /// Lifts the DCA JSON header fields into Songbird's generic metadata.
    fn from(mut d: DcaMetadata) -> Self {
        // `info` is optional; when missing, no title/artist can be recovered.
        let (title, artist) = match d.info.take() {
            Some(mut info) => (info.title.take(), info.artist.take()),
            None => (None, None),
        };
        Self {
            title,
            artist,
            channels: Some(d.opus.channels),
            sample_rate: Some(d.opus.sample_rate),
            ..Default::default()
        }
    }
}

93
src/input/error.rs Normal file
View File

@@ -0,0 +1,93 @@
//! Errors caused by input creation.
use audiopus::Error as OpusError;
use serde_json::{Error as JsonError, Value};
use std::{io::Error as IoError, process::Output};
use streamcatcher::CatcherError;
/// An error returned when creating a new [`Input`].
///
/// Aggregates failures from every input source kind (files, DCA, ffmpeg,
/// youtube-dl, caching).
///
/// [`Input`]: ../struct.Input.html
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// An error occurred while opening a new DCA source.
    Dca(DcaError),
    /// An error occurred while reading, or opening a file.
    Io(IoError),
    /// An error occurred while parsing JSON (i.e., during metadata/stereo detection).
    Json(JsonError),
    /// An error occurred within the Opus codec.
    Opus(OpusError),
    /// Failed to extract metadata from alternate pipe.
    Metadata,
    /// Apparently failed to create stdout.
    Stdout,
    /// An error occurred while checking if a path is stereo.
    Streams,
    /// Configuration error for a cached Input.
    Streamcatcher(CatcherError),
    /// An error occurred while processing the JSON output from `youtube-dl`.
    ///
    /// The JSON output is given.
    YouTubeDLProcessing(Value),
    /// An error occurred while running `youtube-dl`.
    YouTubeDLRun(Output),
    /// The `url` field of the `youtube-dl` JSON output was not present.
    ///
    /// The JSON output is given.
    YouTubeDLUrl(Value),
}
// Lets `?` lift streamcatcher configuration errors into `Error`.
impl From<CatcherError> for Error {
    fn from(e: CatcherError) -> Self {
        Error::Streamcatcher(e)
    }
}
// Lets `?` lift DCA parsing errors into `Error`.
impl From<DcaError> for Error {
    fn from(e: DcaError) -> Self {
        Error::Dca(e)
    }
}
// Lets `?` lift I/O errors into `Error`.
impl From<IoError> for Error {
    // Consistency: sibling `From` impls in this file name the return type
    // `Self`; this one previously spelled out `Error`.
    fn from(e: IoError) -> Self {
        Error::Io(e)
    }
}
// Lets `?` lift JSON parsing errors into `Error`.
impl From<JsonError> for Error {
    fn from(e: JsonError) -> Self {
        Error::Json(e)
    }
}
// Lets `?` lift Opus codec errors into `Error`.
impl From<OpusError> for Error {
    // Consistency: sibling `From` impls in this file name the return type
    // `Self`; this one previously spelled out `Error`.
    fn from(e: OpusError) -> Self {
        Error::Opus(e)
    }
}
/// An error returned from the [`dca`] method.
///
/// Covers every stage of DCA1 parsing: file access, magic-number check,
/// metadata block, and decoder construction.
///
/// [`dca`]: ../fn.dca.html
#[derive(Debug)]
#[non_exhaustive]
pub enum DcaError {
    /// An error occurred while reading, or opening a file.
    IoError(IoError),
    /// The file opened did not have a valid DCA JSON header.
    InvalidHeader,
    /// The file's metadata block was invalid, or could not be parsed.
    InvalidMetadata(JsonError),
    /// The file's header reported an invalid metadata block size.
    InvalidSize(i32),
    /// An error was encountered while creating a new Opus decoder.
    Opus(OpusError),
}
/// Convenience type for fallible return of [`Input`]s.
///
/// [`Input`]: ../struct.Input.html
pub type Result<T> = std::result::Result<T, Error>;

146
src/input/ffmpeg_src.rs Normal file
View File

@@ -0,0 +1,146 @@
use super::{
child_to_reader,
error::{Error, Result},
Codec,
Container,
Input,
Metadata,
};
use serde_json::Value;
use std::{
ffi::OsStr,
process::{Command, Stdio},
};
use tokio::process::Command as TokioCommand;
use tracing::debug;
/// Opens an audio file through `ffmpeg` and creates an audio source.
///
/// Thin monomorphisation shim over `_ffmpeg`, which applies the default
/// 48kHz float-PCM output arguments.
pub async fn ffmpeg<P: AsRef<OsStr>>(path: P) -> Result<Input> {
    _ffmpeg(path.as_ref()).await
}
// Probes the file for channel count, then spawns ffmpeg with default
// arguments producing 48kHz float PCM on stdout.
pub(crate) async fn _ffmpeg(path: &OsStr) -> Result<Input> {
    // Will fail if the path is not to a file on the fs. Likely a YouTube URI.
    let is_stereo = is_stereo(path)
        .await
        .unwrap_or_else(|_e| (false, Default::default()));
    let stereo_val = if is_stereo.0 { "2" } else { "1" };
    // NOTE(review): "-f s16le" names the 16-bit raw muxer while "-acodec
    // pcm_f32le" emits 32-bit floats (matching Codec::FloatPcm below) —
    // confirm ffmpeg accepts this muxer/codec pairing as intended.
    _ffmpeg_optioned(
        path,
        &[],
        &[
            "-f",
            "s16le",
            "-ac",
            stereo_val,
            "-ar",
            "48000",
            "-acodec",
            "pcm_f32le",
            "-",
        ],
        // Pass the probe result through so it isn't recomputed.
        Some(is_stereo),
    )
    .await
}
/// Opens an audio file through `ffmpeg` and creates an audio source, with
/// user-specified arguments to pass to ffmpeg.
///
/// Note that this does _not_ build on the arguments passed by the [`ffmpeg`]
/// function.
///
/// # Examples
///
/// Pass options to create a custom ffmpeg streamer:
///
/// ```rust,no_run
/// use songbird::input;
///
/// let stereo_val = "2";
///
/// let streamer = futures::executor::block_on(input::ffmpeg_optioned("./some_file.mp3", &[], &[
///     "-f",
///     "s16le",
///     "-ac",
///     stereo_val,
///     "-ar",
///     "48000",
///     "-acodec",
///     "pcm_s16le",
///     "-",
/// ]));
///```
pub async fn ffmpeg_optioned<P: AsRef<OsStr>>(
    path: P,
    pre_input_args: &[&str],
    args: &[&str],
) -> Result<Input> {
    // Stereo detection is not pre-supplied here, so `_ffmpeg_optioned` probes.
    _ffmpeg_optioned(path.as_ref(), pre_input_args, args, None).await
}
// Shared worker for all ffmpeg-based sources: optionally probes the input,
// then spawns ffmpeg and wraps its stdout pipe as a float-PCM Input.
pub(crate) async fn _ffmpeg_optioned(
    path: &OsStr,
    pre_input_args: &[&str],
    args: &[&str],
    // Probe result from a prior `is_stereo` call, to avoid running ffprobe twice.
    is_stereo_known: Option<(bool, Metadata)>,
) -> Result<Input> {
    let (is_stereo, metadata) = if let Some(vals) = is_stereo_known {
        vals
    } else {
        // Probe failure falls back to mono with empty metadata.
        is_stereo(path)
            .await
            .ok()
            .unwrap_or_else(|| (false, Default::default()))
    };
    // stdin/stderr are discarded; audio is streamed from the piped stdout.
    let command = Command::new("ffmpeg")
        .args(pre_input_args)
        .arg("-i")
        .arg(path)
        .args(args)
        .stderr(Stdio::null())
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .spawn()?;
    Ok(Input::new(
        is_stereo,
        child_to_reader::<f32>(command),
        Codec::FloatPcm,
        Container::Raw,
        Some(metadata),
    ))
}
// Runs ffprobe on `path`, returning whether the first audio stream is stereo
// alongside the parsed metadata. Errs with `Error::Streams` when no channel
// count can be determined.
pub(crate) async fn is_stereo(path: &OsStr) -> Result<(bool, Metadata)> {
    let args = [
        "-v",
        "quiet",
        "-of",
        "json",
        "-show_format",
        "-show_streams",
        "-i",
    ];
    let out = TokioCommand::new("ffprobe")
        .args(&args)
        .arg(path)
        .stdin(Stdio::null())
        .output()
        .await?;
    // ffprobe was asked for JSON output above; parse it directly from stdout.
    let value: Value = serde_json::from_reader(&out.stdout[..])?;
    let metadata = Metadata::from_ffprobe_json(&value);
    debug!("FFprobe metadata {:?}", metadata);
    if let Some(count) = metadata.channels {
        Ok((count == 2, metadata))
    } else {
        Err(Error::Streams)
    }
}

166
src/input/metadata.rs Normal file
View File

@@ -0,0 +1,166 @@
use crate::constants::*;
use serde_json::Value;
use std::time::Duration;
/// Information about an [`Input`] source.
///
/// Populated from ffprobe JSON, youtube-dl JSON, or a DCA header.
/// All fields are optional; absent information is `None`.
///
/// [`Input`]: struct.Input.html
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Metadata {
    /// The title of this stream.
    pub title: Option<String>,
    /// The main artist of this stream.
    pub artist: Option<String>,
    /// The date of creation of this stream.
    ///
    /// Stored as the raw string reported by the source tool.
    pub date: Option<String>,
    /// The number of audio channels in this stream.
    ///
    /// Any number `>= 2` is treated as stereo.
    pub channels: Option<u8>,
    /// The time at which the first true sample is played back.
    ///
    /// This occurs as an artefact of coder delay.
    pub start_time: Option<Duration>,
    /// The reported duration of this stream.
    pub duration: Option<Duration>,
    /// The sample rate of this stream.
    pub sample_rate: Option<u32>,
}
impl Metadata {
    /// Extract metadata and details from the output of
    /// `ffprobe`.
    ///
    /// Missing or malformed fields simply come out as `None`.
    pub fn from_ffprobe_json(value: &Value) -> Self {
        // Container-level info (duration, start time, tags) lives under "format".
        let format = value.as_object().and_then(|m| m.get("format"));
        // ffprobe reports durations as decimal-second strings.
        let duration = format
            .and_then(|m| m.get("duration"))
            .and_then(Value::as_str)
            .and_then(|v| v.parse::<f64>().ok())
            .map(Duration::from_secs_f64);
        let start_time = format
            .and_then(|m| m.get("start_time"))
            .and_then(Value::as_str)
            .and_then(|v| v.parse::<f64>().ok())
            .map(Duration::from_secs_f64);
        let tags = format.and_then(|m| m.get("tags"));
        let title = tags
            .and_then(|m| m.get("title"))
            .and_then(Value::as_str)
            .map(str::to_string);
        let artist = tags
            .and_then(|m| m.get("artist"))
            .and_then(Value::as_str)
            .map(str::to_string);
        let date = tags
            .and_then(|m| m.get("date"))
            .and_then(Value::as_str)
            .map(str::to_string);
        // Channel/sample-rate info comes from the first stream whose
        // codec_type is "audio".
        let stream = value
            .as_object()
            .and_then(|m| m.get("streams"))
            .and_then(|v| v.as_array())
            .and_then(|v| {
                v.iter()
                    .find(|line| line.get("codec_type").and_then(Value::as_str) == Some("audio"))
            });
        let channels = stream
            .and_then(|m| m.get("channels"))
            .and_then(Value::as_u64)
            .map(|v| v as u8);
        let sample_rate = stream
            .and_then(|m| m.get("sample_rate"))
            .and_then(Value::as_str)
            .and_then(|v| v.parse::<u64>().ok())
            .map(|v| v as u32);
        Self {
            title,
            artist,
            date,
            channels,
            start_time,
            duration,
            sample_rate,
        }
    }
    /// Use `youtube-dl` to extract metadata for an online resource.
    ///
    /// Prefers the more specific fields (`track`, `artist`, `release_date`)
    /// and falls back to the generic ones (`title`, `uploader`, `upload_date`).
    pub fn from_ytdl_output(value: Value) -> Self {
        let obj = value.as_object();
        let track = obj
            .and_then(|m| m.get("track"))
            .and_then(Value::as_str)
            .map(str::to_string);
        let title = track.or_else(|| {
            obj.and_then(|m| m.get("title"))
                .and_then(Value::as_str)
                .map(str::to_string)
        });
        let true_artist = obj
            .and_then(|m| m.get("artist"))
            .and_then(Value::as_str)
            .map(str::to_string);
        let artist = true_artist.or_else(|| {
            obj.and_then(|m| m.get("uploader"))
                .and_then(Value::as_str)
                .map(str::to_string)
        });
        let r_date = obj
            .and_then(|m| m.get("release_date"))
            .and_then(Value::as_str)
            .map(str::to_string);
        let date = r_date.or_else(|| {
            obj.and_then(|m| m.get("upload_date"))
                .and_then(Value::as_str)
                .map(str::to_string)
        });
        let duration = obj
            .and_then(|m| m.get("duration"))
            .and_then(Value::as_f64)
            .map(Duration::from_secs_f64);
        Self {
            title,
            artist,
            date,
            // Fixed values: presumably the stream is later transcoded to
            // stereo 48kHz by ffmpeg — TODO confirm against the ytdl source.
            channels: Some(2),
            duration,
            sample_rate: Some(SAMPLE_RATE_RAW as u32),
            ..Default::default()
        }
    }
    /// Move all fields from a `Metadata` object into a new one.
    pub fn take(&mut self) -> Self {
        Self {
            title: self.title.take(),
            artist: self.artist.take(),
            date: self.date.take(),
            channels: self.channels.take(),
            start_time: self.start_time.take(),
            duration: self.duration.take(),
            sample_rate: self.sample_rate.take(),
        }
    }
}

596
src/input/mod.rs Normal file
View File

@@ -0,0 +1,596 @@
//! Raw audio input data streams and sources.
//!
//! [`Input`] is handled in Songbird by combining metadata with:
//! * A 48kHz audio bytestream, via [`Reader`],
//! * A [`Container`] describing the framing mechanism of the bytestream,
//! * A [`Codec`], defining the format of audio frames.
//!
//! When used as a [`Read`], the output bytestream will be a floating-point
//! PCM stream at 48kHz, matching the channel count of the input source.
//!
//! ## Opus frame passthrough.
//! Some sources, such as [`Compressed`] or the output of [`dca`], support
//! direct frame passthrough to the driver. This lets you directly send the
//! audio data you have *without decoding, re-encoding, or mixing*. In many
//! cases, this can greatly reduce the processing/compute cost of the driver.
//!
//! This functionality requires that:
//! * only one track is active (including paused tracks),
//! * that track's input supports direct Opus frame reads,
//! * its [`Input`] [meets the promises described herein](codec/struct.OpusDecoderState.html#structfield.allow_passthrough),
//! * and that track's volume is set to `1.0`.
//!
//! [`Input`]: struct.Input.html
//! [`Reader`]: reader/enum.Reader.html
//! [`Container`]: enum.Container.html
//! [`Codec`]: codec/enum.Codec.html
//! [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html
//! [`Compressed`]: cached/struct.Compressed.html
//! [`dca`]: fn.dca.html
pub mod cached;
mod child;
pub mod codec;
mod container;
mod dca;
pub mod error;
mod ffmpeg_src;
mod metadata;
pub mod reader;
pub mod restartable;
pub mod utils;
mod ytdl_src;
pub use self::{
child::*,
codec::{Codec, CodecType},
container::{Container, Frame},
dca::dca,
ffmpeg_src::*,
metadata::Metadata,
reader::Reader,
restartable::Restartable,
ytdl_src::*,
};
use crate::constants::*;
use audiopus::coder::GenericCtl;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use cached::OpusCompressor;
use error::{Error, Result};
use tokio::runtime::Handle;
use std::{
convert::TryFrom,
io::{
self,
Error as IoError,
ErrorKind as IoErrorKind,
Read,
Result as IoResult,
Seek,
SeekFrom,
},
mem,
time::Duration,
};
use tracing::{debug, error};
/// Data and metadata needed to correctly parse a [`Reader`]'s audio bytestream.
///
/// See the [module root] for more information.
///
/// [`Reader`]: enum.Reader.html
/// [module root]: index.html
#[derive(Debug)]
pub struct Input {
    /// Information about the played source.
    pub metadata: Metadata,
    /// Indicates whether `source` is stereo or mono.
    pub stereo: bool,
    /// Underlying audio data bytestream.
    pub reader: Reader,
    /// Decoder used to parse the output of `reader`.
    pub kind: Codec,
    /// Framing strategy needed to identify frames of compressed audio.
    pub container: Container,
    // Byte position in the decoded f32 output stream; advanced by reads
    // and used as the cursor for `Seek`.
    pos: usize,
}
impl Input {
    /// Creates a floating-point PCM Input from a given reader.
    pub fn float_pcm(is_stereo: bool, reader: Reader) -> Input {
        Input {
            metadata: Default::default(),
            stereo: is_stereo,
            reader,
            kind: Codec::FloatPcm,
            container: Container::Raw,
            pos: 0,
        }
    }

    /// Creates a new Input using (at least) the given reader, codec, and container.
    pub fn new(
        stereo: bool,
        reader: Reader,
        kind: Codec,
        container: Container,
        metadata: Option<Metadata>,
    ) -> Self {
        Input {
            metadata: metadata.unwrap_or_default(),
            stereo,
            reader,
            kind,
            container,
            pos: 0,
        }
    }

    /// Returns whether the inner [`Reader`] implements [`Seek`].
    ///
    /// [`Reader`]: reader/enum.Reader.html
    /// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html
    pub fn is_seekable(&self) -> bool {
        self.reader.is_seekable()
    }

    /// Returns whether the read audio signal is stereo (or mono).
    pub fn is_stereo(&self) -> bool {
        self.stereo
    }

    /// Returns the type of the inner [`Codec`].
    ///
    /// [`Codec`]: codec/enum.Codec.html
    pub fn get_type(&self) -> CodecType {
        (&self.kind).into()
    }

    /// Mixes the output of this stream into a 20ms stereo audio buffer.
    ///
    /// Returns the number of output bytes mixed in, or `0` on read failure.
    #[inline]
    pub fn mix(&mut self, float_buffer: &mut [f32; STEREO_FRAME_SIZE], volume: f32) -> usize {
        self.add_float_pcm_frame(float_buffer, self.stereo, volume)
            .unwrap_or(0)
    }

    /// Seeks the stream to the given time, if possible.
    ///
    /// Returns the actual time reached.
    pub fn seek_time(&mut self, time: Duration) -> Option<Duration> {
        let future_pos = utils::timestamp_to_byte_count(time, self.stereo);
        Seek::seek(self, SeekFrom::Start(future_pos as u64))
            .ok()
            .map(|a| utils::byte_count_to_timestamp(a as usize, self.stereo))
    }

    /// Core read routine: fills `buffer` with little-endian `f32` samples.
    ///
    /// When `ignore_decode` is set (used during cheap forward seeks),
    /// compressed frames are skipped rather than decoded.
    fn read_inner(&mut self, buffer: &mut [u8], ignore_decode: bool) -> IoResult<usize> {
        // This implementation of Read converts the input stream
        // to floating point output.
        let sample_len = mem::size_of::<f32>();
        let float_space = buffer.len() / sample_len;
        let mut written_floats = 0;
        // TODO: better decouple codec and container here.
        // this is a little bit backwards, and assumes the bottom cases are always raw...
        let out = match &mut self.kind {
            Codec::Opus(decoder_state) => {
                if matches!(self.container, Container::Raw) {
                    return Err(IoError::new(
                        IoErrorKind::InvalidInput,
                        "Raw container cannot demarcate Opus frames.",
                    ));
                }
                if ignore_decode {
                    // If we're less than one frame away from the end of cheap seeking,
                    // then we must decode to make sure the next starting offset is correct.
                    // Step one: use up the remainder of the frame.
                    let mut aud_skipped =
                        decoder_state.current_frame.len() - decoder_state.frame_pos;
                    decoder_state.frame_pos = 0;
                    decoder_state.current_frame.truncate(0);
                    // Step two: take frames if we can.
                    while buffer.len() - aud_skipped >= STEREO_FRAME_BYTE_SIZE {
                        decoder_state.should_reset = true;
                        let frame = self
                            .container
                            .next_frame_length(&mut self.reader, CodecType::Opus)?;
                        self.reader.consume(frame.frame_len);
                        aud_skipped += STEREO_FRAME_BYTE_SIZE;
                    }
                    Ok(aud_skipped)
                } else {
                    // get new frame *if needed*
                    if decoder_state.frame_pos == decoder_state.current_frame.len() {
                        let mut decoder = decoder_state.decoder.lock();
                        if decoder_state.should_reset {
                            decoder
                                .reset_state()
                                .expect("Critical failure resetting decoder.");
                            decoder_state.should_reset = false;
                        }
                        let frame = self
                            .container
                            .next_frame_length(&mut self.reader, CodecType::Opus)?;
                        let mut opus_data_buffer = [0u8; 4000];
                        decoder_state
                            .current_frame
                            .resize(decoder_state.current_frame.capacity(), 0.0);
                        let seen =
                            Read::read(&mut self.reader, &mut opus_data_buffer[..frame.frame_len])?;
                        let samples = decoder
                            .decode_float(
                                Some(&opus_data_buffer[..seen]),
                                &mut decoder_state.current_frame[..],
                                false,
                            )
                            .unwrap_or(0);
                        decoder_state.current_frame.truncate(2 * samples);
                        decoder_state.frame_pos = 0;
                    }
                    // read from frame which is present.
                    let mut buffer = &mut buffer[..];
                    let start = decoder_state.frame_pos;
                    let to_write = float_space.min(decoder_state.current_frame.len() - start);
                    // BUGFIX: copy only `to_write` floats. The previous bound of
                    // `start + float_space` overran `current_frame` (and panicked)
                    // whenever the output buffer had more space than the frame
                    // had samples remaining.
                    for val in &decoder_state.current_frame[start..start + to_write] {
                        buffer.write_f32::<LittleEndian>(*val)?;
                    }
                    decoder_state.frame_pos += to_write;
                    written_floats = to_write;
                    Ok(written_floats * mem::size_of::<f32>())
                }
            },
            Codec::Pcm => {
                // Widen each i16 sample into a normalised f32 in [-1, 1).
                let mut buffer = &mut buffer[..];
                while written_floats < float_space {
                    if let Ok(signal) = self.reader.read_i16::<LittleEndian>() {
                        buffer.write_f32::<LittleEndian>(f32::from(signal) / 32768.0)?;
                        written_floats += 1;
                    } else {
                        break;
                    }
                }
                Ok(written_floats * mem::size_of::<f32>())
            },
            // Already f32le on the wire: pass straight through.
            Codec::FloatPcm => Read::read(&mut self.reader, buffer),
        };
        out.map(|v| {
            self.pos += v;
            v
        })
    }

    /// Advances `count` output bytes by reading and discarding, without
    /// decoding compressed audio where avoidable.
    fn cheap_consume(&mut self, count: usize) -> IoResult<usize> {
        let mut scratch = [0u8; STEREO_FRAME_BYTE_SIZE * 4];
        let len = scratch.len();
        let mut done = 0;
        loop {
            let read = self.read_inner(&mut scratch[..len.min(count - done)], true)?;
            if read == 0 {
                break;
            }
            done += read;
        }
        Ok(done)
    }

    pub(crate) fn supports_passthrough(&self) -> bool {
        match &self.kind {
            Codec::Opus(state) => state.allow_passthrough,
            _ => false,
        }
    }

    pub(crate) fn read_opus_frame(&mut self, buffer: &mut [u8]) -> IoResult<usize> {
        // Called in event of opus passthrough.
        if let Codec::Opus(state) = &mut self.kind {
            // step 1: align to frame.
            self.pos += state.current_frame.len() - state.frame_pos;
            state.frame_pos = 0;
            state.current_frame.truncate(0);
            // step 2: read new header.
            let frame = self
                .container
                .next_frame_length(&mut self.reader, CodecType::Opus)?;
            // step 3: read in bytes.
            self.reader
                .read_exact(&mut buffer[..frame.frame_len])
                .map(|_| {
                    self.pos += STEREO_FRAME_BYTE_SIZE;
                    frame.frame_len
                })
        } else {
            Err(IoError::new(
                IoErrorKind::InvalidInput,
                "Frame passthrough not supported for this file.",
            ))
        }
    }

    pub(crate) fn prep_with_handle(&mut self, handle: Handle) {
        self.reader.prep_with_handle(handle);
    }
}
// Reads always decode: the `ignore_decode` fast path is only used internally
// by seeking.
impl Read for Input {
    fn read(&mut self, buffer: &mut [u8]) -> IoResult<usize> {
        self.read_inner(buffer, false)
    }
}
// Seeking operates on positions in the decoded f32 output stream (`self.pos`),
// translating to the inner reader's byte space where that is trivial.
impl Seek for Input {
    fn seek(&mut self, pos: SeekFrom) -> IoResult<u64> {
        let mut target = self.pos;
        match pos {
            SeekFrom::Start(pos) => {
                target = pos as usize;
            },
            SeekFrom::Current(rel) => {
                // Negative `rel` relies on two's-complement wraparound of
                // `as usize` to subtract correctly.
                target = target.wrapping_add(rel as usize);
            },
            // Stream length is unknown for most sources.
            SeekFrom::End(_pos) => unimplemented!(),
        }
        debug!("Seeking to {:?}", pos);
        (if target == self.pos {
            Ok(0)
        } else if let Some(conversion) = self.container.try_seek_trivial(self.get_type()) {
            // Raw container: map output-f32 bytes to inner-sample bytes directly.
            let inside_target = (target * conversion) / mem::size_of::<f32>();
            Seek::seek(&mut self.reader, SeekFrom::Start(inside_target as u64)).map(|inner_dest| {
                let outer_dest = ((inner_dest as usize) * mem::size_of::<f32>()) / conversion;
                self.pos = outer_dest;
                outer_dest
            })
        } else if target > self.pos {
            // seek in the next amount, disabling decoding if need be.
            let shift = target - self.pos;
            self.cheap_consume(shift)
        } else {
            // start from scratch, then seek in...
            Seek::seek(
                &mut self.reader,
                SeekFrom::Start(self.container.input_start() as u64),
            )?;
            self.cheap_consume(target)
        })
        .map(|_| self.pos as u64)
    }
}
/// Extension trait to pull frames of audio from a byte source.
pub(crate) trait ReadAudioExt {
    // Mixes one 20ms frame of f32le audio from `self` into `float_buffer`
    // at the given volume; returns bytes mixed, or `None` on fatal error.
    fn add_float_pcm_frame(
        &mut self,
        float_buffer: &mut [f32; STEREO_FRAME_SIZE],
        true_stereo: bool,
        volume: f32,
    ) -> Option<usize>;
    // Reads and discards `amt` bytes, returning how many were skipped.
    fn consume(&mut self, amt: usize) -> usize
    where
        Self: Sized;
}
impl<R: Read + Sized> ReadAudioExt for R {
    fn add_float_pcm_frame(
        &mut self,
        float_buffer: &mut [f32; STEREO_FRAME_SIZE],
        stereo: bool,
        volume: f32,
    ) -> Option<usize> {
        // IDEA: Read in 8 floats at a time, then use iterator code
        // to gently nudge the compiler into vectorising for us.
        // Max SIMD float32 lanes is 8 on AVX, older archs use a divisor of this
        // e.g., 4.
        const SAMPLE_LEN: usize = mem::size_of::<f32>();
        const FLOAT_COUNT: usize = 512;
        let mut simd_float_bytes = [0u8; FLOAT_COUNT * SAMPLE_LEN];
        let mut simd_float_buf = [0f32; FLOAT_COUNT];
        let mut frame_pos = 0;
        // Code duplication here is because unifying these codepaths
        // with a dynamic chunk size is not zero-cost.
        if stereo {
            // Stereo: one input float maps to one output float.
            let mut max_bytes = STEREO_FRAME_BYTE_SIZE;
            while frame_pos < float_buffer.len() {
                let progress = self
                    .read(&mut simd_float_bytes[..max_bytes.min(FLOAT_COUNT * SAMPLE_LEN)])
                    .and_then(|byte_len| {
                        // Reinterpret the raw bytes as little-endian f32s.
                        let target = byte_len / SAMPLE_LEN;
                        (&simd_float_bytes[..byte_len])
                            .read_f32_into::<LittleEndian>(&mut simd_float_buf[..target])
                            .map(|_| target)
                    })
                    .map(|f32_len| {
                        // Mix (add) the scaled samples into the output frame.
                        let new_pos = frame_pos + f32_len;
                        for (el, new_el) in float_buffer[frame_pos..new_pos]
                            .iter_mut()
                            .zip(&simd_float_buf[..f32_len])
                        {
                            *el += volume * new_el;
                        }
                        (new_pos, f32_len)
                    });
                match progress {
                    Ok((new_pos, delta)) => {
                        frame_pos = new_pos;
                        max_bytes -= delta * SAMPLE_LEN;
                        // A zero-length read signals end of stream.
                        if delta == 0 {
                            break;
                        }
                    },
                    Err(ref e) =>
                        return if e.kind() == IoErrorKind::UnexpectedEof {
                            error!("EOF unexpectedly: {:?}", e);
                            Some(frame_pos)
                        } else {
                            error!("Input died unexpectedly: {:?}", e);
                            None
                        },
                }
            }
        } else {
            // Mono: each input float is duplicated into both output channels.
            let mut max_bytes = MONO_FRAME_BYTE_SIZE;
            while frame_pos < float_buffer.len() {
                let progress = self
                    .read(&mut simd_float_bytes[..max_bytes.min(FLOAT_COUNT * SAMPLE_LEN)])
                    .and_then(|byte_len| {
                        let target = byte_len / SAMPLE_LEN;
                        (&simd_float_bytes[..byte_len])
                            .read_f32_into::<LittleEndian>(&mut simd_float_buf[..target])
                            .map(|_| target)
                    })
                    .map(|f32_len| {
                        let new_pos = frame_pos + (2 * f32_len);
                        for (els, new_el) in float_buffer[frame_pos..new_pos]
                            .chunks_exact_mut(2)
                            .zip(&simd_float_buf[..f32_len])
                        {
                            let sample = volume * new_el;
                            els[0] += sample;
                            els[1] += sample;
                        }
                        (new_pos, f32_len)
                    });
                match progress {
                    Ok((new_pos, delta)) => {
                        frame_pos = new_pos;
                        max_bytes -= delta * SAMPLE_LEN;
                        if delta == 0 {
                            break;
                        }
                    },
                    Err(ref e) =>
                        return if e.kind() == IoErrorKind::UnexpectedEof {
                            Some(frame_pos)
                        } else {
                            error!("Input died unexpectedly: {:?}", e);
                            None
                        },
                }
            }
        }
        Some(frame_pos * SAMPLE_LEN)
    }
    fn consume(&mut self, amt: usize) -> usize {
        // Discard up to `amt` bytes by copying into the sink.
        io::copy(&mut self.by_ref().take(amt as u64), &mut io::sink()).unwrap_or(0) as usize
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;

    // Float PCM sources must pass through byte-for-byte.
    #[test]
    fn float_pcm_input_unchanged_mono() {
        let data = make_sine(50 * MONO_FRAME_SIZE, false);
        let mut input = Input::new(
            false,
            data.clone().into(),
            Codec::FloatPcm,
            Container::Raw,
            None,
        );
        let mut out_vec = vec![];
        let len = input.read_to_end(&mut out_vec).unwrap();
        assert_eq!(out_vec[..len], data[..]);
    }

    #[test]
    fn float_pcm_input_unchanged_stereo() {
        let data = make_sine(50 * MONO_FRAME_SIZE, true);
        let mut input = Input::new(
            true,
            data.clone().into(),
            Codec::FloatPcm,
            Container::Raw,
            None,
        );
        let mut out_vec = vec![];
        let len = input.read_to_end(&mut out_vec).unwrap();
        assert_eq!(out_vec[..len], data[..]);
    }

    // i16 PCM sources are widened to f32 by dividing by 32768.
    #[test]
    fn pcm_input_becomes_float_mono() {
        let data = make_pcm_sine(50 * MONO_FRAME_SIZE, false);
        let mut input = Input::new(false, data.clone().into(), Codec::Pcm, Container::Raw, None);
        let mut out_vec = vec![];
        // Fix: the returned length was previously bound to an unused variable.
        input.read_to_end(&mut out_vec).unwrap();
        let mut i16_window = &data[..];
        let mut float_window = &out_vec[..];
        while !i16_window.is_empty() {
            let before = i16_window.read_i16::<LittleEndian>().unwrap() as f32;
            let after = float_window.read_f32::<LittleEndian>().unwrap();
            let diff = (before / 32768.0) - after;
            assert!(diff.abs() < f32::EPSILON);
        }
    }

    #[test]
    fn pcm_input_becomes_float_stereo() {
        let data = make_pcm_sine(50 * MONO_FRAME_SIZE, true);
        let mut input = Input::new(true, data.clone().into(), Codec::Pcm, Container::Raw, None);
        let mut out_vec = vec![];
        input.read_to_end(&mut out_vec).unwrap();
        let mut i16_window = &data[..];
        let mut float_window = &out_vec[..];
        while !i16_window.is_empty() {
            let before = i16_window.read_i16::<LittleEndian>().unwrap() as f32;
            let after = float_window.read_f32::<LittleEndian>().unwrap();
            let diff = (before / 32768.0) - after;
            assert!(diff.abs() < f32::EPSILON);
        }
    }
}

180
src/input/reader.rs Normal file
View File

@@ -0,0 +1,180 @@
//! Raw handlers for input bytestreams.
use super::*;
use std::{
fmt::{Debug, Error as FormatError, Formatter},
fs::File,
io::{
BufReader,
Cursor,
Error as IoError,
ErrorKind as IoErrorKind,
Read,
Result as IoResult,
Seek,
SeekFrom,
},
result::Result as StdResult,
};
use streamcatcher::{Catcher, TxCatcher};
/// Usable data/byte sources for an audio stream.
///
/// Users may define their own data sources using [`Extension`]
/// and [`ExtensionSeek`].
///
/// [`Extension`]: #variant.Extension
/// [`ExtensionSeek`]: #variant.ExtensionSeek
pub enum Reader {
    /// Piped output of another program (i.e., [`ffmpeg`]).
    ///
    /// Does not support seeking.
    ///
    /// [`ffmpeg`]: ../fn.ffmpeg.html
    Pipe(BufReader<ChildContainer>),
    /// A cached, raw in-memory store, provided by Songbird.
    ///
    /// Supports seeking.
    Memory(Catcher<Box<Reader>>),
    /// A cached, Opus-compressed in-memory store, provided by Songbird.
    ///
    /// Supports seeking.
    Compressed(TxCatcher<Box<Input>, OpusCompressor>),
    /// A source which supports seeking by recreating its input stream.
    ///
    /// Supports seeking.
    Restartable(Restartable),
    /// A source contained in a local file.
    ///
    /// Supports seeking.
    File(BufReader<File>),
    /// A source contained as an array in memory.
    ///
    /// Supports seeking.
    Vec(Cursor<Vec<u8>>),
    /// A basic user-provided source.
    ///
    /// Does not support seeking.
    Extension(Box<dyn Read + Send>),
    /// A user-provided source which also implements [`Seek`].
    ///
    /// Supports seeking.
    ///
    /// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html
    ExtensionSeek(Box<dyn ReadSeek + Send>),
}
impl Reader {
    /// Returns whether the given source implements [`Seek`].
    ///
    /// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html
    pub fn is_seekable(&self) -> bool {
        use Reader::*;
        match self {
            // Fix: `File` and `Vec` previously fell through a `_ => false`
            // catch-all, contradicting their variant docs ("Supports seeking.")
            // and the `Seek` impl below, which services both.
            Restartable(_) | Compressed(_) | Memory(_) | File(_) | Vec(_) => true,
            ExtensionSeek(_) => true,
            // Pipes and plain extension readers cannot rewind.
            Pipe(_) | Extension(_) => false,
        }
    }

    // Hands a runtime handle to sources which lazily recreate their stream.
    #[allow(clippy::single_match)]
    pub(crate) fn prep_with_handle(&mut self, handle: Handle) {
        use Reader::*;
        match self {
            Restartable(r) => r.prep_with_handle(handle),
            _ => {},
        }
    }
}
// Straight dispatch to each variant's inner reader.
impl Read for Reader {
    fn read(&mut self, buffer: &mut [u8]) -> IoResult<usize> {
        use Reader::*;
        match self {
            Pipe(a) => Read::read(a, buffer),
            Memory(a) => Read::read(a, buffer),
            Compressed(a) => Read::read(a, buffer),
            Restartable(a) => Read::read(a, buffer),
            File(a) => Read::read(a, buffer),
            Vec(a) => Read::read(a, buffer),
            Extension(a) => a.read(buffer),
            ExtensionSeek(a) => a.read(buffer),
        }
    }
}
// Dispatches to the inner reader; unseekable variants return InvalidInput.
impl Seek for Reader {
    fn seek(&mut self, pos: SeekFrom) -> IoResult<u64> {
        use Reader::*;
        match self {
            Pipe(_) | Extension(_) => Err(IoError::new(
                IoErrorKind::InvalidInput,
                "Seeking not supported on Reader of this type.",
            )),
            Memory(a) => Seek::seek(a, pos),
            Compressed(a) => Seek::seek(a, pos),
            File(a) => Seek::seek(a, pos),
            Restartable(a) => Seek::seek(a, pos),
            Vec(a) => Seek::seek(a, pos),
            ExtensionSeek(a) => a.seek(pos),
        }
    }
}
impl Debug for Reader {
    /// Formats as `Reader(<inner>)`, substituting a plain label for
    /// the unprintable trait-object variants.
    fn fmt(&self, f: &mut Formatter<'_>) -> StdResult<(), FormatError> {
        use Reader::*;
        let field = match self {
            Extension(_) => "Extension".to_string(),
            ExtensionSeek(_) => "ExtensionSeek".to_string(),
            Pipe(inner) => format!("{:?}", inner),
            Memory(inner) => format!("{:?}", inner),
            Compressed(inner) => format!("{:?}", inner),
            Restartable(inner) => format!("{:?}", inner),
            File(inner) => format!("{:?}", inner),
            Vec(inner) => format!("{:?}", inner),
        };
        f.debug_tuple("Reader").field(&field).finish()
    }
}
impl From<Vec<u8>> for Reader {
    /// Wraps an in-memory byte buffer as a seekable `Reader::Vec`.
    fn from(val: Vec<u8>) -> Reader {
        let cursor = Cursor::new(val);
        Reader::Vec(cursor)
    }
}
/// Fusion trait for custom input sources which allow seeking.
///
/// This exists so that `Read + Seek` sources can be stored as a single
/// trait object; it is implemented automatically for every `Read + Seek`
/// type via the blanket impl below.
pub trait ReadSeek {
    /// See [`Read::read`].
    ///
    /// [`Read::read`]: https://doc.rust-lang.org/nightly/std/io/trait.Read.html#tymethod.read
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize>;
    /// See [`Seek::seek`].
    ///
    /// [`Seek::seek`]: https://doc.rust-lang.org/nightly/std/io/trait.Seek.html#tymethod.seek
    fn seek(&mut self, pos: SeekFrom) -> IoResult<u64>;
}
// Allow `dyn ReadSeek` trait objects to be used wherever `Read` is expected,
// by forwarding to the fused trait's own method.
impl Read for dyn ReadSeek {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        ReadSeek::read(self, buf)
    }
}
// Likewise for `Seek`.
impl Seek for dyn ReadSeek {
    fn seek(&mut self, pos: SeekFrom) -> IoResult<u64> {
        ReadSeek::seek(self, pos)
    }
}
// Blanket impl: anything which is both `Read` and `Seek` is a `ReadSeek`.
// Calls are fully qualified so they resolve to the std traits rather than
// recursing back into `ReadSeek`'s own methods.
impl<R: Read + Seek> ReadSeek for R {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        Read::read(self, buf)
    }
    fn seek(&mut self, pos: SeekFrom) -> IoResult<u64> {
        Seek::seek(self, pos)
    }
}

294
src/input/restartable.rs Normal file
View File

@@ -0,0 +1,294 @@
//! A source which supports seeking by recreating its input stream.
//!
//! This is intended for use with single-use audio tracks which
//! may require looping or seeking, but where additional memory
//! cannot be spared. Forward seeks will drain the track until reaching
//! the desired timestamp.
//!
//! Restarting occurs by temporarily pausing the track, running the restart
//! mechanism, and then passing the handle back to the mixer thread. Until
//! success/failure is confirmed, the track produces silence.
use super::*;
use flume::{Receiver, TryRecvError};
use futures::executor;
use std::{
ffi::OsStr,
fmt::{Debug, Error as FormatError, Formatter},
io::{Error as IoError, ErrorKind as IoErrorKind, Read, Result as IoResult, Seek, SeekFrom},
result::Result as StdResult,
time::Duration,
};
/// Boxed restart strategy used to rebuild the underlying stream.
type Recreator = Box<dyn Restart + Send + 'static>;
/// Channel on which the async restart task returns the rebuilt input,
/// handing the recreator back alongside it.
type RecreateChannel = Receiver<Result<(Box<Input>, Recreator)>>;
/// A wrapper around a method to create a new [`Input`] which
/// seeks backward by recreating the source.
///
/// The main purpose of this wrapper is to enable seeking on
/// incompatible sources (i.e., ffmpeg output) and to ease resource
/// consumption for commonly reused/shared tracks. [`Compressed`]
/// and [`Memory`] offer the same functionality with different
/// tradeoffs.
///
/// This is intended for use with single-use audio tracks which
/// may require looping or seeking, but where additional memory
/// cannot be spared. Forward seeks will drain the track until reaching
/// the desired timestamp.
///
/// [`Input`]: struct.Input.html
/// [`Memory`]: cached/struct.Memory.html
/// [`Compressed`]: cached/struct.Compressed.html
pub struct Restartable {
    // Runtime handle used to spawn restart tasks; set via `prep_with_handle`.
    async_handle: Option<Handle>,
    // `Some` while a restart is in flight; `read` polls this channel and
    // emits silence until it resolves.
    awaiting_source: Option<RecreateChannel>,
    // Logical position (in bytes) within the stream.
    position: usize,
    // Restart strategy; temporarily `None` while lent out to a restart task.
    recreator: Option<Recreator>,
    // The currently active underlying source.
    source: Box<Input>,
}
impl Restartable {
    /// Create a new source, which can be restarted using a `recreator` function.
    pub fn new(mut recreator: impl Restart + Send + 'static) -> Result<Self> {
        let source = recreator.call_restart(None)?;
        Ok(Self {
            async_handle: None,
            awaiting_source: None,
            position: 0,
            recreator: Some(Box::new(recreator)),
            source: Box::new(source),
        })
    }

    /// Create a new restartable ffmpeg source for a local file.
    pub fn ffmpeg<P: AsRef<OsStr> + Send + Clone + 'static>(path: P) -> Result<Self> {
        Self::new(FfmpegRestarter { path })
    }

    /// Create a new restartable ytdl source.
    ///
    /// The cost of restarting and seeking will probably be *very* high:
    /// expect a pause if you seek backwards.
    pub fn ytdl<P: AsRef<str> + Send + Clone + 'static>(uri: P) -> Result<Self> {
        Self::new(move |time: Option<Duration>| match time {
            Some(time) => {
                // Fractional seconds rendered as "secs.millis" for ffmpeg's -ss.
                let ts = format!("{}.{}", time.as_secs(), time.subsec_millis());
                executor::block_on(_ytdl(uri.as_ref(), &["-ss", &ts]))
            },
            None => executor::block_on(ytdl(uri.as_ref())),
        })
    }

    /// Create a new restartable ytdl source, using the first result of a youtube search.
    ///
    /// The cost of restarting and seeking will probably be *very* high:
    /// expect a pause if you seek backwards.
    pub fn ytdl_search(name: &str) -> Result<Self> {
        Self::ytdl(format!("ytsearch1:{}", name))
    }

    /// Stores the runtime handle later needed to spawn restart tasks on seek.
    pub(crate) fn prep_with_handle(&mut self, handle: Handle) {
        self.async_handle = Some(handle);
    }
}
/// Trait used to create an instance of a [`Reader`] at instantiation and when
/// a backwards seek is needed.
///
/// Many closures derive this automatically.
///
/// [`Reader`]: ../reader/enum.Reader.html
pub trait Restart {
    /// Tries to create a replacement source.
    ///
    /// `time` is `None` for the initial creation, or `Some(target)` when the
    /// stream must be recreated at a given timestamp for a backward seek.
    fn call_restart(&mut self, time: Option<Duration>) -> Result<Input>;
}
// Restart strategy which re-spawns ffmpeg over the same local path each time
// the source must be rebuilt.
struct FfmpegRestarter<P>
where
    P: AsRef<OsStr> + Send,
{
    // Path handed to ffmpeg on every (re)start.
    path: P,
}
impl<P> Restart for FfmpegRestarter<P>
where
    P: AsRef<OsStr> + Send,
{
    /// Re-spawns ffmpeg for the stored path; when a timestamp is given, the
    /// new process seeks via `-ss` before decoding.
    fn call_restart(&mut self, time: Option<Duration>) -> Result<Input> {
        // Restarts happen on the calling (mixer-side) thread, so the async
        // probe/spawn work is driven to completion synchronously here.
        executor::block_on(async {
            if let Some(time) = time {
                // Probe channel count; on probe failure, fall back to mono
                // and empty metadata rather than aborting the restart.
                let is_stereo = is_stereo(self.path.as_ref())
                    .await
                    .unwrap_or_else(|_e| (false, Default::default()));
                let stereo_val = if is_stereo.0 { "2" } else { "1" };
                // Fractional seconds rendered as "secs.millis" for -ss.
                let ts = format!("{}.{}", time.as_secs(), time.subsec_millis());
                _ffmpeg_optioned(
                    self.path.as_ref(),
                    &["-ss", &ts],
                    &[
                        // NOTE(review): output container "-f s16le" paired
                        // with codec "-acodec pcm_f32le" looks inconsistent —
                        // confirm which raw sample format downstream expects.
                        "-f",
                        "s16le",
                        "-ac",
                        stereo_val,
                        "-ar",
                        "48000",
                        "-acodec",
                        "pcm_f32le",
                        "-",
                    ],
                    Some(is_stereo),
                )
                .await
            } else {
                // No seek target: the plain creation path suffices.
                ffmpeg(self.path.as_ref()).await
            }
        })
    }
}
// Blanket impl: any `FnMut(Option<Duration>) -> Result<Input>` closure is a
// valid restart strategy (this is the "many closures derive this
// automatically" mentioned on the trait).
impl<P> Restart for P
where
    P: FnMut(Option<Duration>) -> Result<Input> + Send + 'static,
{
    fn call_restart(&mut self, time: Option<Duration>) -> Result<Input> {
        (self)(time)
    }
}
impl Debug for Restartable {
    // Hand-written because `recreator` (a boxed closure/trait object) is not
    // `Debug`; it is rendered as the "<fn>" placeholder instead.
    fn fmt(&self, f: &mut Formatter<'_>) -> StdResult<(), FormatError> {
        f.debug_struct("Restartable")
            .field("async_handle", &self.async_handle)
            .field("awaiting_source", &self.awaiting_source)
            .field("position", &self.position)
            .field("recreator", &"<fn>")
            .field("source", &self.source)
            .finish()
    }
}
impl From<Restartable> for Input {
    /// Wraps the restartable source in an [`Input`], first lifting out its
    /// stream parameters (codec kind, stereo flag, container, metadata).
    ///
    /// [`Input`]: struct.Input.html
    fn from(mut src: Restartable) -> Self {
        let metadata = Some(src.source.metadata.take());
        let stereo = src.source.stereo;
        let kind = src.source.kind.clone();
        let container = src.source.container;
        Input::new(stereo, Reader::Restartable(src), kind, container, metadata)
    }
}
// How do these work at a high level?
// If you need to restart, send a request to do this to the async context.
// if a request is pending, then just output all zeroes.
impl Read for Restartable {
    /// Reads from the active source, emitting silence (all zeroes) while a
    /// restart request is still in flight on the async runtime.
    fn read(&mut self, buffer: &mut [u8]) -> IoResult<usize> {
        // (out_val, march_pos, remove_async):
        //  - out_val: the result handed back to the caller,
        //  - march_pos: whether `self.position` should advance by bytes read,
        //  - remove_async: whether the pending-restart channel can be dropped.
        let (out_val, march_pos, remove_async) = if let Some(chan) = &self.awaiting_source {
            match chan.try_recv() {
                Ok(Ok((new_source, recreator))) => {
                    // Restart completed: install the fresh source and reclaim
                    // the recreator for use by future seeks.
                    self.source = new_source;
                    self.recreator = Some(recreator);
                    (Read::read(&mut self.source, buffer), true, true)
                },
                Ok(Err(source_error)) => {
                    // Restart ran but failed to build a source.
                    let e = Err(IoError::new(
                        IoErrorKind::UnexpectedEof,
                        format!("Failed to create new reader: {:?}.", source_error),
                    ));
                    (e, false, true)
                },
                Err(TryRecvError::Empty) => {
                    // Restart still pending: play silence without advancing
                    // `position` (the seek already moved it to the target).
                    // Output all zeroes.
                    for el in buffer.iter_mut() {
                        *el = 0;
                    }
                    (Ok(buffer.len()), false, false)
                },
                Err(_) => {
                    // Sender was dropped: the restart task died.
                    let e = Err(IoError::new(
                        IoErrorKind::UnexpectedEof,
                        "Failed to create new reader: dropped.",
                    ));
                    (e, false, true)
                },
            }
        } else {
            // already have a good, valid source.
            (Read::read(&mut self.source, buffer), true, false)
        };
        if remove_async {
            self.awaiting_source = None;
        }
        if march_pos {
            // Track how far into the stream we are, in bytes.
            out_val.map(|a| {
                self.position += a;
                a
            })
        } else {
            out_val
        }
    }
}
impl Seek for Restartable {
    /// Seeks by draining forward through the current source, or — for
    /// backward seeks — by spawning an async task which recreates the source
    /// at the target timestamp (silence plays until it resolves).
    fn seek(&mut self, pos: SeekFrom) -> IoResult<u64> {
        let _local_pos = self.position as u64;
        use SeekFrom::*;
        match pos {
            Start(offset) => {
                let stereo = self.source.stereo;
                let _current_ts = utils::byte_count_to_timestamp(self.position, stereo);
                let offset = offset as usize;
                if offset < self.position {
                    // We're going back in time.
                    if let Some(handle) = self.async_handle.as_ref() {
                        // Single-slot channel: `read` polls it for the result.
                        let (tx, rx) = flume::bounded(1);
                        self.awaiting_source = Some(rx);
                        // Lend the recreator to the restart task; it is sent
                        // back together with the rebuilt source.
                        let recreator = self.recreator.take();
                        if let Some(mut rec) = recreator {
                            handle.spawn(async move {
                                let ret_val = rec.call_restart(Some(
                                    utils::byte_count_to_timestamp(offset, stereo),
                                ));
                                let _ = tx.send(ret_val.map(Box::new).map(|v| (v, rec)));
                            });
                        } else {
                            // `recreator` is None only while an earlier
                            // restart still holds it.
                            return Err(IoError::new(
                                IoErrorKind::Interrupted,
                                "Previous seek in progress.",
                            ));
                        }
                        self.position = offset;
                    } else {
                        return Err(IoError::new(
                            IoErrorKind::Interrupted,
                            "Cannot safely call seek until provided an async context handle.",
                        ));
                    }
                } else {
                    // Forward seek: drain and discard bytes from the live source.
                    self.position += self.source.consume(offset - self.position);
                }
                Ok(offset as u64)
            },
            End(_offset) => Err(IoError::new(
                IoErrorKind::InvalidInput,
                "End point for Restartables is not known.",
            )),
            Current(_offset) => unimplemented!(),
        }
    }
}

41
src/input/utils.rs Normal file
View File

@@ -0,0 +1,41 @@
//! Utility methods for seeking or decoding.
use crate::constants::*;
use audiopus::{coder::Decoder, Channels, Result as OpusResult, SampleRate};
use std::{mem, time::Duration};
/// Calculates the sample position in a FloatPCM stream from a timestamp.
pub fn timestamp_to_sample_count(timestamp: Duration, stereo: bool) -> usize {
    // Stereo streams interleave two channels, doubling the sample count.
    let channel_factor = if stereo { 2 } else { 1 };
    (timestamp.as_millis() as usize) * (MONO_FRAME_SIZE / FRAME_LEN_MS) * channel_factor
}
/// Calculates the time position in a FloatPCM stream from a sample index.
pub fn sample_count_to_timestamp(amt: usize, stereo: bool) -> Duration {
    // Interleaved stereo halves the per-channel sample count.
    let channel_factor = if stereo { 2 } else { 1 };
    let ms = (((amt * FRAME_LEN_MS) / MONO_FRAME_SIZE) as u64) / channel_factor;
    Duration::from_millis(ms)
}
/// Calculates the byte position in a FloatPCM stream from a timestamp.
///
/// Each sample is sized by `mem::size_of::<f32>() == 4usize`.
pub fn timestamp_to_byte_count(timestamp: Duration, stereo: bool) -> usize {
    let samples = timestamp_to_sample_count(timestamp, stereo);
    samples * mem::size_of::<f32>()
}
/// Calculates the time position in a FloatPCM stream from a byte index.
///
/// Each sample is sized by `mem::size_of::<f32>() == 4usize`.
pub fn byte_count_to_timestamp(amt: usize, stereo: bool) -> Duration {
    let samples = amt / mem::size_of::<f32>();
    sample_count_to_timestamp(samples, stereo)
}
/// Create an Opus decoder outputting at a sample rate of 48kHz.
pub fn decoder(stereo: bool) -> OpusResult<Decoder> {
    let channels = if stereo {
        Channels::Stereo
    } else {
        Channels::Mono
    };
    Decoder::new(SampleRate::Hz48000, channels)
}

107
src/input/ytdl_src.rs Normal file
View File

@@ -0,0 +1,107 @@
use super::{
child_to_reader,
error::{Error, Result},
Codec,
Container,
Input,
Metadata,
};
use serde_json::Value;
use std::{
io::{BufRead, BufReader, Read},
process::{Command, Stdio},
};
use tokio::task;
use tracing::trace;
/// Creates a streamed audio source with `youtube-dl` and `ffmpeg`.
///
/// Convenience wrapper over `_ytdl` with no extra ffmpeg pre-args.
pub async fn ytdl(uri: &str) -> Result<Input> {
    _ytdl(uri, &[]).await
}
// Spawns a `youtube-dl | ffmpeg` pipeline, parsing track metadata from the
// first line youtube-dl writes to its stderr.
pub(crate) async fn _ytdl(uri: &str, pre_args: &[&str]) -> Result<Input> {
    // youtube-dl: stream the best audio format to stdout ("-o -"), retrying
    // indefinitely, and print the info JSON for metadata.
    let ytdl_args = [
        "--print-json",
        "-f",
        "webm[abr>0]/bestaudio/best",
        "-R",
        "infinite",
        "--no-playlist",
        "--ignore-config",
        uri,
        "-o",
        "-",
    ];
    // ffmpeg: decode stdin to raw 48kHz stereo PCM on stdout.
    // NOTE(review): "-f s16le" with "-acodec pcm_f32le" looks inconsistent;
    // the reader below consumes f32 samples — confirm the intended format.
    let ffmpeg_args = [
        "-f",
        "s16le",
        "-ac",
        "2",
        "-ar",
        "48000",
        "-acodec",
        "pcm_f32le",
        "-",
    ];
    let mut youtube_dl = Command::new("youtube-dl")
        .args(&ytdl_args)
        .stdin(Stdio::null())
        .stderr(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()?;
    // Temporarily take ownership of stderr so the blocking task can parse the
    // first line as metadata JSON; the handle is returned to the child after.
    // NOTE(review): assumes youtube-dl emits the --print-json blob on stderr
    // when media is streamed to stdout — verify against youtube-dl behaviour.
    let stderr = youtube_dl.stderr.take();
    let (returned_stderr, value) = task::spawn_blocking(move || {
        if let Some(mut s) = stderr {
            let out: Option<Value> = {
                let mut o_vec = vec![];
                let mut serde_read = BufReader::new(s.by_ref());
                // Newline...
                if let Ok(len) = serde_read.read_until(0xA, &mut o_vec) {
                    serde_json::from_slice(&o_vec[..len]).ok()
                } else {
                    None
                }
            };
            (Some(s), out)
        } else {
            (None, None)
        }
    })
    .await
    .map_err(|_| Error::Metadata)?;
    youtube_dl.stderr = returned_stderr;
    // Pipe youtube-dl's stdout straight into ffmpeg's stdin.
    let ffmpeg = Command::new("ffmpeg")
        .args(pre_args)
        .arg("-i")
        .arg("-")
        .args(&ffmpeg_args)
        .stdin(youtube_dl.stdout.ok_or(Error::Stdout)?)
        .stderr(Stdio::null())
        .stdout(Stdio::piped())
        .spawn()?;
    // Missing/unparseable JSON degrades to default (empty) metadata.
    let metadata = Metadata::from_ytdl_output(value.unwrap_or_default());
    trace!("ytdl metadata {:?}", metadata);
    Ok(Input::new(
        true,
        child_to_reader::<f32>(ffmpeg),
        Codec::FloatPcm,
        Container::Raw,
        Some(metadata),
    ))
}
/// Creates a streamed audio source from YouTube search results with `youtube-dl`,`ffmpeg`, and `ytsearch`.
/// Takes the first video listed from the YouTube search.
pub async fn ytdl_search(name: &str) -> Result<Input> {
    let query = format!("ytsearch1:{}", name);
    ytdl(&query).await
}

84
src/lib.rs Normal file
View File

@@ -0,0 +1,84 @@
#![doc(
html_logo_url = "https://raw.githubusercontent.com/FelixMcFelix/serenity/voice-rework/songbird/songbird.png",
html_favicon_url = "https://raw.githubusercontent.com/FelixMcFelix/serenity/voice-rework/songbird/songbird-ico.png"
)]
#![deny(missing_docs)]
//! ![project logo][logo]
//!
//! Songbird is an async, cross-library compatible voice system for Discord, written in Rust.
//! The library offers:
//! * A standalone gateway frontend compatible with [serenity] and [twilight] using the
//! `"gateway"` and `"[serenity/twilight]-[rustls/native]"` features. You can even run
//! driverless, to help manage your [lavalink] sessions.
//! * A standalone driver for voice calls, via the `"driver"` feature. If you can create
//! a [`ConnectionInfo`] using any other gateway, or language for your bot, then you
//! can run the songbird voice driver.
//! * And, by default, a fully featured voice system featuring events, queues, RT(C)P packet
//! handling, seeking on compatible streams, shared multithreaded audio stream caches,
//! and direct Opus data passthrough from DCA files.
//!
//! ## Examples
//! Full examples showing various types of functionality and integrations can be found as part of [serenity's examples],
//! and in [this crate's examples directory].
//!
//! ## Attribution
//!
//! Songbird's logo is based upon the copyright-free image ["Black-Capped Chickadee"] by George Gorgas White.
//!
//! [logo]: https://raw.githubusercontent.com/FelixMcFelix/serenity/voice-rework/songbird/songbird.png
//! [serenity]: https://github.com/serenity-rs/serenity
//! [twilight]: https://github.com/twilight-rs/twilight
//! [serenity's examples]: https://github.com/serenity-rs/serenity/tree/current/examples
//! [this crate's examples directory]: https://github.com/serenity-rs/serenity/tree/current/songbird/examples
//! ["Black-Capped Chickadee"]: https://www.oldbookillustrations.com/illustrations/black-capped-chickadee/
//! [`ConnectionInfo`]: struct.ConnectionInfo.html
//! [lavalink]: https://github.com/Frederikam/Lavalink
pub mod constants;
#[cfg(feature = "driver")]
pub mod driver;
pub mod error;
#[cfg(feature = "driver")]
pub mod events;
#[cfg(feature = "gateway")]
mod handler;
pub mod id;
pub(crate) mod info;
#[cfg(feature = "driver")]
pub mod input;
#[cfg(feature = "gateway")]
mod manager;
#[cfg(feature = "serenity")]
pub mod serenity;
#[cfg(feature = "gateway")]
pub mod shards;
#[cfg(feature = "driver")]
pub mod tracks;
#[cfg(feature = "driver")]
mod ws;
#[cfg(feature = "driver")]
pub use audiopus::{self as opus, Bitrate};
#[cfg(feature = "driver")]
pub use discortp as packet;
#[cfg(feature = "driver")]
pub use serenity_voice_model as model;
#[cfg(test)]
use utils as test_utils;
#[cfg(feature = "driver")]
pub use crate::{
driver::Driver,
events::{CoreEvent, Event, EventContext, EventHandler, TrackEvent},
input::{ffmpeg, ytdl},
tracks::create_player,
};
#[cfg(feature = "gateway")]
pub use crate::{handler::Call, manager::Songbird};
#[cfg(feature = "serenity")]
pub use crate::serenity::*;
pub use info::ConnectionInfo;

353
src/manager.rs Normal file
View File

@@ -0,0 +1,353 @@
#[cfg(feature = "driver")]
use crate::error::ConnectionResult;
use crate::{
error::{JoinError, JoinResult},
id::{ChannelId, GuildId, UserId},
shards::Sharder,
Call,
ConnectionInfo,
};
#[cfg(feature = "serenity")]
use async_trait::async_trait;
use flume::Receiver;
#[cfg(feature = "serenity")]
use futures::channel::mpsc::UnboundedSender as Sender;
use parking_lot::RwLock as PRwLock;
#[cfg(feature = "serenity")]
use serenity::{
client::bridge::voice::VoiceGatewayManager,
gateway::InterMessage,
model::{
id::{GuildId as SerenityGuild, UserId as SerenityUser},
voice::VoiceState,
},
};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
#[cfg(feature = "twilight")]
use twilight_gateway::Cluster;
#[cfg(feature = "twilight")]
use twilight_model::gateway::event::Event as TwilightEvent;
// Shard/user bookkeeping shared by all calls; `Copy` so it can be cheaply
// read out of its lock.
#[derive(Clone, Copy, Debug, Default)]
struct ClientData {
    shard_count: u64,
    // True once shard_count/user_id are set (via `initialise_client_data` or
    // the twilight constructor); later initialisation attempts are no-ops.
    initialised: bool,
    user_id: UserId,
}
/// A shard-aware struct responsible for managing [`Call`]s.
///
/// This manager transparently maps guild state and a source of shard information
/// into individual calls, and forwards state updates which affect call state.
///
/// [`Call`]: struct.Call.html
#[derive(Debug)]
pub struct Songbird {
    // Bot user/shard-count info; set once, then read-mostly.
    client_data: PRwLock<ClientData>,
    // One `Call` per guild, shared out behind an async mutex.
    calls: PRwLock<HashMap<GuildId, Arc<Mutex<Call>>>>,
    // Gateway-library-specific source of shard handles.
    sharder: Sharder,
}
impl Songbird {
    #[cfg(feature = "serenity")]
    /// Create a new Songbird instance for serenity.
    ///
    /// This must be [registered] after creation.
    ///
    /// [registered]: serenity/fn.register_with.html
    pub fn serenity() -> Arc<Self> {
        Arc::new(Self {
            client_data: Default::default(),
            calls: Default::default(),
            sharder: Sharder::Serenity(Default::default()),
        })
    }

    #[cfg(feature = "twilight")]
    /// Create a new Songbird instance for twilight.
    ///
    /// Twilight handlers do not need to be registered, but
    /// users are responsible for passing in any events using
    /// [`process`].
    ///
    /// [`process`]: #method.process
    pub fn twilight<U>(cluster: Cluster, shard_count: u64, user_id: U) -> Arc<Self>
    where
        U: Into<UserId>,
    {
        Arc::new(Self {
            client_data: PRwLock::new(ClientData {
                shard_count,
                initialised: true,
                user_id: user_id.into(),
            }),
            calls: Default::default(),
            sharder: Sharder::Twilight(cluster),
        })
    }

    /// Set the bot's user, and the number of shards in use.
    ///
    /// If this struct is already initialised (e.g., from [`::twilight`]),
    /// or a previous call, then this function is a no-op.
    ///
    /// [`::twilight`]: #method.twilight
    pub fn initialise_client_data<U: Into<UserId>>(&self, shard_count: u64, user_id: U) {
        let mut client_data = self.client_data.write();
        if client_data.initialised {
            return;
        }
        client_data.shard_count = shard_count;
        client_data.user_id = user_id.into();
        client_data.initialised = true;
    }

    /// Retrieves a [`Call`] for the given guild, if one already exists.
    ///
    /// [`Call`]: struct.Call.html
    pub fn get<G: Into<GuildId>>(&self, guild_id: G) -> Option<Arc<Mutex<Call>>> {
        let map_read = self.calls.read();
        map_read.get(&guild_id.into()).cloned()
    }

    /// Retrieves a [`Call`] for the given guild, creating a new one if
    /// none is found.
    ///
    /// This will not join any calls, or cause connection state to change.
    ///
    /// [`Call`]: struct.Call.html
    pub fn get_or_insert(&self, guild_id: GuildId) -> Arc<Mutex<Call>> {
        self.get(guild_id).unwrap_or_else(|| {
            // This is a *write* guard (previously misleadingly named
            // `map_read`): a new entry may need to be inserted. `entry`
            // re-checks under the lock, so a racing insert is safe.
            let mut map_write = self.calls.write();
            map_write
                .entry(guild_id)
                .or_insert_with(|| {
                    let info = self.manager_info();
                    let shard = shard_id(guild_id.0, info.shard_count);
                    let shard_handle = self
                        .sharder
                        .get_shard(shard)
                        .expect("Failed to get shard handle: shard_count incorrect?");
                    Arc::new(Mutex::new(Call::new(guild_id, shard_handle, info.user_id)))
                })
                .clone()
        })
    }

    // Copies out the current client data. A read lock suffices here:
    // `ClientData` is `Copy` and nothing is mutated (previously this took a
    // write lock, needlessly excluding concurrent readers).
    fn manager_info(&self) -> ClientData {
        *self.client_data.read()
    }

    #[cfg(feature = "driver")]
    /// Connects to a target by retrieving its relevant [`Call`] and
    /// connecting, or creating the handler if required.
    ///
    /// This can also switch to the given channel, if a handler already exists
    /// for the target and the current connected channel is not equal to the
    /// given channel.
    ///
    /// The provided channel ID is used as a connection target. The
    /// channel _must_ be in the provided guild. This is _not_ checked by the
    /// library, and will result in an error. If there is already a connected
    /// handler for the guild, _and_ the provided channel is different from the
    /// channel that the connection is already connected to, then the handler
    /// will switch the connection to the provided channel.
    ///
    /// If you _only_ need to retrieve the handler for a target, then use
    /// [`get`].
    ///
    /// [`Call`]: struct.Call.html
    /// [`get`]: #method.get
    #[inline]
    pub async fn join<C, G>(
        &self,
        guild_id: G,
        channel_id: C,
    ) -> (Arc<Mutex<Call>>, JoinResult<Receiver<ConnectionResult<()>>>)
    where
        C: Into<ChannelId>,
        G: Into<GuildId>,
    {
        self._join(guild_id.into(), channel_id.into()).await
    }

    #[cfg(feature = "driver")]
    async fn _join(
        &self,
        guild_id: GuildId,
        channel_id: ChannelId,
    ) -> (Arc<Mutex<Call>>, JoinResult<Receiver<ConnectionResult<()>>>) {
        let call = self.get_or_insert(guild_id);
        // Scope the handler lock so it is released before returning the call.
        let result = {
            let mut handler = call.lock().await;
            handler.join(channel_id).await
        };
        (call, result)
    }

    /// Partially connects to a target by retrieving its relevant [`Call`] and
    /// connecting, or creating the handler if required.
    ///
    /// This method returns the handle and the connection info needed for other libraries
    /// or drivers, such as lavalink, and does not actually start or run a voice call.
    ///
    /// [`Call`]: struct.Call.html
    #[inline]
    pub async fn join_gateway<C, G>(
        &self,
        guild_id: G,
        channel_id: C,
    ) -> (Arc<Mutex<Call>>, JoinResult<Receiver<ConnectionInfo>>)
    where
        C: Into<ChannelId>,
        G: Into<GuildId>,
    {
        self._join_gateway(guild_id.into(), channel_id.into()).await
    }

    async fn _join_gateway(
        &self,
        guild_id: GuildId,
        channel_id: ChannelId,
    ) -> (Arc<Mutex<Call>>, JoinResult<Receiver<ConnectionInfo>>) {
        let call = self.get_or_insert(guild_id);
        let result = {
            let mut handler = call.lock().await;
            handler.join_gateway(channel_id).await
        };
        (call, result)
    }

    /// Retrieves the [handler][`Call`] for the given target and leaves the
    /// associated voice channel, if connected.
    ///
    /// This will _not_ drop the handler, and will preserve it and its settings.
    ///
    /// This is a wrapper around [getting][`get`] a handler and calling
    /// [`leave`] on it.
    ///
    /// [`Call`]: struct.Call.html
    /// [`get`]: #method.get
    /// [`leave`]: struct.Call.html#method.leave
    #[inline]
    pub async fn leave<G: Into<GuildId>>(&self, guild_id: G) -> JoinResult<()> {
        self._leave(guild_id.into()).await
    }

    async fn _leave(&self, guild_id: GuildId) -> JoinResult<()> {
        if let Some(call) = self.get(guild_id) {
            let mut handler = call.lock().await;
            handler.leave().await
        } else {
            Err(JoinError::NoCall)
        }
    }

    /// Retrieves the [`Call`] for the given target and leaves the associated
    /// voice channel, if connected.
    ///
    /// The handler is then dropped, removing settings for the target.
    ///
    /// An Err(...) value implies that the gateway could not be contacted,
    /// and that leaving should be attempted again later (i.e., after reconnect).
    ///
    /// [`Call`]: struct.Call.html
    #[inline]
    pub async fn remove<G: Into<GuildId>>(&self, guild_id: G) -> JoinResult<()> {
        self._remove(guild_id.into()).await
    }

    async fn _remove(&self, guild_id: GuildId) -> JoinResult<()> {
        // Leave first; only drop the stored handler once that has succeeded.
        self.leave(guild_id).await?;
        let mut calls = self.calls.write();
        calls.remove(&guild_id);
        Ok(())
    }
}
#[cfg(feature = "twilight")]
impl Songbird {
    /// Handle events received on the cluster.
    ///
    /// When using twilight, you are required to call this with all inbound
    /// (voice) events, *i.e.*, at least `VoiceStateUpdate`s and `VoiceServerUpdate`s.
    pub async fn process(&self, event: &TwilightEvent) {
        match event {
            TwilightEvent::VoiceServerUpdate(v) => {
                // Forward the new endpoint/token to the guild's call, if any.
                let call = v.guild_id.map(GuildId::from).and_then(|id| self.get(id));
                if let Some(call) = call {
                    let mut handler = call.lock().await;
                    if let Some(endpoint) = &v.endpoint {
                        handler.update_server(endpoint.clone(), v.token.clone());
                    }
                }
            },
            TwilightEvent::VoiceStateUpdate(v) => {
                // Only the bot's own voice-state changes are relevant here.
                if v.0.user_id.0 != self.client_data.read().user_id.0 {
                    return;
                }
                let call = v.0.guild_id.map(GuildId::from).and_then(|id| self.get(id));
                if let Some(call) = call {
                    let mut handler = call.lock().await;
                    handler.update_state(v.0.session_id.clone());
                }
            },
            _ => {},
        }
    }
}
#[cfg(feature = "serenity")]
#[async_trait]
impl VoiceGatewayManager for Songbird {
    /// Records the shard count and bot user ID once serenity knows them.
    async fn initialise(&self, shard_count: u64, user_id: SerenityUser) {
        self.initialise_client_data(shard_count, user_id);
    }
    /// Attaches a live shard-runner sender for the given shard.
    async fn register_shard(&self, shard_id: u64, sender: Sender<InterMessage>) {
        self.sharder.register_shard_handle(shard_id, sender);
    }
    /// Detaches the sender for the given shard (e.g., on reconnect/rebalance).
    async fn deregister_shard(&self, shard_id: u64) {
        self.sharder.deregister_shard_handle(shard_id);
    }
    /// Forwards a voice-server update (endpoint + token) to the guild's call.
    async fn server_update(&self, guild_id: SerenityGuild, endpoint: &Option<String>, token: &str) {
        if let Some(call) = self.get(guild_id) {
            let mut handler = call.lock().await;
            if let Some(endpoint) = endpoint {
                handler.update_server(endpoint.clone(), token.to_string());
            }
        }
    }
    /// Forwards the bot's own voice-state session ID to the guild's call.
    async fn state_update(&self, guild_id: SerenityGuild, voice_state: &VoiceState) {
        // Ignore other users' state changes.
        if voice_state.user_id.0 != self.client_data.read().user_id.0 {
            return;
        }
        if let Some(call) = self.get(guild_id) {
            let mut handler = call.lock().await;
            handler.update_state(voice_state.session_id.clone());
        }
    }
}
#[inline]
// Maps a guild to the shard responsible for it. Discord routes a guild's
// events to shard `(guild_id >> 22) % shard_count`; panics if shard_count is 0.
fn shard_id(guild_id: u64, shard_count: u64) -> u64 {
    let routing_bits = guild_id >> 22;
    routing_bits % shard_count
}

71
src/serenity.rs Normal file
View File

@@ -0,0 +1,71 @@
//! Compatability and convenience methods for working with [serenity].
//! Requires the `"serenity-rustls"` or `"serenity-native"` features.
//!
//! [serenity]: https://crates.io/crates/serenity/0.9.0-rc.2
use crate::manager::Songbird;
use serenity::{
client::{ClientBuilder, Context},
prelude::TypeMapKey,
};
use std::sync::Arc;
/// Zero-size type used to retrieve the registered [`Songbird`] instance
/// from serenity's inner TypeMap.
///
/// [`Songbird`]: ../struct.Songbird.html
pub struct SongbirdKey;

// The TypeMap stores the shared manager behind an `Arc` so it can be cloned
// out cheaply per command invocation.
impl TypeMapKey for SongbirdKey {
    type Value = Arc<Songbird>;
}
/// Installs a new songbird instance into the serenity client.
///
/// This should be called after any uses of `ClientBuilder::type_map`.
pub fn register(client_builder: ClientBuilder) -> ClientBuilder {
    // Build a fresh manager and hand off to the shared registration path.
    register_with(client_builder, Songbird::serenity())
}
/// Installs a given songbird instance into the serenity client.
///
/// This should be called after any uses of `ClientBuilder::type_map`.
pub fn register_with(client_builder: ClientBuilder, voice: Arc<Songbird>) -> ClientBuilder {
    // One handle goes to the gateway voice manager, the other into the
    // TypeMap for retrieval via `get`.
    let manager = Arc::clone(&voice);
    client_builder
        .voice_manager_arc(manager)
        .type_map_insert::<SongbirdKey>(voice)
}
/// Retrieve the Songbird voice client from a serenity context's
/// shared key-value store.
pub async fn get(ctx: &Context) -> Option<Arc<Songbird>> {
    ctx.data.read().await.get::<SongbirdKey>().cloned()
}
/// Helper trait to add installation/creation methods to serenity's
/// `ClientBuilder`.
///
/// These install the client to receive gateway voice events, and
/// store an easily accessible reference to Songbird's managers.
pub trait SerenityInit {
    /// Registers a new Songbird voice system with serenity, storing it for easy
    /// access via [`get`].
    ///
    /// [`get`]: fn.get.html
    fn register_songbird(self) -> Self;
    /// Registers a given Songbird voice system with serenity, as above.
    ///
    /// Use this when the instance must also be shared outside of serenity.
    fn register_songbird_with(self, voice: Arc<Songbird>) -> Self;
}
// Thin forwarding impls over the free `register`/`register_with` functions.
impl SerenityInit for ClientBuilder<'_> {
    fn register_songbird(self) -> Self {
        register(self)
    }
    fn register_songbird_with(self, voice: Arc<Songbird>) -> Self {
        register_with(self, voice)
    }
}

168
src/shards.rs Normal file
View File

@@ -0,0 +1,168 @@
//! Handlers for sending packets over sharded connections.
use crate::error::{JoinError, JoinResult};
#[cfg(feature = "serenity")]
use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender};
#[cfg(feature = "serenity")]
use parking_lot::{lock_api::RwLockWriteGuard, Mutex as PMutex, RwLock as PRwLock};
use serde_json::Value;
#[cfg(feature = "serenity")]
use serenity::gateway::InterMessage;
#[cfg(feature = "serenity")]
use std::{collections::HashMap, result::Result as StdResult, sync::Arc};
use tracing::error;
#[cfg(feature = "twilight")]
use twilight_gateway::{Cluster, Shard as TwilightShard};
#[derive(Debug)]
#[non_exhaustive]
/// Source of individual shard connection handles.
///
/// Variants are feature-gated on the gateway library in use.
pub enum Sharder {
    #[cfg(feature = "serenity")]
    /// Serenity-specific wrapper for sharder state initialised by the library.
    Serenity(SerenitySharder),
    #[cfg(feature = "twilight")]
    /// Twilight-specific wrapper for sharder state initialised by the user.
    Twilight(Cluster),
}
impl Sharder {
    #[allow(unreachable_patterns)]
    /// Returns a new handle to the required inner shard.
    ///
    /// Serenity handles are created on demand; twilight returns `None` for an
    /// unknown shard ID.
    pub fn get_shard(&self, shard_id: u64) -> Option<Shard> {
        match self {
            #[cfg(feature = "serenity")]
            Sharder::Serenity(s) => Some(Shard::Serenity(s.get_or_insert_shard_handle(shard_id))),
            #[cfg(feature = "twilight")]
            Sharder::Twilight(t) => t.shard(shard_id).map(Shard::Twilight),
            _ => None,
        }
    }
}
#[cfg(feature = "serenity")]
impl Sharder {
    #[allow(unreachable_patterns)]
    // Attaches a serenity shard-runner sender; logs an error if this Songbird
    // instance was built for a different gateway backend.
    pub(crate) fn register_shard_handle(&self, shard_id: u64, sender: Sender<InterMessage>) {
        match self {
            Sharder::Serenity(s) => s.register_shard_handle(shard_id, sender),
            _ => error!("Called serenity management function on a non-serenity Songbird instance."),
        }
    }
    #[allow(unreachable_patterns)]
    // Detaches a shard's sender; subsequent sends buffer until re-registration.
    pub(crate) fn deregister_shard_handle(&self, shard_id: u64) {
        match self {
            Sharder::Serenity(s) => s.deregister_shard_handle(shard_id),
            _ => error!("Called serenity management function on a non-serenity Songbird instance."),
        }
    }
}
#[cfg(feature = "serenity")]
#[derive(Debug, Default)]
/// Serenity-specific wrapper for sharder state initialised by the library.
///
/// This is updated and maintained by the library, and is designed to prevent
/// message loss during rebalances and reconnects.
///
/// Maps shard IDs to their buffering [`SerenityShardHandle`]s.
pub struct SerenitySharder(PRwLock<HashMap<u64, Arc<SerenityShardHandle>>>);
#[cfg(feature = "serenity")]
impl SerenitySharder {
    /// Fetches the handle for a shard, lazily creating it on first request.
    fn get_or_insert_shard_handle(&self, shard_id: u64) -> Arc<SerenityShardHandle> {
        // Fast path: existing handle under a shared read lock.
        if let Some(handle) = self.0.read().get(&shard_id) {
            return Arc::clone(handle);
        }
        // Slow path: the entry API re-checks under the write lock, so a
        // racing insert cannot clobber another thread's handle.
        self.0.write().entry(shard_id).or_default().clone()
    }

    /// Attaches a live sender to a shard's handle, flushing its queue.
    fn register_shard_handle(&self, shard_id: u64, sender: Sender<InterMessage>) {
        // Write locks are only used to add new entries to the map.
        self.get_or_insert_shard_handle(shard_id).register(sender);
    }

    /// Detaches a shard's sender; messages buffer until re-registration.
    fn deregister_shard_handle(&self, shard_id: u64) {
        // Write locks are only used to add new entries to the map.
        self.get_or_insert_shard_handle(shard_id).deregister();
    }
}
#[derive(Clone, Debug)]
#[non_exhaustive]
/// A reference to an individual websocket connection.
///
/// Cloning is cheap: both variants wrap shared/clonable handles.
pub enum Shard {
    #[cfg(feature = "serenity")]
    /// Handle to one of serenity's shard runners.
    Serenity(Arc<SerenityShardHandle>),
    #[cfg(feature = "twilight")]
    /// Handle to a twilight shard spawned from a cluster.
    Twilight(TwilightShard),
}
impl Shard {
    #[allow(unreachable_patterns)]
    /// Send a JSON message to the inner shard handle.
    ///
    /// Returns `JoinError::NoSender` when no gateway feature matches.
    pub async fn send(&mut self, msg: Value) -> JoinResult<()> {
        match self {
            #[cfg(feature = "serenity")]
            Shard::Serenity(s) => s.send(InterMessage::Json(msg))?,
            #[cfg(feature = "twilight")]
            Shard::Twilight(t) => t.command(&msg).await?,
            _ => return Err(JoinError::NoSender),
        }
        Ok(())
    }
}
#[cfg(feature = "serenity")]
/// Handle to an individual shard designed to buffer unsent messages while
/// a reconnect/rebalance is ongoing.
#[derive(Debug, Default)]
pub struct SerenityShardHandle {
    // Live channel to the shard runner; `None` while disconnected.
    sender: PRwLock<Option<Sender<InterMessage>>>,
    // Messages accumulated while `sender` is `None`; flushed on `register`.
    queue: PMutex<Vec<InterMessage>>,
}
#[cfg(feature = "serenity")]
impl SerenityShardHandle {
    /// Installs a fresh sender after a shard (re)connect, then flushes any
    /// messages queued while no sender was available.
    fn register(&self, sender: Sender<InterMessage>) {
        let mut sender_lock = self.sender.write();
        *sender_lock = Some(sender);
        // Downgrade write -> read so a guard is still held while draining;
        // the queue mutex below keeps the flush itself serialised.
        let sender_lock = RwLockWriteGuard::downgrade(sender_lock);
        let mut messages_lock = self.queue.lock();
        if let Some(sender) = &*sender_lock {
            for msg in messages_lock.drain(..) {
                if let Err(e) = sender.unbounded_send(msg) {
                    error!("Error while clearing gateway message queue: {:?}", e);
                    // Remaining messages stay queued for the next register.
                    break;
                }
            }
        }
    }
    /// Clears the sender (e.g., when a shard runner dies); later `send`s queue.
    fn deregister(&self) {
        let mut sender_lock = self.sender.write();
        *sender_lock = None;
    }
    /// Sends immediately when a sender is attached; otherwise buffers the
    /// message until the next `register`.
    fn send(&self, message: InterMessage) -> StdResult<(), TrySendError<InterMessage>> {
        let sender_lock = self.sender.read();
        if let Some(sender) = &*sender_lock {
            sender.unbounded_send(message)
        } else {
            let mut messages_lock = self.queue.lock();
            messages_lock.push(message);
            Ok(())
        }
    }
}

53
src/tracks/command.rs Normal file
View File

@@ -0,0 +1,53 @@
use super::*;
use crate::events::EventData;
use std::time::Duration;
use tokio::sync::oneshot::Sender as OneshotSender;
/// A request from external code using a [`TrackHandle`] to modify
/// or act upon an [`Track`] object.
///
/// [`Track`]: struct.Track.html
/// [`TrackHandle`]: struct.TrackHandle.html
/// A request from external code using a [`TrackHandle`] to modify
/// or act upon an [`Track`] object.
///
/// [`Track`]: struct.Track.html
/// [`TrackHandle`]: struct.TrackHandle.html
pub enum TrackCommand {
    /// Set the track's play_mode to play/resume.
    Play,
    /// Set the track's play_mode to pause.
    Pause,
    /// Stop the target track. This cannot be undone.
    Stop,
    /// Set the track's volume.
    Volume(f32),
    /// Seek to the given duration.
    ///
    /// On unsupported input types, this can be fatal.
    Seek(Duration),
    /// Register an event on this track.
    AddEvent(EventData),
    /// Run some closure on this track, with direct access to the core object.
    ///
    /// The closure runs where commands are processed, so it must not block
    /// or perform costly work (see [`TrackHandle::action`]'s warning).
    ///
    /// [`TrackHandle::action`]: struct.TrackHandle.html#method.action
    Do(Box<dyn FnOnce(&mut Track) + Send + Sync + 'static>),
    /// Request a read-only view of this track's state.
    Request(OneshotSender<Box<TrackState>>),
    /// Change the loop count/strategy of this track.
    Loop(LoopState),
}
impl std::fmt::Debug for TrackCommand {
    /// Formats the command name (and payload, where one exists) for
    /// diagnostics.
    ///
    /// Written per-variant so that simple variants are emitted without any
    /// intermediate `String` allocation, unlike a `match`-to-`format!` table.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        use TrackCommand::*;
        match self {
            Play => write!(f, "TrackCommand::Play"),
            Pause => write!(f, "TrackCommand::Pause"),
            Stop => write!(f, "TrackCommand::Stop"),
            Volume(vol) => write!(f, "TrackCommand::Volume({})", vol),
            Seek(d) => write!(f, "TrackCommand::Seek({:?})", d),
            AddEvent(evt) => write!(f, "TrackCommand::AddEvent({:?})", evt),
            // The closure itself is opaque; render a placeholder.
            Do(_f) => write!(f, "TrackCommand::Do([function])"),
            Request(tx) => write!(f, "TrackCommand::Request({:?})", tx),
            Loop(loops) => write!(f, "TrackCommand::Loop({:?})", loops),
        }
    }
}

159
src/tracks/handle.rs Normal file
View File

@@ -0,0 +1,159 @@
use super::*;
use crate::events::{Event, EventData, EventHandler};
use std::time::Duration;
use tokio::sync::{
mpsc::{error::SendError, UnboundedSender},
oneshot,
};
#[derive(Clone, Debug)]
/// Handle for safe control of a [`Track`] track from other threads, outside
/// of the audio mixing and voice handling context.
///
/// Almost all method calls here are fallible; in most cases, this will be because
/// the underlying [`Track`] object has been discarded. Those which aren't refer
/// to immutable properties of the underlying stream.
///
/// [`Track`]: struct.Track.html
pub struct TrackHandle {
    // Sink through which commands reach the driver-owned `Track`.
    command_channel: UnboundedSender<TrackCommand>,
    // Cached hint: whether the underlying `Input` supports seeking.
    // Seek- and loop-related calls fail fast when this is `false`.
    seekable: bool,
}
impl TrackHandle {
/// Creates a new handle, using the given command sink and hint as to whether
/// the underlying [`Input`] supports seek operations.
///
/// [`Input`]: ../input/struct.Input.html
pub fn new(command_channel: UnboundedSender<TrackCommand>, seekable: bool) -> Self {
Self {
command_channel,
seekable,
}
}
/// Unpauses an audio track.
pub fn play(&self) -> TrackResult {
self.send(TrackCommand::Play)
}
/// Pauses an audio track.
pub fn pause(&self) -> TrackResult {
self.send(TrackCommand::Pause)
}
/// Stops an audio track.
///
/// This is *final*, and will cause the audio context to fire
/// a [`TrackEvent::End`] event.
///
/// [`TrackEvent::End`]: ../events/enum.TrackEvent.html#variant.End
pub fn stop(&self) -> TrackResult {
self.send(TrackCommand::Stop)
}
/// Sets the volume of an audio track.
pub fn set_volume(&self, volume: f32) -> TrackResult {
self.send(TrackCommand::Volume(volume))
}
/// Denotes whether the underlying [`Input`] stream is compatible with arbitrary seeking.
///
/// If this returns `false`, all calls to [`seek`] will fail, and the track is
/// incapable of looping.
///
/// [`seek`]: #method.seek
/// [`Input`]: ../input/struct.Input.html
pub fn is_seekable(&self) -> bool {
self.seekable
}
/// Seeks along the track to the specified position.
///
/// If the underlying [`Input`] does not support this behaviour,
/// then all calls will fail.
///
/// [`Input`]: ../input/struct.Input.html
pub fn seek_time(&self, position: Duration) -> TrackResult {
if self.seekable {
self.send(TrackCommand::Seek(position))
} else {
Err(SendError(TrackCommand::Seek(position)))
}
}
/// Attach an event handler to an audio track. These will receive [`EventContext::Track`].
///
/// Users **must** ensure that no costly work or blocking occurs
/// within the supplied function or closure. *Taking excess time could prevent
/// timely sending of packets, causing audio glitches and delays*.
///
/// [`Track`]: struct.Track.html
/// [`EventContext::Track`]: ../events/enum.EventContext.html#variant.Track
pub fn add_event<F: EventHandler + 'static>(&self, event: Event, action: F) -> TrackResult {
let cmd = TrackCommand::AddEvent(EventData::new(event, action));
if event.is_global_only() {
Err(SendError(cmd))
} else {
self.send(cmd)
}
}
/// Perform an arbitrary action on a raw [`Track`] object.
///
/// Users **must** ensure that no costly work or blocking occurs
/// within the supplied function or closure. *Taking excess time could prevent
/// timely sending of packets, causing audio glitches and delays*.
///
/// [`Track`]: struct.Track.html
pub fn action<F>(&self, action: F) -> TrackResult
where
F: FnOnce(&mut Track) + Send + Sync + 'static,
{
self.send(TrackCommand::Do(Box::new(action)))
}
/// Request playback information and state from the audio context.
///
/// Crucially, the audio thread will respond *at a later time*:
/// It is up to the user when or how this should be read from the returned channel.
pub fn get_info(&self) -> TrackQueryResult {
let (tx, rx) = oneshot::channel();
self.send(TrackCommand::Request(tx)).map(move |_| rx)
}
/// Set an audio track to loop indefinitely.
pub fn enable_loop(&self) -> TrackResult {
if self.seekable {
self.send(TrackCommand::Loop(LoopState::Infinite))
} else {
Err(SendError(TrackCommand::Loop(LoopState::Infinite)))
}
}
/// Set an audio track to no longer loop.
pub fn disable_loop(&self) -> TrackResult {
if self.seekable {
self.send(TrackCommand::Loop(LoopState::Finite(0)))
} else {
Err(SendError(TrackCommand::Loop(LoopState::Finite(0))))
}
}
/// Set an audio track to loop a set number of times.
pub fn loop_for(&self, count: usize) -> TrackResult {
if self.seekable {
self.send(TrackCommand::Loop(LoopState::Finite(count)))
} else {
Err(SendError(TrackCommand::Loop(LoopState::Finite(count))))
}
}
#[inline]
/// Send a raw command to the [`Track`] object.
///
/// [`Track`]: struct.Track.html
pub fn send(&self, cmd: TrackCommand) -> TrackResult {
self.command_channel.send(cmd)
}
}

22
src/tracks/looping.rs Normal file
View File

@@ -0,0 +1,22 @@
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
/// Looping behaviour for a [`Track`].
///
/// [`Track`]: struct.Track.html
pub enum LoopState {
    /// Track will loop endlessly until loop state is changed or
    /// manually stopped.
    Infinite,
    /// Track will loop `n` more times.
    ///
    /// `Finite(0)` is the `Default`, stopping the track once its [`Input`] ends.
    ///
    /// [`Input`]: ../input/struct.Input.html
    Finite(usize),
}

impl Default for LoopState {
    /// Tracks do not loop unless asked to: the default is `Finite(0)`,
    /// i.e., zero additional plays after the source finishes.
    fn default() -> Self {
        LoopState::Finite(0)
    }
}

379
src/tracks/mod.rs Normal file
View File

@@ -0,0 +1,379 @@
//! Live, controllable audio instances.
//!
//! Tracks add control and event data around the bytestreams offered by [`Input`],
//! where each represents a live audio source inside of the driver's mixer.
//!
//! To prevent locking and stalling of the driver, tracks are controlled from your bot using a
//! [`TrackHandle`]. These handles remotely send commands from your bot's (a)sync
//! context to control playback, register events, and execute synchronous closures.
//!
//! If you want a new track from an [`Input`], i.e., for direct control before
//! playing your source on the driver, use [`create_player`].
//!
//! [`Input`]: ../input/struct.Input.html
//! [`TrackHandle`]: struct.TrackHandle.html
//! [`create_player`]: fn.create_player.html
mod command;
mod handle;
mod looping;
mod mode;
mod queue;
mod state;
pub use self::{command::*, handle::*, looping::*, mode::*, queue::*, state::*};
use crate::{constants::*, driver::tasks::message::*, events::EventStore, input::Input};
use std::time::Duration;
use tokio::sync::{
mpsc::{
self,
error::{SendError, TryRecvError},
UnboundedReceiver,
},
oneshot::Receiver as OneshotReceiver,
};
/// Control object for audio playback.
///
/// Accessed by both commands and the playback code -- as such, access from user code is
/// almost always guarded via a [`TrackHandle`]. You should expect to receive
/// access to a raw object of this type via [`create_player`], for use in
/// [`Driver::play`] or [`Driver::play_only`].
///
/// # Example
///
/// ```rust,no_run
/// use songbird::{driver::Driver, ffmpeg, tracks::create_player};
///
/// # async {
/// // A Call is also valid here!
/// let mut handler: Driver = Default::default();
/// let source = ffmpeg("../audio/my-favourite-song.mp3")
///     .await
///     .expect("This might fail: handle this error!");
/// let (mut audio, audio_handle) = create_player(source);
///
/// audio.set_volume(0.5);
///
/// handler.play_only(audio);
///
/// // Future access occurs via audio_handle.
/// # };
/// ```
///
/// [`Driver::play_only`]: ../struct.Driver.html#method.play_only
/// [`Driver::play`]: ../struct.Driver.html#method.play
/// [`TrackHandle`]: struct.TrackHandle.html
/// [`create_player`]: fn.create_player.html
#[derive(Debug)]
pub struct Track {
    /// Whether or not this sound is currently playing.
    ///
    /// Can be controlled with [`play`] or [`pause`] if chaining is desired.
    ///
    /// [`play`]: #method.play
    /// [`pause`]: #method.pause
    pub(crate) playing: PlayMode,

    /// The desired volume for playback.
    ///
    /// Sensible values fall between `0.0` and `1.0`.
    ///
    /// Can be controlled with [`volume`] if chaining is desired.
    ///
    /// [`volume`]: #method.volume
    pub(crate) volume: f32,

    /// Underlying data access object.
    ///
    /// *Calling code is not expected to use this.*
    pub(crate) source: Input,

    /// The current playback position in the track.
    ///
    /// Unlike [`play_time`], this is affected by seeks and loops.
    ///
    /// [`play_time`]: #structfield.play_time
    pub(crate) position: Duration,

    /// The total length of time this track has been active.
    pub(crate) play_time: Duration,

    /// List of events attached to this audio track.
    ///
    /// This may be used to add additional events to a track
    /// before it is sent to the audio context for playing.
    pub events: Option<EventStore>,

    /// Channel from which commands are received.
    ///
    /// Track commands are sent in this manner to ensure that access
    /// occurs in a thread-safe manner, without allowing any external
    /// code to lock access to audio objects and block packet generation.
    pub(crate) commands: UnboundedReceiver<TrackCommand>,

    /// Handle for safe control of this audio track from other threads.
    ///
    /// Typically, this is used by internal code to supply context information
    /// to event handlers, though more may be cloned from this handle.
    pub handle: TrackHandle,

    /// Count of remaining loops.
    pub loops: LoopState,
}
impl Track {
    /// Create a new track directly from an input, command source,
    /// and handle.
    ///
    /// In general, you should probably use [`create_player`].
    ///
    /// [`create_player`]: fn.create_player.html
    pub fn new_raw(
        source: Input,
        commands: UnboundedReceiver<TrackCommand>,
        handle: TrackHandle,
    ) -> Self {
        Self {
            playing: Default::default(),
            volume: 1.0,
            source,
            position: Default::default(),
            play_time: Default::default(),
            events: Some(EventStore::new_local()),
            commands,
            handle,
            loops: LoopState::Finite(0),
        }
    }

    /// Sets a track to playing if it is paused.
    pub fn play(&mut self) -> &mut Self {
        self.set_playing(PlayMode::Play)
    }

    /// Pauses a track if it is playing.
    pub fn pause(&mut self) -> &mut Self {
        self.set_playing(PlayMode::Pause)
    }

    /// Manually stops a track.
    ///
    /// This will cause the audio track to be removed, with any relevant events triggered.
    /// Stopped/ended tracks cannot be restarted.
    pub fn stop(&mut self) -> &mut Self {
        self.set_playing(PlayMode::Stop)
    }

    /// Marks a track as having naturally ended; like [`stop`], this is final.
    ///
    /// [`stop`]: #method.stop
    pub(crate) fn end(&mut self) -> &mut Self {
        self.set_playing(PlayMode::End)
    }

    #[inline]
    /// Applies a play-state transition via `PlayMode::change_to`, which
    /// preserves terminal (stopped/ended) states.
    fn set_playing(&mut self, new_state: PlayMode) -> &mut Self {
        self.playing = self.playing.change_to(new_state);
        self
    }

    /// Returns the current play status of this track.
    pub fn playing(&self) -> PlayMode {
        self.playing
    }

    /// Sets [`volume`] in a manner that allows method chaining.
    ///
    /// [`volume`]: #structfield.volume
    pub fn set_volume(&mut self, volume: f32) -> &mut Self {
        self.volume = volume;
        self
    }

    /// Returns the current volume of the track.
    pub fn volume(&self) -> f32 {
        self.volume
    }

    /// Returns the current playback position.
    pub fn position(&self) -> Duration {
        self.position
    }

    /// Returns the total length of time this track has been active.
    pub fn play_time(&self) -> Duration {
        self.play_time
    }

    /// Sets [`loops`] in a manner that allows method chaining.
    ///
    /// [`loops`]: #structfield.loops
    pub fn set_loops(&mut self, loops: LoopState) -> &mut Self {
        self.loops = loops;
        self
    }

    /// Consumes one loop, returning whether the track should restart.
    ///
    /// Finite loop counts are decremented; infinite loops always restart.
    pub(crate) fn do_loop(&mut self) -> bool {
        match self.loops {
            LoopState::Infinite => true,
            LoopState::Finite(0) => false,
            LoopState::Finite(ref mut n) => {
                *n -= 1;
                true
            },
        }
    }

    /// Steps playback location forward by one frame.
    pub(crate) fn step_frame(&mut self) {
        self.position += TIMESTEP_LENGTH;
        self.play_time += TIMESTEP_LENGTH;
    }

    /// Receives and acts upon any commands forwarded by [`TrackHandle`]s.
    ///
    /// *Used internally*, this should not be exposed to users.
    ///
    /// [`TrackHandle`]: struct.TrackHandle.html
    pub(crate) fn process_commands(&mut self, index: usize, ic: &Interconnect) {
        // Note: disconnection and an empty channel are both valid,
        // and should allow the audio object to keep running as intended.

        // Note that interconnect failures are not currently errors.
        // In correct operation, the event thread should never panic,
        // but it receiving status updates is secondary to actually
        // doing the work.
        loop {
            match self.commands.try_recv() {
                Ok(cmd) => {
                    use TrackCommand::*;
                    match cmd {
                        Play => {
                            self.play();
                            let _ = ic.events.send(EventMessage::ChangeState(
                                index,
                                TrackStateChange::Mode(self.playing),
                            ));
                        },
                        Pause => {
                            self.pause();
                            let _ = ic.events.send(EventMessage::ChangeState(
                                index,
                                TrackStateChange::Mode(self.playing),
                            ));
                        },
                        Stop => {
                            self.stop();
                            let _ = ic.events.send(EventMessage::ChangeState(
                                index,
                                TrackStateChange::Mode(self.playing),
                            ));
                        },
                        Volume(vol) => {
                            self.set_volume(vol);
                            let _ = ic.events.send(EventMessage::ChangeState(
                                index,
                                TrackStateChange::Volume(self.volume),
                            ));
                        },
                        Seek(time) => {
                            self.seek_time(time);
                            let _ = ic.events.send(EventMessage::ChangeState(
                                index,
                                TrackStateChange::Position(self.position),
                            ));
                        },
                        AddEvent(evt) => {
                            let _ = ic.events.send(EventMessage::AddTrackEvent(index, evt));
                        },
                        Do(action) => {
                            // The closure may mutate anything, so forward a
                            // full state snapshot to the event thread.
                            action(self);
                            let _ = ic.events.send(EventMessage::ChangeState(
                                index,
                                TrackStateChange::Total(self.state()),
                            ));
                        },
                        Request(tx) => {
                            // The requester may have dropped the receiver;
                            // that is not an error here.
                            let _ = tx.send(Box::new(self.state()));
                        },
                        Loop(loops) => {
                            self.set_loops(loops);
                            let _ = ic.events.send(EventMessage::ChangeState(
                                index,
                                TrackStateChange::Loops(self.loops, true),
                            ));
                        },
                    }
                },
                Err(TryRecvError::Closed) => {
                    // this branch will never be visited: this track owns
                    // `self.handle`, which keeps a sender alive, so the
                    // channel cannot close while the track exists.
                    break;
                },
                Err(TryRecvError::Empty) => {
                    break;
                },
            }
        }
    }

    /// Creates a read-only copy of the audio track's state.
    ///
    /// The primary use-case of this is sending information across
    /// threads in response to a [`TrackHandle`].
    ///
    /// [`TrackHandle`]: struct.TrackHandle.html
    pub fn state(&self) -> TrackState {
        TrackState {
            playing: self.playing,
            volume: self.volume,
            position: self.position,
            play_time: self.play_time,
            loops: self.loops,
        }
    }

    /// Seek to a specific point in the track.
    ///
    /// Returns `None` if unsupported, otherwise the actual position
    /// reported by the underlying source.
    pub fn seek_time(&mut self, pos: Duration) -> Option<Duration> {
        let out = self.source.seek_time(pos);

        if let Some(t) = out {
            self.position = t;
        }

        out
    }
}
/// Creates a [`Track`] object to pass into the audio context, and a [`TrackHandle`]
/// for safe, lock-free access in external code.
///
/// Typically, this would be used if you wished to directly work on or configure
/// the [`Track`] object before it is passed over to the driver.
///
/// [`Track`]: struct.Track.html
/// [`TrackHandle`]: struct.TrackHandle.html
pub fn create_player(source: Input) -> (Track, TrackHandle) {
    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
    let can_seek = source.is_seekable();

    // The track keeps its own copy of the handle so that event handlers
    // can be supplied with context; the second handle goes to the caller.
    let internal_handle = TrackHandle::new(cmd_tx.clone(), can_seek);
    let track = Track::new_raw(source, cmd_rx, internal_handle);

    (track, TrackHandle::new(cmd_tx, can_seek))
}
/// Alias for most result-free calls to a [`TrackHandle`].
///
/// Failure indicates that the accessed audio object has been
/// removed or deleted by the audio context; the unsent [`TrackCommand`]
/// is returned inside the [`SendError`].
///
/// [`TrackHandle`]: struct.TrackHandle.html
/// [`TrackCommand`]: enum.TrackCommand.html
/// [`SendError`]: https://docs.rs/tokio/latest/tokio/sync/mpsc/error/struct.SendError.html
pub type TrackResult = Result<(), SendError<TrackCommand>>;

/// Alias for return value from calls to [`TrackHandle::get_info`].
///
/// Crucially, the audio thread will respond *at a later time*:
/// It is up to the user when or how this should be read from the returned channel.
///
/// Failure indicates that the accessed audio object has been
/// removed or deleted by the audio context.
///
/// [`TrackHandle::get_info`]: struct.TrackHandle.html#method.get_info
pub type TrackQueryResult = Result<OneshotReceiver<Box<TrackState>>, SendError<TrackCommand>>;

37
src/tracks/mode.rs Normal file
View File

@@ -0,0 +1,37 @@
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
/// Playback status of a track.
pub enum PlayMode {
    /// The track is currently playing.
    Play,
    /// The track is currently paused, and may be resumed.
    Pause,
    /// The track has been manually stopped, and cannot be restarted.
    Stop,
    /// The track has naturally ended, and cannot be restarted.
    End,
}

impl PlayMode {
    /// Returns whether the track has irreversibly stopped.
    pub fn is_done(self) -> bool {
        match self {
            PlayMode::Stop | PlayMode::End => true,
            PlayMode::Play | PlayMode::Pause => false,
        }
    }

    /// Computes the play state which results from requesting `other`.
    ///
    /// Finished tracks cannot be restarted -- that transition is final --
    /// so terminal states are "sticky" and ignore the request. Seekable
    /// tracks might support un-cancelling in future, but that needs more
    /// machinery to re-add.
    pub(crate) fn change_to(self, other: Self) -> PlayMode {
        if self.is_done() {
            self
        } else {
            other
        }
    }
}

impl Default for PlayMode {
    /// Tracks start out in the `Play` state.
    fn default() -> Self {
        Self::Play
    }
}

213
src/tracks/queue.rs Normal file
View File

@@ -0,0 +1,213 @@
use crate::{
driver::Driver,
events::{Event, EventContext, EventData, EventHandler, TrackEvent},
input::Input,
tracks::{self, Track, TrackHandle, TrackResult},
};
use async_trait::async_trait;
use parking_lot::Mutex;
use std::{collections::VecDeque, sync::Arc};
use tracing::{info, warn};
#[derive(Default)]
/// A simple queue for several audio sources, designed to
/// play in sequence.
///
/// This makes use of [`TrackEvent`]s to determine when the current
/// song or audio file has finished before playing the next entry.
///
/// `examples/13_voice_events` demonstrates how a user might manage,
/// track and use this to run a song queue in many guilds in parallel.
/// This code is trivial to extend if extra functionality is needed.
///
/// # Example
///
/// ```rust,no_run
/// use songbird::{
///     driver::Driver,
///     id::GuildId,
///     ffmpeg,
///     tracks::{create_player, TrackQueue},
/// };
/// use std::collections::HashMap;
///
/// # async {
/// let guild = GuildId(0);
/// // A Call is also valid here!
/// let mut driver: Driver = Default::default();
///
/// let mut queues: HashMap<GuildId, TrackQueue> = Default::default();
///
/// let source = ffmpeg("../audio/my-favourite-song.mp3")
///     .await
///     .expect("This might fail: handle this error!");
///
/// // We need to ensure that this guild has a TrackQueue created for it.
/// let queue = queues.entry(guild)
///     .or_default();
///
/// // Queueing a track is this easy!
/// queue.add_source(source, &mut driver);
/// # };
/// ```
///
/// [`TrackEvent`]: ../events/enum.TrackEvent.html
pub struct TrackQueue {
    // NOTE: the choice of a parking lot mutex is quite deliberate --
    // `QueueHandler::act` locks it from within an async context without
    // ever holding it across an `.await`.
    inner: Arc<Mutex<TrackQueueCore>>,
}
#[derive(Default)]
/// Inner portion of a [`TrackQueue`].
///
/// This abstracts away thread-safety from the user,
/// and offers a convenient location to store further state if required.
///
/// [`TrackQueue`]: struct.TrackQueue.html
struct TrackQueueCore {
    // Handles for all queued tracks; the front entry is the track
    // currently being played (or attempted).
    tracks: VecDeque<TrackHandle>,
}
/// Per-track event handler which advances the owning queue once
/// the current track ends.
struct QueueHandler {
    // Queue state shared with the `TrackQueue` this handler belongs to.
    remote_lock: Arc<Mutex<TrackQueueCore>>,
}
#[async_trait]
impl EventHandler for QueueHandler {
    /// Fired when the queued track this handler is attached to ends:
    /// drops the finished track's handle and starts the next playable entry.
    async fn act(&self, ctx: &EventContext<'_>) -> Option<Event> {
        // parking_lot mutex: fine to hold here, as nothing below awaits.
        let mut inner = self.remote_lock.lock();

        // The track which just ended sits at the head of the queue.
        let _old = inner.tracks.pop_front();

        info!("Queued track ended: {:?}.", ctx);
        info!("{} tracks remain.", inner.tracks.len());

        // If any audio files die unexpectedly, then keep going until we
        // find one which works, or we run out.
        let mut keep_looking = true;
        while keep_looking && !inner.tracks.is_empty() {
            if let Some(new) = inner.tracks.front() {
                keep_looking = new.play().is_err();

                // Discard files which cannot be used for whatever reason.
                if keep_looking {
                    warn!("Track in Queue couldn't be played...");
                    let _ = inner.tracks.pop_front();
                }
            }
        }

        // NOTE(review): presumably `None` leaves the event registration
        // unchanged -- confirm against the events module's contract.
        None
    }
}
impl TrackQueue {
    /// Create a new, empty, track queue.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(TrackQueueCore {
                tracks: VecDeque::new(),
            })),
        }
    }

    /// Adds an audio source to the queue, to be played in the channel managed by `handler`.
    pub fn add_source(&self, source: Input, handler: &mut Driver) {
        let (audio, audio_handle) = tracks::create_player(source);
        self.add(audio, audio_handle, handler);
    }

    /// Adds a [`Track`] object to the queue, to be played in the channel managed by `handler`.
    ///
    /// This is used with [`create_player`] if additional configuration or event handlers
    /// are required before enqueueing the audio track.
    ///
    /// [`Track`]: struct.Track.html
    /// [`create_player`]: fn.create_player.html
    pub fn add(&self, mut track: Track, track_handle: TrackHandle, handler: &mut Driver) {
        info!("Track added to queue.");
        let remote_lock = self.inner.clone();
        let mut inner = self.inner.lock();

        // Only the queue's head should be audible: later entries wait paused.
        if !inner.tracks.is_empty() {
            track.pause();
        }

        // Attach the queue-advancing End handler *before* the track is
        // handed to the driver, so no end can be missed.
        track
            .events
            .as_mut()
            .expect("Queue inspecting EventStore on new Track: did not exist.")
            .add_event(
                EventData::new(Event::Track(TrackEvent::End), QueueHandler { remote_lock }),
                track.position,
            );

        handler.play(track);
        inner.tracks.push_back(track_handle);
    }

    /// Returns the number of tracks currently in the queue.
    pub fn len(&self) -> usize {
        let inner = self.inner.lock();

        inner.tracks.len()
    }

    /// Returns whether there are no tracks currently in the queue.
    pub fn is_empty(&self) -> bool {
        let inner = self.inner.lock();

        inner.tracks.is_empty()
    }

    /// Pause the track at the head of the queue.
    pub fn pause(&self) -> TrackResult {
        let inner = self.inner.lock();

        if let Some(handle) = inner.tracks.front() {
            handle.pause()
        } else {
            Ok(())
        }
    }

    /// Resume the track at the head of the queue.
    pub fn resume(&self) -> TrackResult {
        let inner = self.inner.lock();

        if let Some(handle) = inner.tracks.front() {
            handle.play()
        } else {
            Ok(())
        }
    }

    /// Stops the currently playing track, and clears the queue.
    pub fn stop(&self) -> TrackResult {
        let mut inner = self.inner.lock();

        let out = inner.stop_current();
        inner.tracks.clear();

        out
    }

    /// Skip to the next track in the queue, if it exists.
    ///
    /// Stopping the head track fires its End event, whose `QueueHandler`
    /// then starts the next entry.
    pub fn skip(&self) -> TrackResult {
        let inner = self.inner.lock();

        inner.stop_current()
    }
}
impl TrackQueueCore {
    /// Stops the currently playing track (the queue's head), if it exists.
    ///
    /// Succeeds trivially when the queue is empty.
    fn stop_current(&self) -> TrackResult {
        if let Some(handle) = self.tracks.front() {
            handle.stop()
        } else {
            Ok(())
        }
    }
}

31
src/tracks/state.rs Normal file
View File

@@ -0,0 +1,31 @@
use super::*;
/// State of an [`Track`] object, designed to be passed to event handlers
/// and retrieved remotely via [`TrackHandle::get_info`] or
/// [`TrackHandle::get_info_blocking`].
///
/// [`Track`]: struct.Track.html
/// [`TrackHandle::get_info`]: struct.TrackHandle.html#method.get_info
/// [`TrackHandle::get_info_blocking`]: struct.TrackHandle.html#method.get_info_blocking
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct TrackState {
/// Play status (e.g., active, paused, stopped) of this track.
pub playing: PlayMode,
/// Current volume of this track.
pub volume: f32,
/// Current playback position in the source.
///
/// This is altered by loops and seeks
pub position: Duration,
/// Total playback time, increasing monotonically.
pub play_time: Duration,
/// Remaining loops on this track.
pub loops: LoopState,
}
impl TrackState {
pub(crate) fn step_frame(&mut self) {
self.position += TIMESTEP_LENGTH;
self.play_time += TIMESTEP_LENGTH;
}
}

208
src/ws.rs Normal file
View File

@@ -0,0 +1,208 @@
// FIXME: this is copied from serenity/src/internal/ws_impl.rs
// To prevent this duplication, we either need to expose this on serenity's API
// (not desirable) or break the common WS elements into a subcrate.
// I believe that decisions is outside of the scope of the voice subcrate PR.
use crate::model::Event;
use async_trait::async_trait;
use async_tungstenite::{
tokio::ConnectStream,
tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message},
WebSocketStream,
};
use futures::{SinkExt, StreamExt, TryStreamExt};
use serde_json::Error as JsonError;
use tokio::time::timeout;
use tracing::{instrument, warn};
/// Websocket stream type used for connections to the voice gateway.
pub type WsStream = WebSocketStream<ConnectStream>;
/// Convenience alias over this module's [`Error`] type.
///
/// [`Error`]: enum.Error.html
pub type Result<T> = std::result::Result<T, Error>;
/// Errors raised while sending or receiving websocket messages.
#[derive(Debug)]
pub enum Error {
    /// Failure to serialize or deserialize a JSON payload.
    Json(JsonError),
    #[cfg(all(feature = "rustls", not(feature = "native")))]
    /// TLS error raised while connecting via rustls.
    Tls(RustlsError),
    /// The discord voice gateway does not support or offer zlib compression.
    /// As a result, only text messages are expected.
    UnexpectedBinaryMessage(Vec<u8>),
    /// Error raised by the underlying websocket implementation.
    Ws(TungsteniteError),
    /// The websocket was closed, possibly carrying a close frame with the reason.
    WsClosed(Option<CloseFrame<'static>>),
}
// Conversions below let `?` lift lower-level errors into this module's `Error`.

impl From<JsonError> for Error {
    fn from(e: JsonError) -> Error {
        Error::Json(e)
    }
}

#[cfg(all(feature = "rustls", not(feature = "native")))]
impl From<RustlsError> for Error {
    fn from(e: RustlsError) -> Error {
        Error::Tls(e)
    }
}

impl From<TungsteniteError> for Error {
    fn from(e: TungsteniteError) -> Error {
        Error::Ws(e)
    }
}
use futures::stream::SplitSink;
#[cfg(all(feature = "rustls", not(feature = "native")))]
use std::{
error::Error as StdError,
fmt::{Display, Formatter, Result as FmtResult},
io::Error as IoError,
};
use url::Url;
#[async_trait]
/// Extension trait for receiving deserialized gateway events from a websocket.
pub trait ReceiverExt {
    /// Receives and deserializes the next event, bounded by an internal timeout.
    async fn recv_json(&mut self) -> Result<Option<Event>>;
    /// Receives and deserializes the next event, waiting indefinitely.
    async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>>;
}

#[async_trait]
/// Extension trait for serializing and sending gateway events over a websocket.
pub trait SenderExt {
    /// Serializes `value` to JSON and sends it as a text message.
    async fn send_json(&mut self, value: &Event) -> Result<()>;
}
#[async_trait]
impl ReceiverExt for WsStream {
    /// Receives and deserializes the next gateway event, waiting at most
    /// 500ms for a message to arrive.
    async fn recv_json(&mut self) -> Result<Option<Event>> {
        const TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_millis(500);

        let ws_message = match timeout(TIMEOUT, self.next()).await {
            Ok(Some(Ok(v))) => Some(v),
            Ok(Some(Err(e))) => return Err(e.into()),
            // End-of-stream and timeout are both treated as "no message yet".
            Ok(None) | Err(_) => None,
        };

        convert_ws_message(ws_message)
    }

    /// Receives and deserializes the next gateway event with no time limit.
    async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
        // NOTE(review): `.ok()` silently maps transport errors to `Ok(None)`
        // here, unlike `recv_json` which propagates them -- confirm intended.
        convert_ws_message(self.try_next().await.ok().flatten())
    }
}
#[async_trait]
impl SenderExt for SplitSink<WsStream, Message> {
    /// Serializes `value` as JSON, then transmits it as a text message
    /// over the write half of the websocket.
    async fn send_json(&mut self, value: &Event) -> Result<()> {
        let payload = serde_json::to_string(value)?;
        self.send(Message::Text(payload)).await?;

        Ok(())
    }
}
#[async_trait]
impl SenderExt for WsStream {
    /// Serializes `value` as JSON, then transmits it as a text message
    /// over the websocket.
    async fn send_json(&mut self, value: &Event) -> Result<()> {
        let payload = serde_json::to_string(value)?;
        self.send(Message::Text(payload)).await?;

        Ok(())
    }
}
#[inline]
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
Ok(match message {
Some(Message::Text(payload)) =>
serde_json::from_str(&payload).map(Some).map_err(|why| {
warn!("Err deserializing text: {:?}; text: {}", why, payload,);
why
})?,
Some(Message::Binary(bytes)) => {
return Err(Error::UnexpectedBinaryMessage(bytes));
},
Some(Message::Close(Some(frame))) => {
return Err(Error::WsClosed(Some(frame)));
},
// Ping/Pong message behaviour is internally handled by tungstenite.
_ => None,
})
}
/// An error that occurred while connecting over rustls
#[derive(Debug)]
#[non_exhaustive]
#[cfg(all(feature = "rustls", not(feature = "native")))]
pub enum RustlsError {
    /// An error with the handshake in tungstenite
    HandshakeError,
    /// Standard IO error happening while creating the tcp stream
    Io(IoError),
}
// Standard error-trait plumbing for `RustlsError`.

#[cfg(all(feature = "rustls", not(feature = "native")))]
impl From<IoError> for RustlsError {
    fn from(e: IoError) -> Self {
        RustlsError::Io(e)
    }
}

#[cfg(all(feature = "rustls", not(feature = "native")))]
impl Display for RustlsError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            RustlsError::HandshakeError =>
                f.write_str("TLS handshake failed when making the websocket connection"),
            RustlsError::Io(inner) => Display::fmt(&inner, f),
        }
    }
}

#[cfg(all(feature = "rustls", not(feature = "native")))]
impl StdError for RustlsError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            // Only the IO variant wraps an underlying error.
            RustlsError::Io(inner) => Some(inner),
            _ => None,
        }
    }
}
#[cfg(all(feature = "rustls", not(feature = "native")))]
#[instrument]
/// Opens a websocket connection to `url` over rustls, with no limits
/// applied to message or frame size.
pub(crate) async fn create_rustls_client(url: Url) -> Result<WsStream> {
    let (stream, _) = async_tungstenite::tokio::connect_async_with_config::<Url>(
        url,
        Some(async_tungstenite::tungstenite::protocol::WebSocketConfig {
            max_message_size: None,
            max_frame_size: None,
            max_send_queue: None,
        }),
    )
    .await
    // NOTE(review): every connection failure is collapsed into
    // `HandshakeError`, discarding the underlying cause; the native-tls
    // path below propagates the original error instead -- confirm intended.
    .map_err(|_| RustlsError::HandshakeError)?;

    Ok(stream)
}
#[cfg(feature = "native")]
#[instrument]
/// Opens a websocket connection to `url` over native TLS, with no limits
/// applied to message or frame size.
pub(crate) async fn create_native_tls_client(url: Url) -> Result<WsStream> {
    // Remove all payload size limits on the connection.
    let ws_config = async_tungstenite::tungstenite::protocol::WebSocketConfig {
        max_message_size: None,
        max_frame_size: None,
        max_send_queue: None,
    };

    let (stream, _) =
        async_tungstenite::tokio::connect_async_with_config::<Url>(url, Some(ws_config)).await?;

    Ok(stream)
}

10
utils/Cargo.toml Normal file
View File

@@ -0,0 +1,10 @@
[package]
name = "utils"
version = "0.1.0"
authors = ["Kyle Simpson <kyleandrew.simpson@gmail.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
byteorder = "1"

1
utils/README.md Normal file
View File

@@ -0,0 +1 @@
Utilities for testing and benchmarking songbird.

67
utils/src/lib.rs Normal file
View File

@@ -0,0 +1,67 @@
use byteorder::{LittleEndian, WriteBytesExt};
use std::mem;
/// Generates `float_len` samples of a sine wave as little-endian `f32` bytes;
/// when `stereo` is set, every sample is duplicated into a (left, right) pair.
pub fn make_sine(float_len: usize, stereo: bool) -> Vec<u8> {
    let sample_len = mem::size_of::<f32>();
    let byte_len = float_len * sample_len;

    let mut out = vec![0u8; byte_len];

    // NOTE(review): the original comment claimed a 100-sample period
    // (480 Hz at 48 kHz), which would need a multiplier of PI / 50.0;
    // the 50.0 / PI factor is kept as-is to preserve existing output.
    for (i, sample_bytes) in out.chunks_exact_mut(sample_len).enumerate() {
        let x_val = (i as f32) * 50.0 / std::f32::consts::PI;
        sample_bytes.copy_from_slice(&x_val.sin().to_le_bytes());
    }

    if !stereo {
        return out;
    }

    // Interleave each mono sample into both channels of a stereo buffer.
    let mut interleaved = vec![0u8; byte_len * 2];
    for (mono, pair) in out
        .chunks(sample_len)
        .zip(interleaved.chunks_mut(2 * sample_len))
    {
        pair[..sample_len].copy_from_slice(mono);
        pair[sample_len..].copy_from_slice(mono);
    }

    interleaved
}
/// Generates `i16_len` samples of a sine wave (amplitude 10 000) as
/// little-endian signed 16-bit PCM bytes; when `stereo` is set, every
/// sample is duplicated into a (left, right) pair.
pub fn make_pcm_sine(i16_len: usize, stereo: bool) -> Vec<u8> {
    let sample_len = mem::size_of::<i16>();
    let byte_len = i16_len * sample_len;

    let mut out = vec![0u8; byte_len];

    // NOTE(review): as in `make_sine`, the frequency multiplier 50.0 / PI
    // does not match the claimed 100-sample period; kept to preserve output.
    for (i, sample_bytes) in out.chunks_exact_mut(sample_len).enumerate() {
        let x_val = (i as f32) * 50.0 / std::f32::consts::PI;
        let sample = (x_val.sin() * 10_000.0) as i16;
        sample_bytes.copy_from_slice(&sample.to_le_bytes());
    }

    if !stereo {
        return out;
    }

    // Interleave each mono sample into both channels of a stereo buffer.
    let mut interleaved = vec![0u8; byte_len * 2];
    for (mono, pair) in out
        .chunks(sample_len)
        .zip(interleaved.chunks_mut(2 * sample_len))
    {
        pair[..sample_len].copy_from_slice(mono);
        pair[sample_len..].copy_from_slice(mono);
    }

    interleaved
}