about summary refs log tree commit diff stats
path: root/src
diff options
context:
space:
mode:
authorBenedikt Peetz <benedikt.peetz@b-peetz.de>2024-08-22 14:22:13 +0200
committerBenedikt Peetz <benedikt.peetz@b-peetz.de>2024-08-22 14:22:13 +0200
commit6bfc7ee06dc1a598014dd5bec659b14a3aa87bbd (patch)
treef12b4892214fd9cd0fbbd206abd6929179f75d2b /src
parenttest(benches/update): Init (diff)
downloadyt-6bfc7ee06dc1a598014dd5bec659b14a3aa87bbd.zip
feat(download): Support limiting the downloader by maximal cache size
Diffstat (limited to 'src')
-rw-r--r--src/cli.rs47
-rw-r--r--src/constants.rs14
-rw-r--r--src/download/download_options.rs2
-rw-r--r--src/download/mod.rs96
-rw-r--r--src/main.rs10
-rw-r--r--src/storage/video_database/extractor_hash.rs2
6 files changed, 153 insertions, 18 deletions
diff --git a/src/cli.rs b/src/cli.rs
index f3f4b7e..a61a57e 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -10,6 +10,7 @@
 
 use std::path::PathBuf;
 
+use anyhow::{bail, Error};
 use chrono::NaiveDate;
 use clap::{ArgAction, Args, Parser, Subcommand};
 use url::Url;
@@ -47,6 +48,11 @@ pub enum Command {
         /// Forcefully re-download all cached videos (i.e. delete the cache path, then download).
         #[arg(short, long)]
         force: bool,
+
+        /// The maximum size the download dir should have. Beware that the value must be given in
+        /// bytes.
+        #[arg(short, long, default_value = "3 GiB", value_parser = byte_parser)]
+        max_cache_size: u64,
     },
 
     /// Watch the already cached (and selected) videos
@@ -102,6 +108,47 @@ pub enum Command {
     },
 }
 
