Input: Add ytdl search (#210)

* Add ytdl search

* fix fmt

* Remove compose, add tests, return AuxMetadata

* fix parsing of AuxMetadata and better test

* Fix playability of `YoutubeDl::new_search`

Refactors such that parsing of (ND)JSON is handled in only one location
now, which allows us to greatly simplify the actual `search` method. The
main change is that any `new_search` is now instantly playable.

---------

Co-authored-by: Kyle Simpson <kyleandrew.simpson@gmail.com>
This commit is contained in:
Cycle Five
2023-12-12 03:28:13 -05:00
committed by GitHub
parent 873aeae16a
commit d681b71b1f
2 changed files with 120 additions and 25 deletions

View File

@@ -278,15 +278,7 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
}, },
}; };
if !url.starts_with("http") { let do_search = !url.starts_with("http");
check_msg(
msg.channel_id
.say(&ctx.http, "Must provide a valid URL")
.await,
);
return Ok(());
}
let guild_id = msg.guild_id.unwrap(); let guild_id = msg.guild_id.unwrap();
@@ -305,8 +297,12 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
if let Some(handler_lock) = manager.get(guild_id) { if let Some(handler_lock) = manager.get(guild_id) {
let mut handler = handler_lock.lock().await; let mut handler = handler_lock.lock().await;
let src = YoutubeDl::new(http_client, url); let mut src = if do_search {
let _ = handler.play_input(src.into()); YoutubeDl::new_search(http_client, url)
} else {
YoutubeDl::new(http_client, url)
};
let _ = handler.play_input(src.clone().into());
check_msg(msg.channel_id.say(&ctx.http, "Playing song").await); check_msg(msg.channel_id.say(&ctx.http, "Playing song").await);
} else { } else {

View File

@@ -18,6 +18,12 @@ use tokio::process::Command;
const YOUTUBE_DL_COMMAND: &str = "yt-dlp"; const YOUTUBE_DL_COMMAND: &str = "yt-dlp";
#[derive(Clone, Debug)]
enum QueryType {
Url(String),
Search(String),
}
/// A lazily instantiated call to download a file, finding its URL via youtube-dl. /// A lazily instantiated call to download a file, finding its URL via youtube-dl.
/// ///
/// By default, this uses yt-dlp and is backed by an [`HttpRequest`]. This handler /// By default, this uses yt-dlp and is backed by an [`HttpRequest`]. This handler
@@ -30,7 +36,7 @@ pub struct YoutubeDl {
program: &'static str, program: &'static str,
client: Client, client: Client,
metadata: Option<AuxMetadata>, metadata: Option<AuxMetadata>,
url: String, query: QueryType,
} }
impl YoutubeDl { impl YoutubeDl {
@@ -52,14 +58,63 @@ impl YoutubeDl {
program, program,
client, client,
metadata: None, metadata: None,
url, query: QueryType::Url(url),
} }
} }
async fn query(&mut self) -> Result<Output, AudioStreamError> { /// Creates a request to search youtube for an optionally specified number of videos matching `query`,
/// using "yt-dlp".
#[must_use]
pub fn new_search(client: Client, query: String) -> Self {
Self::new_search_ytdl_like(YOUTUBE_DL_COMMAND, client, query)
}
/// Creates a request to search youtube for an optionally specified number of videos matching `query`,
/// using `program`.
#[must_use]
pub fn new_search_ytdl_like(program: &'static str, client: Client, query: String) -> Self {
Self {
program,
client,
metadata: None,
query: QueryType::Search(query),
}
}
/// Runs a search for the given query, returning a list of up to `n_results`
/// possible matches which are `AuxMetadata` objects containing a valid URL.
///
/// Returns up to 5 matches by default.
pub async fn search(
&mut self,
n_results: Option<usize>,
) -> Result<Vec<AuxMetadata>, AudioStreamError> {
let n_results = n_results.unwrap_or(5);
Ok(match &self.query {
// Safer to just return the metadata for the pointee if possible
QueryType::Url(_) => vec![self.aux_metadata().await?],
QueryType::Search(_) => self
.query(n_results)
.await?
.into_iter()
.map(|v| v.as_aux_metadata())
.collect(),
})
}
async fn query(&mut self, n_results: usize) -> Result<Vec<Output>, AudioStreamError> {
let new_query;
let query_str = match &self.query {
QueryType::Url(url) => url,
QueryType::Search(query) => {
new_query = format!("ytsearch{n_results}:{query}");
&new_query
},
};
let ytdl_args = [ let ytdl_args = [
"-j", "-j",
&self.url, query_str,
"-f", "-f",
"ba[abr>0][vcodec=none]/best", "ba[abr>0][vcodec=none]/best",
"--no-playlist", "--no-playlist",
@@ -77,14 +132,35 @@ impl YoutubeDl {
}) })
})?; })?;
// NOTE: must be mut for simd-json. if !output.status.success() {
#[allow(clippy::unnecessary_mut_passed)] return Err(AudioStreamError::Fail(
let stdout: Output = crate::json::from_slice(&mut output.stdout[..]) format!(
"{} failed with non-zero status code: {}",
self.program,
std::str::from_utf8(&output.stderr[..]).unwrap_or("<no error message>")
)
.into(),
));
}
// NOTE: must be split_mut for simd-json.
let out = output
.stdout
.split_mut(|&b| b == b'\n')
.filter_map(|x| (!x.is_empty()).then(|| crate::json::from_slice(x)))
.collect::<Result<Vec<Output>, _>>()
.map_err(|e| AudioStreamError::Fail(Box::new(e)))?; .map_err(|e| AudioStreamError::Fail(Box::new(e)))?;
self.metadata = Some(stdout.as_aux_metadata()); let meta = out
.first()
.ok_or_else(|| {
AudioStreamError::Fail(format!("no results found for '{query_str}'").into())
})?
.as_aux_metadata();
Ok(stdout) self.metadata = Some(meta);
Ok(out)
} }
} }
@@ -103,11 +179,13 @@ impl Compose for YoutubeDl {
async fn create_async( async fn create_async(
&mut self, &mut self,
) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> { ) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
let stdout = self.query().await?; // panic safety: `query` should have ensured > 0 results if `Ok`
let mut results = self.query(1).await?;
let result = results.swap_remove(0);
let mut headers = HeaderMap::default(); let mut headers = HeaderMap::default();
if let Some(map) = stdout.http_headers { if let Some(map) = result.http_headers {
headers.extend(map.iter().filter_map(|(k, v)| { headers.extend(map.iter().filter_map(|(k, v)| {
Some(( Some((
HeaderName::from_bytes(k.as_bytes()).ok()?, HeaderName::from_bytes(k.as_bytes()).ok()?,
@@ -118,9 +196,9 @@ impl Compose for YoutubeDl {
let mut req = HttpRequest { let mut req = HttpRequest {
client: self.client.clone(), client: self.client.clone(),
request: stdout.url, request: result.url,
headers, headers,
content_length: stdout.filesize, content_length: result.filesize,
}; };
req.create_async().await req.create_async().await
@@ -135,7 +213,7 @@ impl Compose for YoutubeDl {
return Ok(meta.clone()); return Ok(meta.clone());
} }
self.query().await?; self.query(1).await?;
self.metadata.clone().ok_or_else(|| { self.metadata.clone().ok_or_else(|| {
let msg: Box<dyn Error + Send + Sync + 'static> = let msg: Box<dyn Error + Send + Sync + 'static> =
@@ -185,4 +263,25 @@ mod tests {
assert!(ytdl.aux_metadata().await.is_err()); assert!(ytdl.aux_metadata().await.is_err());
} }
#[tokio::test]
#[ntest::timeout(20_000)]
async fn ytdl_search_plays() {
let mut ytdl = YoutubeDl::new_search(Client::new(), "cloudkicker 94 days".into());
let res = ytdl.search(Some(1)).await;
let res = res.unwrap();
assert_eq!(res.len(), 1);
track_plays_passthrough(move || ytdl).await;
}
#[tokio::test]
#[ntest::timeout(20_000)]
async fn ytdl_search_3() {
let mut ytdl = YoutubeDl::new_search(Client::new(), "test".into());
let res = ytdl.search(Some(3)).await;
assert_eq!(res.unwrap().len(), 3);
}
} }