about summary refs log tree commit diff stats
path: root/crates/yt_dlp/src/post_processors
diff options
context:
space:
mode:
Diffstat (limited to 'crates/yt_dlp/src/post_processors')
-rw-r--r-- crates/yt_dlp/src/post_processors/dearrow.rs 145
-rw-r--r-- crates/yt_dlp/src/post_processors/mod.rs 91
2 files changed, 112 insertions, 124 deletions
diff --git a/crates/yt_dlp/src/post_processors/dearrow.rs b/crates/yt_dlp/src/post_processors/dearrow.rs
index 3cac745..f35f301 100644
--- a/crates/yt_dlp/src/post_processors/dearrow.rs
+++ b/crates/yt_dlp/src/post_processors/dearrow.rs
@@ -9,50 +9,106 @@
 // If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.
 
 use curl::easy::Easy;
-use log::{error, info, warn};
-use rustpython::vm::{
-    PyRef, VirtualMachine,
-    builtins::{PyDict, PyStr},
+use log::{error, info, trace, warn};
+use pyo3::{
+    Bound, PyAny, PyErr, PyResult, Python, exceptions, intern, pyfunction,
+    types::{PyAnyMethods, PyDict, PyModule},
+    wrap_pyfunction,
 };
 use serde::{Deserialize, Serialize};
 
-use crate::{pydict_cast, pydict_get, wrap_post_processor};
+use crate::{
+    pydict_cast, pydict_get,
+    python_error::{IntoPythonError, PythonError},
+};
+
+/// # Errors
+/// - If the underlying function returns an error.
+/// - If python operations fail.
+pub fn process(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
+    #[pyfunction]
+    fn actual_processor(info_json: Bound<'_, PyDict>) -> PyResult<Bound<'_, PyDict>> {
+        let output = match unwrapped_process(info_json) {
+            Ok(ok) => ok,
+            Err(err) => {
+                return Err(PyErr::new::<exceptions::PyRuntimeError, _>(err.to_string()));
+            }
+        };
+        Ok(output)
+    }
 
-wrap_post_processor!("DeArrow", unwrapped_process, process);
+    let module = PyModule::new(py, "rust_post_processors")?;
+    let scope = PyDict::new(py);
+    scope.set_item(
+        intern!(py, "actual_processor"),
+        wrap_pyfunction!(actual_processor, module)?,
+    )?;
+    py.run(
+        c"
+import yt_dlp
+
+class DeArrow(yt_dlp.postprocessor.PostProcessor):
+    def run(self, info):
+        info = actual_processor(info)
+        return [], info
+
+inst = DeArrow()
+",
+        Some(&scope),
+        None,
+    )?;
+
+    Ok(scope.get_item(intern!(py, "inst"))?.downcast_into()?)
+}
 
 /// # Errors
 /// If the API access fails.
-pub fn unwrapped_process(info: PyRef<PyDict>, vm: &VirtualMachine) -> Result<PyRef<PyDict>, Error> {
-    if pydict_get!(@vm, info, "extractor_key", PyStr).as_str() != "Youtube" {
-        warn!("DeArrow: Extractor did not match, exiting.");
+pub fn unwrapped_process(info: Bound<'_, PyDict>) -> Result<Bound<'_, PyDict>, Error> {
+    if pydict_get!(info, "extractor_key", String).as_str() != "Youtube" {
         return Ok(info);
     }
 
+    let mut retry_num = 3;
     let mut output: DeArrowApi = {
-        let output_bytes = {
-            let mut dst = Vec::new();
-
-            let mut easy = Easy::new();
-            easy.url(
-                format!(
-                    "https://sponsor.ajay.app/api/branding?videoID={}",
-                    pydict_get!(@vm, info, "id", PyStr).as_str()
-                )
-                .as_str(),
-            )?;
-
-            let mut transfer = easy.transfer();
-            transfer.write_function(|data| {
-                dst.extend_from_slice(data);
-                Ok(data.len())
-            })?;
-            transfer.perform()?;
-            drop(transfer);
-
-            dst
-        };
-
-        serde_json::from_slice(&output_bytes)?
+        loop {
+            let output_bytes = {
+                let mut dst = Vec::new();
+
+                let mut easy = Easy::new();
+                easy.url(
+                    format!(
+                        "https://sponsor.ajay.app/api/branding?videoID={}",
+                        pydict_get!(info, "id", String)
+                    )
+                    .as_str(),
+                )?;
+
+                let mut transfer = easy.transfer();
+                transfer.write_function(|data| {
+                    dst.extend_from_slice(data);
+                    Ok(data.len())
+                })?;
+                transfer.perform()?;
+                drop(transfer);
+
+                dst
+            };
+
+            match serde_json::from_slice(&output_bytes) {
+                Ok(ok) => break ok,
+                Err(err) => {
+                    if retry_num > 0 {
+                        trace!(
+                            "DeArrow: Api access failed, trying again ({retry_num} retries left)"
+                        );
+                        retry_num -= 1;
+                    } else {
+                        let err: serde_json::Error = err;
+                        return Err(err.into());
+                    }
+                }
+            }
+        }
     };
 
     // We pop the titles, so we need this vector reversed.
@@ -74,7 +130,7 @@ pub fn unwrapped_process(info: PyRef<PyDict>, vm: &VirtualMachine) -> Result<PyR
             continue;
         }
 
-        update_title(&info, &title.value, vm);
+        update_title(&info, &title.value).wrap_exc(info.py())?;
 
         break true;
     };
@@ -82,7 +138,7 @@ pub fn unwrapped_process(info: PyRef<PyDict>, vm: &VirtualMachine) -> Result<PyR
     if !selected && title_len != 0 {
         // No title was selected, even though we had some titles.
         // Just pick the first one in this case.
-        update_title(&info, &output.titles[0].value, vm);
+        update_title(&info, &output.titles[0].value).wrap_exc(info.py())?;
     }
 
     Ok(info)
@@ -90,6 +146,9 @@ pub fn unwrapped_process(info: PyRef<PyDict>, vm: &VirtualMachine) -> Result<PyR
 
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
+    #[error(transparent)]
+    Python(#[from] PythonError),
+
     #[error("Failed to access the DeArrow api: {0}")]
     Get(#[from] curl::Error),
 
@@ -97,17 +156,19 @@ pub enum Error {
     Deserialize(#[from] serde_json::Error),
 }
 
-fn update_title(info: &PyRef<PyDict>, new_title: &str, vm: &VirtualMachine) {
-    assert!(!info.contains_key("original_title", vm));
+fn update_title(info: &Bound<'_, PyDict>, new_title: &str) -> PyResult<()> {
+    let py = info.py();
+
+    assert!(!info.contains(intern!(py, "original_title"))?);
 
-    if let Ok(old_title) = info.get_item("title", vm) {
+    if let Ok(old_title) = info.get_item(intern!(py, "title")) {
         warn!(
             "DeArrow: Updating title from {:#?} to {:#?}",
-            pydict_cast!(@ref old_title, PyStr).as_str(),
+            pydict_cast!(old_title, &str),
             new_title
         );
 
-        info.set_item("original_title", old_title, vm)
+        info.set_item(intern!(py, "original_title"), old_title)
             .expect("We checked, it is a new key");
     } else {
         warn!("DeArrow: Setting title to {new_title:#?}");
@@ -119,8 +180,10 @@ fn update_title(info: &PyRef<PyDict>, new_title: &str, vm: &VirtualMachine) {
         new_title.replace('>', "")
     };
 
-    info.set_item("title", vm.new_pyobj(cleaned_title), vm)
+    info.set_item(intern!(py, "title"), cleaned_title)
         .expect("This should work?");
+
+    Ok(())
 }
 
 #[derive(Serialize, Deserialize)]
diff --git a/crates/yt_dlp/src/post_processors/mod.rs b/crates/yt_dlp/src/post_processors/mod.rs
index 00b0ad5..d9be3f5 100644
--- a/crates/yt_dlp/src/post_processors/mod.rs
+++ b/crates/yt_dlp/src/post_processors/mod.rs
@@ -12,8 +12,9 @@ pub mod dearrow;
 
 #[macro_export]
 macro_rules! pydict_get {
-    (@$vm:expr, $value:expr, $name:literal, $into:ident) => {{
-        match $value.get_item($name, $vm) {
+    ($value:expr, $name:literal, $into:ty) => {{
+        let item = $value.get_item(pyo3::intern!($value.py(), $name));
+        match &item {
             Ok(val) => $crate::pydict_cast!(val, $into),
             Err(_) => panic!(
                 concat!(
@@ -31,93 +32,17 @@ macro_rules! pydict_get {
 
 #[macro_export]
 macro_rules! pydict_cast {
-    ($value:expr, $into:ident) => {{
-        match $value.downcast::<$into>() {
+    ($value:expr, $into:ty) => {{
+        match $value.extract::<$into>() {
             Ok(result) => result,
             Err(val) => panic!(
                 concat!(
-                    "Expected to be able to downcast value ({:#?}) as ",
-                    stringify!($into)
+                    "Expected to be able to extract ",
+                    stringify!($into),
+                    " from value ({:#?})."
                 ),
                 val
             ),
         }
     }};
-    (@ref $value:expr, $into:ident) => {{
-        match $value.downcast_ref::<$into>() {
-            Some(result) => result,
-            None => panic!(
-                concat!(
-                    "Expected to be able to downcast value ({:#?}) as ",
-                    stringify!($into)
-                ),
-                $value
-            ),
-        }
-    }};
-}
-
-#[macro_export]
-macro_rules! wrap_post_processor {
-    ($name:literal, $unwrap:ident, $wrapped:ident) => {
-        use $crate::progress_hook::__priv::vm;
-
-        /// # Errors
-        /// - If the underlying function returns an error.
-        /// - If python operations fail.
-        pub fn $wrapped(vm: &vm::VirtualMachine) -> vm::PyResult<vm::PyObjectRef> {
-            fn actual_processor(
-                mut input: vm::function::FuncArgs,
-                vm: &vm::VirtualMachine,
-            ) -> vm::PyResult<vm::PyRef<vm::builtins::PyDict>> {
-                let input = input
-                    .args
-                    .remove(0)
-                    .downcast::<vm::builtins::PyDict>()
-                    .expect("Should be a py dict");
-
-                let output = match unwrapped_process(input, vm) {
-                    Ok(ok) => ok,
-                    Err(err) => {
-                        return Err(vm.new_runtime_error(err.to_string()));
-                    }
-                };
-
-                Ok(output)
-            }
-
-            let scope = vm.new_scope_with_builtins();
-
-            scope.globals.set_item(
-                "actual_processor",
-                vm.new_function("actual_processor", actual_processor).into(),
-                vm,
-            )?;
-
-            let local_scope = scope.clone();
-            vm.run_code_string(
-                local_scope,
-                format!(
-                    "
-import yt_dlp
-
-class {}(yt_dlp.postprocessor.PostProcessor):
-    def run(self, info):
-        info = actual_processor(info)
-        return [], info
-
-inst = {}()
-",
-                    $name, $name
-                )
-                .as_str(),
-                "<embedded post processor initializing code>".to_owned(),
-            )?;
-
-            Ok(scope
-                .globals
-                .get_item("inst", vm)
-                .expect("We just declared it"))
-        }
-    };
 }