about summary refs log tree commit diff stats
path: root/crates/yt_dlp/src/progress_hook.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/yt_dlp/src/progress_hook.rs')
-rw-r--r--crates/yt_dlp/src/progress_hook.rs83
1 files changed, 48 insertions, 35 deletions
diff --git a/crates/yt_dlp/src/progress_hook.rs b/crates/yt_dlp/src/progress_hook.rs
index b42ae21..7e5f8a5 100644
--- a/crates/yt_dlp/src/progress_hook.rs
+++ b/crates/yt_dlp/src/progress_hook.rs
@@ -9,46 +9,59 @@
 // If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.
 
 #[macro_export]
-macro_rules! mk_python_function {
+macro_rules! wrap_progress_hook {
     ($name:ident, $new_name:ident) => {
-        pub fn $new_name(
-            mut args: $crate::progress_hook::__priv::vm::function::FuncArgs,
-            vm: &$crate::progress_hook::__priv::vm::VirtualMachine,
-        ) {
-            use $crate::progress_hook::__priv::vm;
-
-            let input = {
-                let dict: vm::PyRef<vm::builtins::PyDict> = args
-                    .args
-                    .remove(0)
-                    .downcast()
-                    .expect("The progress hook is always called with these args");
-                let new_dict = vm::builtins::PyDict::new_ref(&vm.ctx);
-                dict.into_iter()
-                    .filter_map(|(name, value)| {
-                        let real_name: vm::PyRefExact<vm::builtins::PyStr> =
-                            name.downcast_exact(vm).expect("Is a string");
-                        let name_str = real_name.to_str().expect("Is a string");
-                        if name_str.starts_with('_') {
-                            None
-                        } else {
-                            Some((name_str.to_owned(), value))
-                        }
-                    })
-                    .for_each(|(key, value)| {
-                        new_dict
-                            .set_item(&key, value, vm)
-                            .expect("This is a transpositions, should always be valid");
-                    });
-
-                $crate::progress_hook::__priv::json_dumps(new_dict, vm)
-            };
-            $name(input).expect("Shall not fail!");
+        pub(crate) fn $new_name(
+            py: yt_dlp::progress_hook::__priv::pyo3::Python<'_>,
+        ) -> yt_dlp::progress_hook::__priv::pyo3::PyResult<
+            yt_dlp::progress_hook::__priv::pyo3::Bound<
+                '_,
+                yt_dlp::progress_hook::__priv::pyo3::types::PyCFunction,
+            >,
+        > {
+            #[yt_dlp::progress_hook::__priv::pyo3::pyfunction]
+            #[pyo3(crate = "yt_dlp::progress_hook::__priv::pyo3")]
+            fn inner(
+                input: yt_dlp::progress_hook::__priv::pyo3::Bound<
+                    '_,
+                    yt_dlp::progress_hook::__priv::pyo3::types::PyDict,
+                >,
+            ) -> yt_dlp::progress_hook::__priv::pyo3::PyResult<()> {
+                let processed_input = {
+                    let new_dict = yt_dlp::progress_hook::__priv::pyo3::types::PyDict::new(input.py());
+
+                    input
+                        .into_iter()
+                        .filter_map(|(name, value)| {
+                            let real_name = yt_dlp::progress_hook::__priv::pyo3::types::PyAnyMethods::extract::<String>(&name).expect("Should always be a string");
+
+                            if real_name.starts_with('_') {
+                                None
+                            } else {
+                                Some((real_name, value))
+                            }
+                        })
+                        .for_each(|(key, value)| {
+                            yt_dlp::progress_hook::__priv::pyo3::types::PyDictMethods::set_item(&new_dict, &key, value)
+                                .expect("This is a transpositions, should always be valid");
+                        });
+                    yt_dlp::progress_hook::__priv::json_dumps(&new_dict)
+                };
+
+                $name(processed_input)?;
+
+                Ok(())
+            }
+
+            let module = yt_dlp::progress_hook::__priv::pyo3::types::PyModule::new(py, "progress_hook")?;
+            let fun = yt_dlp::progress_hook::__priv::pyo3::wrap_pyfunction!(inner, module)?;
+
+            Ok(fun)
         }
     };
 }
 
 pub mod __priv {
     pub use crate::info_json::{json_dumps, json_loads};
-    pub use rustpython::vm;
+    pub use pyo3;
 }