From 47e8a80c1cbc6b76a341029df5555aa098ce4816 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Fri, 17 Apr 2020 22:38:01 +0100 Subject: [PATCH] v0.3.0-alpha.1 with async support Squashed commit of the async branch. --- Cargo.toml | 31 +- README.md | 14 +- build/main.rs | 4 + examples/async_http_client.rs | 52 ++ examples/async_tcp_server.rs | 103 ++++ examples/guided_tour.rs | 28 - mlua_derive/Cargo.toml | 4 +- src/conversion.rs | 6 +- src/error.rs | 31 +- src/lib.rs | 5 +- src/lua.rs | 508 +++++++++++++----- src/multi.rs | 2 +- src/prelude.rs | 11 +- src/scope.rs | 484 ----------------- src/thread.rs | 208 +++++++ src/types.rs | 28 +- src/userdata.rs | 53 +- src/util.rs | 224 ++++---- tests/compile_fail/lua_norefunwindsafe.stderr | 27 +- tests/compile_fail/ref_nounwindsafe.stderr | 29 +- tests/compile_fail/scope_callback_capture.rs | 22 - .../scope_callback_capture.stderr | 45 -- tests/compile_fail/scope_callback_inner.rs | 19 - .../compile_fail/scope_callback_inner.stderr | 45 -- tests/compile_fail/scope_callback_outer.rs | 19 - .../compile_fail/scope_callback_outer.stderr | 11 - tests/compile_fail/scope_invariance.rs | 23 - tests/compile_fail/scope_invariance.stderr | 25 - tests/compile_fail/scope_mutable_aliasing.rs | 15 - .../scope_mutable_aliasing.stderr | 9 - tests/compile_fail/scope_userdata_borrow.rs | 20 - .../compile_fail/scope_userdata_borrow.stderr | 13 - tests/function.rs | 39 +- tests/memory.rs | 10 +- tests/scope.rs | 231 -------- tests/tests.rs | 10 +- tests/thread.rs | 66 +++ tests/userdata.rs | 11 +- 38 files changed, 1141 insertions(+), 1344 deletions(-) create mode 100644 examples/async_http_client.rs create mode 100644 examples/async_tcp_server.rs delete mode 100644 src/scope.rs delete mode 100644 tests/compile_fail/scope_callback_capture.rs delete mode 100644 tests/compile_fail/scope_callback_capture.stderr delete mode 100644 tests/compile_fail/scope_callback_inner.rs delete mode 100644 tests/compile_fail/scope_callback_inner.stderr delete mode 100644 tests/compile_fail/scope_callback_outer.rs delete mode 100644 tests/compile_fail/scope_callback_outer.stderr delete mode 100644 tests/compile_fail/scope_invariance.rs delete mode 100644 tests/compile_fail/scope_invariance.stderr delete mode 100644 tests/compile_fail/scope_mutable_aliasing.rs delete mode 100644 tests/compile_fail/scope_mutable_aliasing.stderr delete mode 100644 tests/compile_fail/scope_userdata_borrow.rs delete mode 100644 tests/compile_fail/scope_userdata_borrow.stderr delete mode 100644 tests/scope.rs diff --git a/Cargo.toml b/Cargo.toml index 96097a2..c436b72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "mlua" -version = "0.2.2" +version = "0.3.0-alpha.1" authors = ["Aleksandr Orlenko ", "kyren "] edition = "2018" repository = "https://github.com/khvzak/mlua" documentation = "https://docs.rs/mlua" readme = "README.md" -keywords = ["lua", "luajit"] -categories = ["api-bindings"] +keywords = ["lua", "luajit", "async", "futures"] +categories = ["api-bindings", "asynchronous"] license = "MIT" links = "lua" build = "build/main.rs" @@ -31,22 +31,39 @@ lua52 = [] lua51 = [] luajit = [] vendored = ["lua-src", "luajit-src"] +async = ["futures-core", "futures-task", "futures-util"] [dependencies] -num-traits = { version = "0.2.6" } bstr = { version = "0.2", features = ["std"], default_features = false } +num-traits = { version = "0.2.11" } +futures-core = { version = "0.3.4", optional = true } +futures-task = { version = "0.3.4", optional = true } +futures-util = { version = "0.3.4", optional = true } [build-dependencies] cc = { version = "1.0" } -pkg-config = { version = "0.3.11" } +pkg-config = { version = "0.3.17" } lua-src = { version = "535.0.1", optional = true } luajit-src = { version = "210.0.0", optional = true } [dev-dependencies] -rustyline = "5.0" -criterion = "0.2.0" +rustyline = "6.0" +criterion = "0.3" trybuild = "1.0" +hyper = "0.13" +tokio = { version = "0.2.18", features = ["full"] } +futures-executor = "0.3.4" +futures-util = "0.3.4" +futures-timer = "3.0" [[bench]] name = "benchmark" harness = false + +[[example]] +name = "async_tcp_server" +required-features = ["async"] + +[[example]] +name = "async_http_client" +required-features = ["async"] diff --git a/README.md b/README.md index 921885f..7f5f946 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,14 @@ modules in Rust. ## Usage +### Async + +Starting from 0.3, mlua supports async/await for all Lua versions. This feature works using Lua [coroutines](https://www.lua.org/manual/5.3/manual.html#2.6) and require running [Thread](https://docs.rs/mlua/latest/mlua/struct.Thread.html). + +**Examples**: +- [HTTP Client](examples/async_http_client.rs) +- [TCP Server](examples/async_tcp_server.rs) + ### Choosing Lua version The following features could be used to choose Lua version: `lua53` (default), `lua52`, `lua51` and `luajit`. @@ -41,7 +49,7 @@ Add to `Cargo.toml` : ``` toml [dependencies] -mlua = "0.2" +mlua = "0.3" ``` `main.rs` @@ -73,8 +81,8 @@ Add to `Cargo.toml` : crate-type = ["cdylib"] [dependencies] -mlua = "0.2" -mlua_derive = "0.2" +mlua = "0.3" +mlua_derive = "0.3" ``` `lib.rs` : diff --git a/build/main.rs b/build/main.rs index b43d0ed..1f54100 100644 --- a/build/main.rs +++ b/build/main.rs @@ -81,6 +81,10 @@ fn main() { #[cfg(all(feature = "lua51", feature = "luajit"))] panic!("You can enable only one of the features: lua53, lua52, lua51, luajit"); + // Async + // #[cfg(all(feature = "async", not(any(feature = "lua53", feature = "lua52"))))] + // panic!("You can enable async only for: lua53, lua52"); + let include_dir = find::probe_lua(); build_glue(&include_dir); } diff --git a/examples/async_http_client.rs b/examples/async_http_client.rs new file mode 100644 index 0000000..3321515 --- /dev/null +++ b/examples/async_http_client.rs @@ -0,0 +1,52 @@ +use std::collections::HashMap; + +use hyper::Client as HyperClient; + +use mlua::{Lua, Result, Thread, Error}; + +#[tokio::main] +async fn main() -> Result<()> { + let lua = Lua::new(); + + let fetch_url = lua.create_async_function(|lua, uri: String| async move { + let client = HyperClient::new(); + let uri = uri.parse().map_err(Error::external)?; + let resp = client.get(uri).await.map_err(Error::external)?; + + let lua_resp = lua.create_table()?; + lua_resp.set("status", resp.status().as_u16())?; + + let mut headers = HashMap::new(); + for (key, value) in resp.headers().iter() { + headers.entry(key.as_str()).or_insert(Vec::new()).push(value.to_str().unwrap()); + } + lua_resp.set("headers", headers)?; + + let buf = hyper::body::to_bytes(resp).await.map_err(Error::external)?; + lua_resp.set("body", String::from_utf8_lossy(&buf).into_owned())?; + + Ok(lua_resp) + })?; + + let globals = lua.globals(); + globals.set("fetch_url", fetch_url)?; + + let thread = lua + .load( + r#" + coroutine.create(function () + local res = fetch_url("http://httpbin.org/ip"); + print(res.status) + for key, vals in pairs(res.headers) do + for _, val in ipairs(vals) do + print(key..": "..val) + end + end + print(res.body) + end) + "#, + ) + .eval::()?; + + thread.into_async(()).await +} diff --git a/examples/async_tcp_server.rs b/examples/async_tcp_server.rs new file mode 100644 index 0000000..c9e695b --- /dev/null +++ b/examples/async_tcp_server.rs @@ -0,0 +1,103 @@ +use std::net::Shutdown; +use std::rc::Rc; + +use bstr::BString; +use tokio::net::{TcpListener, TcpStream}; +use tokio::prelude::*; +use tokio::sync::Mutex; +use tokio::task; + +use mlua::{Function, Lua, Result, Thread, UserData, UserDataMethods}; + +#[derive(Clone)] +struct LuaTcpListener(Option>>); + +#[derive(Clone)] +struct LuaTcpStream(Rc>); + +impl UserData for LuaTcpListener { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_async_function("bind", |_, addr: String| async { + let listener = TcpListener::bind(addr).await?; + Ok(LuaTcpListener(Some(Rc::new(Mutex::new(listener))))) + }); + + methods.add_async_method("accept", |_, listener, ()| async { + let (stream, _) = listener.0.unwrap().lock().await.accept().await?; + Ok(LuaTcpStream(Rc::new(Mutex::new(stream)))) + }); + } +} + +impl UserData for LuaTcpStream { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_async_method("peer_addr", |_, stream, ()| async move { + Ok(stream.0.lock().await.peer_addr()?.to_string()) + }); + + methods.add_async_method("read", |_, stream, size: usize| async move { + let mut buf = vec![0; size]; + let mut stream = stream.0.lock().await; + let n = stream.read(&mut buf).await?; + buf.truncate(n); + Ok(BString::from(buf)) + }); + + methods.add_async_method("write", |_, stream, data: BString| async move { + let mut stream = stream.0.lock().await; + let n = stream.write(&data).await?; + Ok(n) + }); + + methods.add_async_method("close", |_, stream, ()| async move { + stream.0.lock().await.shutdown(Shutdown::Both)?; + Ok(()) + }); + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let lua = Lua::new(); + + let globals = lua.globals(); + globals.set("tcp", LuaTcpListener(None))?; + + globals.set( + "spawn", + lua.create_function(move |lua: &Lua, func: Function| { + let fut = lua.create_thread(func)?.into_async::<_, ()>(()); + task::spawn_local(async move { fut.await.unwrap() }); + Ok(()) + })?, + )?; + + let thread = lua + .load( + r#" + coroutine.create(function () + local listener = tcp.bind("0.0.0.0:1234") + print("listening on 0.0.0.0:1234") + while true do + local stream = listener:accept() + print("connected from " .. stream:peer_addr()) + spawn(function() + while true do + local data = stream:read(100) + data = data:match("^%s*(.-)%s*$") -- trim + print(data) + stream:write("got: "..data.."\n") + if data == "exit" then + stream:close() + break + end + end + end) + end + end) + "#, + ) + .eval::()?; + + thread.into_async(()).await +} diff --git a/examples/guided_tour.rs b/examples/guided_tour.rs index c42b858..dab9ac0 100644 --- a/examples/guided_tour.rs +++ b/examples/guided_tour.rs @@ -164,34 +164,6 @@ fn main() -> Result<()> { < f32::EPSILON ); - // Normally, Rust types passed to `Lua` must be `Send`, because `Lua` itself is `Send`, and - // must be `'static`, because there is no way to be sure of their lifetime inside the Lua - // state. There is, however, a limited way to lift both of these requirements. You can - // call `Lua::scope` to create userdata and callbacks types that only live for as long - // as the call to scope, but do not have to be `Send` OR `'static`. - - { - let mut rust_val = 0; - - lua.scope(|scope| { - // We create a 'sketchy' Lua callback that holds a mutable reference to the variable - // `rust_val`. Outside of a `Lua::scope` call, this would not be allowed - // because it could be unsafe. - - lua.globals().set( - "sketchy", - scope.create_function_mut(|_, ()| { - rust_val = 42; - Ok(()) - })?, - )?; - - lua.load("sketchy()").exec() - })?; - - assert_eq!(rust_val, 42); - } - // We were able to run our 'sketchy' function inside the scope just fine. However, if we // try to run our 'sketchy' function outside of the scope, the function we created will have // been invalidated and we will generate an error. If our function wasn't invalidated, we diff --git a/mlua_derive/Cargo.toml b/mlua_derive/Cargo.toml index c506884..9f37275 100644 --- a/mlua_derive/Cargo.toml +++ b/mlua_derive/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "mlua_derive" -version = "0.2.0" -authors = ["Aleksandr Orlenko "] +version = "0.3.0-alpha.1" +authors = ["Aleksandr Orlenko "] edition = "2018" description = "Procedural macros for the mlua crate." repository = "https://github.com/khvzak/mlua" diff --git a/src/conversion.rs b/src/conversion.rs index 3116c58..51c6efd 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -122,7 +122,7 @@ impl<'lua> FromLua<'lua> for AnyUserData<'lua> { } } -impl<'lua, T: 'static + Send + UserData> ToLua<'lua> for T { +impl<'lua, T: 'static + UserData> ToLua<'lua> for T { fn to_lua(self, lua: &'lua Lua) -> Result> { Ok(Value::UserData(lua.create_userdata(self)?)) } @@ -167,7 +167,7 @@ impl<'lua> ToLua<'lua> for bool { } impl<'lua> FromLua<'lua> for bool { - fn from_lua(v: Value, _: &'lua Lua) -> Result { + fn from_lua(v: Value<'lua>, _: &'lua Lua) -> Result { match v { Value::Nil => Ok(false), Value::Boolean(b) => Ok(b), @@ -183,7 +183,7 @@ impl<'lua> ToLua<'lua> for LightUserData { } impl<'lua> FromLua<'lua> for LightUserData { - fn from_lua(value: Value, _: &'lua Lua) -> Result { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { match value { Value::LightUserData(ud) => Ok(ud), _ => Err(Error::FromLuaConversionError { diff --git a/src/error.rs b/src/error.rs index 4f7493d..24ff1c9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,11 @@ use std::error::Error as StdError; use std::fmt; +use std::io::Error as IoError; +use std::net::AddrParseError; +use std::rc::Rc; use std::result::Result as StdResult; +use std::str::Utf8Error; use std::string::String as StdString; -use std::sync::Arc; /// Error type returned by `mlua` methods. #[derive(Debug, Clone)] @@ -115,7 +118,7 @@ pub enum Error { /// Lua call stack backtrace. traceback: StdString, /// Original error returned by the Rust code. - cause: Arc, + cause: Rc, }, /// A custom error. /// @@ -124,7 +127,7 @@ pub enum Error { /// Returning `Err(ExternalError(...))` from a Rust callback will raise the error as a Lua /// error. The Rust code that originally invoked the Lua code then receives a `CallbackError`, /// from which the original error (and a stack traceback) can be recovered. - ExternalError(Arc), + ExternalError(Rc), } /// A specialized `Result` type used by `mlua`'s API. @@ -203,7 +206,7 @@ impl StdError for Error { } impl Error { - pub fn external>>(err: T) -> Error { + pub fn external>>(err: T) -> Error { Error::ExternalError(err.into().into()) } } @@ -214,7 +217,7 @@ pub trait ExternalError { impl ExternalError for E where - E: Into>, + E: Into>, { fn to_lua_err(self) -> Error { Error::external(self) @@ -233,3 +236,21 @@ where self.map_err(|e| e.to_lua_err()) } } + +impl std::convert::From for Error { + fn from(err: AddrParseError) -> Self { + Error::external(err) + } +} + +impl std::convert::From for Error { + fn from(err: IoError) -> Self { + Error::external(err) + } +} + +impl std::convert::From for Error { + fn from(err: Utf8Error) -> Self { + Error::external(err) + } +} diff --git a/src/lib.rs b/src/lib.rs index 951f4af..e6f10e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,6 @@ mod ffi; mod function; mod lua; mod multi; -mod scope; mod stdlib; mod string; mod table; @@ -65,7 +64,6 @@ pub use crate::error::{Error, ExternalError, ExternalResult, Result}; pub use crate::function::Function; pub use crate::lua::{Chunk, Lua}; pub use crate::multi::Variadic; -pub use crate::scope::Scope; pub use crate::stdlib::StdLib; pub use crate::string::String; pub use crate::table::{Table, TablePairs, TableSequence}; @@ -74,4 +72,7 @@ pub use crate::types::{Integer, LightUserData, Number, RegistryKey}; pub use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; pub use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; +#[cfg(feature = "async")] +pub use crate::thread::AsyncThread; + pub mod prelude; diff --git a/src/lua.rs b/src/lua.rs index a987019..a0ca44f 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -3,34 +3,44 @@ use std::cell::{RefCell, UnsafeCell}; use std::collections::HashMap; use std::ffi::CString; use std::marker::PhantomData; -use std::os::raw::{c_char, c_int, c_void}; -use std::sync::{Arc, Mutex}; +use std::os::raw::{c_char, c_int}; +use std::rc::Rc; use std::{mem, ptr, str}; use crate::error::{Error, Result}; use crate::ffi; use crate::function::Function; -use crate::scope::Scope; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::types::{Callback, Integer, LightUserData, LuaRef, Number, RegistryKey}; use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; -#[cfg(any(feature = "lua51", feature = "luajit"))] -use crate::util::set_main_state; use crate::util::{ - assert_stack, callback_error, check_stack, get_main_state, get_userdata, get_wrapped_error, - init_error_registry, init_userdata_metatable, pop_error, protect_lua, protect_lua_closure, - push_string, push_userdata, push_wrapped_error, userdata_destructor, StackGuard, + assert_stack, callback_error, check_stack, get_gc_userdata, get_main_state, + get_meta_gc_userdata, get_wrapped_error, init_error_registry, init_gc_metatable_for, + init_userdata_metatable, pop_error, protect_lua, protect_lua_closure, push_gc_userdata, + push_meta_gc_userdata, push_string, push_userdata, push_wrapped_error, StackGuard, }; use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; +#[cfg(feature = "async")] +use { + crate::types::AsyncCallback, + futures_core::future::LocalBoxFuture, + futures_task::noop_waker, + futures_util::future::{self, FutureExt, TryFutureExt}, + std::{ + future::Future, + task::{Context, Poll, Waker}, + }, +}; + /// Top level Lua struct which holds the Lua state itself. pub struct Lua { pub(crate) state: *mut ffi::lua_State, main_state: *mut ffi::lua_State, - extra: Arc>, + extra: Rc>, ephemeral: bool, // Lua has lots of interior mutability, should not be RefUnwindSafe _no_ref_unwind_safe: PhantomData>, @@ -39,7 +49,7 @@ pub struct Lua { // Data associated with the lua_State. struct ExtraData { registered_userdata: HashMap, - registry_unref_list: Arc>>>, + registry_unref_list: Rc>>>, ref_thread: *mut ffi::lua_State, ref_stack_size: c_int, @@ -47,7 +57,10 @@ struct ExtraData { ref_free: Vec, } -unsafe impl Send for Lua {} +#[cfg(feature = "async")] +pub(crate) struct AsyncPollPending; +#[cfg(feature = "async")] +pub(crate) static WAKER_REGISTRY_KEY: u8 = 0; impl Drop for Lua { fn drop(&mut self) { @@ -59,7 +72,10 @@ impl Drop for Lua { && extra.ref_stack_max as usize == extra.ref_free.len(), "reference leak detected" ); - *mlua_expect!(extra.registry_unref_list.lock(), "unref list poisoned") = None; + *mlua_expect!( + extra.registry_unref_list.try_borrow_mut(), + "unref list borrowed" + ) = None; ffi::lua_close(self.state); } } @@ -113,55 +129,25 @@ impl Lua { /// Constructs a new Lua instance from the existing state. pub unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> Lua { - #[cfg(any(feature = "lua53", feature = "lua52"))] let main_state = get_main_state(state); - #[cfg(any(feature = "lua51", feature = "luajit"))] - let main_state = { - set_main_state(state); - state - }; let main_state_top = ffi::lua_gettop(state); let ref_thread = mlua_expect!( protect_lua_closure(main_state, 0, 0, |state| { init_error_registry(state); - // Create the function metatables and place them in the registry + // Create the internal metatables and place them in the registry // to prevent them from being garbage collected. - ffi::lua_pushlightuserdata( - state, - &FUNCTION_CALLBACK_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, - ); - - ffi::lua_newtable(state); - - ffi::lua_pushstring(state, cstr!("__gc")); - ffi::lua_pushcfunction(state, userdata_destructor::); - ffi::lua_rawset(state, -3); - - ffi::lua_pushstring(state, cstr!("__metatable")); - ffi::lua_pushboolean(state, 0); - ffi::lua_rawset(state, -3); - - ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); - - ffi::lua_pushlightuserdata( - state, - &FUNCTION_EXTRA_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, - ); - - ffi::lua_newtable(state); - - ffi::lua_pushstring(state, cstr!("__gc")); - ffi::lua_pushcfunction(state, userdata_destructor::>>); - ffi::lua_rawset(state, -3); - - ffi::lua_pushstring(state, cstr!("__metatable")); - ffi::lua_pushboolean(state, 0); - ffi::lua_rawset(state, -3); - - ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); + init_gc_metatable_for::(state, None); + init_gc_metatable_for::(state, None); + #[cfg(feature = "async")] + { + init_gc_metatable_for::(state, None); + init_gc_metatable_for::>>(state, None); + init_gc_metatable_for::(state, None); + init_gc_metatable_for::(state, None); + } // Create ref stack thread and place it in the registry to prevent it from being garbage // collected. @@ -175,9 +161,9 @@ impl Lua { // Create ExtraData - let extra = Arc::new(RefCell::new(ExtraData { + let extra = Rc::new(RefCell::new(ExtraData { registered_userdata: HashMap::new(), - registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))), + registry_unref_list: Rc::new(RefCell::new(Some(Vec::new()))), ref_thread, // We need 1 extra stack space to move values in and out of the ref stack. ref_stack_size: ffi::LUA_MINSTACK - 1, @@ -206,7 +192,7 @@ impl Lua { pub fn entrypoint1<'lua, 'callback, R, F>(&'lua self, func: F) -> Result where R: ToLua<'callback>, - F: 'static + Send + Fn(&'callback Lua) -> Result, + F: 'static + Fn(&'callback Lua) -> Result, { let cb = self.create_callback(Box::new(move |lua, _| func(lua)?.to_lua_multi(lua)))?; unsafe { self.push_value(cb.call(())?).map(|_| 1) } @@ -463,7 +449,7 @@ impl Lua { where A: FromLuaMulti<'callback>, R: ToLuaMulti<'callback>, - F: 'static + Send + Fn(&'callback Lua, A) -> Result, + F: 'static + Fn(&'callback Lua, A) -> Result, { self.create_callback(Box::new(move |lua, args| { func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) @@ -483,7 +469,7 @@ impl Lua { where A: FromLuaMulti<'callback>, R: ToLuaMulti<'callback>, - F: 'static + Send + FnMut(&'callback Lua, A) -> Result, + F: 'static + FnMut(&'callback Lua, A) -> Result, { let func = RefCell::new(func); self.create_function(move |lua, args| { @@ -493,6 +479,68 @@ impl Lua { }) } + /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. + /// + /// While executing the function Rust will poll Future and if the result is not ready, call + /// `lua_yield()` returning internal representation of a `Poll::Pending` value. + /// + /// The function must be called inside [`Thread`] coroutine to be able to suspend its execution. + /// An executor could be used together with [`ThreadStream`] and mlua will use a provided Waker + /// in that case. Otherwise noop waker will be used if try to call the function outside of Rust + /// executors. + /// + /// # Examples + /// + /// Non blocking sleep: + /// + /// ``` + /// use std::time::Duration; + /// use futures_executor::block_on; + /// use futures_timer::Delay; + /// # use mlua::{Lua, Result, Thread}; + /// + /// async fn sleep(_lua: &Lua, n: u64) -> Result<&'static str> { + /// Delay::new(Duration::from_secs(n)).await; + /// Ok("done") + /// } + /// + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// lua.globals().set("async_sleep", lua.create_async_function(sleep)?)?; + /// let thr = lua.load("coroutine.create(function(n) return async_sleep(n) end)").eval::()?; + /// let res: String = block_on(async { + /// thr.into_async(1).await // Sleep 1 second + /// })?; + /// + /// assert_eq!(res, "done"); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`Thread`]: struct.Thread.html + /// [`ThreadStream`]: struct.ThreadStream.html + #[cfg(feature = "async")] + pub fn create_async_function<'lua, 'callback, A, R, F, FR>( + &'lua self, + func: F, + ) -> Result> + where + A: FromLuaMulti<'callback>, + R: ToLuaMulti<'callback>, + F: 'static + Fn(&'callback Lua, A) -> FR, + FR: 'static + Future>, + { + self.create_async_callback(Box::new(move |lua, args| { + let args = match A::from_lua_multi(args, lua) { + Ok(x) => x, + Err(e) => return future::err(e).boxed_local(), + }; + func(lua, args) + .and_then(move |x| future::ready(x.to_lua_multi(lua))) + .boxed_local() + })) + } + /// Wraps a Lua function into a new thread (or coroutine). /// /// Equivalent to `coroutine.create`. @@ -513,7 +561,7 @@ impl Lua { /// Create a Lua userdata object from a custom userdata type. pub fn create_userdata(&self, data: T) -> Result where - T: 'static + Send + UserData, + T: 'static + UserData, { unsafe { self.make_userdata(data) } } @@ -540,33 +588,6 @@ impl Lua { } } - /// Calls the given function with a `Scope` parameter, giving the function the ability to create - /// userdata and callbacks from rust types that are !Send or non-'static. - /// - /// The lifetime of any function or userdata created through `Scope` lasts only until the - /// completion of this method call, on completion all such created values are automatically - /// dropped and Lua references to them are invalidated. If a script accesses a value created - /// through `Scope` outside of this method, a Lua error will result. Since we can ensure the - /// lifetime of values created through `Scope`, and we know that `Lua` cannot be sent to another - /// thread while `Scope` is live, it is safe to allow !Send datatypes and whose lifetimes only - /// outlive the scope lifetime. - /// - /// Inside the scope callback, all handles created through Scope will share the same unique 'lua - /// lifetime of the parent `Lua`. This allows scoped and non-scoped values to be mixed in - /// API calls, which is very useful (e.g. passing a scoped userdata to a non-scoped function). - /// However, this also enables handles to scoped values to be trivially leaked from the given - /// callback. This is not dangerous, though! After the callback returns, all scoped values are - /// invalidated, which means that though references may exist, the Rust types backing them have - /// dropped. `Function` types will error when called, and `AnyUserData` will be typeless. It - /// would be impossible to prevent handles to scoped values from escaping anyway, since you - /// would always be able to smuggle them through Lua state. - pub fn scope<'scope, 'lua: 'scope, F, R>(&'lua self, f: F) -> R - where - F: FnOnce(&Scope<'lua, 'scope>) -> R, - { - f(&Scope::new(self)) - } - /// Attempts to coerce a Lua value into a String in a manner consistent with Lua's internal /// behavior. /// @@ -806,7 +827,7 @@ impl Lua { /// `Error::MismatchedRegistryKey` if passed a `RegistryKey` that was not created with a /// matching `Lua` state. pub fn owns_registry_value(&self, key: &RegistryKey) -> bool { - Arc::ptr_eq(&key.unref_list, &self.extra.borrow().registry_unref_list) + Rc::ptr_eq(&key.unref_list, &self.extra.borrow().registry_unref_list) } /// Remove any registry values whose `RegistryKey`s have all been dropped. @@ -818,8 +839,8 @@ impl Lua { unsafe { let unref_list = mem::replace( &mut *mlua_expect!( - self.extra.borrow().registry_unref_list.lock(), - "unref list poisoned" + self.extra.borrow().registry_unref_list.try_borrow_mut(), + "unref list borrowed" ), Some(Vec::new()), ); @@ -1009,7 +1030,12 @@ impl Lua { })?; } - if methods.methods.is_empty() { + #[cfg(feature = "async")] + let no_methods = methods.methods.is_empty() && methods.async_methods.is_empty(); + #[cfg(not(feature = "async"))] + let no_methods = methods.methods.is_empty(); + + if no_methods { init_userdata_metatable::>(self.state, -1, None)?; } else { protect_lua_closure(self.state, 0, 1, |state| { @@ -1022,6 +1048,14 @@ impl Lua { ffi::lua_rawset(state, -3); })?; } + #[cfg(feature = "async")] + for (k, m) in methods.async_methods { + push_string(self.state, &k)?; + self.push_value(Value::Function(self.create_async_callback(m)?))?; + protect_lua_closure(self.state, 3, 1, |state| { + ffi::lua_rawset(state, -3); + })?; + } init_userdata_metatable::>(self.state, -2, Some(-1))?; ffi::lua_pop(self.state, 1); @@ -1053,10 +1087,10 @@ impl Lua { ) -> Result> { unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int { callback_error(state, |nargs| { - if ffi::lua_type(state, ffi::lua_upvalueindex(1)) == ffi::LUA_TNIL { - return Err(Error::CallbackDestructed); - } - if ffi::lua_type(state, ffi::lua_upvalueindex(2)) == ffi::LUA_TNIL { + let func = + get_meta_gc_userdata::(state, ffi::lua_upvalueindex(1)); + let lua = get_gc_userdata::(state, ffi::lua_upvalueindex(2)); + if func.is_null() || lua.is_null() { return Err(Error::CallbackDestructed); } @@ -1064,16 +1098,8 @@ impl Lua { check_stack(state, ffi::LUA_MINSTACK - nargs)?; } - let extra = - get_userdata::>>(state, ffi::lua_upvalueindex(2)); - - let lua = Lua { - state: state, - main_state: get_main_state(state), - extra: (*extra).clone(), - ephemeral: true, - _no_ref_unwind_safe: PhantomData, - }; + let lua = &mut *lua; + lua.state = state; let mut args = MultiValue::new(); args.reserve(nargs as usize); @@ -1081,9 +1107,7 @@ impl Lua { args.push_front(lua.pop_value()); } - let func = get_userdata::(state, ffi::lua_upvalueindex(1)); - - let results = (*func)(&lua, args)?; + let results = (*func)(lua, args)?; let nresults = results.len() as c_int; check_stack(state, nresults)?; @@ -1099,21 +1123,8 @@ impl Lua { let _sg = StackGuard::new(self.state); assert_stack(self.state, 6); - push_userdata::(self.state, func)?; - ffi::lua_pushlightuserdata( - self.state, - &FUNCTION_CALLBACK_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, - ); - ffi::lua_rawget(self.state, ffi::LUA_REGISTRYINDEX); - ffi::lua_setmetatable(self.state, -2); - - push_userdata::>>(self.state, self.extra.clone())?; - ffi::lua_pushlightuserdata( - self.state, - &FUNCTION_EXTRA_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, - ); - ffi::lua_rawget(self.state, ffi::LUA_REGISTRYINDEX); - ffi::lua_setmetatable(self.state, -2); + push_meta_gc_userdata::(self.state, func)?; + push_gc_userdata(self.state, self.clone())?; protect_lua_closure(self.state, 2, 1, |state| { ffi::lua_pushcclosure(state, call_callback, 2); @@ -1123,7 +1134,140 @@ impl Lua { } } - // Does not require Send bounds, which can lead to unsafety. + #[cfg(feature = "async")] + pub(crate) fn create_async_callback<'lua, 'callback>( + &'lua self, + func: AsyncCallback<'callback, 'static>, + ) -> Result> { + #[cfg(any(feature = "lua53", feature = "lua52"))] + self.load_from_std_lib(StdLib::COROUTINE)?; + + unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int { + callback_error(state, |nargs| { + let func = get_meta_gc_userdata::( + state, + ffi::lua_upvalueindex(1), + ); + let lua = get_gc_userdata::(state, ffi::lua_upvalueindex(2)); + if func.is_null() || lua.is_null() { + return Err(Error::CallbackDestructed); + } + + if nargs < ffi::LUA_MINSTACK { + check_stack(state, ffi::LUA_MINSTACK - nargs)?; + } + + let lua = &mut *lua; + lua.state = state; + + let mut args = MultiValue::new(); + args.reserve(nargs as usize); + for _ in 0..nargs { + args.push_front(lua.pop_value()); + } + + let fut = (*func)(lua, args); + push_gc_userdata(state, fut)?; + push_gc_userdata(state, lua.clone())?; + + ffi::lua_pushcclosure(state, poll_future, 2); + + Ok(1) + }) + } + + unsafe extern "C" fn poll_future(state: *mut ffi::lua_State) -> c_int { + callback_error(state, |nargs| { + let fut = get_gc_userdata::>>( + state, + ffi::lua_upvalueindex(1), + ); + let lua = get_gc_userdata::(state, ffi::lua_upvalueindex(2)); + if fut.is_null() || lua.is_null() { + return Err(Error::CallbackDestructed); + } + + if nargs < ffi::LUA_MINSTACK { + check_stack(state, ffi::LUA_MINSTACK - nargs)?; + } + + let lua = &mut *lua; + let mut waker = noop_waker(); + + // Try to get an outer poll waker + ffi::lua_pushlightuserdata( + state, + &WAKER_REGISTRY_KEY as *const u8 as *mut ::std::os::raw::c_void, + ); + ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX); + if let Some(w) = get_gc_userdata::(state, -1).as_ref() { + waker = (*w).clone(); + } + ffi::lua_pop(state, 1); + + let mut ctx = Context::from_waker(&waker); + + match (*fut).as_mut().poll(&mut ctx) { + Poll::Pending => { + check_stack(state, 6)?; + ffi::lua_pushboolean(state, 0); + push_gc_userdata(state, AsyncPollPending)?; + Ok(2) + } + Poll::Ready(results) => { + let results = lua.create_sequence_from(results?)?; + check_stack(state, 2)?; + ffi::lua_pushboolean(state, 1); + lua.push_value(Value::Table(results))?; + Ok(2) + } + } + }) + } + + let get_poll = unsafe { + let _sg = StackGuard::new(self.state); + assert_stack(self.state, 6); + + push_meta_gc_userdata::(self.state, func)?; + push_gc_userdata(self.state, self.clone())?; + + protect_lua_closure(self.state, 2, 1, |state| { + ffi::lua_pushcclosure(state, call_callback, 2); + })?; + + Function(self.pop_ref()) + }; + + let env = self.create_table()?; + env.set("get_poll", get_poll)?; + env.set("coroutine", self.globals().get::<_, Value>("coroutine")?)?; + env.set( + "unpack", + self.create_function(|_, tbl: Table| { + Ok(MultiValue::from_vec( + tbl.sequence_values().collect::>>()?, + )) + })?, + )?; + + self.load( + r#" + local poll = get_poll(...) + while true do + ready, res = poll() + if ready then + return unpack(res) + end + coroutine.yield(res) + end + "#, + ) + .set_name("_mlua_async_poll")? + .set_environment(env)? + .into_function() + } + pub(crate) unsafe fn make_userdata(&self, data: T) -> Result where T: 'static + UserData, @@ -1143,6 +1287,16 @@ impl Lua { Ok(AnyUserData(self.pop_ref())) } + + pub(crate) fn clone(&self) -> Self { + Lua { + state: self.state, + main_state: self.main_state, + extra: self.extra.clone(), + ephemeral: true, + _no_ref_unwind_safe: PhantomData, + } + } } /// Returned from [`Lua::load`] and is used to finalize loading and executing Lua main chunks. @@ -1343,11 +1497,10 @@ unsafe fn ref_stack_pop(extra: &mut ExtraData) -> c_int { } } -static FUNCTION_CALLBACK_METATABLE_REGISTRY_KEY: u8 = 0; -static FUNCTION_EXTRA_METATABLE_REGISTRY_KEY: u8 = 0; - struct StaticUserDataMethods<'lua, T: 'static + UserData> { methods: Vec<(Vec, Callback<'lua, 'static>)>, + #[cfg(feature = "async")] + async_methods: Vec<(Vec, AsyncCallback<'lua, 'static>)>, meta_methods: Vec<(MetaMethod, Callback<'lua, 'static>)>, _type: PhantomData, } @@ -1356,6 +1509,8 @@ impl<'lua, T: 'static + UserData> Default for StaticUserDataMethods<'lua, T> { fn default() -> StaticUserDataMethods<'lua, T> { StaticUserDataMethods { methods: Vec::new(), + #[cfg(feature = "async")] + async_methods: Vec::new(), meta_methods: Vec::new(), _type: PhantomData, } @@ -1368,7 +1523,7 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + M: 'static + Fn(&'lua Lua, &T, A) -> Result, { self.methods .push((name.as_ref().to_vec(), Self::box_method(method))); @@ -1379,18 +1534,32 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + M: 'static + FnMut(&'lua Lua, &mut T, A) -> Result, { self.methods .push((name.as_ref().to_vec(), Self::box_method_mut(method))); } + #[cfg(feature = "async")] + fn add_async_method(&mut self, name: &S, method: M) + where + T: Clone, + S: ?Sized + AsRef<[u8]>, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Fn(&'lua Lua, T, A) -> MR, + MR: 'static + Future>, + { + self.async_methods + .push((name.as_ref().to_vec(), Self::box_async_method(method))); + } + fn add_function(&mut self, name: &S, function: F) where S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, + F: 'static + Fn(&'lua Lua, A) -> Result, { self.methods .push((name.as_ref().to_vec(), Self::box_function(function))); @@ -1401,17 +1570,31 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + F: 'static + FnMut(&'lua Lua, A) -> Result, { self.methods .push((name.as_ref().to_vec(), Self::box_function_mut(function))); } + #[cfg(feature = "async")] + fn add_async_function(&mut self, name: &S, function: F) + where + T: Clone, + S: ?Sized + AsRef<[u8]>, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Fn(&'lua Lua, A) -> FR, + FR: 'static + Future>, + { + self.async_methods + .push((name.as_ref().to_vec(), Self::box_async_function(function))); + } + fn add_meta_method(&mut self, meta: MetaMethod, method: M) where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + M: 'static + Fn(&'lua Lua, &T, A) -> Result, { self.meta_methods.push((meta, Self::box_method(method))); } @@ -1420,7 +1603,7 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + M: 'static + FnMut(&'lua Lua, &mut T, A) -> Result, { self.meta_methods.push((meta, Self::box_method_mut(method))); } @@ -1429,7 +1612,7 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, + F: 'static + Fn(&'lua Lua, A) -> Result, { self.meta_methods.push((meta, Self::box_function(function))); } @@ -1438,7 +1621,7 @@ impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMet where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + F: 'static + FnMut(&'lua Lua, A) -> Result, { self.meta_methods .push((meta, Self::box_function_mut(function))); @@ -1450,7 +1633,7 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, + M: 'static + Fn(&'lua Lua, &T, A) -> Result, { Box::new(move |lua, mut args| { if let Some(front) = args.pop_front() { @@ -1471,7 +1654,7 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, + M: 'static + FnMut(&'lua Lua, &mut T, A) -> Result, { let method = RefCell::new(method); Box::new(move |lua, mut args| { @@ -1492,11 +1675,43 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { }) } + #[cfg(feature = "async")] + fn box_async_method(method: M) -> AsyncCallback<'lua, 'static> + where + T: Clone, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Fn(&'lua Lua, T, A) -> MR, + MR: 'static + Future>, + { + Box::new(move |lua, mut args| { + let fut = || { + if let Some(front) = args.pop_front() { + let userdata = AnyUserData::from_lua(front, lua)?; + let userdata = userdata.borrow::()?.clone(); + Ok(method(lua, userdata, A::from_lua_multi(args, lua)?)) + } else { + Err(Error::FromLuaConversionError { + from: "missing argument", + to: "userdata", + message: None, + }) + } + }; + match fut() { + Ok(f) => f + .and_then(move |fr| future::ready(fr.to_lua_multi(lua))) + .boxed_local(), + Err(e) => future::err(e).boxed_local(), + } + }) + } + fn box_function(function: F) -> Callback<'lua, 'static> where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, + F: 'static + Fn(&'lua Lua, A) -> Result, { Box::new(move |lua, args| function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)) } @@ -1505,7 +1720,7 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, + F: 'static + FnMut(&'lua Lua, A) -> Result, { let function = RefCell::new(function); Box::new(move |lua, args| { @@ -1515,4 +1730,23 @@ impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) }) } + + #[cfg(feature = "async")] + fn box_async_function(function: F) -> AsyncCallback<'lua, 'static> + where + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Fn(&'lua Lua, A) -> FR, + FR: 'static + Future>, + { + Box::new(move |lua, args| { + let args = match A::from_lua_multi(args, lua) { + Ok(x) => x, + Err(e) => return future::err(e).boxed_local(), + }; + function(lua, args) + .and_then(move |x| future::ready(x.to_lua_multi(lua))) + .boxed_local() + }) + } } diff --git a/src/multi.rs b/src/multi.rs index 46a673f..b7540d0 100644 --- a/src/multi.rs +++ b/src/multi.rs @@ -146,7 +146,7 @@ macro_rules! impl_tuple { } impl<'lua> FromLuaMulti<'lua> for () { - fn from_lua_multi(_: MultiValue, _: &'lua Lua) -> Result { + fn from_lua_multi(_: MultiValue<'lua>, _: &'lua Lua) -> Result { Ok(()) } } diff --git a/src/prelude.rs b/src/prelude.rs index 66eb15f..20045a8 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -5,8 +5,11 @@ pub use crate::{ ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti, Function as LuaFunction, Integer as LuaInteger, LightUserData as LuaLightUserData, Lua, MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, - RegistryKey as LuaRegistryKey, Result as LuaResult, Scope as LuaScope, String as LuaString, - Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, - Thread as LuaThread, ThreadStatus as LuaThreadStatus, ToLua, ToLuaMulti, - UserData as LuaUserData, UserDataMethods as LuaUserDataMethods, Value as LuaValue, + RegistryKey as LuaRegistryKey, Result as LuaResult, String as LuaString, Table as LuaTable, + TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, Thread as LuaThread, + ThreadStatus as LuaThreadStatus, ToLua, ToLuaMulti, UserData as LuaUserData, + UserDataMethods as LuaUserDataMethods, Value as LuaValue, }; + +#[cfg(feature = "async")] +pub use crate::AsyncThread as LuaAsyncThread; diff --git a/src/scope.rs b/src/scope.rs deleted file mode 100644 index ec354af..0000000 --- a/src/scope.rs +++ /dev/null @@ -1,484 +0,0 @@ -use std::any::Any; -use std::cell::Cell; -use std::cell::RefCell; -use std::marker::PhantomData; -use std::mem; -use std::os::raw::c_void; -use std::rc::Rc; - -use crate::error::{Error, Result}; -use crate::ffi; -use crate::function::Function; -use crate::lua::Lua; -use crate::types::{Callback, LuaRef}; -use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataMethods}; -use crate::util::{ - assert_stack, init_userdata_metatable, protect_lua_closure, push_string, push_userdata, - take_userdata, StackGuard, -}; -use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti, Value}; - -/// Constructed by the [`Lua::scope`] method, allows temporarily creating Lua userdata and -/// callbacks that are not required to be Send or 'static. -/// -/// See [`Lua::scope`] for more details. -/// -/// [`Lua::scope`]: struct.Lua.html#method.scope -pub struct Scope<'lua, 'scope> { - lua: &'lua Lua, - destructors: RefCell, fn(LuaRef<'lua>) -> Box)>>, - _scope_invariant: PhantomData>, -} - -impl<'lua, 'scope> Scope<'lua, 'scope> { - pub(crate) fn new(lua: &'lua Lua) -> Scope<'lua, 'scope> { - Scope { - lua, - destructors: RefCell::new(Vec::new()), - _scope_invariant: PhantomData, - } - } - - /// Wraps a Rust function or closure, creating a callable Lua function handle to it. - /// - /// This is a version of [`Lua::create_function`] that creates a callback which expires on - /// scope drop. See [`Lua::scope`] for more details. - /// - /// [`Lua::create_function`]: struct.Lua.html#method.create_function - /// [`Lua::scope`]: struct.Lua.html#method.scope - pub fn create_function<'callback, A, R, F>(&'callback self, func: F) -> Result> - where - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'scope + Fn(&'callback Lua, A) -> Result, - { - // Safe, because 'scope must outlive 'callback (due to Self containing 'scope), however the - // callback itself must be 'scope lifetime, so the function should not be able to capture - // anything of 'callback lifetime. 'scope can't be shortened due to being invariant, and - // the 'callback lifetime here can't be enlarged due to coming from a universal - // quantification in Lua::scope. - // - // I hope I got this explanation right, but in any case this is tested with compiletest_rs - // to make sure callbacks can't capture handles with lifetime outside the scope, inside the - // scope, and owned inside the callback itself. - unsafe { - self.create_callback(Box::new(move |lua, args| { - func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })) - } - } - - /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. - /// - /// This is a version of [`Lua::create_function_mut`] that creates a callback which expires - /// on scope drop. See [`Lua::scope`] and [`Scope::create_function`] for more details. - /// - /// [`Lua::create_function_mut`]: struct.Lua.html#method.create_function_mut - /// [`Lua::scope`]: struct.Lua.html#method.scope - /// [`Scope::create_function`]: #method.create_function - pub fn create_function_mut<'callback, A, R, F>( - &'callback self, - func: F, - ) -> Result> - where - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'scope + FnMut(&'callback Lua, A) -> Result, - { - let func = RefCell::new(func); - self.create_function(move |lua, args| { - (&mut *func - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?)(lua, args) - }) - } - - /// Create a Lua userdata object from a custom userdata type. - /// - /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on - /// scope drop, and does not require that the userdata type be Send (but still requires that the - /// UserData be 'static). See [`Lua::scope`] for more details. - /// - /// [`Lua::create_userdata`]: struct.Lua.html#method.create_userdata - /// [`Lua::scope`]: struct.Lua.html#method.scope - pub fn create_static_userdata(&self, data: T) -> Result> - where - T: 'static + UserData, - { - // Safe even though T may not be Send, because the parent Lua cannot be sent to another - // thread while the Scope is alive (or the returned AnyUserData handle even). - unsafe { - let u = self.lua.make_userdata(data)?; - self.destructors.borrow_mut().push((u.0.clone(), |u| { - let state = u.lua.state; - assert_stack(state, 2); - u.lua.push_ref(&u); - // We know the destructor has not run yet because we hold a reference to the - // userdata. - Box::new(take_userdata::>(state)) - })); - Ok(u) - } - } - - /// Create a Lua userdata object from a custom userdata type. - /// - /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on - /// scope drop, and does not require that the userdata type be Send or 'static. See - /// [`Lua::scope`] for more details. - /// - /// Lifting the requirement that the UserData type be 'static comes with some important - /// limitations, so if you only need to eliminate the Send requirement, it is probably better to - /// use [`Scope::create_static_userdata`] instead. - /// - /// The main limitation that comes from using non-'static userdata is that the produced userdata - /// will no longer have a `TypeId` associated with it, becuase `TypeId` can only work for - /// 'static types. This means that it is impossible, once the userdata is created, to get a - /// reference to it back *out* of an `AnyUserData` handle. This also implies that the - /// "function" type methods that can be added via [`UserDataMethods`] (the ones that accept - /// `AnyUserData` as a first parameter) are vastly less useful. Also, there is no way to re-use - /// a single metatable for multiple non-'static types, so there is a higher cost associated with - /// creating the userdata metatable each time a new userdata is created. - /// - /// [`create_static_userdata`]: #method.create_static_userdata - /// [`Lua::create_userdata`]: struct.Lua.html#method.create_userdata - /// [`Lua::scope`]: struct.Lua.html#method.scope - /// [`UserDataMethods`]: trait.UserDataMethods.html - pub fn create_nonstatic_userdata(&self, data: T) -> Result> - where - T: 'scope + UserData, - { - let data = Rc::new(RefCell::new(data)); - - // 'callback outliving 'scope is a lie to make the types work out, required due to the - // inability to work with the more correct callback type that is universally quantified over - // 'lua. This is safe though, because `UserData::add_methods` does not get to pick the 'lua - // lifetime, so none of the static methods UserData types can add can possibly capture - // parameters. - fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>( - scope: &Scope<'lua, 'scope>, - data: Rc>, - method: NonStaticMethod<'callback, T>, - ) -> Result> { - // On methods that actually receive the userdata, we fake a type check on the passed in - // userdata, where we pretend there is a unique type per call to - // `Scope::create_nonstatic_userdata`. You can grab a method from a userdata and call - // it on a mismatched userdata type, which when using normal 'static userdata will fail - // with a type mismatch, but here without this check would proceed as though you had - // called the method on the original value (since we otherwise completely ignore the - // first argument). - let check_data = data.clone(); - let check_ud_type = move |lua: &'callback Lua, value| { - if let Some(value) = value { - if let Value::UserData(u) = value { - unsafe { - assert_stack(lua.state, 1); - lua.push_ref(&u.0); - ffi::lua_getuservalue(lua.state, -1); - #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] - { - ffi::lua_pushinteger(lua.state, 1); - ffi::lua_gettable(lua.state, -2); - ffi::lua_remove(lua.state, -2); - } - return ffi::lua_touserdata(lua.state, -1) - == check_data.as_ptr() as *mut c_void; - } - } - } - - false - }; - - match method { - NonStaticMethod::Method(method) => { - let method_data = data.clone(); - let f = Box::new(move |lua, mut args: MultiValue<'callback>| { - if !check_ud_type(lua, args.pop_front()) { - return Err(Error::UserDataTypeMismatch); - } - let data = method_data - .try_borrow() - .map_err(|_| Error::UserDataBorrowError)?; - method(lua, &*data, args) - }); - unsafe { scope.create_callback(f) } - } - NonStaticMethod::MethodMut(method) => { - let method = RefCell::new(method); - let method_data = data.clone(); - let f = Box::new(move |lua, mut args: MultiValue<'callback>| { - if !check_ud_type(lua, args.pop_front()) { - return Err(Error::UserDataTypeMismatch); - } - let mut method = method - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?; - let mut data = method_data - .try_borrow_mut() - .map_err(|_| Error::UserDataBorrowMutError)?; - (&mut *method)(lua, &mut *data, args) - }); - unsafe { scope.create_callback(f) } - } - NonStaticMethod::Function(function) => unsafe { scope.create_callback(function) }, - NonStaticMethod::FunctionMut(function) => { - let function = RefCell::new(function); - let f = Box::new(move |lua, args| { - (&mut *function - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?)( - lua, args - ) - }); - unsafe { scope.create_callback(f) } - } - } - } - - let mut ud_methods = NonStaticUserDataMethods::default(); - T::add_methods(&mut ud_methods); - - unsafe { - let lua = self.lua; - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 6); - - push_userdata(lua.state, ())?; - #[cfg(feature = "lua53")] - ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void); - #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] - protect_lua_closure(lua.state, 0, 1, |state| { - // Lua 5.2/5.1 allows to store only table. Then we will wrap the value. - ffi::lua_createtable(state, 1, 0); - ffi::lua_pushinteger(state, 1); - ffi::lua_pushlightuserdata(state, data.as_ptr() as *mut c_void); - ffi::lua_settable(state, -3); - })?; - ffi::lua_setuservalue(lua.state, -2); - - protect_lua_closure(lua.state, 0, 1, move |state| { - ffi::lua_newtable(state); - })?; - - for (k, m) in ud_methods.meta_methods { - push_string(lua.state, k.name())?; - lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?))?; - - protect_lua_closure(lua.state, 3, 1, |state| { - ffi::lua_rawset(state, -3); - })?; - } - - if ud_methods.methods.is_empty() { - init_userdata_metatable::<()>(lua.state, -1, None)?; - } else { - protect_lua_closure(lua.state, 0, 1, |state| { - ffi::lua_newtable(state); - })?; - for (k, m) in ud_methods.methods { - push_string(lua.state, &k)?; - lua.push_value(Value::Function(wrap_method(self, data.clone(), m)?))?; - protect_lua_closure(lua.state, 3, 1, |state| { - ffi::lua_rawset(state, -3); - })?; - } - - init_userdata_metatable::<()>(lua.state, -2, Some(-1))?; - ffi::lua_pop(lua.state, 1); - } - - ffi::lua_setmetatable(lua.state, -2); - - Ok(AnyUserData(lua.pop_ref())) - } - } - - // Unsafe, because the callback can improperly capture any value with 'callback scope, such as - // improperly capturing an argument. Since the 'callback lifetime is chosen by the user and the - // lifetime of the callback itself is 'scope (non-'static), the borrow checker will happily pick - // a 'callback that outlives 'scope to allow this. In order for this to be safe, the callback - // must NOT capture any parameters. - unsafe fn create_callback<'callback>( - &self, - f: Callback<'callback, 'scope>, - ) -> Result> { - let f = mem::transmute::, Callback<'lua, 'static>>(f); - let f = self.lua.create_callback(f)?; - - let mut destructors = self.destructors.borrow_mut(); - destructors.push((f.0.clone(), |f| { - let state = f.lua.state; - assert_stack(state, 3); - f.lua.push_ref(&f); - - ffi::lua_getupvalue(state, -1, 1); - // We know the destructor has not run yet because we hold a reference to the callback. - let ud = take_userdata::(state); - - ffi::lua_pushnil(state); - ffi::lua_setupvalue(state, -2, 1); - - ffi::lua_pop(state, 1); - Box::new(ud) - })); - Ok(f) - } -} - -impl<'lua, 'scope> Drop for Scope<'lua, 'scope> { - fn drop(&mut self) { - // We separate the action of invalidating the userdata in Lua and actually dropping the - // userdata type into two phases. This is so that, in the event a userdata drop panics, we - // can be sure that all of the userdata in Lua is actually invalidated. - - // All destructors are non-panicking, so this is fine - let to_drop = self - .destructors - .get_mut() - .drain(..) - .map(|(r, dest)| dest(r)) - .collect::>(); - - drop(to_drop); - } -} - -enum NonStaticMethod<'lua, T> { - Method(Box) -> Result>>), - MethodMut(Box) -> Result>>), - Function(Box) -> Result>>), - FunctionMut(Box) -> Result>>), -} - -struct NonStaticUserDataMethods<'lua, T: UserData> { - methods: Vec<(Vec, NonStaticMethod<'lua, T>)>, - meta_methods: Vec<(MetaMethod, NonStaticMethod<'lua, T>)>, -} - -impl<'lua, T: UserData> Default for NonStaticUserDataMethods<'lua, T> { - fn default() -> NonStaticUserDataMethods<'lua, T> { - NonStaticUserDataMethods { - methods: Vec::new(), - meta_methods: Vec::new(), - } - } -} - -impl<'lua, T: UserData> UserDataMethods<'lua, T> for NonStaticUserDataMethods<'lua, T> { - fn add_method(&mut self, name: &S, method: M) - where - S: ?Sized + AsRef<[u8]>, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::Method(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_method_mut(&mut self, name: &S, mut method: M) - where - S: ?Sized + AsRef<[u8]>, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_function(&mut self, name: &S, function: F) - where - S: ?Sized + AsRef<[u8]>, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::Function(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_function_mut(&mut self, name: &S, mut function: F) - where - S: ?Sized + AsRef<[u8]>, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::FunctionMut(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_meta_method(&mut self, meta: MetaMethod, method: M) - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result, - { - self.meta_methods.push(( - meta, - NonStaticMethod::Method(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_meta_method_mut(&mut self, meta: MetaMethod, mut method: M) - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.meta_methods.push(( - meta, - NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_meta_function(&mut self, meta: MetaMethod, function: F) - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result, - { - self.meta_methods.push(( - meta, - NonStaticMethod::Function(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_meta_function_mut(&mut self, meta: MetaMethod, mut function: F) - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result, - { - self.meta_methods.push(( - meta, - NonStaticMethod::FunctionMut(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } -} diff --git a/src/thread.rs b/src/thread.rs index 2452fec..95bc96d 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -8,6 +8,24 @@ use crate::util::{ }; use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti}; +#[cfg(feature = "async")] +use { + crate::{ + error::ExternalError, + lua::{AsyncPollPending, Lua, WAKER_REGISTRY_KEY}, + util::{get_gc_userdata, push_gc_userdata}, + value::Value, + }, + futures_core::{future::Future, stream::Stream}, + std::{ + cell::RefCell, + marker::PhantomData, + os::raw::c_void, + pin::Pin, + task::{Context, Poll, Waker}, + }, +}; + /// Status of a Lua thread (or coroutine). #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ThreadStatus { @@ -27,6 +45,15 @@ pub enum ThreadStatus { #[derive(Clone, Debug)] pub struct Thread<'lua>(pub(crate) LuaRef<'lua>); +/// Thread (coroutine) representation as an async Future or Stream. +#[cfg(feature = "async")] +#[derive(Debug)] +pub struct AsyncThread<'lua, R> { + thread: Thread<'lua>, + args0: RefCell>>>, + ret: PhantomData, +} + impl<'lua> Thread<'lua> { /// Resumes execution of this thread. /// @@ -142,6 +169,62 @@ impl<'lua> Thread<'lua> { } } } + + /// Converts Thread to an AsyncThread which implements Future and Stream traits. + /// + /// `args` are passed as arguments to the thread function for first call. + /// The object call `resume()` while polling and also allows to run rust futures + /// to completion using an executor. + /// + /// Using AsyncThread as a Stream allows to iterate through `coroutine.yield()` + /// values whereas Future version discards that values and poll until the final + /// one (returned from the thread function). + /// + /// # Examples + /// + /// ``` + /// # use mlua::{Error, Lua, Result, Thread}; + /// use futures_executor::block_on; + /// use futures_util::stream::TryStreamExt; + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// let thread: Thread = lua.load(r#" + /// coroutine.create(function(sum) + /// for i = 1,10 do + /// sum = sum + i + /// coroutine.yield(sum) + /// end + /// return sum + /// end) + /// "#).eval()?; + /// + /// let result = block_on(async { + /// let mut s = thread.into_async::<_, i64>(1); + /// let mut sum = 0; + /// while let Some(n) = s.try_next().await? { + /// sum += n; + /// } + /// Ok::<_, Error>(sum) + /// })?; + /// + /// assert_eq!(result, 286); + /// + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "async")] + pub fn into_async(self, args: A) -> AsyncThread<'lua, R> + where + A: ToLuaMulti<'lua>, + R: FromLuaMulti<'lua>, + { + let args = args.to_lua_multi(&self.0.lua); + AsyncThread { + thread: self, + args0: RefCell::new(Some(args)), + ret: PhantomData, + } + } } impl<'lua> PartialEq for Thread<'lua> { @@ -149,3 +232,128 @@ impl<'lua> PartialEq for Thread<'lua> { self.0 == other.0 } } + +#[cfg(feature = "async")] +impl<'lua, R> Stream for AsyncThread<'lua, R> +where + R: FromLuaMulti<'lua>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let lua = self.thread.0.lua; + + match self.thread.status() { + ThreadStatus::Resumable => {} + _ => return Poll::Ready(None), + }; + + let _wg = WakerGuard::new(lua.state, cx.waker().clone()); + let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() { + self.thread.resume(args?)? + } else { + self.thread.resume(())? + }; + + if is_poll_pending(lua, &ret) { + return Poll::Pending; + } + + cx.waker().wake_by_ref(); + Poll::Ready(Some(R::from_lua_multi(ret, lua))) + } +} + +#[cfg(feature = "async")] +impl<'lua, R> Future for AsyncThread<'lua, R> +where + R: FromLuaMulti<'lua>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let lua = self.thread.0.lua; + + match self.thread.status() { + ThreadStatus::Resumable => {} + _ => return Poll::Ready(Err("Thread already finished".to_lua_err())), + }; + + let _wg = WakerGuard::new(lua.state, cx.waker().clone()); + let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() { + self.thread.resume(args?)? + } else { + self.thread.resume(())? + }; + + if is_poll_pending(lua, &ret) { + return Poll::Pending; + } + + if let ThreadStatus::Resumable = self.thread.status() { + // Ignore value returned via yield() + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + Poll::Ready(R::from_lua_multi(ret, lua)) + } +} + +#[cfg(feature = "async")] +fn is_poll_pending(lua: &Lua, val: &MultiValue) -> bool { + if val.len() != 1 { + return false; + } + + if let Some(Value::UserData(ud)) = val.iter().next() { + unsafe { + let _sg = StackGuard::new(lua.state); + assert_stack(lua.state, 3); + + lua.push_ref(&ud.0); + let is_pending = get_gc_userdata::(lua.state, -1) + .as_ref() + .is_some(); + ffi::lua_pop(lua.state, 1); + + return is_pending; + } + } + + false +} + +#[cfg(feature = "async")] +struct WakerGuard(*mut ffi::lua_State); + +#[cfg(feature = "async")] +impl WakerGuard { + pub fn new(state: *mut ffi::lua_State, waker: Waker) -> Result { + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 6); + + ffi::lua_pushlightuserdata(state, &WAKER_REGISTRY_KEY as *const u8 as *mut c_void); + push_gc_userdata(state, waker)?; + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); + + Ok(WakerGuard(state)) + } + } +} + +#[cfg(feature = "async")] +impl Drop for WakerGuard { + fn drop(&mut self) { + unsafe { + let state = self.0; + let _sg = StackGuard::new(state); + assert_stack(state, 2); + + ffi::lua_pushlightuserdata(state, &WAKER_REGISTRY_KEY as *const u8 as *mut c_void); + ffi::lua_pushnil(state); + ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); + } + } +} diff --git a/src/types.rs b/src/types.rs index e88d0c1..38065bc 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,7 +1,11 @@ +use std::cell::RefCell; use std::os::raw::{c_int, c_void}; -use std::sync::{Arc, Mutex}; +use std::rc::Rc; use std::{fmt, mem, ptr}; +#[cfg(feature = "async")] +use futures_core::future::LocalBoxFuture; + use crate::error::Result; use crate::ffi; use crate::lua::Lua; @@ -20,29 +24,28 @@ pub struct LightUserData(pub *mut c_void); pub(crate) type Callback<'lua, 'a> = Box) -> Result> + 'a>; +#[cfg(feature = "async")] +pub(crate) type AsyncCallback<'lua, 'a> = + Box) -> LocalBoxFuture<'lua, Result>> + 'a>; + /// An auto generated key into the Lua registry. /// -/// This is a handle to a value stored inside the Lua registry. It is not directly usable like the -/// `Table` or `Function` handle types, but since it doesn't hold a reference to a parent Lua and is -/// Send + Sync + 'static, it is much more flexible and can be used in many situations where it is -/// impossible to directly store a normal handle type. It is not automatically garbage collected on -/// Drop, but it can be removed with [`Lua::remove_registry_value`], and instances not manually -/// removed can be garbage collected with [`Lua::expire_registry_values`]. +/// This is a handle to a value stored inside the Lua registry. It is not automatically +/// garbage collected on Drop, but it can be removed with [`Lua::remove_registry_value`], +/// and instances not manually removed can be garbage collected with [`Lua::expire_registry_values`]. /// /// Be warned, If you place this into Lua via a `UserData` type or a rust callback, it is *very /// easy* to accidentally cause reference cycles that the Lua garbage collector cannot resolve. /// Instead of placing a `RegistryKey` into a `UserData` type, prefer instead to use -/// [`UserData::set_user_value`] / [`UserData::get_user_value`], and instead of moving a RegistryKey -/// into a callback, prefer [`Lua::scope`]. +/// [`UserData::set_user_value`] / [`UserData::get_user_value`]. /// /// [`Lua::remove_registry_value`]: struct.Lua.html#method.remove_registry_value /// [`Lua::expire_registry_values`]: struct.Lua.html#method.expire_registry_values -/// [`Lua::scope`]: struct.Lua.html#method.scope /// [`UserData::set_user_value`]: struct.UserData.html#method.set_user_value /// [`UserData::get_user_value`]: struct.UserData.html#method.get_user_value pub struct RegistryKey { pub(crate) registry_id: c_int, - pub(crate) unref_list: Arc>>>, + pub(crate) unref_list: Rc>>>, } impl fmt::Debug for RegistryKey { @@ -53,7 +56,8 @@ impl fmt::Debug for RegistryKey { impl Drop for RegistryKey { fn drop(&mut self) { - if let Some(list) = mlua_expect!(self.unref_list.lock(), "unref_list poisoned").as_mut() { + let mut unref_list = mlua_expect!(self.unref_list.try_borrow_mut(), "unref list borrowed"); + if let Some(list) = unref_list.as_mut() { list.push(self.registry_id); } } diff --git a/src/userdata.rs b/src/userdata.rs index ed5db3d..fc6f86f 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -1,5 +1,8 @@ use std::cell::{Ref, RefCell, RefMut}; +#[cfg(feature = "async")] +use std::future::Future; + use crate::error::{Error, Result}; use crate::ffi; use crate::function::Function; @@ -134,7 +137,7 @@ pub trait UserDataMethods<'lua, T: UserData> { S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result; + M: 'static + Fn(&'lua Lua, &T, A) -> Result; /// Add a regular method which accepts a `&mut T` as the first parameter. /// @@ -146,7 +149,23 @@ pub trait UserDataMethods<'lua, T: UserData> { S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result; + M: 'static + FnMut(&'lua Lua, &mut T, A) -> Result; + + /// Add an async method which accepts a `T` as the first parameter and returns Future. + /// The passed `T` is cloned from the original value. + /// + /// Refer to [`add_method`] for more information about the implementation. + /// + /// [`add_method`]: #method.add_method + #[cfg(feature = "async")] + fn add_async_method(&mut self, name: &S, method: M) + where + T: Clone, + S: ?Sized + AsRef<[u8]>, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + M: 'static + Fn(&'lua Lua, T, A) -> MR, + MR: 'static + Future>; /// Add a regular method as a function which accepts generic arguments, the first argument will /// be a `UserData` of type T if the method is called with Lua method syntax: @@ -162,7 +181,7 @@ pub trait UserDataMethods<'lua, T: UserData> { S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result; + F: 'static + Fn(&'lua Lua, A) -> Result; /// Add a regular method as a mutable function which accepts generic arguments. /// @@ -174,7 +193,23 @@ pub trait UserDataMethods<'lua, T: UserData> { S: ?Sized + AsRef<[u8]>, A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result; + F: 'static + FnMut(&'lua Lua, A) -> Result; + + /// Add a regular method as an async function which accepts generic arguments + /// and returns Future. + /// + /// This is an async version of [`add_function`]. + /// + /// [`add_function`]: #method.add_function + #[cfg(feature = "async")] + fn add_async_function(&mut self, name: &S, function: F) + where + T: Clone, + S: ?Sized + AsRef<[u8]>, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + Fn(&'lua Lua, A) -> FR, + FR: 'static + Future>; /// Add a metamethod which accepts a `&T` as the first parameter. /// @@ -188,7 +223,7 @@ pub trait UserDataMethods<'lua, T: UserData> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + Fn(&'lua Lua, &T, A) -> Result; + M: 'static + Fn(&'lua Lua, &T, A) -> Result; /// Add a metamethod as a function which accepts a `&mut T` as the first parameter. /// @@ -202,7 +237,7 @@ pub trait UserDataMethods<'lua, T: UserData> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - M: 'static + Send + FnMut(&'lua Lua, &mut T, A) -> Result; + M: 'static + FnMut(&'lua Lua, &mut T, A) -> Result; /// Add a metamethod which accepts generic arguments. /// @@ -213,7 +248,7 @@ pub trait UserDataMethods<'lua, T: UserData> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + Fn(&'lua Lua, A) -> Result; + F: 'static + Fn(&'lua Lua, A) -> Result; /// Add a metamethod as a mutable function which accepts generic arguments. /// @@ -224,7 +259,7 @@ pub trait UserDataMethods<'lua, T: UserData> { where A: FromLuaMulti<'lua>, R: ToLuaMulti<'lua>, - F: 'static + Send + FnMut(&'lua Lua, A) -> Result; + F: 'static + FnMut(&'lua Lua, A) -> Result; } /// Trait for custom userdata types. @@ -293,7 +328,7 @@ pub trait UserDataMethods<'lua, T: UserData> { /// [`UserDataMethods`]: trait.UserDataMethods.html pub trait UserData: Sized { /// Adds custom methods and operators specific to this userdata. - fn add_methods<'lua, T: UserDataMethods<'lua, Self>>(_methods: &mut T) {} + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(_methods: &mut M) {} } /// Handle to an internal Lua userdata for any type that implements [`UserData`]. diff --git a/src/util.rs b/src/util.rs index b3cf022..88e3fe0 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,14 +1,20 @@ -use std::any::Any; +use std::any::{Any, TypeId}; use std::borrow::Cow; +use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Write; use std::os::raw::{c_char, c_int, c_void}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; -use std::sync::Arc; +use std::rc::Rc; use std::{mem, ptr, slice}; use crate::error::{Error, Result}; use crate::ffi; +thread_local! { + static METATABLE_CACHE: RefCell> = RefCell::new(HashMap::new()); +} + // Checks that Lua has enough free stack space for future stack operations. On failure, this will // panic with an internal error message. pub unsafe fn assert_stack(state: *mut ffi::lua_State, amount: c_int) { @@ -175,8 +181,7 @@ pub unsafe fn pop_error(state: *mut ffi::lua_State, err_code: c_int) -> Error { if let Some(err) = get_wrapped_error(state, -1).as_ref() { ffi::lua_pop(state, 1); err.clone() - } else if is_wrapped_panic(state, -1) { - let panic = get_userdata::(state, -1); + } else if let Some(panic) = get_gc_userdata::(state, -1).as_mut() { if let Some(p) = (*panic).0.take() { resume_unwind(p); } else { @@ -255,6 +260,41 @@ pub unsafe fn take_userdata(state: *mut ffi::lua_State) -> T { ptr::read(ud) } +// Pushes the userdata and attaches a metatable with __gc method +// Internally uses 5 stack spaces, does not call checkstack +pub unsafe fn push_gc_userdata(state: *mut ffi::lua_State, t: T) -> Result<()> { + push_meta_gc_userdata::(state, t) +} + +pub unsafe fn push_meta_gc_userdata(state: *mut ffi::lua_State, t: T) -> Result<()> { + let ud = protect_lua_closure(state, 0, 1, move |state| { + ffi::lua_newuserdata(state, mem::size_of::()) as *mut T + })?; + ptr::write(ud, t); + get_gc_metatable_for::(state); + ffi::lua_setmetatable(state, -2); + Ok(()) +} + +// Uses 2 stack spaces, does not call checkstack +pub unsafe fn get_gc_userdata(state: *mut ffi::lua_State, index: c_int) -> *mut T { + get_meta_gc_userdata::(state, index) +} + +pub unsafe fn get_meta_gc_userdata(state: *mut ffi::lua_State, index: c_int) -> *mut T { + let ud = ffi::lua_touserdata(state, index) as *mut T; + if ud.is_null() || ffi::lua_getmetatable(state, index) == 0 { + return ptr::null_mut(); + } + get_gc_metatable_for::(state); + let res = ffi::lua_rawequal(state, -1, -2) != 0; + ffi::lua_pop(state, 2); + if !res { + return ptr::null_mut(); + } + ud +} + // Populates the given table with the appropriate members to be a userdata metatable for the given // type. This function takes the given table at the `metatable` index, and adds an appropriate __gc // member to it for the given type and a __metatable entry to protect the table from script access. @@ -380,14 +420,14 @@ where Ok(Err(err)) => { ffi::lua_settop(state, 1); ptr::write(ud as *mut WrappedError, WrappedError(err)); - get_error_metatable(state); + get_gc_metatable_for::(state); ffi::lua_setmetatable(state, -2); ffi::lua_error(state) } Err(p) => { ffi::lua_settop(state, 1); ptr::write(ud as *mut WrappedPanic, WrappedPanic(Some(p))); - get_panic_metatable(state); + get_gc_metatable_for::(state); ffi::lua_setmetatable(state, -2); ffi::lua_error(state) } @@ -428,12 +468,12 @@ pub unsafe extern "C" fn error_traceback(state: *mut ffi::lua_State) -> c_int { ud, WrappedError(Error::CallbackError { traceback, - cause: Arc::new(error), + cause: Rc::new(error), }), ); - get_error_metatable(state); + get_gc_metatable_for::(state); ffi::lua_setmetatable(state, -2); - } else if !is_wrapped_panic(state, -1) { + } else if let None = get_gc_userdata::(state, -1).as_ref() { if ffi::lua_checkstack(state, LUA_TRACEBACK_STACK) != 0 { let s = ffi::luaL_tolstring(state, -1, ptr::null_mut()); ffi::luaL_traceback(state, state, s, 0); @@ -443,68 +483,72 @@ pub unsafe extern "C" fn error_traceback(state: *mut ffi::lua_State) -> c_int { 1 } -// Does not call lua_checkstack, uses 2 stack spaces. -#[cfg(any(feature = "lua51", feature = "luajit"))] -pub unsafe fn set_main_state(state: *mut ffi::lua_State) { - ffi::lua_pushlightuserdata(state, &MAIN_THREAD_REGISTRY_KEY as *const u8 as *mut c_void); - ffi::lua_pushthread(state); - ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); -} - // Does not call lua_checkstack, uses 1 stack space. pub unsafe fn get_main_state(state: *mut ffi::lua_State) -> *mut ffi::lua_State { #[cfg(any(feature = "lua53", feature = "lua52"))] - ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_MAINTHREAD); - #[cfg(any(feature = "lua51", feature = "luajit"))] { - ffi::lua_pushlightuserdata(state, &MAIN_THREAD_REGISTRY_KEY as *const u8 as *mut c_void); - ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX); + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_MAINTHREAD); + let main_state = ffi::lua_tothread(state, -1); + ffi::lua_pop(state, 1); + main_state } - let main_state = ffi::lua_tothread(state, -1); - ffi::lua_pop(state, 1); - main_state + #[cfg(any(feature = "lua51", feature = "luajit"))] + state } // Pushes a WrappedError to the top of the stack. Uses two stack spaces and does not call // lua_checkstack. pub unsafe fn push_wrapped_error(state: *mut ffi::lua_State, err: Error) -> Result<()> { - let ud = protect_lua_closure(state, 0, 1, move |state| { - ffi::lua_newuserdata(state, mem::size_of::()) as *mut WrappedError - })?; - ptr::write(ud, WrappedError(err)); - get_error_metatable(state); - ffi::lua_setmetatable(state, -2); - Ok(()) + push_gc_userdata::(state, WrappedError(err)) } // Checks if the value at the given index is a WrappedError, and if it is returns a pointer to it, // otherwise returns null. Uses 2 stack spaces and does not call lua_checkstack. pub unsafe fn get_wrapped_error(state: *mut ffi::lua_State, index: c_int) -> *const Error { - let userdata = ffi::lua_touserdata(state, index); - if userdata.is_null() { + let ud = get_gc_userdata::(state, index); + if ud.is_null() { return ptr::null(); } + &(*ud).0 +} - if ffi::lua_getmetatable(state, index) == 0 { - return ptr::null(); +// Initialize the internal (with __gc) metatable for a type T +pub unsafe fn init_gc_metatable_for( + state: *mut ffi::lua_State, + customize_fn: Option, +) { + let type_id = TypeId::of::(); + + ffi::lua_newtable(state); + + ffi::lua_pushstring(state, cstr!("__gc")); + ffi::lua_pushcfunction(state, userdata_destructor::); + ffi::lua_rawset(state, -3); + + ffi::lua_pushstring(state, cstr!("__metatable")); + ffi::lua_pushboolean(state, 0); + ffi::lua_rawset(state, -3); + + if let Some(f) = customize_fn { + f(state) } - get_error_metatable(state); - let res = ffi::lua_rawequal(state, -1, -2) != 0; - ffi::lua_pop(state, 2); + let ref_addr = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); + METATABLE_CACHE.with(|mc| mc.borrow_mut().insert(type_id, ref_addr)); +} - if res { - &(*get_userdata::(state, -1)).0 - } else { - ptr::null() - } +pub unsafe fn get_gc_metatable_for(state: *mut ffi::lua_State) { + let type_id = TypeId::of::(); + let ref_addr = METATABLE_CACHE + .with(|mc| *mlua_expect!(mc.borrow().get(&type_id), "gc metatable does not exist")); + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ref_addr as ffi::lua_Integer); } // Initialize the error, panic, and destructed userdata metatables. pub unsafe fn init_error_registry(state: *mut ffi::lua_State) { assert_stack(state, 8); - // Create error metatable + // Create error and panic metatables unsafe extern "C" fn error_tostring(state: *mut ffi::lua_State) -> c_int { let err_buf = callback_error(state, |_| { @@ -524,8 +568,7 @@ pub unsafe fn init_error_registry(state: *mut ffi::lua_State) { // kind of recursive error structure?) let _ = write!(&mut (*err_buf), "{}", error); Ok(err_buf) - } else if is_wrapped_panic(state, -1) { - let panic = get_userdata::(state, -1); + } else if let Some(panic) = get_gc_userdata::(state, -1).as_ref() { if let Some(ref p) = (*panic).0 { ffi::lua_pushlightuserdata( state, @@ -564,56 +607,31 @@ pub unsafe fn init_error_registry(state: *mut ffi::lua_State) { 1 } - ffi::lua_pushlightuserdata( + init_gc_metatable_for::( state, - &ERROR_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, + Some(|state| { + ffi::lua_pushstring(state, cstr!("__tostring")); + ffi::lua_pushcfunction(state, error_tostring); + ffi::lua_rawset(state, -3); + }), ); - ffi::lua_newtable(state); - ffi::lua_pushstring(state, cstr!("__gc")); - ffi::lua_pushcfunction(state, userdata_destructor::); - ffi::lua_rawset(state, -3); - - ffi::lua_pushstring(state, cstr!("__tostring")); - ffi::lua_pushcfunction(state, error_tostring); - ffi::lua_rawset(state, -3); - - ffi::lua_pushstring(state, cstr!("__metatable")); - ffi::lua_pushboolean(state, 0); - ffi::lua_rawset(state, -3); - - ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); - - // Create panic metatable - - ffi::lua_pushlightuserdata( + init_gc_metatable_for::( state, - &PANIC_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, + Some(|state| { + ffi::lua_pushstring(state, cstr!("__tostring")); + ffi::lua_pushcfunction(state, error_tostring); + ffi::lua_rawset(state, -3); + }), ); - ffi::lua_newtable(state); - - ffi::lua_pushstring(state, cstr!("__gc")); - ffi::lua_pushcfunction(state, userdata_destructor::); - ffi::lua_rawset(state, -3); - - ffi::lua_pushstring(state, cstr!("__tostring")); - ffi::lua_pushcfunction(state, error_tostring); - ffi::lua_rawset(state, -3); - - ffi::lua_pushstring(state, cstr!("__metatable")); - ffi::lua_pushboolean(state, 0); - ffi::lua_rawset(state, -3); - - ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); // Create destructed userdata metatable unsafe extern "C" fn destructed_error(state: *mut ffi::lua_State) -> c_int { ffi::luaL_checkstack(state, 2, ptr::null()); let ud = ffi::lua_newuserdata(state, mem::size_of::()) as *mut WrappedError; - ptr::write(ud, WrappedError(Error::CallbackDestructed)); - get_error_metatable(state); + get_gc_metatable_for::(state); ffi::lua_setmetatable(state, -2); ffi::lua_error(state) } @@ -709,40 +727,6 @@ unsafe fn to_string<'a>(state: *mut ffi::lua_State, index: c_int) -> Cow<'a, str } } -// Checks if the value at the given index is a WrappedPanic. Uses 2 stack spaces and does not call -// lua_checkstack. -unsafe fn is_wrapped_panic(state: *mut ffi::lua_State, index: c_int) -> bool { - let userdata = ffi::lua_touserdata(state, index); - if userdata.is_null() { - return false; - } - - if ffi::lua_getmetatable(state, index) == 0 { - return false; - } - - get_panic_metatable(state); - let res = ffi::lua_rawequal(state, -1, -2) != 0; - ffi::lua_pop(state, 2); - res -} - -unsafe fn get_error_metatable(state: *mut ffi::lua_State) { - ffi::lua_pushlightuserdata( - state, - &ERROR_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, - ); - ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX); -} - -unsafe fn get_panic_metatable(state: *mut ffi::lua_State) { - ffi::lua_pushlightuserdata( - state, - &PANIC_METATABLE_REGISTRY_KEY as *const u8 as *mut c_void, - ); - ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX); -} - unsafe fn get_destructed_userdata_metatable(state: *mut ffi::lua_State) { ffi::lua_pushlightuserdata( state, @@ -751,9 +735,5 @@ unsafe fn get_destructed_userdata_metatable(state: *mut ffi::lua_State) { ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX); } -#[cfg(any(feature = "lua51", feature = "luajit"))] -static MAIN_THREAD_REGISTRY_KEY: u8 = 0; -static ERROR_METATABLE_REGISTRY_KEY: u8 = 0; -static PANIC_METATABLE_REGISTRY_KEY: u8 = 0; static DESTRUCTED_USERDATA_METATABLE: u8 = 0; static ERROR_PRINT_BUFFER_KEY: u8 = 0; diff --git a/tests/compile_fail/lua_norefunwindsafe.stderr b/tests/compile_fail/lua_norefunwindsafe.stderr index 1cf48ad..edd2803 100644 --- a/tests/compile_fail/lua_norefunwindsafe.stderr +++ b/tests/compile_fail/lua_norefunwindsafe.stderr @@ -10,6 +10,21 @@ error[E0277]: the type `std::cell::UnsafeCell<()>` may contain interior mutabili = note: required because of the requirements on the impl of `std::panic::UnwindSafe` for `&mlua::lua::Lua` = note: required because it appears within the type `[closure@$DIR/tests/compile_fail/lua_norefunwindsafe.rs:7:18: 7:48 lua:&mlua::lua::Lua]` +error[E0277]: the type `std::cell::UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + --> $DIR/lua_norefunwindsafe.rs:7:5 + | +7 | catch_unwind(|| lua.create_table().unwrap()); + | ^^^^^^^^^^^^ `std::cell::UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + | + = help: within `mlua::lua::Lua`, the trait `std::panic::RefUnwindSafe` is not implemented for `std::cell::UnsafeCell` + = note: required because it appears within the type `std::cell::Cell` + = note: required because it appears within the type `std::rc::RcBox>` + = note: required because it appears within the type `std::marker::PhantomData>>` + = note: required because it appears within the type `std::rc::Rc>` + = note: required because it appears within the type `mlua::lua::Lua` + = note: required because of the requirements on the impl of `std::panic::UnwindSafe` for `&mlua::lua::Lua` + = note: required because it appears within the type `[closure@$DIR/tests/compile_fail/lua_norefunwindsafe.rs:7:18: 7:48 lua:&mlua::lua::Lua]` + error[E0277]: the type `std::cell::UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary --> $DIR/lua_norefunwindsafe.rs:7:5 | @@ -18,9 +33,9 @@ error[E0277]: the type `std::cell::UnsafeCell` may contain | = help: within `mlua::lua::Lua`, the trait `std::panic::RefUnwindSafe` is not implemented for `std::cell::UnsafeCell` = note: required because it appears within the type `std::cell::RefCell` - = note: required because it appears within the type `alloc::sync::ArcInner>` - = note: required because it appears within the type `std::marker::PhantomData>>` - = note: required because it appears within the type `std::sync::Arc>` + = note: required because it appears within the type `std::rc::RcBox>` + = note: required because it appears within the type `std::marker::PhantomData>>` + = note: required because it appears within the type `std::rc::Rc>` = note: required because it appears within the type `mlua::lua::Lua` = note: required because of the requirements on the impl of `std::panic::UnwindSafe` for `&mlua::lua::Lua` = note: required because it appears within the type `[closure@$DIR/tests/compile_fail/lua_norefunwindsafe.rs:7:18: 7:48 lua:&mlua::lua::Lua]` @@ -34,9 +49,9 @@ error[E0277]: the type `std::cell::UnsafeCell` may contain interior mutab = help: within `mlua::lua::Lua`, the trait `std::panic::RefUnwindSafe` is not implemented for `std::cell::UnsafeCell` = note: required because it appears within the type `std::cell::Cell` = note: required because it appears within the type `std::cell::RefCell` - = note: required because it appears within the type `alloc::sync::ArcInner>` - = note: required because it appears within the type `std::marker::PhantomData>>` - = note: required because it appears within the type `std::sync::Arc>` + = note: required because it appears within the type `std::rc::RcBox>` + = note: required because it appears within the type `std::marker::PhantomData>>` + = note: required because it appears within the type `std::rc::Rc>` = note: required because it appears within the type `mlua::lua::Lua` = note: required because of the requirements on the impl of `std::panic::UnwindSafe` for `&mlua::lua::Lua` = note: required because it appears within the type `[closure@$DIR/tests/compile_fail/lua_norefunwindsafe.rs:7:18: 7:48 lua:&mlua::lua::Lua]` diff --git a/tests/compile_fail/ref_nounwindsafe.stderr b/tests/compile_fail/ref_nounwindsafe.stderr index 46f8352..ba6e696 100644 --- a/tests/compile_fail/ref_nounwindsafe.stderr +++ b/tests/compile_fail/ref_nounwindsafe.stderr @@ -12,6 +12,23 @@ error[E0277]: the type `std::cell::UnsafeCell<()>` may contain interior mutabili = note: required because it appears within the type `mlua::table::Table<'_>` = note: required because it appears within the type `[closure@$DIR/tests/compile_fail/ref_nounwindsafe.rs:8:18: 8:54 table:mlua::table::Table<'_>]` +error[E0277]: the type `std::cell::UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + --> $DIR/ref_nounwindsafe.rs:8:5 + | +8 | catch_unwind(move || table.set("a", "b").unwrap()); + | ^^^^^^^^^^^^ `std::cell::UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + | + = help: within `mlua::lua::Lua`, the trait `std::panic::RefUnwindSafe` is not implemented for `std::cell::UnsafeCell` + = note: required because it appears within the type `std::cell::Cell` + = note: required because it appears within the type `std::rc::RcBox>` + = note: required because it appears within the type `std::marker::PhantomData>>` + = note: required because it appears within the type `std::rc::Rc>` + = note: required because it appears within the type `mlua::lua::Lua` + = note: required because of the requirements on the impl of `std::panic::UnwindSafe` for `&mlua::lua::Lua` + = note: required because it appears within the type `mlua::types::LuaRef<'_>` + = note: required because it appears within the type `mlua::table::Table<'_>` + = note: required because it appears within the type `[closure@$DIR/tests/compile_fail/ref_nounwindsafe.rs:8:18: 8:54 table:mlua::table::Table<'_>]` + error[E0277]: the type `std::cell::UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary --> $DIR/ref_nounwindsafe.rs:8:5 | @@ -20,9 +37,9 @@ error[E0277]: the type `std::cell::UnsafeCell` may contain | = help: within `mlua::lua::Lua`, the trait `std::panic::RefUnwindSafe` is not implemented for `std::cell::UnsafeCell` = note: required because it appears within the type `std::cell::RefCell` - = note: required because it appears within the type `alloc::sync::ArcInner>` - = note: required because it appears within the type `std::marker::PhantomData>>` - = note: required because it appears within the type `std::sync::Arc>` + = note: required because it appears within the type `std::rc::RcBox>` + = note: required because it appears within the type `std::marker::PhantomData>>` + = note: required because it appears within the type `std::rc::Rc>` = note: required because it appears within the type `mlua::lua::Lua` = note: required because of the requirements on the impl of `std::panic::UnwindSafe` for `&mlua::lua::Lua` = note: required because it appears within the type `mlua::types::LuaRef<'_>` @@ -38,9 +55,9 @@ error[E0277]: the type `std::cell::UnsafeCell` may contain interior mutab = help: within `mlua::lua::Lua`, the trait `std::panic::RefUnwindSafe` is not implemented for `std::cell::UnsafeCell` = note: required because it appears within the type `std::cell::Cell` = note: required because it appears within the type `std::cell::RefCell` - = note: required because it appears within the type `alloc::sync::ArcInner>` - = note: required because it appears within the type `std::marker::PhantomData>>` - = note: required because it appears within the type `std::sync::Arc>` + = note: required because it appears within the type `std::rc::RcBox>` + = note: required because it appears within the type `std::marker::PhantomData>>` + = note: required because it appears within the type `std::rc::Rc>` = note: required because it appears within the type `mlua::lua::Lua` = note: required because of the requirements on the impl of `std::panic::UnwindSafe` for `&mlua::lua::Lua` = note: required because it appears within the type `mlua::types::LuaRef<'_>` diff --git a/tests/compile_fail/scope_callback_capture.rs b/tests/compile_fail/scope_callback_capture.rs deleted file mode 100644 index 93667f0..0000000 --- a/tests/compile_fail/scope_callback_capture.rs +++ /dev/null @@ -1,22 +0,0 @@ -use mlua::{Lua, Table, Result}; - -struct Test { - field: i32, -} - -fn main() { - let lua = Lua::new(); - lua.scope(|scope| -> Result<()> { - let mut inner: Option = None; - let f = scope - .create_function_mut(move |lua, t: Table| { - if let Some(old) = inner.take() { - // Access old callback `Lua`. - } - inner = Some(t); - Ok(()) - })?; - f.call::<_, ()>(lua.create_table()?)?; - Ok(()) - }); -} diff --git a/tests/compile_fail/scope_callback_capture.stderr b/tests/compile_fail/scope_callback_capture.stderr deleted file mode 100644 index b23ea95..0000000 --- a/tests/compile_fail/scope_callback_capture.stderr +++ /dev/null @@ -1,45 +0,0 @@ -error[E0495]: cannot infer an appropriate lifetime for autoref due to conflicting requirements - --> $DIR/scope_callback_capture.rs:12:14 - | -12 | .create_function_mut(move |lua, t: Table| { - | ^^^^^^^^^^^^^^^^^^^ - | -note: first, the lifetime cannot outlive the anonymous lifetime #2 defined on the body at 9:15... - --> $DIR/scope_callback_capture.rs:9:15 - | -9 | lua.scope(|scope| -> Result<()> { - | _______________^ -10 | | let mut inner: Option
= None; -11 | | let f = scope -12 | | .create_function_mut(move |lua, t: Table| { -... | -20 | | Ok(()) -21 | | }); - | |_____^ -note: ...so that reference does not outlive borrowed content - --> $DIR/scope_callback_capture.rs:11:17 - | -11 | let f = scope - | ^^^^^ -note: but, the lifetime must be valid for the method call at 9:5... - --> $DIR/scope_callback_capture.rs:9:5 - | -9 | / lua.scope(|scope| -> Result<()> { -10 | | let mut inner: Option
= None; -11 | | let f = scope -12 | | .create_function_mut(move |lua, t: Table| { -... | -20 | | Ok(()) -21 | | }); - | |______^ -note: ...so that a type/lifetime parameter is in scope here - --> $DIR/scope_callback_capture.rs:9:5 - | -9 | / lua.scope(|scope| -> Result<()> { -10 | | let mut inner: Option
= None; -11 | | let f = scope -12 | | .create_function_mut(move |lua, t: Table| { -... | -20 | | Ok(()) -21 | | }); - | |______^ diff --git a/tests/compile_fail/scope_callback_inner.rs b/tests/compile_fail/scope_callback_inner.rs deleted file mode 100644 index 56b56ba..0000000 --- a/tests/compile_fail/scope_callback_inner.rs +++ /dev/null @@ -1,19 +0,0 @@ -use mlua::{Lua, Table, Result}; - -struct Test { - field: i32, -} - -fn main() { - let lua = Lua::new(); - lua.scope(|scope| -> Result<()> { - let mut inner: Option
= None; - let f = scope - .create_function_mut(|_, t: Table| { - inner = Some(t); - Ok(()) - })?; - f.call::<_, ()>(lua.create_table()?)?; - Ok(()) - }); -} diff --git a/tests/compile_fail/scope_callback_inner.stderr b/tests/compile_fail/scope_callback_inner.stderr deleted file mode 100644 index 077fba4..0000000 --- a/tests/compile_fail/scope_callback_inner.stderr +++ /dev/null @@ -1,45 +0,0 @@ -error[E0495]: cannot infer an appropriate lifetime for autoref due to conflicting requirements - --> $DIR/scope_callback_inner.rs:12:14 - | -12 | .create_function_mut(|_, t: Table| { - | ^^^^^^^^^^^^^^^^^^^ - | -note: first, the lifetime cannot outlive the anonymous lifetime #2 defined on the body at 9:15... - --> $DIR/scope_callback_inner.rs:9:15 - | -9 | lua.scope(|scope| -> Result<()> { - | _______________^ -10 | | let mut inner: Option
= None; -11 | | let f = scope -12 | | .create_function_mut(|_, t: Table| { -... | -17 | | Ok(()) -18 | | }); - | |_____^ -note: ...so that reference does not outlive borrowed content - --> $DIR/scope_callback_inner.rs:11:17 - | -11 | let f = scope - | ^^^^^ -note: but, the lifetime must be valid for the method call at 9:5... - --> $DIR/scope_callback_inner.rs:9:5 - | -9 | / lua.scope(|scope| -> Result<()> { -10 | | let mut inner: Option
= None; -11 | | let f = scope -12 | | .create_function_mut(|_, t: Table| { -... | -17 | | Ok(()) -18 | | }); - | |______^ -note: ...so that a type/lifetime parameter is in scope here - --> $DIR/scope_callback_inner.rs:9:5 - | -9 | / lua.scope(|scope| -> Result<()> { -10 | | let mut inner: Option
= None; -11 | | let f = scope -12 | | .create_function_mut(|_, t: Table| { -... | -17 | | Ok(()) -18 | | }); - | |______^ diff --git a/tests/compile_fail/scope_callback_outer.rs b/tests/compile_fail/scope_callback_outer.rs deleted file mode 100644 index 57437c9..0000000 --- a/tests/compile_fail/scope_callback_outer.rs +++ /dev/null @@ -1,19 +0,0 @@ -use mlua::{Lua, Table, Result}; - -struct Test { - field: i32, -} - -fn main() { - let lua = Lua::new(); - let mut outer: Option
= None; - lua.scope(|scope| -> Result<()> { - let f = scope - .create_function_mut(|_, t: Table| { - outer = Some(t); - Ok(()) - })?; - f.call::<_, ()>(lua.create_table()?)?; - Ok(()) - }); -} diff --git a/tests/compile_fail/scope_callback_outer.stderr b/tests/compile_fail/scope_callback_outer.stderr deleted file mode 100644 index c68bb4f..0000000 --- a/tests/compile_fail/scope_callback_outer.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: borrowed data cannot be stored outside of its closure - --> $DIR/scope_callback_outer.rs:11:17 - | -9 | let mut outer: Option
= None; - | --------- ...so that variable is valid at time of its declaration -10 | lua.scope(|scope| -> Result<()> { - | --------------------- borrowed data cannot outlive this closure -11 | let f = scope - | ^^^^^ cannot be stored outside of its closure -12 | .create_function_mut(|_, t: Table| { - | ------------------- cannot infer an appropriate lifetime... diff --git a/tests/compile_fail/scope_invariance.rs b/tests/compile_fail/scope_invariance.rs deleted file mode 100644 index 569ada1..0000000 --- a/tests/compile_fail/scope_invariance.rs +++ /dev/null @@ -1,23 +0,0 @@ -use mlua::{Lua, Result}; - -struct Test { - field: i32, -} - -fn main() { - let lua = Lua::new(); - lua.scope(|scope| -> Result<()> { - let f = { - let mut test = Test { field: 0 }; - - scope - .create_function_mut(|_, ()| { - test.field = 42; - //~^ error: `test` does not live long enough - Ok(()) - })? - }; - - f.call::<_, ()>(()) - }); -} diff --git a/tests/compile_fail/scope_invariance.stderr b/tests/compile_fail/scope_invariance.stderr deleted file mode 100644 index 434eb78..0000000 --- a/tests/compile_fail/scope_invariance.stderr +++ /dev/null @@ -1,25 +0,0 @@ -error[E0373]: closure may outlive the current function, but it borrows `test`, which is owned by the current function - --> $DIR/scope_invariance.rs:14:38 - | -9 | lua.scope(|scope| -> Result<()> { - | ----- has type `&mlua::scope::Scope<'_, '1>` -... -14 | .create_function_mut(|_, ()| { - | ^^^^^^^ may outlive borrowed value `test` -15 | test.field = 42; - | ---- `test` is borrowed here - | -note: function requires argument type to outlive `'1` - --> $DIR/scope_invariance.rs:13:13 - | -13 | / scope -14 | | .create_function_mut(|_, ()| { -15 | | test.field = 42; -16 | | //~^ error: `test` does not live long enough -17 | | Ok(()) -18 | | })? - | |__________________^ -help: to force the closure to take ownership of `test` (and any other referenced variables), use the `move` keyword - | -14 | .create_function_mut(move |_, ()| { - | ^^^^^^^^^^^^ diff --git a/tests/compile_fail/scope_mutable_aliasing.rs b/tests/compile_fail/scope_mutable_aliasing.rs deleted file mode 100644 index 4d8d2f0..0000000 --- a/tests/compile_fail/scope_mutable_aliasing.rs +++ /dev/null @@ -1,15 +0,0 @@ -use mlua::{Lua, UserData, Result}; - -struct MyUserData<'a>(&'a mut i32); -impl<'a> UserData for MyUserData<'a> {} - -fn main() { - let mut i = 1; - - let lua = Lua::new(); - lua.scope(|scope| -> Result<()> { - let _a = scope.create_nonstatic_userdata(MyUserData(&mut i))?; - let _b = scope.create_nonstatic_userdata(MyUserData(&mut i))?; - Ok(()) - }); -} diff --git a/tests/compile_fail/scope_mutable_aliasing.stderr b/tests/compile_fail/scope_mutable_aliasing.stderr deleted file mode 100644 index d78f983..0000000 --- a/tests/compile_fail/scope_mutable_aliasing.stderr +++ /dev/null @@ -1,9 +0,0 @@ -error[E0499]: cannot borrow `i` as mutable more than once at a time - --> $DIR/scope_mutable_aliasing.rs:12:61 - | -11 | let _a = scope.create_nonstatic_userdata(MyUserData(&mut i))?; - | ------ first mutable borrow occurs here -12 | let _b = scope.create_nonstatic_userdata(MyUserData(&mut i))?; - | ------------------------- ^^^^^^ second mutable borrow occurs here - | | - | first borrow later used by call diff --git a/tests/compile_fail/scope_userdata_borrow.rs b/tests/compile_fail/scope_userdata_borrow.rs deleted file mode 100644 index b88d1d6..0000000 --- a/tests/compile_fail/scope_userdata_borrow.rs +++ /dev/null @@ -1,20 +0,0 @@ -use mlua::{Lua, UserData, Result}; - -struct MyUserData<'a>(&'a i32); -impl<'a> UserData for MyUserData<'a> {} - -fn main() { - // Should not allow userdata borrow to outlive lifetime of AnyUserData handle - - let igood = 1; - - let lua = Lua::new(); - lua.scope(|scope| -> Result<()> { - let _ugood = scope.create_nonstatic_userdata(MyUserData(&igood))?; - let _ubad = { - let ibad = 42; - scope.create_nonstatic_userdata(MyUserData(&ibad))?; - }; - Ok(()) - }); -} diff --git a/tests/compile_fail/scope_userdata_borrow.stderr b/tests/compile_fail/scope_userdata_borrow.stderr deleted file mode 100644 index c143c9a..0000000 --- a/tests/compile_fail/scope_userdata_borrow.stderr +++ /dev/null @@ -1,13 +0,0 @@ -error[E0597]: `ibad` does not live long enough - --> $DIR/scope_userdata_borrow.rs:16:56 - | -12 | lua.scope(|scope| -> Result<()> { - | ----- has type `&mlua::scope::Scope<'_, '1>` -... -16 | scope.create_nonstatic_userdata(MyUserData(&ibad))?; - | -------------------------------------------^^^^^-- - | | | - | | borrowed value does not live long enough - | argument requires that `ibad` is borrowed for `'1` -17 | }; - | - `ibad` dropped here while still borrowed diff --git a/tests/function.rs b/tests/function.rs index cc1c9ab..ef2f9fb 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -1,4 +1,10 @@ -use mlua::{Function, Lua, Result, String}; +#![allow(unused_imports)] + +use std::{string::String as StdString, time::Duration}; + +use futures_executor::block_on; + +use mlua::{Error, Function, Lua, Result, String, Thread}; #[test] fn test_function() -> Result<()> { @@ -75,3 +81,34 @@ fn test_rust_function() -> Result<()> { Ok(()) } + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_async_function() -> Result<()> { + let lua = Lua::new(); + + let f = lua.create_async_function(move |_lua, n: u64| async move { + futures_timer::Delay::new(Duration::from_secs(n)).await; + Ok("hello") + })?; + lua.globals().set("rust_async_sleep", f)?; + + let thread = lua + .load( + r#" + coroutine.create(function () + ret = rust_async_sleep(1) + assert(ret == "hello") + coroutine.yield() + return "world" + end) + "#, + ) + .eval::()?; + + let fut = thread.into_async(()); + let ret: StdString = fut.await?; + assert_eq!(ret, "world"); + + Ok(()) +} diff --git a/tests/memory.rs b/tests/memory.rs index cf67c18..376e5b1 100644 --- a/tests/memory.rs +++ b/tests/memory.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::rc::Rc; use mlua::{Lua, Result, UserData}; @@ -16,17 +16,17 @@ fn test_gc_control() -> Result<()> { assert!(lua.gc_is_running()); } - struct MyUserdata(Arc<()>); + struct MyUserdata(Rc<()>); impl UserData for MyUserdata {} - let rc = Arc::new(()); + let rc = Rc::new(()); globals.set("userdata", lua.create_userdata(MyUserdata(rc.clone()))?)?; globals.raw_remove("userdata")?; - assert_eq!(Arc::strong_count(&rc), 2); + assert_eq!(Rc::strong_count(&rc), 2); lua.gc_collect()?; lua.gc_collect()?; - assert_eq!(Arc::strong_count(&rc), 1); + assert_eq!(Rc::strong_count(&rc), 1); Ok(()) } diff --git a/tests/scope.rs b/tests/scope.rs deleted file mode 100644 index 54840d6..0000000 --- a/tests/scope.rs +++ /dev/null @@ -1,231 +0,0 @@ -use std::cell::Cell; -use std::rc::Rc; - -use mlua::{Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataMethods}; - -#[test] -fn scope_func() -> Result<()> { - let lua = Lua::new(); - - let rc = Rc::new(Cell::new(0)); - lua.scope(|scope| { - let r = rc.clone(); - let f = scope.create_function(move |_, ()| { - r.set(42); - Ok(()) - })?; - lua.globals().set("bad", f.clone())?; - f.call::<_, ()>(())?; - assert_eq!(Rc::strong_count(&rc), 2); - Ok(()) - })?; - assert_eq!(rc.get(), 42); - assert_eq!(Rc::strong_count(&rc), 1); - - match lua.globals().get::<_, Function>("bad")?.call::<_, ()>(()) { - Err(Error::CallbackError { .. }) => {} - r => panic!("improper return for destructed function: {:?}", r), - }; - - Ok(()) -} - -#[test] -fn scope_drop() -> Result<()> { - let lua = Lua::new(); - - struct MyUserdata(Rc<()>); - impl UserData for MyUserdata { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("method", |_, _, ()| Ok(())); - } - } - - let rc = Rc::new(()); - - lua.scope(|scope| { - lua.globals().set( - "test", - scope.create_static_userdata(MyUserdata(rc.clone()))?, - )?; - assert_eq!(Rc::strong_count(&rc), 2); - Ok(()) - })?; - assert_eq!(Rc::strong_count(&rc), 1); - - match lua.load("test:method()").exec() { - Err(Error::CallbackError { .. }) => {} - r => panic!("improper return for destructed userdata: {:?}", r), - }; - - Ok(()) -} - -#[test] -fn scope_capture() -> Result<()> { - let lua = Lua::new(); - - let mut i = 0; - lua.scope(|scope| { - scope - .create_function_mut(|_, ()| { - i = 42; - Ok(()) - })? - .call::<_, ()>(()) - })?; - assert_eq!(i, 42); - - Ok(()) -} - -#[test] -fn outer_lua_access() -> Result<()> { - let lua = Lua::new(); - - let table = lua.create_table()?; - lua.scope(|scope| { - scope - .create_function_mut(|_, ()| table.set("a", "b"))? - .call::<_, ()>(()) - })?; - assert_eq!(table.get::<_, String>("a")?, "b"); - - Ok(()) -} - -#[test] -fn scope_userdata_methods() -> Result<()> { - struct MyUserData<'a>(&'a Cell); - - impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("inc", |_, data, ()| { - data.0.set(data.0.get() + 1); - Ok(()) - }); - - methods.add_method("dec", |_, data, ()| { - data.0.set(data.0.get() - 1); - Ok(()) - }); - } - } - - let lua = Lua::new(); - - let i = Cell::new(42); - let f: Function = lua - .load( - r#" - function(u) - u:inc() - u:inc() - u:inc() - u:dec() - end - "#, - ) - .eval()?; - - lua.scope(|scope| f.call::<_, ()>(scope.create_nonstatic_userdata(MyUserData(&i))?))?; - - assert_eq!(i.get(), 44); - - Ok(()) -} - -#[test] -fn scope_userdata_functions() -> Result<()> { - struct MyUserData<'a>(&'a i64); - - impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_meta_function(MetaMethod::Add, |lua, ()| { - let globals = lua.globals(); - globals.set("i", globals.get::<_, i64>("i")? + 1)?; - Ok(()) - }); - methods.add_meta_function(MetaMethod::Sub, |lua, ()| { - let globals = lua.globals(); - globals.set("i", globals.get::<_, i64>("i")? + 1)?; - Ok(()) - }); - } - } - - let lua = Lua::new(); - - let dummy = 0; - let f = lua - .load( - r#" - i = 0 - return function(u) - _ = u + u - _ = u - 1 - _ = 1 + u - end - "#, - ) - .eval::()?; - - lua.scope(|scope| f.call::<_, ()>(scope.create_nonstatic_userdata(MyUserData(&dummy))?))?; - - assert_eq!(lua.globals().get::<_, i64>("i")?, 3); - - Ok(()) -} - -#[test] -fn scope_userdata_mismatch() -> Result<()> { - struct MyUserData<'a>(&'a Cell); - - impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("inc", |_, data, ()| { - data.0.set(data.0.get() + 1); - Ok(()) - }); - } - } - - let lua = Lua::new(); - - lua.load( - r#" - function okay(a, b) - a.inc(a) - b.inc(b) - end - - function bad(a, b) - a.inc(b) - end - "#, - ) - .exec()?; - - let a = Cell::new(1); - let b = Cell::new(1); - - let okay: Function = lua.globals().get("okay")?; - let bad: Function = lua.globals().get("bad")?; - - lua.scope(|scope| { - let au = scope.create_nonstatic_userdata(MyUserData(&a))?; - let bu = scope.create_nonstatic_userdata(MyUserData(&b))?; - assert!(okay.call::<_, ()>((au.clone(), bu.clone())).is_ok()); - match bad.call::<_, ()>((au, bu)) { - Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { - Error::UserDataTypeMismatch => {} - ref other => panic!("wrong error type {:?}", other), - }, - Err(other) => panic!("wrong error type {:?}", other), - Ok(_) => panic!("incorrectly returned Ok"), - } - Ok(()) - })?; - - Ok(()) -} diff --git a/tests/tests.rs b/tests/tests.rs index 2784da9..a41e217 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,6 +1,6 @@ use std::iter::FromIterator; use std::panic::catch_unwind; -use std::sync::Arc; +use std::rc::Rc; use std::{error, f32, f64, fmt}; use mlua::{ @@ -584,22 +584,22 @@ fn test_registry_value() -> Result<()> { #[test] fn test_drop_registry_value() -> Result<()> { - struct MyUserdata(Arc<()>); + struct MyUserdata(Rc<()>); impl UserData for MyUserdata {} let lua = Lua::new(); - let rc = Arc::new(()); + let rc = Rc::new(()); let r = lua.create_registry_value(MyUserdata(rc.clone()))?; - assert_eq!(Arc::strong_count(&rc), 2); + assert_eq!(Rc::strong_count(&rc), 2); drop(r); lua.expire_registry_values(); lua.load(r#"collectgarbage("collect")"#).exec()?; - assert_eq!(Arc::strong_count(&rc), 1); + assert_eq!(Rc::strong_count(&rc), 1); Ok(()) } diff --git a/tests/thread.rs b/tests/thread.rs index 666a7e6..985fd88 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -1,4 +1,11 @@ +#![allow(unused_imports)] + use std::panic::catch_unwind; +use std::rc::Rc; +use std::time::Duration; + +use futures_executor::block_on; +use futures_util::stream::TryStreamExt; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; @@ -93,6 +100,38 @@ fn test_thread() -> Result<()> { Ok(()) } +#[cfg(feature = "async")] +#[tokio::test] +async fn test_thread_stream() -> Result<()> { + let lua = Lua::new(); + + let thread = lua.create_thread( + lua.load( + r#" + function (s) + local sum = s + for i = 1,10 do + sum = sum + i + coroutine.yield(sum) + end + return sum + end + "#, + ) + .eval()?, + )?; + + let mut s = thread.into_async::<_, i64>(0); + let mut sum = 0; + while let Some(n) = s.try_next().await? { + sum += n; + } + + assert_eq!(sum, 275); + + Ok(()) +} + #[test] fn coroutine_from_closure() -> Result<()> { let lua = Lua::new(); @@ -128,3 +167,30 @@ fn coroutine_panic() { Err(p) => assert!(*p.downcast::<&str>().unwrap() == "test_panic"), } } + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_thread_async() -> Result<()> { + let lua = Lua::new(); + + let cnt = Rc::new(1); // sleep 1 second + let cnt2 = cnt.clone(); + let f = lua.create_async_function(move |_lua, ()| { + let cnt3 = cnt2.clone(); + async move { + futures_timer::Delay::new(Duration::from_secs(*cnt3.as_ref())).await; + Ok("hello") + } + })?; + + let mut thread_s = lua.create_thread(f)?.into_async(()); + let val: String = thread_s.try_next().await?.unwrap_or_default(); + + // thread_s is non-resumable and subject to garbage collection + + lua.gc_collect()?; + assert_eq!(Rc::strong_count(&cnt), 1); + assert_eq!(val, "hello"); + + Ok(()) +} diff --git a/tests/userdata.rs b/tests/userdata.rs index 64919cf..f7ffdd2 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::rc::Rc; use mlua::{ AnyUserData, ExternalError, Function, Lua, MetaMethod, Result, String, UserData, @@ -196,22 +196,22 @@ fn test_gc_userdata() -> Result<()> { #[test] fn detroys_userdata() -> Result<()> { - struct MyUserdata(Arc<()>); + struct MyUserdata(Rc<()>); impl UserData for MyUserdata {} - let rc = Arc::new(()); + let rc = Rc::new(()); let lua = Lua::new(); lua.globals().set("userdata", MyUserdata(rc.clone()))?; - assert_eq!(Arc::strong_count(&rc), 2); + assert_eq!(Rc::strong_count(&rc), 2); // should destroy all objects let _ = lua.globals().raw_remove("userdata")?; lua.gc_collect()?; - assert_eq!(Arc::strong_count(&rc), 1); + assert_eq!(Rc::strong_count(&rc), 1); Ok(()) } @@ -219,6 +219,7 @@ fn detroys_userdata() -> Result<()> { #[test] fn user_value() -> Result<()> { struct MyUserData; + impl UserData for MyUserData {} let lua = Lua::new();