about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--crates/yt/Cargo.toml1
-rw-r--r--crates/yt/src/update/mod.rs20
-rw-r--r--crates/yt/src/update/updater.rs198
3 files changed, 116 insertions, 103 deletions
diff --git a/crates/yt/Cargo.toml b/crates/yt/Cargo.toml
index 17d4016..6803e68 100644
--- a/crates/yt/Cargo.toml
+++ b/crates/yt/Cargo.toml
@@ -50,6 +50,7 @@ yt_dlp.workspace = true
 termsize.workspace = true
 uu_fmt.workspace = true
 notify = { version = "8.0.0", default-features = false }
+tokio-util = { version = "0.7.15", features = ["rt"] }
 
 [[bin]]
 name = "yt"
diff --git a/crates/yt/src/update/mod.rs b/crates/yt/src/update/mod.rs
index 07674de..a25c233 100644
--- a/crates/yt/src/update/mod.rs
+++ b/crates/yt/src/update/mod.rs
@@ -13,7 +13,7 @@ use std::{str::FromStr, time::Duration};
 
 use anyhow::{Context, Ok, Result};
 use chrono::{DateTime, Utc};
-use log::{info, warn};
+use log::warn;
 use url::Url;
 use yt_dlp::{InfoJson, json_cast, json_get};
 
