about summary refs log tree commit diff stats
path: root/crates/yt_dlp/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--crates/yt_dlp/src/lib.rs151
1 files changed, 117 insertions, 34 deletions
diff --git a/crates/yt_dlp/src/lib.rs b/crates/yt_dlp/src/lib.rs
index 970bfe2..40610c2 100644
--- a/crates/yt_dlp/src/lib.rs
+++ b/crates/yt_dlp/src/lib.rs
@@ -12,8 +12,8 @@
 #![allow(unsafe_op_in_unsafe_fn)]
 #![allow(clippy::missing_errors_doc)]
 
-use std::env;
-use std::io::stdout;
+use std::io::stderr;
+use std::{env, process};
 use std::{fs::File, io::Write};
 
 use std::{path::PathBuf, sync::Once};
@@ -21,18 +21,20 @@ use std::{path::PathBuf, sync::Once};
 use crate::{duration::Duration, logging::setup_logging, wrapper::info_json::InfoJson};
 
 use bytes::Bytes;
-use log::{info, log_enabled, Level};
+use error::YtDlpError;
+use log::{Level, debug, info, log_enabled};
 use pyo3::types::{PyString, PyTuple, PyTupleMethods};
 use pyo3::{
-    pyfunction,
+    Bound, PyAny, PyResult, Python, pyfunction,
     types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyModule},
-    wrap_pyfunction, Bound, PyAny, PyResult, Python,
+    wrap_pyfunction,
 };
 use serde::Serialize;
 use serde_json::{Map, Value};
 use url::Url;
 
 pub mod duration;
+pub mod error;
 pub mod logging;
 pub mod wrapper;
 
@@ -51,6 +53,33 @@ pub fn add_logger_and_sig_handler<'a>(
     opts: Bound<'a, PyDict>,
     py: Python<'_>,
 ) -> PyResult<Bound<'a, PyDict>> {
+    /// Is the specified record to be logged? Returns false for no,
+    /// true for yes. Filters can either modify log records in-place or
+    /// return a completely different record instance which will replace
+    /// the original log record in any future processing of the event.
+    #[pyfunction]
+    fn filter_error_log(_py: Python<'_>, record: &Bound<'_, PyAny>) -> bool {
+        // Filter out all error logs (they are propagated as rust errors)
+        let levelname: String = record
+            .getattr("levelname")
+            .expect("This should exist")
+            .extract()
+            .expect("This should be a String");
+
+        let return_value = levelname.as_str() != "ERROR";
+
+        if log_enabled!(Level::Debug) && !return_value {
+            let message: String = record
+                .call_method0("getMessage")
+                .expect("This method exists")
+                .extract()
+                .expect("The message is a string");
+
+            debug!("Swollowed error message: '{message}'");
+        }
+        return_value
+    }
+
     setup_logging(py, "yt_dlp")?;
 
     let logging = PyModule::import(py, "logging")?;
@@ -81,6 +110,11 @@ signal.signal(signal.SIGINT, signal.SIG_DFL)",
             .expect("This method exists");
     });
 
+    ytdl_logger.call_method1(
+        "addFilter",
+        (wrap_pyfunction!(filter_error_log, py).expect("This function can be wrapped"),),
+    )?;
+
     // This was taken from `ytcc`, I don't think it is still applicable
     // ytdl_logger.setattr("propagate", false)?;
     // let logging_null_handler = logging.call_method0("NullHandler")?;
@@ -111,10 +145,10 @@ pub fn progress_hook(py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<()>
     // see: https://en.wikipedia.org/wiki/ANSI_escape_code#Control_Sequence_Introducer_commands
     const CSI: &str = "\x1b[";
     fn clear_whole_line() {
-        print!("{CSI}2K");
+        eprint!("{CSI}2K");
     }
     fn move_to_col(x: usize) {
-        print!("{CSI}{x}G");
+        eprint!("{CSI}{x}G");
     }
     // }}}
 
@@ -125,7 +159,7 @@ pub fn progress_hook(py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<()>
             .expect("Will always work")
             .to_owned(),
     )?)
