about summary refs log tree commit diff stats
path: root/yt/src/download/progress_hook.rs
diff options
context:
space:
mode:
Diffstat (limited to 'yt/src/download/progress_hook.rs')
-rw-r--r--yt/src/download/progress_hook.rs190
1 files changed, 190 insertions, 0 deletions
diff --git a/yt/src/download/progress_hook.rs b/yt/src/download/progress_hook.rs
new file mode 100644
index 0000000..65156e7
--- /dev/null
+++ b/yt/src/download/progress_hook.rs
@@ -0,0 +1,190 @@
+use std::{
+    io::{Write, stderr},
+    process,
+};
+
+use bytes::Bytes;
+use log::{Level, log_enabled};
+use yt_dlp::mk_python_function;
+
+use crate::select::selection_file::duration::MaybeDuration;
+
+// #[allow(clippy::too_many_lines)]
+// #[allow(clippy::missing_panics_doc)]
+// #[allow(clippy::items_after_statements)]
+// #[allow(
+//     clippy::cast_possible_truncation,
+//     clippy::cast_sign_loss,
+//     clippy::cast_precision_loss
+// )]
+pub fn progress_hook(
+    input: serde_json::Map<String, serde_json::Value>,
+) -> Result<(), std::io::Error> {
+    // Only add the handler, if the log-level is higher than Debug (this avoids covering debug
+    // messages).
+    if log_enabled!(Level::Debug) {
+        return Ok(());
+    }
+
+    // ANSI ESCAPE CODES Wrappers {{{
+    // see: https://en.wikipedia.org/wiki/ANSI_escape_code#Control_Sequence_Introducer_commands
+    const CSI: &str = "\x1b[";
+    fn clear_whole_line() {
+        eprint!("{CSI}2K");
+    }
+    fn move_to_col(x: usize) {
+        eprint!("{CSI}{x}G");
+    }
+    // }}}
+
+    macro_rules! get {
+        (@interrogate $item:ident, $type_fun:ident, $get_fun:ident, $name:expr) => {{
+            let a = $item.get($name).expect(concat!(
+                "The field '",
+                stringify!($name),
+                "' should exist."
+            ));
+
+            if a.$type_fun() {
+                a.$get_fun().expect(
+                    "The should have been checked in the if guard, so unpacking here is fine",
+                )
+            } else {
+                panic!(
+                    "Value {} => \n{}\n is not of type: {}",
+                    $name,
+                    a,
+                    stringify!($type_fun)
+                );
+            }
+        }};
+
+        ($type_fun:ident, $get_fun:ident, $name1:expr, $name2:expr) => {{
+            let a = get! {@interrogate input, is_object, as_object, $name1};
+            let b = get! {@interrogate a, $type_fun, $get_fun, $name2};
+            b
+        }};
+
+        ($type_fun:ident, $get_fun:ident, $name:expr) => {{
+            get! {@interrogate input, $type_fun, $get_fun, $name}
+        }};
+    }
+
+    macro_rules! default_get {
+        (@interrogate $item:ident, $default:expr, $get_fun:ident, $name:expr) => {{
+            let a = if let Some(field) = $item.get($name) {
+                field.$get_fun().unwrap_or($default)
+            } else {
+                $default
+            };
+            a
+        }};
+
+        ($get_fun:ident, $default:expr, $name1:expr, $name2:expr) => {{
+            let a = get! {@interrogate input, is_object, as_object, $name1};
+            let b = default_get! {@interrogate a, $default, $get_fun, $name2};
+            b
+        }};
+
+        ($get_fun:ident, $default:expr, $name:expr) => {{
+            default_get! {@interrogate input, $default, $get_fun, $name}
+        }};
+    }
+
+    macro_rules! c {
+        ($color:expr, $format:expr) => {
+            format!("\x1b[{}m{}\x1b[0m", $color, $format)
+        };
+    }
+
+    fn format_bytes(bytes: u64) -> String {
+        let bytes = Bytes::new(bytes);
+        bytes.to_string()
+    }
+
+    fn format_speed(speed: f64) -> String {
+        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+        let bytes = Bytes::new(speed.floor() as u64);
+        format!("{bytes}/s")
+    }
+
+    let get_title = || -> String {
+        match get! {is_string, as_str, "info_dict", "ext"} {
+            "vtt" => {
+                format!(
+                    "Subtitles ({})",
+                    default_get! {as_str, "<No Subtitle Language>", "info_dict", "name"}
+                )
+            }
+            "webm" | "mp4" | "mp3" | "m4a" => {
+                default_get! { as_str, "<No title>", "info_dict", "title"}.to_owned()
+            }
+            other => panic!("The extension '{other}' is not yet implemented"),
+        }
+    };
+
+    match get! {is_string, as_str, "status"} {
+        "downloading" => {
+            let elapsed = default_get! {as_f64, 0.0f64, "elapsed"};
+            let eta = default_get! {as_f64, 0.0, "eta"};
+            let speed = default_get! {as_f64, 0.0, "speed"};
+
+            let downloaded_bytes = get! {is_u64, as_u64, "downloaded_bytes"};
+            let (total_bytes, bytes_is_estimate): (u64, &'static str) = {
+                let total_bytes = default_get!(as_u64, 0, "total_bytes");
+                if total_bytes == 0 {
+                    let maybe_estimate = default_get!(as_u64, 0, "total_bytes_estimate");
+
+                    if maybe_estimate == 0 {
+                        // The download speed should be in bytes per second and the eta in seconds.
+                        // Thus multiplying them gets us the raw bytes (which were estimated by `yt_dlp`, from their `info.json`)
+                        let bytes_still_needed = (speed * eta).ceil() as u64;
+
+                        (downloaded_bytes + bytes_still_needed, "~")
+                    } else {
+                        (maybe_estimate, "~")
+                    }
+                } else {
+                    (total_bytes, "")
+                }
+            };
+            let percent: f64 = {
+                if total_bytes == 0 {
+                    100.0
+                } else {
+                    (downloaded_bytes as f64 / total_bytes as f64) * 100.0
+                }
+            };
+
+            clear_whole_line();
+            move_to_col(1);
+
+            eprint!(
+                "'{}' [{}/{} at {}] -> [{} of {}{} {}] ",
+                c!("34;1", get_title()),
+                c!("33;1", MaybeDuration::from_secs_f64(elapsed)),
+                c!("33;1", MaybeDuration::from_secs_f64(eta)),
+                c!("32;1", format_speed(speed)),
+                c!("31;1", format_bytes(downloaded_bytes)),
+                c!("31;1", bytes_is_estimate),
+                c!("31;1", format_bytes(total_bytes)),
+                c!("36;1", format!("{:.02}%", percent))
+            );
+            stderr().flush()?;
+        }
+        "finished" => {
+            eprintln!("-> Finished downloading.");
+        }
+        "error" => {
+            // TODO: This should probably return an Err. But I'm not so sure where the error would
+            // bubble up to (i.e., who would catch it) <2025-01-21>
+            eprintln!("-> Error while downloading: {}", get_title());
+            process::exit(1);
+        }
+        other => unreachable!("'{other}' should not be a valid state!"),
+    };
+
+    Ok(())
+}
+
+mk_python_function!(progress_hook, wrapped_progress_hook);