@@ -39,12 +39,12 @@ pub async fn update(
 ) -> Result<()> {
     let subscriptions = subscriptions::get(app).await?;
 
-    let urls: Vec<_> = if subscription_names_to_update.is_empty() {
-        subscriptions.0.values().collect()
+    let subs: Vec<Subscription> = if subscription_names_to_update.is_empty() {
+        subscriptions.0.into_values().collect()
     } else {
         subscriptions
             .0
-            .values()
+            .into_values()
             .filter(|sub| subscription_names_to_update.contains(&sub.name))
             .collect()
     };
@@ -53,10 +53,8 @@ pub async fn update(
     // should not contain duplicates.
     let hashes = get_all_hashes(app).await?;
 
-    {
-        let mut updater = Updater::new(max_backlog, &hashes);
-        updater.update(app, &urls).await?;
-    }
+    let updater = Updater::new(max_backlog, hashes);
+    updater.update(app, subs).await?;
 
     Ok(())
 }
@@ -182,9 +180,9 @@ pub fn video_entry_to_video(entry: &InfoJson, sub: Option<&Subscription>) -> Res
     Ok(video)
 }
 
-async fn process_subscription(app: &App, sub: &Subscription, entry: InfoJson) -> Result<()> {
-    let video =
-        video_entry_to_video(&entry, Some(sub)).context("Failed to parse search entry as Video")?;
+async fn process_subscription(app: &App, sub: Subscription, entry: InfoJson) -> Result<()> {
+    let video = video_entry_to_video(&entry, Some(&sub))
+        .context("Failed to parse search entry as Video")?;
 
     add_video(app, video.clone())
         .await
diff --git a/crates/yt/src/update/updater.rs b/crates/yt/src/update/updater.rs
index b314172..934b84b 100644
--- a/crates/yt/src/update/updater.rs
+++ b/crates/yt/src/update/updater.rs
@@ -8,17 +8,18 @@
 // You should have received a copy of the License along with this program.
 // If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.
 
-use std::io::{Write, stderr};
+use std::{
+    io::{Write, stderr},
+    sync::atomic::AtomicUsize,
+};
 
 use anyhow::{Context, Result};
 use blake3::Hash;
-use futures::{
-    StreamExt, TryStreamExt,
-    stream::{self},
-};
+use futures::{StreamExt, future::join_all, stream};
 use log::{Level, debug, error, log_enabled};
 use serde_json::json;
-use yt_dlp::{InfoJson, YoutubeDL, YoutubeDLOptions, json_cast, json_get, process_ie_result};
+use tokio_util::task::LocalPoolHandle;
+use yt_dlp::{InfoJson, YoutubeDLOptions, json_cast, json_get, process_ie_result};
 
 use crate::{
     ansi_escape_codes::{clear_whole_line, move_to_col},
@@ -28,44 +29,41 @@ use crate::{
 
 use super::process_subscription;
 
-pub(super) struct Updater<'a> {
+pub(super) struct Updater {
     max_backlog: usize,
-    hashes: &'a [Hash],
+    hashes: Vec<Hash>,
+    pool: LocalPoolHandle,
 }
 
-impl<'a> Updater<'a> {
-    pub(super) fn new(max_backlog: usize, hashes: &'a [Hash]) -> Self {
+impl Updater {
+    pub(super) fn new(max_backlog: usize, hashes: Vec<Hash>) -> Self {
+        // TODO(@bpeetz): The number should not be hardcoded. <2025-06-14>
+        let pool = LocalPoolHandle::new(16);
+
         Self {
             max_backlog,
             hashes,
+            pool,
         }
     }
 
-    pub(super) async fn update(
-        &mut self,
-        app: &App,
-        subscriptions: &[&Subscription],
-    ) -> Result<()> {
+    pub(super) async fn update(self, app: &App, subscriptions: Vec<Subscription>) -> Result<()> {
         let mut stream = stream::iter(subscriptions)
             .map(|sub| self.get_new_entries(sub))
-            .buffer_unordered(100);
+            .buffer_unordered(16 * 4);
 
         while let Some(output) = stream.next().await {
             let mut entries = output?;
 
-            if entries.is_empty() {
-                continue;
-            }
+            if let Some(next) = entries.next() {
+                let (sub, entry) = next;
+                process_subscription(app, sub, entry).await?;
 
-            let (sub, entry) = entries.remove(0);
-            process_subscription(app, sub, entry).await?;
-
-            let entry_stream: Result<()> = stream::iter(entries)
-                .map(|(sub, entry)| process_subscription(app, sub, entry))
-                .buffer_unordered(100)
-                .try_collect()
-                .await;
-            entry_stream?;
+                join_all(entries.map(|(sub, entry)| process_subscription(app, sub, entry)))
+                    .await
+                    .into_iter()
+                    .collect::<Result<(), _>>()?;
+            }
         }
 
         Ok(())
@@ -73,11 +71,14 @@ impl<'a> Updater<'a> {
 
     async fn get_new_entries(
         &self,
-        sub: &'a Subscription,
-    ) -> Result<Vec<(&'a Subscription, InfoJson)>> {
+        sub: Subscription,
+    ) -> Result<impl Iterator<Item = (Subscription, InfoJson)>> {
+        let max_backlog = self.max_backlog;
+        let hashes = self.hashes.clone();
+
         let yt_dlp = YoutubeDLOptions::new()
             .set("playliststart", 1)
-            .set("playlistend", self.max_backlog)
+            .set("playlistend", max_backlog)
             .set("noplaylist", false)
             .set(
                 "extractor_args",
@@ -88,70 +89,83 @@ impl<'a> Updater<'a> {
             .set("match-filter", "availability=public")
             .build()?;
 
-        if !log_enabled!(Level::Debug) {
-            clear_whole_line();
-            move_to_col(1);
-            eprint!("Checking playlist {}...", sub.name);
-            move_to_col(1);
-            stderr().flush()?;
-        }
-
-        let info = yt_dlp
-            .extract_info(&sub.url, false, false)
-            .with_context(|| format!("Failed to get playlist '{}'.", sub.name))?;
-
-        let empty = vec![];
-        let entries = info
-            .get("entries")
-            .map_or(&empty, |val| json_cast!(val, as_array));
-
-        let valid_entries: Vec<(&Subscription, InfoJson)> = entries
-            .iter()
-            .take(self.max_backlog)
-            .filter_map(|entry| -> Option<(&Subscription, InfoJson)> {
-                let id = json_get!(entry, "id", as_str);
-                let extractor_hash = blake3::hash(id.as_bytes());
-                if self.hashes.contains(&extractor_hash) {
-                    debug!("Skipping entry, as it is already present: '{extractor_hash}'",);
-                    None
-                } else {
-                    Some((sub, json_cast!(entry, as_object).to_owned()))
-                }
-            })
-            .collect();
-
-        let processed_entries: Vec<(&Subscription, InfoJson)> = stream::iter(valid_entries)
-            .map(
-                async |(sub, entry)| match yt_dlp.process_ie_result(entry, false) {
-                    Ok(output) => Ok((sub, output)),
-                    Err(err) => Err(err),
-                },
-            )
-            .buffer_unordered(100)
-            .collect::<Vec<_>>()
-            .await
-            .into_iter()
-            // Don't fail the whole update, if one of the entries fails to fetch.
-            .filter_map(|base| match base {
-                Ok(ok) => Some(ok),
-                Err(err) => {
-                    let process_ie_result::Error::Python(err) = &err;
-
-                    if err.contains("Join this channel to get access to members-only content ") {
-                        // Hide this error
-                    } else {
-                        // Show the error, but don't fail.
-                        let error = err
-                            .strip_prefix("DownloadError: \u{1b}[0;31mERROR:\u{1b}[0m ")
-                            .unwrap_or(err);
-                        error!("{error}");
+        self.pool
+            .spawn_pinned(move || {
+                async move {
+                    if !log_enabled!(Level::Debug) {
+                        clear_whole_line();
+                        move_to_col(1);
+                        eprint!(
+                            "Checking playlist {}...",
+                            sub.name
+                        );
+                        move_to_col(1);
+                        stderr().flush()?;
                     }
 
-                    None
+                    let info = yt_dlp
+                        .extract_info(&sub.url, false, false)
+                        .with_context(|| format!("Failed to get playlist '{}'.", sub.name))?;
+
+                    let empty = vec![];
+                    let entries = info
+                        .get("entries")
+                        .map_or(&empty, |val| json_cast!(val, as_array));
+
+                    let valid_entries: Vec<(Subscription, InfoJson)> = entries
+                        .iter()
+                        .take(max_backlog)
+                        .filter_map(|entry| -> Option<(Subscription, InfoJson)> {
+                            let id = json_get!(entry, "id", as_str);
+                            let extractor_hash = blake3::hash(id.as_bytes());
+
+                            if hashes.contains(&extractor_hash) {
+                                debug!(
+                                    "Skipping entry, as it is already present: '{extractor_hash}'",
+                                );
+                                None
+                            } else {
+                                Some((sub.clone(), json_cast!(entry, as_object).to_owned()))
+                            }
+                        })
+                        .collect();
+
+                    Ok(valid_entries
+                        .into_iter()
+                        .map(|(sub, entry)| {
+                            let inner_yt_dlp = YoutubeDLOptions::new()
+                                .set("noplaylist", true)
+                                .build()
+                                .expect("Worked before, should work now");
+
+                            match inner_yt_dlp.process_ie_result(entry, false) {
+                                Ok(output) => Ok((sub, output)),
+                                Err(err) => Err(err),
+                            }
+                        })
+                        // Don't fail the whole update, if one of the entries fails to fetch.
+                        .filter_map(|base| match base {
+                            Ok(ok) => Some(ok),
+                            Err(err) => {
+                                let process_ie_result::Error::Python(err) = &err;
+
+                                if err.contains(
+                                    "Join this channel to get access to members-only content ",
+                                ) {
+                                    // Hide this error
+                                } else {
+                                    // Show the error, but don't fail.
+                                    let error = err
+                                        .strip_prefix("DownloadError: \u{1b}[0;31mERROR:\u{1b}[0m ")
+                                        .unwrap_or(err);
+                                    error!("{error}");
+                                }
+
+                                None
+                            }
+                        }))
                 }
             })
-            .collect();
-
-        Ok(processed_entries)
+            .await?
     }
 }