From 2aed548747a6c83f6d2db65ce7071b6a651aac6b Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 21 Feb 2021 23:34:35 +0000 Subject: [PATCH] Fix scoped async destruction of partially polled futures --- src/lua.rs | 4 +++- src/scope.rs | 29 ++++++++++++++++++++++------- tests/async.rs | 19 ++++++++++++++++++- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/lua.rs b/src/lua.rs index 8720c88..91e0edd 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -1787,9 +1787,11 @@ impl Lua { })?, )?; + // We set `poll` variable in the env table to be able to destroy upvalues self.load( r#" - local poll = get_poll(...) + poll = get_poll(...) + local poll = poll while true do ready, res = poll() if ready then diff --git a/src/scope.rs b/src/scope.rs index 4965a1f..fe2cb0f 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -23,9 +23,8 @@ use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti, Value}; #[cfg(feature = "async")] use { crate::types::AsyncCallback, - futures_core::future::Future, + futures_core::future::{Future, LocalBoxFuture}, futures_util::future::{self, TryFutureExt}, - std::os::raw::c_char, }; /// Constructed by the [`Lua::scope`] method, allows temporarily creating Lua userdata and @@ -420,12 +419,11 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { #[cfg(any(feature = "lua51", feature = "luajit"))] ffi::lua_getfenv(state, -1); - // Then, get the get_poll() closure using the corresponding key - let key = "get_poll"; - ffi::lua_pushlstring(state, key.as_ptr() as *const c_char, key.len()); + // Second, get the `get_poll()` closure using the corresponding key + ffi::lua_pushstring(state, cstr!("get_poll")); ffi::lua_rawget(state, -2); - // Finally, destroy all upvalues + // Destroy all upvalues ffi::lua_getupvalue(state, -1, 1); let ud1 = take_userdata::(state); ffi::lua_pushnil(state); @@ -437,8 +435,25 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { ffi::lua_setupvalue(state, -2, 2); ffi::lua_pop(state, 1); + let mut data: Vec> = vec![Box::new(ud1), Box::new(ud2)]; - vec![Box::new(ud1), Box::new(ud2)] + // Finally, get polled future and destroy it + ffi::lua_pushstring(state, cstr!("poll")); + if ffi::lua_rawget(state, -2) == ffi::LUA_TFUNCTION { + ffi::lua_getupvalue(state, -1, 1); + let ud3 = take_userdata::>>(state); + ffi::lua_pushnil(state); + ffi::lua_setupvalue(state, -2, 1); + data.push(Box::new(ud3)); + + ffi::lua_getupvalue(state, -1, 2); + let ud4 = take_userdata::(state); + ffi::lua_pushnil(state); + ffi::lua_setupvalue(state, -2, 2); + data.push(Box::new(ud4)); + } + + data })); Ok(f) diff --git a/tests/async.rs b/tests/async.rs index caadb00..e2496f1 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -22,7 +22,9 @@ use std::time::Duration; use futures_timer::Delay; use futures_util::stream::TryStreamExt; -use mlua::{Error, Function, Lua, Result, Table, TableExt, UserData, UserDataMethods}; +use mlua::{ + Error, Function, Lua, Result, Table, TableExt, Thread, UserData, UserDataMethods, Value, +}; #[tokio::test] async fn test_async_function() -> Result<()> { @@ -332,11 +334,18 @@ async fn test_async_scope() -> Result<()> { let _ = f.call_async::(10).await?; assert_eq!(Rc::strong_count(rc), 1); + // Create future in partialy polled state (Poll::Pending) + let g = lua.create_thread(f)?; + g.resume::(10)?; + lua.globals().set("g", g)?; + assert_eq!(Rc::strong_count(rc), 2); + Ok(()) }); assert_eq!(Rc::strong_count(rc), 1); let _ = fut.await?; + assert_eq!(Rc::strong_count(rc), 1); match lua .globals() @@ -351,6 +360,14 @@ async fn test_async_scope() -> Result<()> { r => panic!("improper return for destructed function: {:?}", r), }; + match lua.globals().get::<_, Thread>("g")?.resume::<_, Value>(()) { + Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { + Error::CallbackDestructed => {} + e => panic!("expected `CallbackDestructed` error cause, got {:?}", e), + }, + r => panic!("improper return for destructed function: {:?}", r), + }; + Ok(()) }