From 812ff4c5bafffc5708a6d5066f1ebadb6d9fc958 Mon Sep 17 00:00:00 2001 From: ThetaDev Date: Wed, 29 Jan 2025 02:07:18 +0100 Subject: [PATCH] fix: ensure downloader futures are send --- downloader/src/lib.rs | 57 ++++++++++++++++++--------------------- downloader/tests/tests.rs | 24 +++++++++++++---- 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/downloader/src/lib.rs b/downloader/src/lib.rs index 339af94..7bcbdd1 100644 --- a/downloader/src/lib.rs +++ b/downloader/src/lib.rs @@ -15,7 +15,7 @@ use std::{ time::Duration, }; -use futures_util::stream::{self, StreamExt}; +use futures_util::stream::{self, StreamExt, TryStreamExt}; use once_cell::sync::Lazy; use rand::Rng; use regex::Regex; @@ -871,8 +871,8 @@ impl DownloadQuery { if let Some(pb) = pb { pb.set_message(format!("Downloading {name}{attempt_suffix}")) } - download_streams( - &downloads, + let downloads = download_streams( + downloads, &self.dl.i.http, &user_agent, pot, @@ -930,13 +930,9 @@ impl DownloadQuery { } // Delete original files - stream::iter(&downloads) - .map(|d| fs::remove_file(d.file.clone())) - .buffer_unordered(downloads.len()) - .collect::>() - .await - .into_iter() - .collect::>()?; + for d in &downloads { + fs::remove_file(&d.file).await?; + } #[cfg(feature = "indicatif")] if let Some(pb) = pb { @@ -1442,33 +1438,32 @@ struct StreamDownload { } async fn download_streams( - downloads: &Vec, + downloads: Vec, http: &Client, user_agent: &str, pot: Option<&str>, #[cfg(feature = "indicatif")] pb: Option, -) -> Result<()> { - let n = downloads.len(); - - stream::iter(downloads) - .map(|d| { - download_single_file( - &d.url, - &d.file, - http, - user_agent, - pot, - #[cfg(feature = "indicatif")] - pb.clone(), - ) +) -> Result> { + stream::iter(downloads.iter().map(Ok)) + .try_for_each_concurrent(2, |d| { + #[cfg(feature = "indicatif")] + let pb = pb.clone(); + async move { + download_single_file( + &d.url, + &d.file, + http, + user_agent, + pot, + #[cfg(feature = "indicatif")] + pb, + ) + .await + } }) - .buffer_unordered(n) - .collect::>() - .await - .into_iter() - .collect::>>()?; + .await?; - Ok(()) + Ok(downloads) } async fn convert_streams( diff --git a/downloader/tests/tests.rs b/downloader/tests/tests.rs index aa4a05e..b8a3987 100644 --- a/downloader/tests/tests.rs +++ b/downloader/tests/tests.rs @@ -47,11 +47,13 @@ async fn download_music(rp: RustyPipe) { let td = TempDir::default(); let td_path = td.to_path_buf(); - let dl = Downloader::builder() - .audio_tag() - .crop_cover() - .rustypipe(&rp) - .build(); + #[allow(unused_mut)] + let mut dl = Downloader::builder().rustypipe(&rp); + #[cfg(feature = "audiotag")] + { + dl = dl.audio_tag().crop_cover(); + } + let dl = dl.build(); let res = dl .id("bVtv3st8bgc") @@ -111,3 +113,15 @@ fn assert_audio_meta(p: &Path, title: &str, artist: &str, album: &str, date: &st assert_eq!(tags["ALBUM"].as_str(), Some(album)); assert_eq!(tags["DATE"].as_str(), Some(date)); } + +/// This is just a static check to make sure all RustyPipe futures can be sent +/// between threads safely. +/// Otherwise this may cause issues when integrating RustyPipe into async projects. +#[allow(unused)] +async fn all_send_and_sync() { + fn send_and_sync(t: T) {} + + let dl = Downloader::default(); + let dlq = dl.id(""); + send_and_sync(dlq.download()); +}