+fn byte_parser(s: &str) -> Result<u64, Error> {
+    const B: u64 = 1;
+
+    const KIB: u64 = 1024 * B;
+    const MIB: u64 = 1024 * KIB;
+    const GIB: u64 = 1024 * MIB;
+
+    const KB: u64 = 1000 * B;
+    const MB: u64 = 1000 * KB;
+    const GB: u64 = 1000 * MB;
+
+    let s = s
+        .chars()
+        .filter(|elem| !elem.is_whitespace())
+        .collect::<String>();
+
+    let number: u64 = s
+        .chars()
+        .take_while(|x| x.is_numeric())
+        .collect::<String>()
+        .parse()?;
+    let extension = s.chars().skip_while(|x| x.is_numeric()).collect::<String>();
+
+    let output = match extension.to_lowercase().as_str() {
+        "" => number,
+        "b" => number * B,
+        "kib" => number * KIB,
+        "mib" => number * MIB,
+        "gib" => number * GIB,
+        "kb" => number * KB,
+        "mb" => number * MB,
+        "gb" => number * GB,
+        other => bail!(
+            "Your extension '{}' is not yet supported. Only KB,MB,GB or KiB,MiB,GiB are supported",
+            other
+        ),
+    };
+
+    Ok(output)
+}
+
 impl Default for Command {
     fn default() -> Self {
         Self::Select {
diff --git a/src/constants.rs b/src/constants.rs
index f4eef3d..00919ce 100644
--- a/src/constants.rs
+++ b/src/constants.rs
@@ -8,9 +8,9 @@
 // You should have received a copy of the License along with this program.
 // If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.
 
-use std::{env::temp_dir, path::PathBuf};
+use std::{env::temp_dir, fs, path::PathBuf};
 
-use anyhow::Context;
+use anyhow::{Context, Result};
 
 pub const HELP_STR: &str = include_str!("./select/selection_file/help.str");
 pub const LOCAL_COMMENTS_LENGTH: usize = 1000;
@@ -21,9 +21,15 @@ pub const DEFAULT_SUBTITLE_LANGS: &str = "en";
 
 pub const CONCURRENT_DOWNLOADS: u32 = 5;
 // We download to the temp dir to avoid taxing the disk
-pub fn download_dir() -> PathBuf {
+pub fn download_dir(create: bool) -> Result<PathBuf> {
     const DOWNLOAD_DIR: &str = "/tmp/yt";
-    PathBuf::from(DOWNLOAD_DIR)
+    let dir = PathBuf::from(DOWNLOAD_DIR);
+
+    if !dir.exists() && create {
+        fs::create_dir_all(&dir).context("Failed to create the download directory")?
+    }
+
+    Ok(dir)
 }
 
 const PREFIX: &str = "yt";
diff --git a/src/download/download_options.rs b/src/download/download_options.rs
index 17cf66c..04c1600 100644
--- a/src/download/download_options.rs
+++ b/src/download/download_options.rs
@@ -50,7 +50,7 @@ pub fn download_opts(additional_opts: YtDlpOptions) -> serde_json::Map<String, s
       "writeautomaticsub": true,
 
       "outtmpl": {
-        "default": constants::download_dir().join("%(channel)s/%(title)s.%(ext)s"),
+        "default": constants::download_dir(false).expect("We're not creating this dir, thus this function can't error").join("%(channel)s/%(title)s.%(ext)s"),
         "chapter": "%(title)s - %(section_number)03d %(section_title)s [%(id)s].%(ext)s"
       },
       "compat_opts": {},
diff --git a/src/download/mod.rs b/src/download/mod.rs
index 3785876..c3d79b7 100644
--- a/src/download/mod.rs
+++ b/src/download/mod.rs
@@ -8,22 +8,24 @@
 // You should have received a copy of the License along with this program.
 // If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.
 
-use std::{sync::Arc, time::Duration};
+use std::{collections::HashMap, sync::Arc, time::Duration};
 
 use crate::{
     app::App,
+    constants::download_dir,
     download::download_options::download_opts,
     storage::video_database::{
         downloader::{get_next_uncached_video, set_video_cache_path},
         extractor_hash::ExtractorHash,
         getters::get_video_yt_dlp_opts,
-        Video,
+        Video, YtDlpOptions,
     },
 };
 
-use anyhow::{Context, Result};
-use log::{debug, info};
-use tokio::{task::JoinHandle, time};
+use anyhow::{bail, Context, Result};
+use futures::{future::BoxFuture, FutureExt};
+use log::{debug, info, warn};
+use tokio::{fs, task::JoinHandle, time};
 
 mod download_options;
 
@@ -53,12 +55,14 @@ impl CurrentDownload {
 
 pub struct Downloader {
     current_download: Option<CurrentDownload>,
+    video_size_cache: HashMap<ExtractorHash, u64>,
 }
 
 impl Downloader {
     pub fn new() -> Self {
         Self {
             current_download: None,
+            video_size_cache: HashMap::new(),
         }
     }
 
@@ -67,7 +71,20 @@ impl Downloader {
     /// change which videos it downloads.
     /// This will run, until the database doesn't contain any watchable videos
     pub async fn consume(&mut self, app: Arc<App>, max_cache_size: u64) -> Result<()> {
-        while let Some(next_video) = get_next_uncached_video(app).await? {
+        while let Some(next_video) = get_next_uncached_video(&app).await? {
+            if Self::get_current_cache_allocation().await?
+                + self.get_approx_video_size(&next_video).await?
+                >= max_cache_size
+            {
+                warn!(
+                    "Can't download video: '{}' as it's too large for the cache.",
+                    next_video.title
+                );
+                // Wait and hope, that a large video is deleted from the cache.
+                time::sleep(Duration::from_secs(10)).await;
+                continue;
+            }
+
             if let Some(_) = &self.current_download {
                 let current_download = self.current_download.take().expect("Is Some");
 
@@ -99,7 +116,6 @@ impl Downloader {
                     );
                     // Reset the taken value
                     self.current_download = Some(current_download);
-                    time::sleep(Duration::new(1, 0)).await;
                 }
             } else {
                 info!(
@@ -111,15 +127,75 @@ impl Downloader {
                 self.current_download = Some(new_current_download);
             }
 
-            // if get_allocated_cache().await? < CONCURRENT {
-            //     .cache_video(next_video).await?;
-            // }
+            time::sleep(Duration::new(1, 0)).await;
         }
 
         info!("Finished downloading!");
         Ok(())
     }
 
+    async fn get_current_cache_allocation() -> Result<u64> {
+        fn dir_size(mut dir: fs::ReadDir) -> BoxFuture<'static, Result<u64>> {
+            async move {
+                let mut acc = 0;
+                while let Some(entry) = dir.next_entry().await? {
+                    let size = match entry.metadata().await? {
+                        data if data.is_dir() => {
+                            let path = entry.path();
+                            let read_dir = fs::read_dir(path).await?;
+
+                            dir_size(read_dir).await?
+                        }
+                        data => data.len(),
+                    };
+                    acc += size;
+                }
+                Ok(acc)
+            }
+            .boxed()
+        }
+
+        let val = dir_size(fs::read_dir(download_dir(true)?).await?).await;
+        if let Ok(val) = val.as_ref() {
+            info!("Cache dir has a size of '{}'", val);
+        }
+        val
+    }
+
+    async fn get_approx_video_size(&mut self, video: &Video) -> Result<u64> {
+        if let Some(value) = self.video_size_cache.get(&video.extractor_hash) {
+            Ok(*value)
+        } else {
+            // the subtitle file size should be negligible
+            let add_opts = YtDlpOptions {
+                subtitle_langs: "".to_owned(),
+            };
+            let opts = &download_opts(add_opts);
+
+            let result = yt_dlp::extract_info(&opts, &video.url, false, true)
+                .await
+                .with_context(|| {
+                    format!("Failed to extract video information: '{}'", video.title)
+                })?;
+
+            let size = if let Some(val) = result.filesize {
+                val
+            } else if let Some(val) = result.filesize_approx {
+                val
+            } else {
+                bail!("Failed to find a filesize for video: '{}'", video.title);
+            };
+
+            assert_eq!(
+                self.video_size_cache
+                    .insert(video.extractor_hash.clone(), size),
+                None
+            );
+
+            Ok(size)
+        }
+    }
+
     async fn actually_cache_video(app: &App, video: &Video) -> Result<()> {
         debug!("Download started: {}", &video.title);
 
diff --git a/src/main.rs b/src/main.rs
index c223140..ebbb45f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,13 +8,14 @@
 // You should have received a copy of the License along with this program.
 // If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.
 
-use std::{collections::HashMap, fs};
+use std::{collections::HashMap, fs, sync::Arc};
 
 use anyhow::{bail, Context, Result};
 use app::App;
 use cache::invalidate;
 use clap::Parser;
 use cli::{CacheCommand, CheckCommand, SelectCommand, SubscriptionCommand};
+use log::info;
 use select::cmds::handle_select_cmd;
 use tokio::{
     fs::File,
@@ -56,7 +57,12 @@ async fn main() -> Result<()> {
     let app = App::new(args.db_path.unwrap_or(constants::database()?)).await?;
 
     match args.command.unwrap_or(Command::default()) {
-        Command::Download { force } => {
+        Command::Download {
+            force,
+            max_cache_size,
+        } => {
+            info!("max cache size: '{}'", max_cache_size);
+
             if force {
                 invalidate(&app, true).await?;
             }
diff --git a/src/storage/video_database/extractor_hash.rs b/src/storage/video_database/extractor_hash.rs
index 3af4f60..593b5c4 100644
--- a/src/storage/video_database/extractor_hash.rs
+++ b/src/storage/video_database/extractor_hash.rs
@@ -19,7 +19,7 @@ use crate::{app::App, storage::video_database::getters::get_all_hashes};
 
 static EXTRACTOR_HASH_LENGTH: OnceCell<usize> = OnceCell::const_new();
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct ExtractorHash {
     hash: Hash,
 }