-    .expect("Python should always produce valid json");
+    .expect("python's json is valid");
 
     macro_rules! get {
         (@interrogate $item:ident, $type_fun:ident, $get_fun:ident, $name:expr) => {{
@@ -198,7 +232,7 @@ pub fn progress_hook(py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<()>
         format!("{bytes}/s")
     }
 
-    let get_title = |add_extension: bool| -> String {
+    let get_title = || -> String {
         match get! {is_string, as_str, "info_dict", "ext"} {
             "vtt" => {
                 format!(
@@ -206,16 +240,8 @@ pub fn progress_hook(py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<()>
                     default_get! {as_str, "<No Subtitle Language>", "info_dict", "name"}
                 )
             }
-            title_extension @ ("webm" | "mp4" | "m4a") => {
-                if add_extension {
-                    format!(
-                        "{} ({})",
-                        default_get! { as_str, "<No title>", "info_dict", "title"},
-                        title_extension
-                    )
-                } else {
-                    default_get! { as_str, "<No title>", "info_dict", "title"}.to_owned()
-                }
+            "webm" | "mp4" | "mp3" | "m4a" => {
+                default_get! { as_str, "<No title>", "info_dict", "title"}.to_owned()
             }
             other => panic!("The extension '{other}' is not yet implemented"),
         }
@@ -257,9 +283,9 @@ pub fn progress_hook(py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<()>
             clear_whole_line();
             move_to_col(1);
 
-            print!(
+            eprint!(
                 "'{}' [{}/{} at {}] -> [{} of {}{} {}] ",
-                c!("34;1", get_title(true)),
+                c!("34;1", get_title()),
                 c!("33;1", Duration::from(Some(elapsed))),
                 c!("33;1", Duration::from(Some(eta))),
                 c!("32;1", format_speed(speed)),
@@ -268,13 +294,16 @@ pub fn progress_hook(py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<()>
                 c!("31;1", format_bytes(total_bytes)),
                 c!("36;1", format!("{:.02}%", percent))
             );
-            stdout().flush()?;
+            stderr().flush()?;
         }
         "finished" => {
-            println!("-> Finished downloading.");
+            eprintln!("-> Finished downloading.");
         }
         "error" => {
-            panic!("-> Error while downloading: {}", get_title(true))
+            // TODO: This should probably return an Err. But I'm not so sure where the error would
+            // bubble up to (i.e., who would catch it) <2025-01-21>
+            eprintln!("-> Error while downloading: {}", get_title());
+            process::exit(1);
         }
         other => unreachable!("'{other}' should not be a valid state!"),
     };
@@ -298,6 +327,42 @@ pub fn add_hooks<'a>(opts: Bound<'a, PyDict>, py: Python<'_>) -> PyResult<Bound<
     Ok(opts)
 }
 
+/// Take the result of the ie (may be modified) and resolve all unresolved
+/// references (URLs, playlist items).
+///
+/// It will also download the videos if 'download'.
+/// Returns the resolved `ie_result`.
+#[allow(clippy::unused_async)]
+#[allow(clippy::missing_panics_doc)]
+pub async fn process_ie_result(
+    yt_dlp_opts: &Map<String, Value>,
+    ie_result: InfoJson,
+    download: bool,
+) -> Result<InfoJson, YtDlpError> {
+    Python::with_gil(|py| -> Result<InfoJson, YtDlpError> {
+        let opts = json_map_to_py_dict(yt_dlp_opts, py)?;
+
+        let instance = get_yt_dlp(py, opts)?;
+
+        let args = {
+            let ie_result = json_loads_str(py, ie_result)?;
+            (ie_result,)
+        };
+
+        let kwargs = PyDict::new(py);
+        kwargs.set_item("download", download)?;
+
+        let result = instance
+            .call_method("process_ie_result", args, Some(&kwargs))?
+            .downcast_into::<PyDict>()
+            .expect("This is a dict");
+
+        let result_str = json_dumps(py, result.into_any())?;
+
+        serde_json::from_str(&result_str).map_err(Into::into)
+    })
+}
+
 /// `extract_info(self, url, download=True, ie_key=None, extra_info=None, process=True, force_generic_extractor=False)`
 ///
 /// Extract and return the information dictionary of the URL
@@ -320,8 +385,8 @@ pub async fn extract_info(
     url: &Url,
     download: bool,
     process: bool,
-) -> PyResult<InfoJson> {
-    Python::with_gil(|py| {
+) -> Result<InfoJson, YtDlpError> {
+    Python::with_gil(|py| -> Result<InfoJson, YtDlpError> {
         let opts = json_map_to_py_dict(yt_dlp_opts, py)?;
 
         let instance = get_yt_dlp(py, opts)?;
@@ -331,14 +396,33 @@ pub async fn extract_info(
         kwargs.set_item("download", download)?;
         kwargs.set_item("process", process)?;
 
-        let result = instance.call_method("extract_info", args, Some(&kwargs))?;
+        let result = instance
+            .call_method("extract_info", args, Some(&kwargs))?
+            .downcast_into::<PyDict>()
+            .expect("This is a dict");
+
+        // Resolve the generator object
+        if let Some(generator) = result.get_item("entries")? {
+            if generator.is_instance_of::<PyList>() {
+                // already resolved. Do nothing
+            } else {
+                let max_backlog = yt_dlp_opts.get("playlistend").map_or(10000, |value| {
+                    usize::try_from(value.as_u64().expect("Works")).expect("Should work")
+                });
+
+                let mut out = vec![];
+                while let Ok(output) = generator.call_method0("__next__") {
+                    out.push(output);
 
-        // Remove the `<generator at 0xsome_hex>`, by setting it to null
-        if !process {
-            result.set_item("entries", ())?;
+                    if out.len() == max_backlog {
+                        break;
+                    }
+                }
+                result.set_item("entries", out)?;
+            }
         }
 
-        let result_str = json_dumps(py, result)?;
+        let result_str = json_dumps(py, result.into_any())?;
 
         if let Ok(confirm) = env::var("YT_STORE_INFO_JSON") {
             if confirm == "yes" {
@@ -347,8 +431,7 @@ pub async fn extract_info(
             }
         }
 
-        Ok(serde_json::from_str(&result_str)
-            .expect("Python should be able to produce correct json"))
+        serde_json::from_str(&result_str).map_err(Into::into)
     })
 }
 
@@ -380,7 +463,7 @@ pub fn unsmuggle_url(smug_url: &Url) -> PyResult<Url> {
 pub async fn download(
     urls: &[Url],
     download_options: &Map<String, Value>,
-) -> PyResult<Vec<PathBuf>> {
+) -> Result<Vec<PathBuf>, YtDlpError> {
     let mut out_paths = Vec::with_capacity(urls.len());
 
     for url in urls {